// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QS8", "QU8"]
$assert BATCH_TILE % 32 == 0
$assert BATCH_TILE >= 32
#include <assert.h>

#include <hvx_hexagon_protos.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>

#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/math.h"
#include "xnnpack/vbinary.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
void xnn_${DATATYPE.lower()}_vadd_minmax_ukernel__hvx_u${BATCH_TILE}(
    size_t batch,
    const ${XINT8_T}* input_a,
    const ${XINT8_T}* input_b,
    ${XINT8_T}* output,
    const struct xnn_${DATATYPE.lower()}_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch != 0);
  assert(batch % sizeof(${XINT8_T}) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  const HVX_Vector vbias = Q6_V_vsplat_R(*((int32_t *) &params->scalar.bias));
  const HVX_Vector va_multiplier = Q6_V_vsplat_R(*((int32_t *) &params->scalar.a_multiplier));
  const HVX_Vector vb_multiplier = Q6_V_vsplat_R(*((int32_t *) &params->scalar.b_multiplier));

  const uint32_t shift = params->scalar.shift;
  const int32_t first_shift = min(shift, 15);
  const int32_t rest_shift = shift - first_shift;
  assert(first_shift < 16);
  assert(rest_shift < 16);

  const HVX_Vector voutput_zero_point = Q6_Vh_vsplat_R(*((int16_t *) &params->scalar.output_zero_point));
  const HVX_Vector voutput_min = Q6_Vb_vsplat_R(*((int8_t *) &params->scalar.output_min));
  const HVX_Vector voutput_max = Q6_Vb_vsplat_R(*((int8_t *) &params->scalar.output_max));

  for (; batch >= ${BATCH_TILE} * sizeof(${XINT8_T}); batch -= ${BATCH_TILE} * sizeof(${XINT8_T})) {
    HVX_Vector va0 = *((HVX_UVector*)input_a);
    HVX_Vector vb0 = *((HVX_UVector*)input_b);
    $for N in range(128, BATCH_TILE, 128):
      HVX_Vector va${int(N/128)} = *((HVX_UVector*)(input_a + ${N}));
      HVX_Vector vb${int(N/128)} = *((HVX_UVector*)(input_b + ${N}));
    input_a += ${BATCH_TILE};
    input_b += ${BATCH_TILE};

    // widen 8-bit to 16-bit
    $for N in range(0, BATCH_TILE, 128):
      HVX_VectorPair va${int(N/128)}_i16 = Q6_Wh_vunpack_Vb(va${int(N/128)});
      HVX_VectorPair vb${int(N/128)}_i16 = Q6_Wh_vunpack_Vb(vb${int(N/128)});
      HVX_Vector va${int(N/128)}_lo = Q6_V_lo_W(va${int(N/128)}_i16);
      $if N + 64 < BATCH_TILE:
        HVX_Vector va${int(N/128)}_hi = Q6_V_hi_W(va${int(N/128)}_i16);
      HVX_Vector vb${int(N/128)}_lo = Q6_V_lo_W(vb${int(N/128)}_i16);
      $if N + 64 < BATCH_TILE:
        HVX_Vector vb${int(N/128)}_hi = Q6_V_hi_W(vb${int(N/128)}_i16);

    // vacc = vbias + va * va_multiplier + vb * vb_multiplier with widening 16-bit to 32-bit
    $for N in range(0, BATCH_TILE, 128):
      HVX_Vector vacc${int(N/128)}_lo_even = vbias;
      vacc${int(N/128)}_lo_even = Q6_Vw_vmpyieacc_VwVwVh(vacc${int(N/128)}_lo_even, va_multiplier, va${int(N/128)}_lo);
      HVX_Vector vacc${int(N/128)}_lo_odd = Q6_Vw_vadd_VwVw(vbias, Q6_Vw_vmpyio_VwVh(va_multiplier, va${int(N/128)}_lo));

      vacc${int(N/128)}_lo_even = Q6_Vw_vmpyieacc_VwVwVh(vacc${int(N/128)}_lo_even, vb_multiplier, vb${int(N/128)}_lo);
      vacc${int(N/128)}_lo_odd = Q6_Vw_vadd_VwVw(vacc${int(N/128)}_lo_odd, Q6_Vw_vmpyio_VwVh(vb_multiplier, vb${int(N/128)}_lo));

      $if N + 64 < BATCH_TILE:
        HVX_Vector vacc${int(N/128)}_hi_even = vbias;
        vacc${int(N/128)}_hi_even = Q6_Vw_vmpyieacc_VwVwVh(vacc${int(N/128)}_hi_even, va_multiplier, va${int(N/128)}_hi);
        HVX_Vector vacc${int(N/128)}_hi_odd = Q6_Vw_vadd_VwVw(vbias, Q6_Vw_vmpyio_VwVh(va_multiplier, va${int(N/128)}_hi));

        vacc${int(N/128)}_hi_even = Q6_Vw_vmpyieacc_VwVwVh(vacc${int(N/128)}_hi_even, vb_multiplier, vb${int(N/128)}_hi);
        vacc${int(N/128)}_hi_odd = Q6_Vw_vadd_VwVw(vacc${int(N/128)}_hi_odd, Q6_Vw_vmpyio_VwVh(vb_multiplier, vb${int(N/128)}_hi));

    // narrow shift to 16-bit
    // vacc = vacc + voutput_zero_point
    $for N in range(0, BATCH_TILE, 128):
      HVX_Vector vacc${int(N/128)}_lo = Q6_Vh_vasr_VwVwR_sat(vacc${int(N/128)}_lo_odd, vacc${int(N/128)}_lo_even, first_shift);
      vacc${int(N/128)}_lo = Q6_Vh_vadd_VhVh(voutput_zero_point, Q6_Vh_vasr_VhR(vacc${int(N/128)}_lo, rest_shift));
      $if N + 64 < BATCH_TILE:
        HVX_Vector vacc${int(N/128)}_hi = Q6_Vh_vasr_VwVwR_sat(vacc${int(N/128)}_hi_odd, vacc${int(N/128)}_hi_even, first_shift);
        vacc${int(N/128)}_hi = Q6_Vh_vadd_VhVh(voutput_zero_point, Q6_Vh_vasr_VhR(vacc${int(N/128)}_hi, rest_shift));

    // narrow 16-bit to 8-bit
    $for N in range(0, BATCH_TILE, 128):
      $if N + 64 < BATCH_TILE:
        HVX_Vector vout${int(N/128)} = Q6_Vb_vpack_VhVh_sat(vacc${int(N/128)}_hi, vacc${int(N/128)}_lo);
      $else:
        HVX_Vector vout${int(N/128)} = Q6_Vb_vpack_VhVh_sat(vacc${int(N/128)}_lo, vacc${int(N/128)}_lo);

    // minmax
    $for N in range(0, BATCH_TILE, 128):
      vout${int(N/128)} = Q6_Vb_vmax_VbVb(voutput_min, vout${int(N/128)});
      vout${int(N/128)} = Q6_Vb_vmin_VbVb(voutput_max, vout${int(N/128)});

    // store output
    $for N in range(0, BATCH_TILE, 128):
      $if N + 128 <= BATCH_TILE:
        *((HVX_UVector *) output) = vout${int(N/128)};
        output += 128;
      $else:
        Q6_V_vstu_variable(output, ${BATCH_TILE - N}, vout${int(N/128)});
        output += ${BATCH_TILE - N};
  }
  if XNN_UNLIKELY(batch != 0){
    do {
      HVX_Vector va = *((HVX_UVector*)input_a);
      HVX_Vector vb = *((HVX_UVector*)input_b);
      $if BATCH_TILE > 32:
        if XNN_LIKELY(batch > (32 * sizeof(int8_t))) {
          input_a += 32;
          input_b += 32;
        }

      HVX_VectorPair va_i16 = Q6_Wh_vunpack_Vb(va);
      HVX_Vector va_lo = Q6_V_lo_W(va_i16);

      HVX_VectorPair vb_i16 = Q6_Wh_vunpack_Vb(vb);
      HVX_Vector vb_lo = Q6_V_lo_W(vb_i16);

      HVX_Vector vacc_even = vbias;
      vacc_even = Q6_Vw_vmpyieacc_VwVwVh(vacc_even, va_multiplier, va_lo);
      HVX_Vector vacc_odd = Q6_Vw_vadd_VwVw(vbias, Q6_Vw_vmpyio_VwVh(va_multiplier, va_lo));

      vacc_even = Q6_Vw_vmpyieacc_VwVwVh(vacc_even, vb_multiplier, vb_lo);
      vacc_odd = Q6_Vw_vadd_VwVw(vacc_odd, Q6_Vw_vmpyio_VwVh(vb_multiplier, vb_lo));

      HVX_Vector vacc = Q6_Vh_vasr_VwVwR_sat(vacc_odd, vacc_even, first_shift);
      vacc = Q6_Vh_vadd_VhVh(voutput_zero_point, Q6_Vh_vasr_VhR(vacc, rest_shift));

      HVX_Vector vout = Q6_Vb_vpack_VhVh_sat(vacc, vacc);

      vout = Q6_Vb_vmax_VbVb(voutput_min, vout);
      vout = Q6_Vb_vmin_VbVb(voutput_max, vout);

      $if BATCH_TILE > 32:
        if XNN_LIKELY(batch > (32 * sizeof(int8_t))) {
          Q6_V_vstu_variable(output, 32, vout);
          output += 32;
          batch -=32;
        }
        else{
          Q6_V_vstu_variable(output, batch, vout);
          batch = 0;
        }
      $else:
        Q6_V_vstu_variable(output, batch, vout);
        batch = 0;
    } while (batch != 0);
  }
}
