// Copyright 2024 SiFive, Inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR in ["m1", "m2", "m4", "m8"]
$LMUL = int(NR[1])
$assert KBLOCK in [1, 2, 4, 8]

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <riscv_vector.h>

#include "xnnpack/packw.h"

// Packs 32-bit weights (and optional bias) into tiles of nr output channels
// using RVV strided and segmented-strided loads.
//
// Input: for each of the g groups, weights holds nc channels of kc contiguous
// 32-bit words (the pointer advances by nc * kc per group). bias, when not
// NULL, is consumed sequentially, one word per channel.
//
// Output, for every tile of up to nr channels:
//   - one row of bias words (zeros when bias is NULL),
//   - kc rows of weight words; row k holds word k of every channel in the tile,
//   - extra_bytes that are skipped without being written.
// Rows are always nr words apart. In the final partial tile (n < nr channels)
// only the first n words of each row are written.
//
// Constraints (asserted): g, nc, kc are non-zero, nr equals VLMAX for e32 at
// this kernel's LMUL, and kr == sr == 1. scale and params are not used.
void xnn_x32_packw_gemm_goi_ukernel_x${LMUL}v__rvv_u${KBLOCK}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const uint32_t* weights,
  const uint32_t* bias,
  const void* scale,
  uint32_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == __riscv_vsetvlmax_e32m${LMUL}());
  assert(kr == 1);
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  uint32_t* out = packed_weights;
  const uint32_t* b = bias;
  // Byte distance between the same k index of two consecutive channels.
  size_t kc_bstride = kc * 4;

  do {
    const uint32_t* w0 = weights;
    size_t n = nc;
    // NC main loop: process multiple of NR
    for (;n >= nr; n -= nr) {
      // Pack nr bias at beginning of tile
      size_t vlmax = __riscv_vsetvlmax_e32m${LMUL}();
      vuint32m${LMUL}_t v_bias;
      if XNN_LIKELY(b != NULL) {
        v_bias = __riscv_vle32_v_u32m${LMUL}(b, vlmax); b += nr;
      } else {
        v_bias = __riscv_vmv_v_x_u32m${LMUL}(0, vlmax);
      }
      __riscv_vse32_v_u32m${LMUL}(out, v_bias, vlmax); out += nr;

      uint32_t* out0 = out;
      size_t k = kc;
      $if KBLOCK == 8:
        // vlsseg8, LMUL must <= 1
        vlmax = __riscv_vsetvlmax_e32m1();
        // Pack 8 x nr weights
        for (; k >= 8; k -= 8) {
          $for N in range(1, 8):
            uint32_t* out${N} = out${N-1} + nr;
          $if LMUL == 1:
            // 1 load & store can pack 8 x nr weights
            vuint32m1x8_t v_w_m1x8 = __riscv_vlsseg8e32_v_u32m1x8(w0, kc_bstride, vlmax);
            $for KB in range(8):
              vuint32m1_t v_w${KB} = __riscv_vget_v_u32m1x8_u32m1(v_w_m1x8, ${KB});
              __riscv_vse32_v_u32m1(out${KB}, v_w${KB}, vlmax);
            out0 = out7 + vlmax;
          $else:
            // With vlsseg8, LMUL is constrained to 1. We need multiple loads & stores.
            const uint32_t* w_ptr = w0;
            size_t remaining_n = nr;
            do {
              vuint32m1x8_t v_w_m1x8 = __riscv_vlsseg8e32_v_u32m1x8(w_ptr, kc_bstride, vlmax);
              w_ptr += kc * vlmax;
              $for KB in range(8):
                vuint32m1_t v_w${KB} = __riscv_vget_v_u32m1x8_u32m1(v_w_m1x8, ${KB});
                __riscv_vse32_v_u32m1(out${KB}, v_w${KB}, vlmax); out${KB} += vlmax;
              remaining_n -= vlmax;
            } while(remaining_n > 0);
            out0 = out7;
          w0 += 8;
        }
      $if KBLOCK >= 4:
        // vlsseg4, LMUL must <= 2
        vlmax = __riscv_vsetvlmax_e32m${min(LMUL, 2)}();
        // Pack 4 x nr weights
        for (; k >= 4; k -= 4) {
          $for N in range(1, 4):
            uint32_t* out${N} = out${N-1} + nr;
          $if LMUL <= 2:
            // 1 load & store can pack 4 x nr weights
            vuint32m${LMUL}x4_t v_w_m${LMUL}x4 = __riscv_vlsseg4e32_v_u32m${LMUL}x4(w0, kc_bstride, vlmax);
            $for KB in range(4):
              vuint32m${LMUL}_t v_w${KB} = __riscv_vget_v_u32m${LMUL}x4_u32m${LMUL}(v_w_m${LMUL}x4, ${KB});
              __riscv_vse32_v_u32m${LMUL}(out${KB}, v_w${KB}, vlmax);
            out0 = out3 + vlmax;
          $else:
            // With vlsseg4, LMUL is constrained to 2. We need multiple loads & stores.
            const uint32_t* w_ptr = w0;
            size_t remaining_n = nr;
            do {
              vuint32m2x4_t v_w_m2x4 = __riscv_vlsseg4e32_v_u32m2x4(w_ptr, kc_bstride, vlmax);
              w_ptr += kc * vlmax;
              $for KB in range(4):
                vuint32m2_t v_w${KB} = __riscv_vget_v_u32m2x4_u32m2(v_w_m2x4, ${KB});
                __riscv_vse32_v_u32m2(out${KB}, v_w${KB}, vlmax); out${KB} += vlmax;
              remaining_n -= vlmax;
            } while(remaining_n > 0);
            out0 = out3;
          w0 += 4;
        }
      $if KBLOCK == 2:
        // vlsseg2, LMUL must <= 4
        vlmax = __riscv_vsetvlmax_e32m${min(LMUL, 4)}();
        // Pack 2 x nr weights
        for (; k >= 2; k -= 2) {
          $for N in range(1, 2):
            uint32_t* out${N} = out${N-1} + nr;
          $if LMUL <= 4:
            // 1 load & store can pack 2 x nr weights
            vuint32m${LMUL}x2_t v_w_m${LMUL}x2 = __riscv_vlsseg2e32_v_u32m${LMUL}x2(w0, kc_bstride, vlmax);
            $for KB in range(2):
              vuint32m${LMUL}_t v_w${KB} = __riscv_vget_v_u32m${LMUL}x2_u32m${LMUL}(v_w_m${LMUL}x2, ${KB});
              __riscv_vse32_v_u32m${LMUL}(out${KB}, v_w${KB}, vlmax);
            out0 = out1 + vlmax;
          $else:
            // With vlsseg2, LMUL is constrained to 4. We need multiple loads & stores.
            const uint32_t* w_ptr = w0;
            size_t remaining_n = nr;
            do {
              vuint32m4x2_t v_w_m4x2 = __riscv_vlsseg2e32_v_u32m4x2(w_ptr, kc_bstride, vlmax);
              w_ptr += kc * vlmax;
              $for KB in range(2):
                vuint32m4_t v_w${KB} = __riscv_vget_v_u32m4x2_u32m4(v_w_m4x2, ${KB});
                __riscv_vse32_v_u32m4(out${KB}, v_w${KB}, vlmax); out${KB} += vlmax;
              remaining_n -= vlmax;
            } while(remaining_n > 0);
            out0 = out1;
          w0 += 2;
        }
      vlmax = __riscv_vsetvlmax_e32m${LMUL}();
      // Pack nr weights
      for (; k >= 1; k -= 1) {
        vuint32m${LMUL}_t v_w = __riscv_vlse32_v_u32m${LMUL}(w0, kc_bstride, vlmax);
        __riscv_vse32_v_u32m${LMUL}(out0, v_w, vlmax);
        out0 += vlmax;
        w0 += 1;
      }
      out = (uint32_t*) ((uintptr_t) out0 + extra_bytes);
      // w0 already advanced by kc (one channel); skip the other nr - 1 channels.
      w0 += (nr - 1) * kc;
    }
    // NC remainder: process n < NR
    if (n > 0) {
      // Pack n bias at beginning of tile; the bias row still occupies nr words.
      size_t vl = __riscv_vsetvl_e32m${LMUL}(n);
      vuint32m${LMUL}_t v_bias;
      if XNN_LIKELY(b != NULL) {
        v_bias = __riscv_vle32_v_u32m${LMUL}(b, vl); b += vl;
      } else {
        v_bias = __riscv_vmv_v_x_u32m${LMUL}(0, vl);
      }
      __riscv_vse32_v_u32m${LMUL}(out, v_bias, vl); out += nr;

      size_t vlmax;
      uint32_t* out0 = out;
      size_t k = kc;
      $if KBLOCK == 8:
        // vlsseg8, LMUL must <= 1
        vlmax = __riscv_vsetvlmax_e32m1();
        // Pack 8 x n weights
        for (; k >= 8; k -= 8) {
          $for N in range(1, 8):
            uint32_t* out${N} = out${N-1} + nr;
          $if LMUL == 1:
            // 1 load & store can pack 8 x n weights
            size_t vl = __riscv_vsetvl_e32m1(n);
            vuint32m1x8_t v_w_m1x8 = __riscv_vlsseg8e32_v_u32m1x8(w0, kc_bstride, vl);
            $for KB in range(8):
              vuint32m1_t v_w${KB} = __riscv_vget_v_u32m1x8_u32m1(v_w_m1x8, ${KB});
              __riscv_vse32_v_u32m1(out${KB}, v_w${KB}, vl);
            out0 = out7 + vlmax;
          $else:
            // With vlsseg8, LMUL is constrained to 1. We need multiple loads & stores.
            const uint32_t* w_ptr = w0;
            // Number of vlmax-sized chunks of the row not yet stepped over.
            unsigned char remaining_blocks = ${LMUL};
            size_t remaining_n = n;
            do {
              size_t vl;
              if XNN_LIKELY(remaining_n >= vlmax) {
                vl = vlmax;
              } else {
                vl = __riscv_vsetvl_e32m1(remaining_n);
              }
              vuint32m1x8_t v_w_m1x8 = __riscv_vlsseg8e32_v_u32m1x8(w_ptr, kc_bstride, vl);
              w_ptr += kc * vl;
              $for KB in range(8):
                vuint32m1_t v_w${KB} = __riscv_vget_v_u32m1x8_u32m1(v_w_m1x8, ${KB});
                __riscv_vse32_v_u32m1(out${KB}, v_w${KB}, vl); out${KB} += vlmax;
              remaining_n -= vl;
              remaining_blocks--;
            } while(remaining_n > 0);
            // Skip the untouched chunks so the next row starts on an nr boundary.
            out0 = out7 + remaining_blocks * vlmax;
          w0 += 8;
        }
      $if KBLOCK >= 4:
        // vlsseg4, LMUL must <= 2
        vlmax = __riscv_vsetvlmax_e32m${min(LMUL, 2)}();
        // Pack 4 x n weights
        for (; k >= 4; k -= 4) {
          $for N in range(1, 4):
            uint32_t* out${N} = out${N-1} + nr;
          $if LMUL <= 2:
            // 1 load & store can pack 4 x n weights
            size_t vl = __riscv_vsetvl_e32m${LMUL}(n);
            vuint32m${LMUL}x4_t v_w_m${LMUL}x4 = __riscv_vlsseg4e32_v_u32m${LMUL}x4(w0, kc_bstride, vl);
            $for KB in range(4):
              vuint32m${LMUL}_t v_w${KB} = __riscv_vget_v_u32m${LMUL}x4_u32m${LMUL}(v_w_m${LMUL}x4, ${KB});
              __riscv_vse32_v_u32m${LMUL}(out${KB}, v_w${KB}, vl);
            out0 = out3 + vlmax;
          $else:
            // With vlsseg4, LMUL is constrained to 2. We need multiple loads & stores.
            const uint32_t* w_ptr = w0;
            // Number of vlmax-sized chunks of the row not yet stepped over.
            unsigned char remaining_blocks = ${int(LMUL/2)};
            size_t remaining_n = n;
            do {
              size_t vl;
              if XNN_LIKELY(remaining_n >= vlmax) {
                vl = vlmax;
              } else {
                vl = __riscv_vsetvl_e32m2(remaining_n);
              }
              vuint32m2x4_t v_w_m2x4 = __riscv_vlsseg4e32_v_u32m2x4(w_ptr, kc_bstride, vl);
              w_ptr += kc * vl;
              $for KB in range(4):
                vuint32m2_t v_w${KB} = __riscv_vget_v_u32m2x4_u32m2(v_w_m2x4, ${KB});
                __riscv_vse32_v_u32m2(out${KB}, v_w${KB}, vl); out${KB} += vlmax;
              remaining_n -= vl;
              remaining_blocks--;
            } while(remaining_n > 0);
            // Skip the untouched chunks so the next row starts on an nr boundary.
            out0 = out3 + remaining_blocks * vlmax;
          w0 += 4;
        }
      $if KBLOCK == 2:
        // vlsseg2, LMUL must <= 4
        vlmax = __riscv_vsetvlmax_e32m${min(LMUL, 4)}();
        // Pack 2 x n weights
        for (; k >= 2; k -= 2) {
          $for N in range(1, 2):
            uint32_t* out${N} = out${N-1} + nr;
          $if LMUL <= 4:
            // 1 load & store can pack 2 x n weights
            size_t vl = __riscv_vsetvl_e32m${LMUL}(n);
            vuint32m${LMUL}x2_t v_w_m${LMUL}x2 = __riscv_vlsseg2e32_v_u32m${LMUL}x2(w0, kc_bstride, vl);
            $for KB in range(2):
              vuint32m${LMUL}_t v_w${KB} = __riscv_vget_v_u32m${LMUL}x2_u32m${LMUL}(v_w_m${LMUL}x2, ${KB});
              __riscv_vse32_v_u32m${LMUL}(out${KB}, v_w${KB}, vl);
            out0 = out1 + vlmax;
          $else:
            // With vlsseg2, LMUL is constrained to 4. We need multiple loads & stores.
            const uint32_t* w_ptr = w0;
            // Number of vlmax-sized chunks of the row not yet stepped over.
            unsigned char remaining_blocks = ${int(LMUL/4)};
            size_t remaining_n = n;
            do {
              size_t vl;
              if XNN_LIKELY(remaining_n >= vlmax) {
                vl = vlmax;
              } else {
                vl = __riscv_vsetvl_e32m4(remaining_n);
              }
              // Load and store only vl channels: using vlmax here would read
              // weight rows beyond the n channels that remain.
              vuint32m4x2_t v_w_m4x2 = __riscv_vlsseg2e32_v_u32m4x2(w_ptr, kc_bstride, vl);
              w_ptr += kc * vl;
              $for KB in range(2):
                vuint32m4_t v_w${KB} = __riscv_vget_v_u32m4x2_u32m4(v_w_m4x2, ${KB});
                __riscv_vse32_v_u32m4(out${KB}, v_w${KB}, vl); out${KB} += vlmax;
              remaining_n -= vl;
              remaining_blocks--;
            } while(remaining_n > 0);
            // Skip the untouched chunks so the next row starts on an nr boundary.
            out0 = out1 + remaining_blocks * vlmax;
          w0 += 2;
        }
      vlmax = __riscv_vsetvlmax_e32m${LMUL}();
      vl = __riscv_vsetvl_e32m${LMUL}(n);
      // Pack n weights
      for (; k >= 1; k -= 1) {
        vuint32m${LMUL}_t v_w = __riscv_vlse32_v_u32m${LMUL}(w0, kc_bstride, vl);
        __riscv_vse32_v_u32m${LMUL}(out0, v_w, vl);
        out0 += vlmax;
        w0 += 1;
      }
      out = (uint32_t*) ((uintptr_t) out0 + extra_bytes);
      // w0 is not advanced here: this is the last tile of the group and w0 is
      // re-derived from weights on the next group iteration.
    }
    weights += nc * kc;
  } while (--g != 0);
}