// Copyright 2024 SiFive, Inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert LMUL in [1, 2, 4, 8]
$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB", "SQRDIFF"]
#include <assert.h>

#include <riscv_vector.h>

#include "xnnpack/common.h"
#include "xnnpack/vbinary.h"


$OP_FUNC = {
$  "ADD": "__riscv_vfadd_vv_f32",
$  "DIV": "__riscv_vfdiv_vv_f32",
$  "MAX": "__riscv_vfmax_vv_f32",
$  "MIN": "__riscv_vfmin_vv_f32",
$  "MUL": "__riscv_vfmul_vv_f32",
$  "SUB": "__riscv_vfsub_vv_f32",
$  "SQRDIFF": "__riscv_vfsub_vv_f32",
$}[OP]
void xnn_f32_v${OP.lower()}_ukernel__rvv_u${LMUL}v(
    size_t batch,
    const float* input_a,
    const float* input_b,
    float* output,
    const struct xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(float) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  size_t n = batch >> 2;

  do {
    size_t vl = __riscv_vsetvl_e32m${LMUL}(n);
    n -= vl;
    vfloat32m${LMUL}_t va = __riscv_vle32_v_f32m${LMUL}(input_a, vl);
    input_a += vl;
    vfloat32m${LMUL}_t vb = __riscv_vle32_v_f32m${LMUL}(input_b, vl);
    input_b += vl;
    vfloat32m${LMUL}_t vacc = ${OP_FUNC}m${LMUL}(va, vb, vl);
    $if OP == "SQRDIFF":
      vacc = __riscv_vfmul_vv_f32m${LMUL}(vacc, vacc, vl);
    __riscv_vse32_v_f32m${LMUL}(output, vacc, vl);
    output += vl;
  } while (n > 0);
}
