// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 32 == 0
$assert BATCH_TILE >= 32
$SIMD_TILE = BATCH_TILE // 32

#include <assert.h>

#include "xnnpack/simd/f32-hvx.h"
#include "xnnpack/raddstoreexpminusmax.h"

$PARAMS_STRUCT = "hvx_rr2_p5"
$VMULADD_F32 = "xnn_fmadd_f32"
// For every float in `input`, computes f = exp(input[i] - *max), stores f to
// `output`, and writes the sum of all stored values to `*sum`.
//
//   batch  - size of `input`/`output` in BYTES; must be a non-zero multiple of
//            sizeof(float).
//   input  - source elements. The tail path issues a full-vector load even when
//            fewer than 32 elements remain, hence the XNN_OOB_READS annotation.
//   max    - pointer to the single value subtracted from every input element.
//   output - destination; receives exactly `batch` bytes.
//   sum    - receives the scalar total of all values written to `output`.
//   params - not read by this kernel.
//
// exp(x) is evaluated as s * p(t), where n = round(x / log(2)), s = 2**n, and
// t = x - n * log(2) is approximated by a degree-5 polynomial.
void xnn_f32_raddstoreexpminusmax_ukernel__hvx_rr2_p5_u${BATCH_TILE}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
    size_t batch,
    const float* input,
    const float* max,
    float* output,
    float* sum,
    const void* params) XNN_OOB_READS
{
  assert(batch != 0);
  assert(batch % sizeof(float) == 0);
  assert(input != NULL);
  assert(max != NULL);
  assert(output != NULL);
  assert(sum != NULL);

  const HVX_Vector vi_max = xnn_set1_f32(*max);
  const HVX_Vector vlog2e = xnn_set1_f32(0x1.715476p+0f);
  const HVX_Vector vmagic_bias = xnn_set1_f32(0x1.8000FEp23f);
  const HVX_Vector vminus_ln2_hi = xnn_set1_f32(-0x1.62E400p-1f);
  const HVX_Vector vminus_ln2_lo = xnn_set1_f32(-0x1.7F7D1Cp-20f);
  const HVX_Vector vc5 = xnn_set1_f32(0x1.0F9F9Cp-7f);
  const HVX_Vector vc4 = xnn_set1_f32(0x1.573A1Ap-5f);
  const HVX_Vector vc3 = xnn_set1_f32(0x1.555A80p-3f);
  const HVX_Vector vc2 = xnn_set1_f32(0x1.FFFDC6p-2f);
  const HVX_Vector vc1 = xnn_set1_f32(0x1.FFFFF6p-1f);
  const HVX_Vector vdenorm_cutoff = xnn_set1_f32(-0x1.5D589Ep6f);

  XNN_FORCE_REALIZATION(vlog2e);
  XNN_FORCE_REALIZATION(vmagic_bias);
  XNN_FORCE_REALIZATION(vminus_ln2_hi);
  XNN_FORCE_REALIZATION(vminus_ln2_lo);
  XNN_FORCE_REALIZATION(vc5);
  XNN_FORCE_REALIZATION(vc4);
  XNN_FORCE_REALIZATION(vc3);
  XNN_FORCE_REALIZATION(vc2);
  XNN_FORCE_REALIZATION(vc1);
  XNN_FORCE_REALIZATION(vdenorm_cutoff);

  $if BATCH_TILE > 32:
    HVX_Vector vacc0 = Q6_V_vzero();
    $for K in range(1, ACCUMULATORS):
      HVX_Vector vacc${K} = vacc0;
    for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
      const HVX_Vector vi0 = xnn_loadu_f32(input);
      $for N in range(32, BATCH_TILE, 32):
        const HVX_Vector vi${int(N/32)} = xnn_loadu_f32(input + ${N});
      input += ${BATCH_TILE};

      // Subtract maximum input x := i - i_max
      $for N in range(0, BATCH_TILE, 32):
        const HVX_Vector vx${int(N/32)} = xnn_sub_f32(vi${int(N/32)}, vi_max);

      // n := round(x / log(2))
      $for N in range(0, BATCH_TILE, 32):
        HVX_Vector vn${int(N/32)} = ${VMULADD_F32}(vx${int(N/32)}, vlog2e, vmagic_bias);

      // Create a floating-point number s (scale) such that s == 2**n for inputs which don't cause underflow.
      $for N in range(0, BATCH_TILE, 32):
        const HVX_Vector vs${int(N/32)} = Q6_Vw_vasl_VwR(vn${int(N/32)}, 23);

      // Subtract the large number back to get final n := round(x / log(2)).
      $for N in range(0, BATCH_TILE, 32):
        vn${int(N/32)} = xnn_sub_f32(vn${int(N/32)}, vmagic_bias);

      // Compute reduced argument t := x - n * log(2).
      // Use Cody-Waite range reduction method (note two constants to represent log(2)) to improve accuracy.
      $for N in range(0, BATCH_TILE, 32):
        HVX_Vector vt${int(N/32)} = ${VMULADD_F32}(vn${int(N/32)}, vminus_ln2_hi, vx${int(N/32)});

      $for N in range(0, BATCH_TILE, 32):
        vt${int(N/32)} = ${VMULADD_F32}(vn${int(N/32)}, vminus_ln2_lo, vt${int(N/32)});

      // Compute degree-5 polynomial approximation for exp(t) on [-log(2)/2, log(2)/2].
      //  p := c5 * t + c4;
      //  p = p * t + c3;
      //  p = p * t + c2;
      //  p = p * t + c1;
      $for N in range(0, BATCH_TILE, 32):
        HVX_Vector vp${int(N/32)} = ${VMULADD_F32}(vc5, vt${int(N/32)}, vc4);

      $for N in range(0, BATCH_TILE, 32):
        vp${int(N/32)} = ${VMULADD_F32}(vp${int(N/32)}, vt${int(N/32)}, vc3);

      $for N in range(0, BATCH_TILE, 32):
        vp${int(N/32)} = ${VMULADD_F32}(vp${int(N/32)}, vt${int(N/32)}, vc2);

      $for N in range(0, BATCH_TILE, 32):
        vp${int(N/32)} = ${VMULADD_F32}(vp${int(N/32)}, vt${int(N/32)}, vc1);

      // Reconstruct the final f value:
      //   f = s * (1 + t * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))))
      //     = s + (t * s) * (c1 + t * (c2 + t * (c3 + t * (c4 + t * c5))))
      //     = s + (t * s) * p
      $for N in range(0, BATCH_TILE, 32):
        vt${int(N/32)} = xnn_mul_f32(vt${int(N/32)}, vs${int(N/32)});

      $for N in range(0, BATCH_TILE, 32):
        HVX_Vector vf${int(N/32)} = ${VMULADD_F32}(vt${int(N/32)}, vp${int(N/32)}, vs${int(N/32)});

      // For inputs below zero cutoff, replace output with +0.0f.
      // Note that for NaN inputs, comparison result is false, and outputs are left unchanged.
      $for N in range(0, BATCH_TILE, 32):
        vf${int(N/32)} = Q6_V_vand_QnV(Q6_Q_vcmp_gt_VsfVsf(vdenorm_cutoff, vx${int(N/32)}), vf${int(N/32)});

      xnn_storeu_f32(output, vf0);
      $for N in range(32, BATCH_TILE, 32):
        xnn_storeu_f32(output + ${N}, vf${int(N/32)});
      output += ${BATCH_TILE};

      // Distribute the per-vector results round-robin over the accumulators.
      // The accumulator is selected by vector index, not by element offset:
      // element offsets are multiples of 32 and would all map to accumulator 0
      // for power-of-two accumulator counts.
      $for N in range(0, BATCH_TILE, 32):
        vacc${int(N/32) % ACCUMULATORS} = xnn_add_f32(vacc${int(N/32) % ACCUMULATORS}, vf${int(N/32)});
    }
    $if ACCUMULATORS > 1:
      // Pairwise-combine the partial accumulators into vacc0.
      $ACC_SLICE = 1
      $while ACC_SLICE < ACCUMULATORS:
        $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
          $if A + ACC_SLICE < ACCUMULATORS:
            vacc${A} = xnn_add_f32(vacc${A}, vacc${A + ACC_SLICE});
        $ACC_SLICE *= 2

    HVX_Vector vacc = vacc0;
  $if BATCH_TILE == 32:
    HVX_Vector vacc = Q6_V_vzero();
  // Process one full vector (32 floats) per iteration.
  for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) {
    const HVX_Vector vi = xnn_loadu_f32(input);
    input += 32;

    // x := i - i_max
    const HVX_Vector vx = xnn_sub_f32(vi, vi_max);

    // n := round(x / log(2)), still offset by the magic bias.
    HVX_Vector vn = ${VMULADD_F32}(vx, vlog2e, vmagic_bias);

    // s := 2**n for inputs which don't cause underflow.
    const HVX_Vector vs = Q6_Vw_vasl_VwR(vn, 23);

    // Remove the magic bias to obtain n itself.
    vn = xnn_sub_f32(vn, vmagic_bias);

    // t := x - n * log(2), using the two-constant (hi/lo) representation of log(2).
    HVX_Vector vt = ${VMULADD_F32}(vn, vminus_ln2_hi, vx);
    vt = ${VMULADD_F32}(vn, vminus_ln2_lo, vt);

    // p := c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))
    HVX_Vector vp = ${VMULADD_F32}(vc5, vt, vc4);
    vp = ${VMULADD_F32}(vp, vt, vc3);
    vp = ${VMULADD_F32}(vp, vt, vc2);
    vp = ${VMULADD_F32}(vp, vt, vc1);

    // f := s + (t * s) * p
    vt = xnn_mul_f32(vt, vs);
    HVX_Vector vf = ${VMULADD_F32}(vt, vp, vs);

    // Zero the lanes whose x is below the cutoff.
    vf = Q6_V_vand_QnV(Q6_Q_vcmp_gt_VsfVsf(vdenorm_cutoff, vx), vf);

    xnn_storeu_f32(output, vf);
    output += 32;

    vacc = xnn_add_f32(vacc, vf);
  }

  // Collapse the vector accumulator into a scalar running total.
  float vacc_lo = Q6_f32_vrsum_Vsf(vacc);
  if XNN_UNLIKELY(batch != 0) {
    assert(batch >= 1 * sizeof(float));
    assert(batch < 32 * sizeof(float));

    // Tail: fewer than 32 floats remain. A whole vector is still loaded and
    // evaluated; only the leading `batch` bytes are stored and summed.
    const HVX_Vector vi = xnn_loadu_f32(input);

    // x := i - i_max
    const HVX_Vector vx = xnn_sub_f32(vi, vi_max);

    // n := round(x / log(2)), still offset by the magic bias.
    HVX_Vector vn = ${VMULADD_F32}(vx, vlog2e, vmagic_bias);

    // s := 2**n for inputs which don't cause underflow.
    const HVX_Vector vs = Q6_Vw_vasl_VwR(vn, 23);

    // Remove the magic bias to obtain n itself.
    vn = xnn_sub_f32(vn, vmagic_bias);

    // t := x - n * log(2), using the two-constant (hi/lo) representation of log(2).
    HVX_Vector vt = ${VMULADD_F32}(vn, vminus_ln2_hi, vx);
    vt = ${VMULADD_F32}(vn, vminus_ln2_lo, vt);

    // p := c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)))
    HVX_Vector vp = ${VMULADD_F32}(vc5, vt, vc4);
    vp = ${VMULADD_F32}(vp, vt, vc3);
    vp = ${VMULADD_F32}(vp, vt, vc2);
    vp = ${VMULADD_F32}(vp, vt, vc1);

    // f := s + (t * s) * p
    vt = xnn_mul_f32(vt, vs);
    HVX_Vector vf = ${VMULADD_F32}(vt, vp, vs);

    // Zero the lanes whose x is below the cutoff.
    vf = Q6_V_vand_QnV(Q6_Q_vcmp_gt_VsfVsf(vdenorm_cutoff, vx), vf);

    // Write only the remaining `batch` bytes of the result.
    Q6_V_vstu_variable(output, batch, vf);

    // Clear the bytes past `batch` so that out-of-range lanes do not
    // contribute to the sum.
    vf = Q6_V_vand_QV(Q6_Q_vsetq_R(batch), vf);
    vacc_lo += Q6_f32_vrsum_Vsf(vf);
  }
  *sum = vacc_lo;
}
