// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <riscv_vector.h>
#include "xnnpack/vbinary.h"

$assert DATATYPE in ["QS8", "QU8"]
$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
$VSIZE1 = {1: "m1", 2: "m2"}[LMUL]
$VSIZE2 = {1: "m2", 2: "m4"}[LMUL]
$VSIZE4 = {1: "m4", 2: "m8"}[LMUL]

// Multiplies every 8-bit quantized element of input_a by the single quantized
// value stored at *input_b, rescales the product in fp32, clamps it, and
// writes the re-quantized 8-bit result to output.
//
//   batch   - number of 8-bit elements to process; must be non-zero.
//   input_a - vector operand, batch elements.
//   input_b - scalar operand; only *input_b is read, once, before the loop.
//   output  - destination, batch elements.
//   params  - params->scalar supplies a_zero_point, b_zero_point, scale,
//             output_zero_point, output_min and output_max.
void xnn_${DATATYPE.lower()}_vmulc_minmax_fp32_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const ${XINT8_T}* input_a,
    const ${XINT8_T}* input_b,
    ${XINT8_T}* output,
    const union xnn_${DATATYPE.lower()}_mul_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(${XINT8_T}) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  // Loop-invariant requantization constants.
  const float vscale = params->scalar.scale;
  const int32_t va_zero_point = params->scalar.a_zero_point;
  const int32_t voutput_zero_point = (int32_t) params->scalar.output_zero_point;
  // Clamping bounds are expressed relative to the output zero point because
  // the zero point is only added back after the clamp.
  const float vmin_less_zero_point = (int32_t) params->scalar.output_min - voutput_zero_point;
  const float vmax_less_zero_point = (int32_t) params->scalar.output_max - voutput_zero_point;
  // 12582912.0f is 1.5 * 2^23; adding it rounds the clamped value to an
  // integer in the low mantissa bits. 0x4B400000 is the bit pattern of that
  // float.
  // NOTE(review): the biased float is converted numerically below (vfcvt),
  // not bit-reinterpreted. The stored result is still the rounded value plus
  // the output zero point, since 0x4B400000 - 12582912 == 0x4A800000 has
  // all-zero low 16 bits and the narrowing steps keep only the low bits --
  // confirm this is the intended form of the magic-bias trick.
  const float vmagic_bias = 12582912.0f;
  const int32_t vmagic_bias_less_output_zero_point = INT32_C(0x4B400000) - voutput_zero_point;

  // The scalar multiplier with its zero point removed.
  const int32_t vb = (int32_t) *input_b - params->scalar.b_zero_point;

  do {
    // Number of elements handled by this iteration.
    const size_t vl = __riscv_vsetvl_e8${VSIZE1}(batch);

    // Load a strip of input_a and widen it to 16 bits while subtracting the
    // input zero point.
    $if DATATYPE == "QS8":
      vint8${VSIZE1}_t va_i8 = __riscv_vle8_v_i8${VSIZE1}(input_a, vl);
      input_a += vl;
      vint16${VSIZE2}_t va_i16 = __riscv_vwsub_vx_i16${VSIZE2}(va_i8, va_zero_point, vl);
    $else:
      vuint8${VSIZE1}_t va_u8 = __riscv_vle8_v_u8${VSIZE1}(input_a, vl);
      input_a += vl;
      // The unsigned widening subtract may wrap; viewing the bits as signed
      // 16-bit values recovers the signed difference.
      vuint16${VSIZE2}_t va_u16 = __riscv_vwsubu_vx_u16${VSIZE2}(va_u8, va_zero_point, vl);
      vint16${VSIZE2}_t va_i16 = __riscv_vreinterpret_v_u16${VSIZE2}_i16${VSIZE2}(va_u16);

    // 32-bit product, then scale in fp32.
    vint32${VSIZE4}_t vprod_i32 = __riscv_vwmul_vx_i32${VSIZE4}(va_i16, vb, vl);
    vfloat32${VSIZE4}_t vacc_f32 = __riscv_vfcvt_f_x_v_f32${VSIZE4}(vprod_i32, vl);
    vacc_f32 = __riscv_vfmul_vf_f32${VSIZE4}(vacc_f32, vscale, vl);

    // Clamp to [min, max] relative to the output zero point, then apply the
    // magic bias.
    vacc_f32 = __riscv_vfmax_vf_f32${VSIZE4}(vacc_f32, vmin_less_zero_point, vl);
    vacc_f32 = __riscv_vfmin_vf_f32${VSIZE4}(vacc_f32, vmax_less_zero_point, vl);
    vacc_f32 = __riscv_vfadd_vf_f32${VSIZE4}(vacc_f32, vmagic_bias, vl);

    // Back to integers, remove the bias / add the output zero point, and
    // narrow 32 -> 16 -> 8 bits before storing.
    $if DATATYPE == "QS8":
      vint32${VSIZE4}_t vout_i32 = __riscv_vfcvt_x_f_v_i32${VSIZE4}(vacc_f32, vl);
      vout_i32 = __riscv_vsub_vx_i32${VSIZE4}(vout_i32, vmagic_bias_less_output_zero_point, vl);
      vint16${VSIZE2}_t vout_i16 = __riscv_vncvt_x_x_w_i16${VSIZE2}(vout_i32, vl);
      vint8${VSIZE1}_t vout_i8 = __riscv_vncvt_x_x_w_i8${VSIZE1}(vout_i16, vl);
      __riscv_vse8_v_i8${VSIZE1}(output, vout_i8, vl);
      output += vl;
    $else:
      vuint32${VSIZE4}_t vout_u32 = __riscv_vfcvt_xu_f_v_u32${VSIZE4}(vacc_f32, vl);
      vout_u32 = __riscv_vsub_vx_u32${VSIZE4}(vout_u32, vmagic_bias_less_output_zero_point, vl);
      vuint16${VSIZE2}_t vout_u16 = __riscv_vncvt_x_x_w_u16${VSIZE2}(vout_u32, vl);
      vuint8${VSIZE1}_t vout_u8 = __riscv_vncvt_x_x_w_u8${VSIZE1}(vout_u16, vl);
      __riscv_vse8_v_u8${VSIZE1}(output, vout_u8, vl);
      output += vl;

    batch -= vl;
  } while (batch != 0);
}
