// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert LMUL in [1, 2, 4, 8]
$assert DATATYPE in ["S8", "U8"]

#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/common.h"
#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/vunary.h"

$XINT8_T = {"S8": "int8_t", "U8": "uint8_t"}[DATATYPE]

// Element-wise clamp of ${XINT8_T} data using RVV intrinsics at LMUL=${LMUL}.
//
// Every input element is first raised to at least params->scalar.min and then
// lowered to at most params->scalar.max; the result is written to the matching
// position in output.
//
//   batch  - number of elements to process (each element is one byte); must be
//            non-zero.
//   input  - source elements; must not be NULL.
//   output - destination elements; must not be NULL.
//   params - supplies the scalar.min / scalar.max clamp bounds.
void xnn_${DATATYPE.lower()}_vclamp_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const ${XINT8_T}* input,
    ${XINT8_T}* output,
    const struct xnn_${DATATYPE.lower()}_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(${XINT8_T}) == 0);
  assert(input != NULL);
  assert(output != NULL);

  // The bounds are loop-invariant, so read them from params once.
  const ${XINT8_T} lower_bound = params->scalar.min;
  const ${XINT8_T} upper_bound = params->scalar.max;

  do {
    // Ask for as many 8-bit elements as remain; vl is the count granted for
    // this pass, so the final partial chunk needs no separate tail code.
    const size_t vl = __riscv_vsetvl_e8m${LMUL}(batch);
    $if DATATYPE == "S8":
      vint8m${LMUL}_t vx = __riscv_vle8_v_i8m${LMUL}(input, vl);
      // Signed comparison: apply the lower bound, then the upper bound.
      vint8m${LMUL}_t vraised = __riscv_vmax_vx_i8m${LMUL}(vx, lower_bound, vl);
      vint8m${LMUL}_t vy = __riscv_vmin_vx_i8m${LMUL}(vraised, upper_bound, vl);
      __riscv_vse8_v_i8m${LMUL}(output, vy, vl);
    $else:
      vuint8m${LMUL}_t vx = __riscv_vle8_v_u8m${LMUL}(input, vl);
      // Unsigned comparison: apply the lower bound, then the upper bound.
      vuint8m${LMUL}_t vraised = __riscv_vmaxu_vx_u8m${LMUL}(vx, lower_bound, vl);
      vuint8m${LMUL}_t vy = __riscv_vminu_vx_u8m${LMUL}(vraised, upper_bound, vl);
      __riscv_vse8_v_u8m${LMUL}(output, vy, vl);
    // Step past the vl elements just handled.
    batch -= vl;
    input += vl;
    output += vl;
  } while (batch != 0);
}
