// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert LMUL in [1, 2]
$VXINT = {"QS8": "vint", "QU8": "vuint"}[DATATYPE]
$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
$XLOAD = {"QS8": "__riscv_vle8_v_i8",  "QU8": "__riscv_vle8_v_u8"}[DATATYPE]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/common.h"
#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/vcvt.h"


// Converts a buffer of 8-bit quantized values to float using RVV intrinsics:
//
//   output[i] = (float) ((int32_t) input[i] - zero_point) * scale
//
// Parameters:
//   batch  - size of the input in bytes; must be non-zero.
//   input  - source elements, read contiguously.
//   output - destination floats, written contiguously (one per input element).
//   params - supplies `scalar.scale` and `scalar.zero_point`.
//
// Each loop iteration processes as many elements as the hardware grants for
// the requested 8-bit vector configuration, so no separate remainder path is
// required.
void xnn_${DATATYPE.lower()}_f32_vcvt_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const ${XINT8_T}* input,
    float* output,
    const struct xnn_${DATATYPE.lower()}_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(${XINT8_T}) == 0);
  assert(input != NULL);
  assert(output != NULL);

  // Convert the byte count into an element count.
  batch >>= XNN_LOG2_SIZEOF_INT8_T;

  const float scale = params->scalar.scale;
  // Negated once up front so the loop can use a vector-scalar add.
  const int32_t minus_zero_point = -params->scalar.zero_point;

  while (batch != 0) {
    // vsetvl returns the number of elements to process this iteration as a
    // size_t; keep it unsigned so the counter update, the pointer advances and
    // the `vl` arguments below involve no narrowing or sign conversions.
    const size_t n = __riscv_vsetvl_e8m${LMUL}(batch);
    batch -= n;

    // Load n 8-bit elements and widen them to 32-bit integers
    // (sign-extend for QS8, zero-extend then reinterpret as signed for QU8).
    $if DATATYPE == "QS8":
      vint8m${LMUL}_t x_i8v = __riscv_vle8_v_i8m${LMUL}(input, n);
      input += n;

      vint32m${LMUL*4}_t wx_i32v = __riscv_vsext_vf4_i32m${LMUL*4}(x_i8v, n);
    $else:
      vuint8m${LMUL}_t x_u8v = __riscv_vle8_v_u8m${LMUL}(input, n);
      input += n;

      vint32m${LMUL*4}_t wx_i32v = __riscv_vreinterpret_v_u32m${LMUL*4}_i32m${LMUL*4}(__riscv_vzext_vf4_u32m${LMUL*4}(x_u8v, n));
    // Subtract the zero point, convert to float, then apply the scale.
    wx_i32v = __riscv_vadd_vx_i32m${LMUL*4}(wx_i32v, minus_zero_point, n);
    vfloat32m${LMUL*4}_t y_f32v = __riscv_vfcvt_f_x_v_f32m${LMUL*4}(wx_i32v, n);
    y_f32v = __riscv_vfmul_vf_f32m${LMUL*4}(y_f32v, scale, n);

    __riscv_vse32_v_f32m${LMUL*4}(output, y_f32v, n);
    output += n;
  }
}
