// Copyright 2023 SiFive, Inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert LMUL >= 1
$assert OP in ["MAX", "MIN", "MINMAX"]
#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

#include "xnnpack/common.h"
#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/vunary.h"


$EMIT_MIN = "MIN" in OP
$EMIT_MAX = "MAX" in OP
$MAX_POS = 1 if OP == "MINMAX" else 0
$OP_0 = "max" if OP == "MAX" else "min"
void xnn_f32_r${OP.lower()}_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const float* input,
    float* output,
    const struct xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(float) == 0);
  assert(input != NULL);
  assert(output != NULL);

  size_t N = batch >> 2;
  size_t avl;
  size_t vl = __riscv_vsetvl_e32m${LMUL}(N);

  vfloat32m${LMUL}_t t0 = __riscv_vle32_v_f32m${LMUL}(input, vl);
  input += vl;
  $if MAX_POS == 1:
    vfloat32m${LMUL}_t t1 = __riscv_vmv_v_v_f32m${LMUL}(t0, vl);

  for (avl = N - vl; avl; avl -= vl, input += vl) {
    vl = __riscv_vsetvl_e32m${LMUL}(avl);
    vfloat32m${LMUL}_t vec = __riscv_vle32_v_f32m${LMUL}(input, vl);
    t0 = __riscv_vf${OP_0}_vv_f32m${LMUL}_tu(t0, t0, vec, vl);
    $if MAX_POS == 1:
      t1 = __riscv_vfmax_vv_f32m${LMUL}_tu(t1, t1, vec, vl);
  }

  $if EMIT_MIN:
    vfloat32m1_t fmin = __riscv_vfmv_s_f_f32m1(INFINITY, 1);
  $if EMIT_MAX:
    vfloat32m1_t fmax = __riscv_vfmv_s_f_f32m1(-INFINITY, 1);
  $if EMIT_MIN:
    output[0] = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmin_vs_f32m${LMUL}_f32m1(t0, fmin, N));
  $if EMIT_MAX:
    output[${MAX_POS}] = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredmax_vs_f32m${LMUL}_f32m1(t${MAX_POS}, fmax, N));
}
