// Copyright 2024 Imagination Technologies, inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert LMUL in [1, 2, 4, 8]
#include <assert.h>
#include "xnnpack/avgpool.h"
#include <riscv_vector.h>

// Average-pooling micro-kernel with min/max clamping for pooling windows of at
// most 9 elements, written with RISC-V Vector intrinsics. The register group
// multiplier is taken from the LMUL value of this template instantiation.
//
// For each output pixel the kernel reads `kernel_elements` row pointers from
// `input`, adds the rows together channel by channel, multiplies the sum by
// params->scalar.scale, clamps the product to
// [params->scalar.min, params->scalar.max] and writes `channels` floats to
// `output`.
//
//   output_pixels    - number of output pixels to produce; must be non-zero.
//   kernel_elements  - number of row pointers consumed per pixel; 1 to 9.
//   channels         - number of floats processed per row; must be non-zero.
//   input            - array of row pointers; advanced by `input_increment`
//                      bytes after the pointers of a pixel have been read.
//   input_offset     - byte offset (multiple of 4) applied to every row
//                      pointer that is not equal to `zero`.
//   zero             - row substituted for window slots beyond
//                      `kernel_elements`; it is read like any other row, so
//                      `channels` floats must be readable from it.
//   output           - destination; advanced by `channels` floats plus
//                      `output_increment` bytes per pixel.
//   input_increment  - byte stride applied to `input` per pixel.
//   output_increment - byte stride applied to `output` after each pixel.
//   params           - provides scalar.scale, scalar.min and scalar.max.
void xnn_f32_avgpool_minmax_ukernel_9x__rvv_c${LMUL}v(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    const float* zero,
    float* output,
    size_t input_increment,
    size_t output_increment,
    const struct xnn_f32_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(output_pixels != 0);
  assert(kernel_elements != 0);
  assert(kernel_elements <= 9);
  assert(channels != 0);
  assert((input_offset & 3) == 0);

  // The offset arrives in bytes; convert it to a float element count so it can
  // be added directly to the `const float*` row pointers below.
  input_offset >>= XNN_LOG2_SIZEOF_FLOAT;

  const float scale = params->scalar.scale;
  const float min = params->scalar.min;
  const float max = params->scalar.max;

  do {
    // Collect the 9 row pointers of this pixel. Rows equal to `zero` are used
    // as-is (no offset); unused slots are filled with `zero` so the fixed
    // 9-way sum below is valid for any kernel_elements.
    const float *i[9];
    for (size_t kk = 0; kk < kernel_elements; ++kk) {
      assert(input[kk] != NULL);
      i[kk] = (input[kk] != zero ? input[kk] + input_offset : zero);
    }
    for (size_t tail = kernel_elements; tail < 9; ++tail) {
      i[tail] = zero;
    }
    input = (const float**) ((uintptr_t) input + input_increment);

    for (size_t c = channels; c != 0; ) {
      // vsetvl returns the number of elements handled in this iteration as a
      // size_t; keep it unsigned so it matches `c`, the pointer steps and the
      // vl argument of the intrinsics without any narrowing conversion.
      const size_t n = __riscv_vsetvl_e32m${LMUL}(c);

      vfloat32m${LMUL}_t i0_f32v = __riscv_vle32_v_f32m${LMUL}(i[0], n); i[0] += n;
      vfloat32m${LMUL}_t i1_f32v = __riscv_vle32_v_f32m${LMUL}(i[1], n); i[1] += n;
      vfloat32m${LMUL}_t i2_f32v = __riscv_vle32_v_f32m${LMUL}(i[2], n); i[2] += n;
      vfloat32m${LMUL}_t i3_f32v = __riscv_vle32_v_f32m${LMUL}(i[3], n); i[3] += n;
      vfloat32m${LMUL}_t i4_f32v = __riscv_vle32_v_f32m${LMUL}(i[4], n); i[4] += n;
      vfloat32m${LMUL}_t i5_f32v = __riscv_vle32_v_f32m${LMUL}(i[5], n); i[5] += n;
      vfloat32m${LMUL}_t i6_f32v = __riscv_vle32_v_f32m${LMUL}(i[6], n); i[6] += n;
      vfloat32m${LMUL}_t i7_f32v = __riscv_vle32_v_f32m${LMUL}(i[7], n); i[7] += n;
      vfloat32m${LMUL}_t i8_f32v = __riscv_vle32_v_f32m${LMUL}(i[8], n); i[8] += n;

      // Pairwise reduction of the 9 rows. The association order is kept fixed
      // because floating-point addition is not associative.
      vfloat32m${LMUL}_t sum01_f32v = __riscv_vfadd_vv_f32m${LMUL}(i0_f32v, i1_f32v, n);
      vfloat32m${LMUL}_t sum23_f32v = __riscv_vfadd_vv_f32m${LMUL}(i2_f32v, i3_f32v, n);
      vfloat32m${LMUL}_t sum45_f32v = __riscv_vfadd_vv_f32m${LMUL}(i4_f32v, i5_f32v, n);
      vfloat32m${LMUL}_t sum67_f32v = __riscv_vfadd_vv_f32m${LMUL}(i6_f32v, i7_f32v, n);
      vfloat32m${LMUL}_t sum018_f32v = __riscv_vfadd_vv_f32m${LMUL}(sum01_f32v, i8_f32v, n);
      vfloat32m${LMUL}_t sum2345_f32v = __riscv_vfadd_vv_f32m${LMUL}(sum23_f32v, sum45_f32v, n);
      vfloat32m${LMUL}_t sum01678_f32v = __riscv_vfadd_vv_f32m${LMUL}(sum018_f32v, sum67_f32v, n);
      vfloat32m${LMUL}_t sum_f32v = __riscv_vfadd_vv_f32m${LMUL}(sum2345_f32v, sum01678_f32v, n);

      // Scale the sum, then clamp: lower bound first, upper bound second.
      vfloat32m${LMUL}_t out_f32v = __riscv_vfmul_vf_f32m${LMUL}(sum_f32v, scale, n);
      out_f32v = __riscv_vfmin_vf_f32m${LMUL}(__riscv_vfmax_vf_f32m${LMUL}(out_f32v, min, n), max, n);
      __riscv_vse32_v_f32m${LMUL}(output, out_f32v, n); output += n;

      c -= n;
    }
    output = (float*) ((uintptr_t) output + output_increment);
  } while (--output_pixels != 0);
}
