// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 32 == 0
$assert BATCH_TILE >= 32
#include <assert.h>

#include "xnnpack/simd/f32-hvx.h"
#include "xnnpack/vcvt.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
void xnn_f32_${DATATYPE.lower()}_vcvt_ukernel__hvx_u${BATCH_TILE}(
    size_t batch,
    const float* input,
    ${XINT8_T}* output,
    const struct xnn_f32_${DATATYPE.lower()}_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch != 0);
  assert(batch % sizeof(float) == 0);
  assert(input != NULL);
  assert(output != NULL);

  const HVX_Vector vscale = xnn_set1_f32(params->scalar.scale);
  const HVX_Vector vmagic_bias = xnn_set1_f32(12582912.0f);
  const HVX_Vector vmagic_bias_less_zero_point = Q6_V_vsplat_R(INT32_C(0x4B400000) - (int32_t) params->scalar.output_zero_point);
  XNN_FORCE_REALIZATION(vmagic_bias);
  $if BATCH_TILE > 32:
    for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
      HVX_Vector vx0 = xnn_loadu_f32(input);
      $for N in range(32, BATCH_TILE, 32):
        HVX_Vector vx${int(N/32)} = xnn_loadu_f32(input + ${N});
      input += ${BATCH_TILE};

      $for N in range(0, BATCH_TILE, 32):
        vx${int(N/32)} = xnn_fmadd_f32(vx${int(N/32)}, vscale, vmagic_bias);

      $for N in range(0, BATCH_TILE, 32):
        const HVX_Vector vacc${int(N/32)} = Q6_Vw_vsub_VwVw_sat(vx${int(N/32)}, vmagic_bias_less_zero_point);

      // narrowing 32-bit to 16-bit
      $for N in range(0, BATCH_TILE, 64):
        $if N + 32 < BATCH_TILE:
          const HVX_Vector vacc_h${int(N/64)} = Q6_Vh_vpack_VwVw_sat(vacc${int((N+32)/32)}, vacc${int(N/32)});
        $else:
          const HVX_Vector vacc_h${int(N/64)} = Q6_Vh_vpack_VwVw_sat(vacc${int(N/32)}, vacc${int(N/32)});

      // narrowing 16-bit to 8-bit
      $for N in range(0, BATCH_TILE, 128):
        $if N + 64 < BATCH_TILE:
          HVX_Vector vy${int(N/128)} = Q6_Vb_vpack_VhVh_sat(vacc_h${int((N+64)/64)}, vacc_h${int(N/64)});
        $else:
          HVX_Vector vy${int(N/128)} = Q6_Vb_vpack_VhVh_sat(vacc_h${int(N/64)}, vacc_h${int(N/64)});

      $for N in range(0, BATCH_TILE, 128):
        $if N + 128 <= BATCH_TILE:
          *((HVX_UVector *) output) = vy${int(N/128)};
          output += 128;
        $else:
          Q6_V_vstu_variable(output, ${BATCH_TILE - N}, vy${int(N/128)});
          output += ${BATCH_TILE - N};
    }
  for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) {
    HVX_Vector vx = xnn_loadu_f32(input);
    input += 32;

    vx = xnn_fmadd_f32(vx, vscale, vmagic_bias);

    const HVX_Vector vacc = Q6_Vw_vsub_VwVw_sat(vx, vmagic_bias_less_zero_point);

    const HVX_Vector vacc_h = Q6_Vh_vpack_VwVw_sat(vacc, vacc);

    HVX_Vector vy = Q6_Vb_vpack_VhVh_sat(vacc_h, vacc_h);

    Q6_V_vstu_variable(output, 32, vy);
    output += 32;
  }
  if XNN_UNLIKELY(batch != 0) {
    assert(batch >= 1 * sizeof(float));
    assert(batch < 32 * sizeof(float));
    HVX_Vector vx = xnn_loadu_f32(input);

    vx = xnn_fmadd_f32(vx, vscale, vmagic_bias);

    const HVX_Vector vacc = Q6_Vw_vsub_VwVw_sat(vx, vmagic_bias_less_zero_point);

    const HVX_Vector vacc_h = Q6_Vh_vpack_VwVw_sat(vacc, vacc);

    HVX_Vector vy = Q6_Vb_vpack_VhVh_sat(vacc_h, vacc_h);

    // Since the output data type is int8_t,
    // we simply determine the number of elements using batch >> 2
    // without multiplying by sizeof(int8_t).
    Q6_V_vstu_variable(output, batch >> 2, vy);
  }
}
