// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DIV in ("DIV", "NR")
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(","))
$SIMD_SIZE = BATCH_TILES[0]
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/simd/f32-${ARCH}.h"

#include "xnnpack/common.h"
#include "xnnpack/microparams.h"
#include "xnnpack/vunary.h"

$for BATCH_TILE in BATCH_TILES:
  $assert BATCH_TILE % SIMD_SIZE == 0
  $assert BATCH_TILE >= SIMD_SIZE
  $SIMD_TILE = BATCH_TILE // SIMD_SIZE

  // Computes `tanh(x)` element-wise over a contiguous array of floats using a
  // rational approximation `p(x) / q(x)`, where `p` is an odd polynomial of
  // degree 9 and `q` is an even polynomial of degree 8.
  //
  //   batch         - size of `input` in bytes; must be a non-zero multiple of
  //                   `sizeof(float)`.
  //   input         - the values to evaluate; must not be NULL.
  //   output        - receives one result per input element; must not be NULL.
  //   unused_params - not read by this kernel.
  void xnn_f32_vtanh_ukernel__${ARCH}_rational_9_8_${DIV.lower()}_u${BATCH_TILE}(
      size_t batch,
      const float* input,
      float* output,
      const struct xnn_f32_default_params unused_params[restrict XNN_MIN_ELEMENTS(1)])
  {
    assert(batch != 0);
    assert(batch % sizeof(float) == 0);
    assert(input != NULL);
    assert(output != NULL);
    assert(xnn_simd_size_f32 == ${SIMD_SIZE});

    // Cap the inputs to this value as `tanh(x)` will always be `+/-1.0f` beyond
    // this point. This value is chosen as the first floating point number from
    // which the interpolation returns 1.0f. The bounds are written as `float`
    // literals, like every other constant in this kernel, so that no implicit
    // `double` to `float` narrowing takes place.
    #if XNN_SIMD_HAS_NATIVE_FMA
      $if DIV == "DIV":
        XNN_SIMD_CONST_F32(vmax_x, 7.9807181358e+00f);
        XNN_SIMD_CONST_F32(vmin_x, -7.9807181358e+00f);
      $else:
        $if ARCH == "avx512f":
          XNN_SIMD_CONST_F32(vmax_x, 7.8473143578e+00f);
          XNN_SIMD_CONST_F32(vmin_x, -7.8473143578e+00f);
        $else:
          XNN_SIMD_CONST_F32(vmax_x, 7.8947348595e+00f);
          XNN_SIMD_CONST_F32(vmin_x, -7.8947348595e+00f);
    #else
      $if DIV == "DIV":
        XNN_SIMD_CONST_F32(vmax_x, 7.8522667885e+00f);
        XNN_SIMD_CONST_F32(vmin_x, -7.8522667885e+00f);
      $else:
        XNN_SIMD_CONST_F32(vmax_x, 7.7497606277e+00f);
        XNN_SIMD_CONST_F32(vmin_x, -7.7497606277e+00f);
    #endif  // XNN_SIMD_HAS_NATIVE_FMA

    // The monomial coefficients of the numerator polynomial (odd). The
    // coefficient of `x` is 1, so `vone` is used in its place.
    // XNN_SIMD_CONST_F32(valpha_1, 1.0000000000e+00f);
    XNN_SIMD_CONST_F32(valpha_3, 1.3412411511e-01f);
    XNN_SIMD_CONST_F32(valpha_5, 3.5330520477e-03f);
    XNN_SIMD_CONST_F32(valpha_7, 2.1235626264e-05f);
    XNN_SIMD_CONST_F32(valpha_9, 1.4248920266e-08f);

    // The monomial coefficients of the denominator polynomial (even). The
    // constant term is 1, so `vone` is used in its place.
    // XNN_SIMD_CONST_F32(vbeta_0, 1.0000000000e+00f);
    XNN_SIMD_CONST_F32(vbeta_2, 4.6745735407e-01f);
    XNN_SIMD_CONST_F32(vbeta_4, 2.6018999517e-02f);
    XNN_SIMD_CONST_F32(vbeta_6, 3.3472978976e-04f);
    XNN_SIMD_CONST_F32(vbeta_8, 8.1365948290e-07f);

    // Some useful constants.
    XNN_SIMD_CONST_F32(vone, 1.0f);
    $if DIV == "NR":
      XNN_SIMD_CONST_F32(vtwo, 2.0f);

    $if SIMD_TILE > 1:
      // Main loop: handles several SIMD vectors per iteration, with the same
      // step applied to every vector before moving on to the next step.
      for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
        xnn_simd_f32_t vx_${ABC[0]} = xnn_loadu_f32(input);
        $for N in range(1, SIMD_TILE):
          xnn_simd_f32_t vx_${ABC[N]} = xnn_loadu_f32(input + ${N} * xnn_simd_size_f32);
        input += ${BATCH_TILE};

        // Clamp the inputs to the interpolation range.
        $for N in range(SIMD_TILE):
          vx_${ABC[N]} = xnn_min_f32(vmax_x, vx_${ABC[N]});
        $for N in range(SIMD_TILE):
          vx_${ABC[N]} = xnn_max_f32(vmin_x, vx_${ABC[N]});

        // Since the polynomials are odd/even, we need x^2.
        $for N in range(SIMD_TILE):
          const xnn_simd_f32_t vx2_${ABC[N]} = xnn_mul_f32(vx_${ABC[N]}, vx_${ABC[N]});

        // Evaluate the numerator polynomial p. The final multiplication by `x`
        // turns the even polynomial in `x^2` into the odd polynomial.
        $for N in range(SIMD_TILE):
          xnn_simd_f32_t vp_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, valpha_9, valpha_7);
        $for N in range(SIMD_TILE):
          vp_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vp_${ABC[N]}, valpha_5);
        $for N in range(SIMD_TILE):
          vp_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vp_${ABC[N]}, valpha_3);
        $for N in range(SIMD_TILE):
          vp_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vp_${ABC[N]}, vone);
        $for N in range(SIMD_TILE):
          vp_${ABC[N]} = xnn_mul_f32(vx_${ABC[N]}, vp_${ABC[N]});

        // Evaluate the denominator polynomial q.
        $for N in range(SIMD_TILE):
          xnn_simd_f32_t vq_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vbeta_8, vbeta_6);
        $for N in range(SIMD_TILE):
          vq_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vq_${ABC[N]}, vbeta_4);
        $for N in range(SIMD_TILE):
          vq_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vq_${ABC[N]}, vbeta_2);
        $for N in range(SIMD_TILE):
          vq_${ABC[N]} = xnn_fmadd_f32(vx2_${ABC[N]}, vq_${ABC[N]}, vone);

        // Divide the numerator by the denominator.
        $if DIV == "DIV":
          $for N in range(SIMD_TILE):
            const xnn_simd_f32_t vy_${ABC[N]} = xnn_div_f32(vp_${ABC[N]}, vq_${ABC[N]});
        $else:
          // Start from a reciprocal estimate of q and refine it with
          // `r <- r * (2 - r * q)` (assuming `xnn_fnmadd_f32(a, b, c)` computes
          // `c - a * b`), then multiply by p.
          $for N in range(SIMD_TILE):
            xnn_simd_f32_t vrq_${ABC[N]} = xnn_rcp_f32(vq_${ABC[N]});
          for (size_t iter = 0; iter < XNN_SIMD_NUM_RCP_ITER_F32; iter++) {
            $for N in range(SIMD_TILE):
              vrq_${ABC[N]} = xnn_mul_f32(vrq_${ABC[N]}, xnn_fnmadd_f32(vrq_${ABC[N]}, vq_${ABC[N]}, vtwo));
          }
          $for N in range(SIMD_TILE):
            const xnn_simd_f32_t vy_${ABC[N]} = xnn_mul_f32(vp_${ABC[N]}, vrq_${ABC[N]});

        xnn_storeu_f32(output, vy_${ABC[0]});
        $for N in range(1, SIMD_TILE):
          xnn_storeu_f32(output + ${N} * xnn_simd_size_f32, vy_${ABC[N]});
        output += ${BATCH_TILE};
      }
    // Single-vector loop: handles one full SIMD vector per iteration.
    for (; batch >= xnn_simd_bytes_f32; batch -= xnn_simd_bytes_f32) {
      xnn_simd_f32_t vx = xnn_loadu_f32(input);
      input += xnn_simd_size_f32;

      // Clamp the inputs to the interpolation range.
      vx = xnn_min_f32(vmax_x, vx);
      vx = xnn_max_f32(vmin_x, vx);

      // Since the polynomials are odd/even, we need x^2.
      const xnn_simd_f32_t vx2 = xnn_mul_f32(vx, vx);

      // Evaluate the numerator polynomial p.
      xnn_simd_f32_t vp = xnn_fmadd_f32(vx2, valpha_9, valpha_7);
      vp = xnn_fmadd_f32(vx2, vp, valpha_5);
      vp = xnn_fmadd_f32(vx2, vp, valpha_3);
      vp = xnn_fmadd_f32(vx2, vp, vone);
      vp = xnn_mul_f32(vx, vp);

      // Evaluate the denominator polynomial q.
      xnn_simd_f32_t vq = xnn_fmadd_f32(vx2, vbeta_8, vbeta_6);
      vq = xnn_fmadd_f32(vx2, vq, vbeta_4);
      vq = xnn_fmadd_f32(vx2, vq, vbeta_2);
      vq = xnn_fmadd_f32(vx2, vq, vone);

      // Divide the numerator by the denominator.
      $if DIV == "DIV":
        const xnn_simd_f32_t vy = xnn_div_f32(vp, vq);
      $else:
        // Refine the reciprocal estimate of q with `r <- r * (2 - r * q)`.
        xnn_simd_f32_t vrq = xnn_rcp_f32(vq);
        for (size_t iter = 0; iter < XNN_SIMD_NUM_RCP_ITER_F32; iter++) {
          vrq = xnn_mul_f32(vrq, xnn_fnmadd_f32(vrq, vq, vtwo));
        }
        const xnn_simd_f32_t vy = xnn_mul_f32(vp, vrq);

      xnn_storeu_f32(output, vy);
      output += xnn_simd_size_f32;
    }
    $if SIMD_SIZE > 1:
      // Remainder: fewer elements than one full SIMD vector are left. They are
      // loaded and stored through the tail helpers, which are given the
      // remaining element count (`batch` converted from bytes to floats).
      if XNN_UNLIKELY(batch != 0) {
        xnn_simd_f32_t vx = xnn_load_tail_f32(input, batch >> XNN_LOG2_SIZEOF_FLOAT);

        // Clamp the inputs to the interpolation range.
        vx = xnn_min_f32(vmax_x, vx);
        vx = xnn_max_f32(vmin_x, vx);

        // Since the polynomials are odd/even, we need x^2.
        const xnn_simd_f32_t vx2 = xnn_mul_f32(vx, vx);

        // Evaluate the numerator polynomial p.
        xnn_simd_f32_t vp = xnn_fmadd_f32(vx2, valpha_9, valpha_7);
        vp = xnn_fmadd_f32(vx2, vp, valpha_5);
        vp = xnn_fmadd_f32(vx2, vp, valpha_3);
        vp = xnn_fmadd_f32(vx2, vp, vone);
        vp = xnn_mul_f32(vx, vp);

        // Evaluate the denominator polynomial q.
        xnn_simd_f32_t vq = xnn_fmadd_f32(vx2, vbeta_8, vbeta_6);
        vq = xnn_fmadd_f32(vx2, vq, vbeta_4);
        vq = xnn_fmadd_f32(vx2, vq, vbeta_2);
        vq = xnn_fmadd_f32(vx2, vq, vone);

        // Divide the numerator by the denominator.
        $if DIV == "DIV":
          const xnn_simd_f32_t vy = xnn_div_f32(vp, vq);
        $else:
          // Refine the reciprocal estimate of q with `r <- r * (2 - r * q)`.
          xnn_simd_f32_t vrq = xnn_rcp_f32(vq);
          for (size_t iter = 0; iter < XNN_SIMD_NUM_RCP_ITER_F32; iter++) {
            vrq = xnn_mul_f32(vrq, xnn_fnmadd_f32(vrq, vq, vtwo));
          }
          const xnn_simd_f32_t vy = xnn_mul_f32(vp, vrq);

        xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT);
      }
  }
