// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert LMUL in [1, 2, 4, 8]
$LMUL_8 =  {1: "f4", 2: "f2", 4: "1", 8: "2"}[LMUL]
$LMUL_16 = {1: "f2", 2: "1",  4: "2", 8: "4"}[LMUL]
$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/math.h"
#include "xnnpack/vcvt.h"


$OUTPUT_MIN = {"QS8": -128, "QU8": 0}[DATATYPE]
$OUTPUT_MAX = {"QS8": 127, "QU8": 255}[DATATYPE]
// Converts a buffer of floats to quantized 8-bit integers using RVV.
//
// For every input element the kernel computes
//   output = narrow_to_8_bits(convert_to_int32(clamp(input * scale)) + zero_point)
// where the clamp bounds are the 8-bit output limits with the zero point
// subtracted, so that the final sum always fits in the 8-bit output type.
//
//   batch  - size of the input in BYTES; non-zero multiple of sizeof(float).
//   input  - source floats; must not be NULL.
//   output - destination, one 8-bit element per input float; must not be NULL.
//   params - supplies scalar.scale and scalar.output_zero_point.
void xnn_f32_${DATATYPE.lower()}_vcvt_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const float* input,
    ${XINT8_T}* output,
    const struct xnn_f32_${DATATYPE.lower()}_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(float) == 0);
  assert(input != NULL);
  assert(output != NULL);

  // From here on `batch` counts float elements rather than bytes.
  batch >>= XNN_LOG2_SIZEOF_FLOAT;

  const float scale = params->scalar.scale;
  const int32_t output_zero_point = params->scalar.output_zero_point;
  // The clamp below is required: vncvt.x.x.w is an alias of the narrowing
  // shift vnsrl with a zero shift amount, which truncates the upper bits and
  // does NOT saturate (only vnclip saturates). Clamping in the float domain,
  // before the zero point is added, keeps the final value within the 8-bit
  // output range so the truncating narrow is lossless.
  const float output_min_less_zero_point = (float) ((int32_t) ${OUTPUT_MIN} - output_zero_point);
  const float output_max_less_zero_point = (float) ((int32_t) ${OUTPUT_MAX} - output_zero_point);

  for (; batch > 0; ) {
    // vsetvl returns the number of elements processed this iteration as a
    // size_t, which is also the type every intrinsic expects for `vl`.
    const size_t n = __riscv_vsetvl_e32m${LMUL}(batch);
    batch -= n;

    vfloat32m${LMUL}_t x_f32v = __riscv_vle32_v_f32m${LMUL}(input, n);
    input += n;

    // Scale, then clamp to [min - zero_point, max - zero_point].
    x_f32v = __riscv_vfmul_vf_f32m${LMUL}(x_f32v, scale, n);
    x_f32v = __riscv_vfmax_vf_f32m${LMUL}(x_f32v, output_min_less_zero_point, n);
    x_f32v = __riscv_vfmin_vf_f32m${LMUL}(x_f32v, output_max_less_zero_point, n);

    // Convert to int32 and shift by the zero point.
    vint32m${LMUL}_t y_i32v = __riscv_vfcvt_x_f_v_i32m${LMUL}(x_f32v, n);
    y_i32v = __riscv_vadd_vx_i32m${LMUL}(y_i32v, output_zero_point, n);

    // Narrow 32 -> 16 -> 8 bits (truncating) and store.
    $if DATATYPE == "QS8":
      __riscv_vse8_v_i8m${LMUL_8}(output, __riscv_vncvt_x_x_w_i8m${LMUL_8}(__riscv_vncvt_x_x_w_i16m${LMUL_16}(y_i32v, n), n), n); output += n;
    $else:
      __riscv_vse8_v_u8m${LMUL_8}(output, __riscv_vncvt_x_x_w_u8m${LMUL_8}(__riscv_vncvt_x_x_w_u16m${LMUL_16}(__riscv_vreinterpret_v_i32m${LMUL}_u32m${LMUL}(y_i32v), n), n), n); output += n;
  }
}
