// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert LMUL in [1, 2, 4, 8]
$assert DATATYPE in ["QS8", "QU8"]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/vunary.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]

// Quantized 8-bit leaky-ReLU micro-kernel for the RISC-V Vector extension.
//
// For every input element x the kernel computes
//   acc = bias + (x - input_zero_point) * multiplier
// where multiplier is params->scalar.negative_multiplier when
// (x - input_zero_point) is negative and params->scalar.positive_multiplier
// otherwise. The accumulator is then shifted right by 8 bits and saturated to
// the 8-bit output type.
//
// Parameters:
//   batch  - size of the input in bytes; must be non-zero.
//   input  - pointer to batch input elements; must not be NULL.
//   output - pointer to batch output elements; must not be NULL.
//   params - supplies input_zero_point, output_zero_point and the two
//            multipliers (read from the .scalar member).
//
// NOTE(review): the 32-bit accumulator uses a register group of LMUL*4. RVV
// register groups top out at m8, so presumably only LMUL 1 and 2 produce valid
// intrinsic types even though the template-level assert admits 4 and 8 -
// confirm against the build's generator invocations.
void xnn_${DATATYPE.lower()}_vlrelu_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const ${XINT8_T}* input,
    ${XINT8_T}* output,
    const struct xnn_${DATATYPE.lower()}_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(${XINT8_T}) == 0);
  assert(input != NULL);
  assert(output != NULL);

  const int32_t input_zero_point = params->scalar.input_zero_point;
  // diff = negative ^ positive, so (base ^ diff) == negative multiplier and
  // (base ^ 0) == positive multiplier; see the and/xor select in the loop.
  const int32_t multiplier_diff = params->scalar.negative_multiplier ^ params->scalar.positive_multiplier;
  const int32_t multiplier_base = params->scalar.positive_multiplier;
  // The output zero point is pre-scaled by 2^8 to match the final right shift
  // by 8; the extra 128 (half of 2^8) makes that truncating shift round to
  // nearest. The shift is done on uint32_t so a negative zero point is never
  // left-shifted as a signed value.
  const int32_t bias = (int32_t) (((uint32_t) (int32_t) params->scalar.output_zero_point) << 8) + 128;
  // vsetvl returns size_t and every intrinsic takes its vector length as
  // size_t, so keep the element count in that type.
  size_t n = __riscv_vsetvl_e8m${LMUL}(batch);
  // Splat the bias once, outside the loop.
  // NOTE(review): only the first n lanes are written here; this relies on no
  // later iteration using a vector length larger than this first one.
  vint32m${LMUL*4}_t bias_i32v = __riscv_vmv_v_x_i32m${LMUL*4}(bias, n);

  do {
    // Number of elements handled by this iteration.
    n = __riscv_vsetvl_e8m${LMUL}(batch); batch -= n;

    $if DATATYPE == "QS8":
      // Load n signed bytes and widen to 16 bits while subtracting the zero point.
      vint8m${LMUL}_t in_i8v = __riscv_vle8_v_i8m${LMUL}(input, n); input += n;
      vint16m${LMUL*2}_t acc_i16v = __riscv_vwsub_vx_i16m${LMUL*2}(in_i8v, input_zero_point, n);
    $else:
      // Load n unsigned bytes, do the widening subtract in unsigned 16-bit
      // arithmetic, then reinterpret the bits as signed 16-bit values.
      vuint8m${LMUL}_t in_u8v = __riscv_vle8_v_u8m${LMUL}(input, n); input += n;
      vuint16m${LMUL*2}_t acc_u16v = __riscv_vwsubu_vx_u16m${LMUL*2}(in_u8v, input_zero_point, n);
      vint16m${LMUL*2}_t acc_i16v = __riscv_vreinterpret_v_u16m${LMUL*2}_i16m${LMUL*2}(acc_u16v);

    // Widen the zero-point-adjusted input to 32 bits.
    vint32m${LMUL*4}_t acc_i32v = __riscv_vwcvt_x_x_v_i32m${LMUL*4}(acc_i16v, n);
    // Arithmetic shift by 31 yields an all-ones mask for negative lanes and
    // zero for non-negative lanes.
    vint32m${LMUL*4}_t sra_i32v = __riscv_vsra_vx_i32m${LMUL*4}(acc_i32v, 31, n);
    // Branch-free select: base ^ (diff & mask) is the negative multiplier on
    // negative lanes and the positive multiplier elsewhere.
    vint32m${LMUL*4}_t and_i32v = __riscv_vand_vx_i32m${LMUL*4}(sra_i32v, multiplier_diff, n);
    vint32m${LMUL*4}_t mult_i32v = __riscv_vxor_vx_i32m${LMUL*4}(and_i32v, multiplier_base, n);
    // acc = bias + acc * multiplier.
    acc_i32v = __riscv_vmacc_vv_i32m${LMUL*4}(bias_i32v, acc_i32v, mult_i32v, n);

    $if DATATYPE == "QS8":
      // Narrow 32 -> 16 bits with a right shift by 8; round-down is used
      // because the rounding term is already folded into the bias. The second
      // narrowing shifts by 0, so it only saturates to the int8 range.
      vint16m${LMUL*2}_t out_i16v = __riscv_vnclip_wx_i16m${LMUL*2}(acc_i32v, 8, __RISCV_VXRM_RDN, n);
      vint8m${LMUL}_t out_i8v = __riscv_vnclip_wx_i8m${LMUL}(out_i16v, 0, __RISCV_VXRM_RNU, n);
      __riscv_vse8_v_i8m${LMUL}(output, out_i8v, n); output += n;
    $else:
      // Clamp negative accumulators to zero before treating the bits as
      // unsigned, then narrow 32 -> 16 bits with a right shift by 8
      // (round-down, the rounding term is already in the bias) and saturate
      // 16 -> 8 bits with a shift of 0.
      acc_i32v = __riscv_vmax_vx_i32m${LMUL*4}(acc_i32v, 0, n);
      vuint32m${LMUL*4}_t out_u32v = __riscv_vreinterpret_v_i32m${LMUL*4}_u32m${LMUL*4}(acc_i32v);
      vuint16m${LMUL*2}_t out_u16v =__riscv_vnclipu_wx_u16m${LMUL*2}(out_u32v, 8, __RISCV_VXRM_RDN, n);
      vuint8m${LMUL}_t out_u8v = __riscv_vnclipu_wx_u8m${LMUL}(out_u16v, 0, __RISCV_VXRM_RNU, n);
      __riscv_vse8_v_u8m${LMUL}(output, out_u8v, n); output += n;
  } while (batch != 0);
}
