// Copyright 2024 SiFive, Inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QD8", "QC4"]
$assert MR >= 1
$assert NR in ["m4", "m8"]
$OUT_LMUL = NR
$IN_LMUL = {"m4": "m1", "m8": "m2"}[OUT_LMUL]
$INTERMEDIATE_MLUL = {"m4": "m2", "m8": "m4"}[OUT_LMUL]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/gemm.h"
#include "xnnpack/math.h"

// GEMM microkernel for the RISC-V Vector extension (RVV), generated from a
// template parameterized on the row count (MR), the output register grouping
// (LMUL) and the weight datatype (QD8: one int8 weight per byte, QC4: two
// 4-bit weights per byte).
//
// For each tile of output columns the kernel accumulates int8 activations
// times packed weights in int32, converts the sums to float, multiplies by the
// per-row input scale and the per-column filter scale, adds the per-column
// bias, clamps to [params->scalar.min, params->scalar.max] and stores to `c`.
//
// Parameters:
//   mr        - number of valid rows; must be non-zero and at most MR.
//   nc        - number of output columns; must be non-zero.
//   kc        - number of activation bytes consumed per row; must be non-zero.
//               For QC4 it is rounded up to a multiple of 2 before use.
//   a         - first row of int8 activations.
//   a_stride  - distance in bytes between consecutive rows of `a`.
//   w         - packed weights. As consumed below, each column tile holds, in
//               order: nr int32 values (vksum), nr weight bytes per reduction
//               step, nr float filter scales, nr float biases.
//   c         - first row of float output.
//   cm_stride - distance in bytes between consecutive rows of `c`.
//   cn_stride - number of bytes each output row pointer advances after a
//               column tile has been stored.
//   params    - supplies the output clamping bounds (scalar.min, scalar.max).
//   quantization_params - per-row zero_point and inv_scale. Entries 0..MR-1
//               are read regardless of `mr`.
//               NOTE(review): presumably callers guarantee MR readable
//               entries - confirm against the caller.
$DATATYPE_SPEC = {"QD8": "qd8_f32_qc8w", "QC4": "qd8_f32_qc4w"}[DATATYPE]
$PARAMS_TYPE = {"QD8": "union xnn_f32_minmax_params", "QC4": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE]
void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${OUT_LMUL[1]}v__rvv(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    float* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)],
    const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);

  // Per-row input and output pointers. Any row at index >= mr is pointed at
  // the same memory as the previous row, so the fully unrolled code below can
  // run unconditionally: surplus rows just recompute and re-store the last
  // valid row.
  const int8_t* a0 = a;
  float* c0 = c;
  $for M in range(1, MR):
    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  // nr is the column-tile width: the maximum number of 32-bit elements in one
  // vector register group. vl is the number of columns actually processed in
  // the current tile; it only drops below nr for the final, partial tile.
  const size_t nr = __riscv_vsetvlmax_e32${OUT_LMUL}();
  size_t vl = nr;
  $if DATATYPE == "QC4":
    kc = round_up_po2(kc, 2);
  do {
    if XNN_UNLIKELY(nc < nr) {
      vl = __riscv_vsetvl_e32${OUT_LMUL}(nc);
    }
    nc = nc - vl;

    // Seed each row's accumulator with vksum * zero_point of that row. Note
    // that w always advances by the full tile width nr, even when vl < nr.
    vint32${OUT_LMUL}_t vksum = __riscv_vle32_v_i32${OUT_LMUL}((const int32_t*)w, vl);
    w = (const int32_t*) w + nr;
    $for M in range(MR):
      const int32_t vinput_zero_point${M} = quantization_params[${M}].zero_point;
    $for M in range(MR):
      vint32${OUT_LMUL}_t vacc${M} = __riscv_vmul_vx_i32${OUT_LMUL}(vksum, vinput_zero_point${M}, vl);

    // Reduction over k: one scalar activation per row is multiplied by a
    // vector of weights (int8 x int8 -> int16), then widened and added into
    // the int32 accumulators.
    size_t k = kc;
    $if DATATYPE == "QC4":
      // Each weight byte packs two 4-bit values, so every iteration consumes
      // two activations per row and one vector of weight bytes.
      for (; k >= 2 * sizeof(uint8_t); k -= 2 * sizeof(uint8_t)) {
        $for M in range(MR):
          const int8_t va${M}c0 = a${M}[0];
          const int8_t va${M}c1 = a${M}[1];
          a${M} += 2;
        // vbc0: low nibble moved into the upper 4 bits (paired with a[0]).
        // vbc1: high nibble kept in place, low bits cleared (paired with a[1]).
        // Both therefore carry their 4-bit value in the upper half of the byte.
        // NOTE(review): 0xF0 (240) is passed where the intrinsic presumably
        // takes an int8_t scalar, relying on narrowing to -16 - confirm.
        const vint8${IN_LMUL}_t vbi = __riscv_vle8_v_i8${IN_LMUL}((const int8_t*) w, vl);
        w = (const uint8_t*) w + nr;
        const vint8${IN_LMUL}_t vbc0 = __riscv_vsll_vx_i8${IN_LMUL}(vbi, 4, vl);
        const vint8${IN_LMUL}_t vbc1 = __riscv_vand_vx_i8${IN_LMUL}(vbi, 0xF0, vl);

        $for M in range(MR):
          vint16${INTERMEDIATE_MLUL}_t va${M}bc0 = __riscv_vwmul_vx_i16${INTERMEDIATE_MLUL}(vbc0, va${M}c0, vl);
          vacc${M} = __riscv_vwadd_wv_i32${OUT_LMUL}(vacc${M}, va${M}bc0, vl);
          vint16${INTERMEDIATE_MLUL}_t va${M}bc1 = __riscv_vwmul_vx_i16${INTERMEDIATE_MLUL}(vbc1, va${M}c1, vl);
          vacc${M} = __riscv_vwadd_wv_i32${OUT_LMUL}(vacc${M}, va${M}bc1, vl);
      }
    $else:
      do {
        $for M in range(MR):
          const int8_t va${M} = *a${M}++;
        const vint8${IN_LMUL}_t vb = __riscv_vle8_v_i8${IN_LMUL}((const int8_t*) w, vl);
        w = (const int8_t*) w + nr;

        $for M in range(MR):
          vint16${INTERMEDIATE_MLUL}_t va${M}b = __riscv_vwmul_vx_i16${INTERMEDIATE_MLUL}(vb, va${M}, vl);
          vacc${M} = __riscv_vwadd_wv_i32${OUT_LMUL}(vacc${M}, va${M}b, vl);

        k -= sizeof(int8_t);
      } while (k != 0);
    $if DATATYPE == "QC4":
      // Shift the accumulators right by 4 to undo the placement of each 4-bit
      // weight in the upper half of its byte. The shift also applies to the
      // vksum * zero_point seed.
      // NOTE(review): presumably the packed vksum is pre-scaled to match -
      // confirm against the weight packing code.
      $for M in range(MR):
        vacc${M} = __riscv_vsra_vx_i32${OUT_LMUL}(vacc${M}, 4, vl);
    // i32 -> f32
    $for M in range(MR):
      vfloat32${OUT_LMUL}_t vout${M} = __riscv_vfcvt_f_x_v_f32${OUT_LMUL}(vacc${M}, vl);

    // vout * input_scale
    $for M in range(MR):
      const float vinput_scale${M} = quantization_params[${M}].inv_scale;
    $for M in range(MR):
      vout${M} = __riscv_vfmul_vf_f32${OUT_LMUL}(vout${M}, vinput_scale${M}, vl);

    // Per-column filter scale, stored in w right after the weight bytes.
    const vfloat32${OUT_LMUL}_t vfilter_output_scale = __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl);
    w = (const float*) w + nr;
    $for M in range(MR):
      vout${M} = __riscv_vfmul_vv_f32${OUT_LMUL}(vout${M}, vfilter_output_scale, vl);

    // Per-column bias, stored in w right after the filter scales.
    const vfloat32${OUT_LMUL}_t vbias =  __riscv_vle32_v_f32${OUT_LMUL}((const float*) w, vl);
    w = (const float*) w + nr;
    $for M in range(MR):
      vout${M} = __riscv_vfadd_vv_f32${OUT_LMUL}(vout${M}, vbias, vl);

    // Clamp to [min, max]: apply the lower bound first, then the upper bound.
    const float vmin = params->scalar.min;
    $for M in range(MR):
      vout${M} = __riscv_vfmax_vf_f32${OUT_LMUL}(vout${M}, vmin, vl);
    const float vmax = params->scalar.max;
    $for M in range(MR):
      vout${M} = __riscv_vfmin_vf_f32${OUT_LMUL}(vout${M}, vmax, vl);

    // store ${MR} x vl results to c
    $for M in range(MR):
      __riscv_vse32_v_f32${OUT_LMUL}(c${M}, vout${M}, vl);
      c${M} = (float*) ((uintptr_t) c${M} + cn_stride);

    // Rewind the activation pointers by the kc bytes consumed above so the
    // next column tile re-reads the same rows from their start.
    $for M in range(MR):
      a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
  } while (nc != 0);
}