// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DIV in ("DIV", "NR")
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(","))
$SIMD_SIZE = BATCH_TILES[0]
#include <assert.h>
#include <stddef.h>

#include "xnnpack/simd/f32-${ARCH}.h"

#include "xnnpack/common.h"
#include "xnnpack/microparams.h"
#include "xnnpack/vunary.h"

$for BATCH_TILE in BATCH_TILES:
  $assert BATCH_TILE % SIMD_SIZE == 0
  $assert BATCH_TILE >= SIMD_SIZE
  $SIMD_TILE = BATCH_TILE // SIMD_SIZE

  // Computes GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) elementwise, where
  // `erf(x / sqrt(2))` is approximated by a rational function with a
  // degree-12 (odd) numerator and a degree-10 (even) denominator.
  //
  //   batch:  number of input *bytes*; must be a non-zero multiple of
  //           `sizeof(float)`.
  //   input:  the `batch / sizeof(float)` input values.
  //   output: receives the same number of output values.
  //
  // `unused_params` is never read by this kernel.
  void xnn_f32_vgelu_ukernel__${ARCH}_rational_12_10_${DIV.lower()}_u${BATCH_TILE}(
      size_t batch,
      const float* input,
      float* output,
      const struct xnn_f32_default_params unused_params[restrict XNN_MIN_ELEMENTS(1)])
  {
    assert(batch != 0);
    assert(batch % sizeof(float) == 0);
    assert(input != NULL);
    assert(output != NULL);
    assert(xnn_simd_size_f32 == ${SIMD_SIZE});

    // Cap the inputs to this value as `erf(x/sqrt(2))` will always be `+/-1.0f`
    // beyond this point. This value is chosen as the first floating point
    // number as of which the interpolation returns +/-1.0f.
    #if XNN_SIMD_HAS_NATIVE_FMA || (XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR)
      $if DIV == "NR":
        XNN_SIMD_CONST_F32(vmax_abs_x, 5.1164608002e+00f);
      $else:
        XNN_SIMD_CONST_F32(vmax_abs_x, 5.1638283730e+00f);
    #else
      XNN_SIMD_CONST_F32(vmax_abs_x, 5.1158981323e+00f);
    #endif  // XNN_SIMD_HAS_NATIVE_FMA

    // The monomial coefficients of the numerator polynomial (odd).
    XNN_SIMD_CONST_F32(valpha_1, 7.9788452387e-01f);
    XNN_SIMD_CONST_F32(valpha_3, 6.6972173750e-02f);
    XNN_SIMD_CONST_F32(valpha_5, 9.3065137044e-03f);
    XNN_SIMD_CONST_F32(valpha_7, 3.2973114867e-04f);
    XNN_SIMD_CONST_F32(valpha_9, 1.2609783880e-05f);
    XNN_SIMD_CONST_F32(valpha_11, 4.5835321316e-08f);

    // The monomial coefficients of the denominator polynomial (even).
    // XNN_SIMD_CONST_F32(vbeta_0, 1.0f);
    XNN_SIMD_CONST_F32(vbeta_2, 2.5060352683e-01f);
    XNN_SIMD_CONST_F32(vbeta_4, 2.8431978077e-02f);
    XNN_SIMD_CONST_F32(vbeta_6, 1.8622842617e-03f);
    XNN_SIMD_CONST_F32(vbeta_8, 7.2267655923e-05f);
    XNN_SIMD_CONST_F32(vbeta_10, 1.1988805682e-06f);

    XNN_SIMD_CONST_F32(vone, 1.0f);
    XNN_SIMD_CONST_F32(vhalf, 0.5f);
    $if DIV == "NR":
      // Constant needed for the Newton-Raphson iteration of the reciprocal.
      XNN_SIMD_CONST_F32(vtwo, 2.0f);

    $if SIMD_TILE > 1:
      // Main loop: process ${BATCH_TILE} elements (${SIMD_TILE} SIMD vectors)
      // per iteration.
      for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
        const xnn_simd_f32_t vx_orig_${ABC[0]} = xnn_loadu_f32(input);
        $for N in range(1, SIMD_TILE):
          const xnn_simd_f32_t vx_orig_${ABC[N]} = xnn_loadu_f32(input + ${N} * xnn_simd_size_f32);
        input += ${BATCH_TILE};

        // Clamp the inputs to the interpolation range.
        $for N in range(SIMD_TILE):
          xnn_simd_f32_t vx_${ABC[N]} = xnn_min_f32(vmax_abs_x, vx_orig_${ABC[N]});
        $for N in range(SIMD_TILE):
          vx_${ABC[N]} = xnn_max_f32(xnn_neg_f32(vmax_abs_x), vx_${ABC[N]});

        // Since the polynomials are odd/even, we need x^2.
        $for N in range(SIMD_TILE):
          const xnn_simd_f32_t vx2_${ABC[N]} = xnn_mul_f32(vx_${ABC[N]}, vx_${ABC[N]});

        // Evaluate the numerator polynomial p.
        $for N in range(SIMD_TILE):
          xnn_simd_f32_t vp_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, valpha_11, valpha_9);
        $for N in range(SIMD_TILE):
          vp_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vp_${ABC[N]}, valpha_7);
        $for N in range(SIMD_TILE):
          vp_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vp_${ABC[N]}, valpha_5);
        $for N in range(SIMD_TILE):
          vp_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vp_${ABC[N]}, valpha_3);
        $for N in range(SIMD_TILE):
          vp_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vp_${ABC[N]}, valpha_1);
        $for N in range(SIMD_TILE):
          vp_${ABC[N]} = xnn_mul_f32(vx_${ABC[N]}, vp_${ABC[N]});

        // Evaluate the denominator polynomial q.
        $for N in range(SIMD_TILE):
          xnn_simd_f32_t vq_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vbeta_10, vbeta_8);
        $for N in range(SIMD_TILE):
          vq_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vq_${ABC[N]}, vbeta_6);
        $for N in range(SIMD_TILE):
          vq_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vq_${ABC[N]}, vbeta_4);
        $for N in range(SIMD_TILE):
          vq_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vq_${ABC[N]}, vbeta_2);
        $for N in range(SIMD_TILE):
          vq_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vq_${ABC[N]}, vone);

        // Divide the numerator by the denominator.
        $if DIV == "DIV":
          $for N in range(SIMD_TILE):
            const xnn_simd_f32_t verf_${ABC[N]} = xnn_div_f32(vp_${ABC[N]}, vq_${ABC[N]});
        $else:
          // Refine an initial reciprocal estimate of q with Newton-Raphson
          // iterations: rq <- rq * (2 - rq * q).
          $for N in range(SIMD_TILE):
            xnn_simd_f32_t vrq_${ABC[N]} = xnn_rcp_f32(vq_${ABC[N]});
          for (size_t iter = 0; iter < XNN_SIMD_NUM_RCP_ITER_F32; iter++) {
            $for N in range(SIMD_TILE):
              vrq_${ABC[N]} = xnn_mul_f32(vrq_${ABC[N]}, xnn_fnmadd_f32(vrq_${ABC[N]}, vq_${ABC[N]}, vtwo));
          }
          // Note that we _could_ use a fused multiply-add to compute `p * rq + 1`,
          // but we actually want this to round to zero near the edges, so we
          // don't want the extended precision of the fused multiply-add.
          $for N in range(SIMD_TILE):
            const xnn_simd_f32_t verf_${ABC[N]} = xnn_mul_f32(vp_${ABC[N]}, vrq_${ABC[N]});

        // Add one to the rational interpolant, and multiply by 0.5 times the
        // original input.
        $for N in range(SIMD_TILE):
          const xnn_simd_f32_t vy_${ABC[N]} = xnn_mul_f32(xnn_mul_f32(vx_orig_${ABC[N]}, vhalf),
                                              xnn_add_f32(verf_${ABC[N]}, vone));

        xnn_storeu_f32(output, vy_${ABC[0]});
        $for N in range(1, SIMD_TILE):
          xnn_storeu_f32(output + ${N} * xnn_simd_size_f32, vy_${ABC[N]});
        output += ${BATCH_TILE};
      }
    // Process the remaining full SIMD vectors, one vector per iteration.
    for (; batch >= xnn_simd_bytes_f32; batch -= xnn_simd_bytes_f32) {
      const xnn_simd_f32_t vx_orig = xnn_loadu_f32(input);
      input += xnn_simd_size_f32;

      // Clamp the inputs to the interpolation range.
      xnn_simd_f32_t vx = xnn_min_f32(vmax_abs_x, vx_orig);
      vx = xnn_max_f32(xnn_neg_f32(vmax_abs_x), vx);

      // Since the polynomials are odd/even, we need x^2.
      const xnn_simd_f32_t vx2 = xnn_mul_f32(vx, vx);

      // Evaluate the numerator polynomial p.
      xnn_simd_f32_t vp = xnn_fmadd_f32(vx2, valpha_11, valpha_9);
      vp = xnn_fmadd_f32(vx2, vp, valpha_7);
      vp = xnn_fmadd_f32(vx2, vp, valpha_5);
      vp = xnn_fmadd_f32(vx2, vp, valpha_3);
      vp = xnn_fmadd_f32(vx2, vp, valpha_1);
      vp = xnn_mul_f32(vx, vp);

      // Evaluate the denominator polynomial q.
      xnn_simd_f32_t vq = xnn_fmadd_f32(vx2, vbeta_10, vbeta_8);
      vq = xnn_fmadd_f32(vx2, vq, vbeta_6);
      vq = xnn_fmadd_f32(vx2, vq, vbeta_4);
      vq = xnn_fmadd_f32(vx2, vq, vbeta_2);
      vq = xnn_fmadd_f32(vx2, vq, vone);

      // Divide the numerator by the denominator.
      $if DIV == "DIV":
        const xnn_simd_f32_t verf = xnn_div_f32(vp, vq);
      $else:
        // Refine an initial reciprocal estimate of q with Newton-Raphson
        // iterations: rq <- rq * (2 - rq * q).
        xnn_simd_f32_t vrq = xnn_rcp_f32(vq);
        for (size_t iter = 0; iter < XNN_SIMD_NUM_RCP_ITER_F32; iter++) {
          vrq = xnn_mul_f32(vrq, xnn_fnmadd_f32(vrq, vq, vtwo));
        }
        // Note that we _could_ use a fused multiply-add to compute `p * rq + 1`,
        // but we actually want this to round to zero near the edges, so we
        // don't want the extended precision of the fused multiply-add.
        const xnn_simd_f32_t verf = xnn_mul_f32(vp, vrq);

      // Add one to the rational interpolant, and multiply by 0.5 times the
      // original input.
      const xnn_simd_f32_t vy = xnn_mul_f32(xnn_mul_f32(vx_orig, vhalf),
                                            xnn_add_f32(verf, vone));

      xnn_storeu_f32(output, vy);
      output += xnn_simd_size_f32;
    }
    $if SIMD_SIZE > 1:
      // Handle the final partial vector of `batch >> XNN_LOG2_SIZEOF_FLOAT`
      // elements with masked/partial load and store.
      if XNN_UNLIKELY(batch != 0) {
        xnn_simd_f32_t vx_orig = xnn_load_tail_f32(input, batch >> XNN_LOG2_SIZEOF_FLOAT);

        // See above for comments.
        xnn_simd_f32_t vx = xnn_min_f32(vmax_abs_x, vx_orig);
        vx = xnn_max_f32(xnn_neg_f32(vmax_abs_x), vx);
        const xnn_simd_f32_t vx2 = xnn_mul_f32(vx, vx);
        xnn_simd_f32_t vp = xnn_fmadd_f32(vx2, valpha_11, valpha_9);
        vp = xnn_fmadd_f32(vx2, vp, valpha_7);
        vp = xnn_fmadd_f32(vx2, vp, valpha_5);
        vp = xnn_fmadd_f32(vx2, vp, valpha_3);
        vp = xnn_fmadd_f32(vx2, vp, valpha_1);
        vp = xnn_mul_f32(vx, vp);
        xnn_simd_f32_t vq = xnn_fmadd_f32(vx2, vbeta_10, vbeta_8);
        vq = xnn_fmadd_f32(vx2, vq, vbeta_6);
        vq = xnn_fmadd_f32(vx2, vq, vbeta_4);
        vq = xnn_fmadd_f32(vx2, vq, vbeta_2);
        vq = xnn_fmadd_f32(vx2, vq, vone);
        $if DIV == "DIV":
          const xnn_simd_f32_t verf = xnn_div_f32(vp, vq);
        $else:
          xnn_simd_f32_t vrq = xnn_rcp_f32(vq);
          for (size_t iter = 0; iter < XNN_SIMD_NUM_RCP_ITER_F32; iter++) {
            vrq = xnn_mul_f32(vrq, xnn_fnmadd_f32(vrq, vq, vtwo));
          }
          const xnn_simd_f32_t verf = xnn_mul_f32(vp, vrq);
        const xnn_simd_f32_t vy = xnn_mul_f32(xnn_mul_f32(vx_orig, vhalf),
                                              xnn_add_f32(verf, vone));

        xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT);
      }
  }
