// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// Template parameters:
//   ARCH        - SIMD backend; selects the `xnnpack/simd/f32-<ARCH>.h` header.
//   DIV         - "DIV" computes `p / q` with a division; "NR" multiplies `p` by
//                 a reciprocal of `q` refined with Newton-Raphson steps.
//   BATCH_TILES - comma-separated batch tiles, one kernel generated per tile;
//                 the first tile is taken as the SIMD width (checked at runtime).
// `ABC` supplies the single-character suffixes of the unrolled register names.
$assert DIV in ("DIV", "NR")
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(","))
$SIMD_SIZE = BATCH_TILES[0]
#include <assert.h>
$if ARCH == "scalar":
  #include <math.h>
#include <stddef.h>

#include "xnnpack/simd/f32-${ARCH}.h"

#include "xnnpack/common.h"
#include "xnnpack/microparams.h"
#include "xnnpack/vunary.h"

// `M_LN2` (the natural logarithm of 2) is an XSI extension of `math.h` rather
// than part of ISO C, so supply it ourselves when the platform headers do not.
#if !defined(M_LN2)
#define M_LN2 0.69314718055994531
#endif  // !defined(M_LN2)

// Extracts the unbiased binary exponent of each lane of `a` as a `float`, i.e.
// `floor(log2(a))` for positive normal inputs, with sign-aware special cases:
// zero inputs (of either sign) produce `-Inf` and negative inputs produce `NaN`.
//
// NOTE(review): the `avx512f` and generic paths differ on some special inputs,
// e.g. `+Inf`/positive `NaN` (see the note in the generic path below) and
// negative denormals (`NaN` with `avx512f`, `-Inf` in the generic path).
//
// The `HAVE_XNN_SIGNED_GETEXP_F32` guard presumably keeps the helper from being
// defined twice if several sources providing it share a translation unit.
#ifndef HAVE_XNN_SIGNED_GETEXP_F32
#define HAVE_XNN_SIGNED_GETEXP_F32
static XNN_INLINE xnn_simd_f32_t xnn_signed_getexp_f32(xnn_simd_f32_t a) {
  $if ARCH == "avx512f":
    // Create a mask of the zeros (`+0.0f` and `-0.0f`) in the input.
    __mmask16 zero_mask = _mm512_cmp_ps_mask(a, _mm512_setzero_ps(), _CMP_EQ_OQ);

    // Create a mask of the negative inputs.
    __mmask16 neg_mask = _mm512_cmp_ps_mask(a, _mm512_setzero_ps(), _CMP_LT_OQ);

    // Extract the exponent. `_mm512_getexp_ps` does not itself map negative
    // inputs to `NaN` (per Intel's documentation it uses the magnitude), hence
    // the explicit fix-ups below.
    __m512 res = _mm512_getexp_ps(a);

    // Set the zero inputs to `-Inf` and the negative inputs to `NaN`.
    res = _mm512_castsi512_ps(_mm512_mask_set1_epi32(
        _mm512_castps_si512(res), zero_mask, 0xFF800000 /*-Inf*/));
    res = _mm512_castsi512_ps(_mm512_mask_set1_epi32(
        _mm512_castps_si512(res), neg_mask, 0x7FC00001 /*NaN*/));

    return res;
  $else:
    // The bits of the IEEE 754 single-precision floating-point format are:
    //
    //   s | e e e e e e e e | m m m m m m m m m m m m m m m m m m m m m m m
    //
    // We start by keeping only the sign and exponent bits and shifting them 8
    // bits to the right arithmetically, i.e. extending by the leftmost sign bit:
    //
    //   s | s s s s s s s s | e e e e e e e e 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    //
    // These bits are then `or`-ed with `256.0f`, which has a biased exponent of
    // `135` and all mantissa bits set to zero. This is equivalent to adding the
    // biased integer exponent to `256.0`:
    //
    //   0 | 1 0 0 0 0 1 1 1 | e e e e e e e e 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    //
    // We can re-extract the exponent as a `float` value by subtracting `256.0`
    // plus the exponent bias `127.0`, i.e. `383.0`.
    //
    // Note that if the sign bit is `1`, we end up with the floating point bits:
    //
    //   1 | 1 1 1 1 1 1 1 1 | e e e e e e e e 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    //
    // Which is `-NaN` if the exponent is non-zero, and `-Inf` if the exponent is
    // zero (e.g. the input was `0.0f` or denormal).
    //
    // NOTE(review): for `+Inf` and positive `NaN` inputs (biased exponent 255,
    // sign 0) this path returns `255 - 127 = 128.0f`, whereas `_mm512_getexp_ps`
    // in the `avx512f` path is documented to return `+Inf`/`NaN` -- confirm
    // which behavior callers expect.

    // Some useful constants.
    XNN_SIMD_CONST_F32(sign_mask, -0.0f);
    XNN_SIMD_CONST_F32_FROM_INT32(sign_and_exp_mask, 0xff800000);
    XNN_SIMD_CONST_F32(bias_256, 256.0f);
    XNN_SIMD_CONST_F32(bias_383, 383.0f);

    // If `a` is `+0.0f` or `-0.0f`, set its sign bit so that we return `-Inf`.
    a = xnn_or_f32(xnn_and_f32(xnn_cmpeq_f32(a, xnn_zero_f32()), sign_mask), a);

    // Extract the exponent and shift the exponent to the most significant bits of
    // the mantissa.
    const xnn_simd_f32_t exp =
        xnn_sra_f32(xnn_and_f32(a, sign_and_exp_mask), 8);

    // Add the shifted exponent to `256.0f` by copying its bits to the mantissa,
    // then subtract out `383.0f`, i.e. the original `256.0f` plus the `127`
    // exponent bias, resulting in the unbiased exponent.
    return xnn_sub_f32(xnn_or_f32(bias_256, exp), bias_383);
}
#endif  // HAVE_XNN_SIGNED_GETEXP_F32

$for BATCH_TILE in BATCH_TILES:
  $assert BATCH_TILE % SIMD_SIZE == 0
  $assert BATCH_TILE >= SIMD_SIZE
  $SIMD_TILE = BATCH_TILE // SIMD_SIZE

  // Computes the natural logarithm `output[i] = log(input[i])` element-wise.
  //
  // `batch` is the size of `input` (and of `output`) in bytes; it must be a
  // non-zero multiple of `sizeof(float)`. `input` and `output` must be non-NULL.
  // `unused_params` is not read.
  //
  // Each input is decomposed as `input ~= 2^e * (1 + x)` with `x` in
  // `[sqrt(1/2) - 1, sqrt(2) - 1)`, and `log(1 + x)` is approximated by a
  // degree-3/3 rational function `p(x) / q(x)`, so that
  // `log(input) ~= e * log(2) + p(x) / q(x)`. Zero inputs yield `-Inf` and
  // negative inputs yield `NaN` through `xnn_signed_getexp_f32`.
  //
  // NOTE(review): the input is pre-multiplied by `sqrt(2)`, which overflows to
  // `+Inf` for inputs above roughly `FLT_MAX / sqrt(2)` (~2.4e38). Those lanes
  // then get whatever `xnn_signed_getexp_f32` returns for `+Inf` (`+Inf` with
  // `avx512f`, a finite but inaccurate value of about 88.38 otherwise) instead
  // of `log(input)`. `+Inf`/`NaN` inputs are likewise not special-cased here --
  // confirm the intended behavior against the reference implementation/tests.
  void xnn_f32_vlog_ukernel__${ARCH}_rational_3_3_${DIV.lower()}_u${BATCH_TILE}(
      size_t batch,
      const float* input,
      float* output,
      const struct xnn_f32_default_params unused_params[restrict XNN_MIN_ELEMENTS(1)])
  {
    assert(batch != 0);
    assert(batch % sizeof(float) == 0);
    assert(input != NULL);
    assert(output != NULL);
    assert(xnn_simd_size_f32 == ${SIMD_SIZE});

    // Some useful constants.
    XNN_SIMD_CONST_F32(vone, 1.0f);
    XNN_SIMD_CONST_F32(vln2, M_LN2);
    XNN_SIMD_CONST_F32_FROM_INT32(vmantissa_bits_mask, 0x007FFFFFUL);

    // Note that these two values are not _exactly_ `(float)M_SQRT2` and
    // `(float)M_SQRT1_2`, but are instead chosen such that their product is
    // exactly `1.0f` when evaluated in `float` precision.
    XNN_SIMD_CONST_F32(vsqrt2, 1.4142134190e+00);
    XNN_SIMD_CONST_F32(vsqrt1_2, 7.0710688829e-01);

    // The monomial coefficients of the numerator polynomial
    // `p(x) = alpha_0 + alpha_1*x + alpha_2*x^2 + alpha_3*x^3`. The coefficients
    // that are exactly `0` or `1` are folded into the Horner evaluation below,
    // so they are not materialized as constants.
    // XNN_SIMD_CONST_F32(valpha_0, 0.0f);
    // XNN_SIMD_CONST_F32(valpha_1, 1.0f);
    // XNN_SIMD_CONST_F32(valpha_2, 1.0f);
    XNN_SIMD_CONST_F32(valpha_3, 1.824996918440e-01);

    // The monomial coefficients of the denominator polynomial
    // `q(x) = beta_0 + beta_1*x + beta_2*x^2 + beta_3*x^3`. `beta_0 = 1` is
    // folded into the Horner evaluation below.
    // XNN_SIMD_CONST_F32(vbeta_0, 1.0f);
    XNN_SIMD_CONST_F32(vbeta_1, 1.5f);
    XNN_SIMD_CONST_F32(vbeta_2, 0.599170029163);
    XNN_SIMD_CONST_F32(vbeta_3, 0.049584995955);

    $if DIV == "NR":
      // Constant needed for the Newton-Raphson iteration of the reciprocal.
      XNN_SIMD_CONST_F32(vtwo, 2.0f);

    $if SIMD_TILE > 1:
      // Process `BATCH_TILE` elements per iteration, unrolled across several
      // SIMD vectors.
      for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
        xnn_simd_f32_t vx_${ABC[0]} = xnn_loadu_f32(input);
        $for N in range(1, SIMD_TILE):
          xnn_simd_f32_t vx_${ABC[N]} = xnn_loadu_f32(input + ${N} * xnn_simd_size_f32);
        input += ${BATCH_TILE};

        // Scale `x` by `sqrt(2)` so that the extracted exponent is `log2(x)`
        // rounded to the nearest integer rather than truncated.
        $for N in range(0, SIMD_TILE):
          vx_${ABC[N]} = xnn_mul_f32(vx_${ABC[N]}, vsqrt2);

        // Extract the exponent (`-Inf` for zero, `NaN` for negative inputs).
        $for N in range(0, SIMD_TILE):
          const xnn_simd_f32_t vexp_${ABC[N]} = xnn_signed_getexp_f32(vx_${ABC[N]});

        // Normalize `x` to an exponent of zero, i.e. into `[1.0, 2.0)`.
        $for N in range(0, SIMD_TILE):
          vx_${ABC[N]} = xnn_or_f32(xnn_and_f32(vx_${ABC[N]}, vmantissa_bits_mask), vone);

        // Scale `x` back with `1/sqrt(2)` to move its range from `[1.0, 2.0)` to
        // `[sqrt(1/2), sqrt(2))`, and further subtract `1.0` so that it is around
        // zero, i.e. `[sqrt(1/2) - 1, sqrt(2) - 1)`, or `[-0.29289, 0.4142136)`.
        $for N in range(0, SIMD_TILE):
          vx_${ABC[N]} = xnn_sub_f32(xnn_mul_f32(vx_${ABC[N]}, vsqrt1_2), vone);

        // In the following, we use a degree-3/3 rational function to
        // approximate the (shifted) `log(x + 1)` on the (shifted) interval
        // `[sqrt(1/2) - 1, sqrt(2) - 1)`. The shifted interval is chosen so that
        // `f(0) = 0`.

        // Evaluate the numerator polynomial p.
        $for N in range(0, SIMD_TILE):
          xnn_simd_f32_t vp_${ABC[N]} = xnn_fmadd_f32(vx_${ABC[N]}, valpha_3, vone);
        $for N in range(0, SIMD_TILE):
          vp_${ABC[N]} = xnn_fmadd_f32(vx_${ABC[N]}, vp_${ABC[N]}, vone);
        $for N in range(0, SIMD_TILE):
          vp_${ABC[N]} = xnn_mul_f32(vx_${ABC[N]}, vp_${ABC[N]});

        // Evaluate the denominator polynomial q.
        $for N in range(0, SIMD_TILE):
          xnn_simd_f32_t vq_${ABC[N]} = xnn_fmadd_f32(vx_${ABC[N]}, vbeta_3, vbeta_2);
        $for N in range(0, SIMD_TILE):
          vq_${ABC[N]} = xnn_fmadd_f32(vx_${ABC[N]}, vq_${ABC[N]}, vbeta_1);
        $for N in range(0, SIMD_TILE):
          vq_${ABC[N]} = xnn_fmadd_f32(vx_${ABC[N]}, vq_${ABC[N]}, vone);

        // Divide the numerator by the denominator.
        $if DIV == "DIV":
          $for N in range(SIMD_TILE):
            xnn_simd_f32_t vy_${ABC[N]} = xnn_div_f32(vp_${ABC[N]}, vq_${ABC[N]});
        $else:
          $for N in range(SIMD_TILE):
            xnn_simd_f32_t vrq_${ABC[N]} = xnn_rcp_f32(vq_${ABC[N]});
          for (size_t iter = 0; iter < XNN_SIMD_NUM_RCP_ITER_F32; iter++) {
            $for N in range(SIMD_TILE):
              vrq_${ABC[N]} = xnn_mul_f32(vrq_${ABC[N]}, xnn_fnmadd_f32(vrq_${ABC[N]}, vq_${ABC[N]}, vtwo));
          }
          $for N in range(SIMD_TILE):
            xnn_simd_f32_t vy_${ABC[N]} = xnn_mul_f32(vp_${ABC[N]}, vrq_${ABC[N]});

        // Put it all together, i.e. `log(x) = log(2) * exp + y`.
        $for N in range(0, SIMD_TILE):
          vy_${ABC[N]} = xnn_fmadd_f32(vexp_${ABC[N]}, vln2, vy_${ABC[N]});

        xnn_storeu_f32(output, vy_${ABC[0]});
        $for N in range(1, SIMD_TILE):
          xnn_storeu_f32(output + ${N} * xnn_simd_size_f32, vy_${ABC[N]});
        output += ${BATCH_TILE};
      }
    // Process one SIMD vector per iteration.
    for (; batch >= xnn_simd_bytes_f32; batch -= xnn_simd_bytes_f32) {
      xnn_simd_f32_t vx = xnn_loadu_f32(input);
      input += xnn_simd_size_f32;

      // Scale `x` by `sqrt(2)` so that the extracted exponent is `log2(x)`
      // rounded to the nearest integer rather than truncated.
      vx = xnn_mul_f32(vx, vsqrt2);

      // Extract the exponent (`-Inf` for zero, `NaN` for negative inputs).
      const xnn_simd_f32_t vexp = xnn_signed_getexp_f32(vx);

      // Normalize `x` to an exponent of zero, i.e. into `[1.0, 2.0)`.
      vx = xnn_or_f32(xnn_and_f32(vx, vmantissa_bits_mask), vone);

      // Scale `x` back with `1/sqrt(2)` to move its range from `[1.0, 2.0)` to
      // `[sqrt(1/2), sqrt(2))`, and further subtract `1.0` so that it is around
      // zero, i.e. `[sqrt(1/2) - 1, sqrt(2) - 1)`, or `[-0.29289, 0.4142136)`.
      vx = xnn_sub_f32(xnn_mul_f32(vx, vsqrt1_2), vone);

      // In the following, we use a degree-3/3 rational function to
      // approximate the (shifted) `log(x + 1)` on the (shifted) interval
      // `[sqrt(1/2) - 1, sqrt(2) - 1)`. The shifted interval is chosen so that
      // `f(0) = 0`.

      // Evaluate the numerator polynomial p.
      xnn_simd_f32_t vp = xnn_fmadd_f32(vx, valpha_3, vone);
      vp = xnn_fmadd_f32(vx, vp, vone);
      vp = xnn_mul_f32(vx, vp);

      // Evaluate the denominator polynomial q.
      xnn_simd_f32_t vq = xnn_fmadd_f32(vx, vbeta_3, vbeta_2);
      vq = xnn_fmadd_f32(vx, vq, vbeta_1);
      vq = xnn_fmadd_f32(vx, vq, vone);

      // Divide the numerator by the denominator.
      $if DIV == "DIV":
        xnn_simd_f32_t vy =  xnn_div_f32(vp, vq);
      $else:
        // Start from a reciprocal estimate of `q` and refine it with
        // Newton-Raphson steps `r = r * (2 - r * q)`.
        xnn_simd_f32_t vrq = xnn_rcp_f32(vq);
        for (size_t iter = 0; iter < XNN_SIMD_NUM_RCP_ITER_F32; iter++) {
          vrq = xnn_mul_f32(vrq, xnn_fnmadd_f32(vrq, vq, vtwo));
        }
        xnn_simd_f32_t vy = xnn_mul_f32(vp, vrq);

      // Put it all together, i.e. `log(x) = log(2) * exp + y`.
      vy = xnn_fmadd_f32(vexp, vln2, vy);

      xnn_storeu_f32(output, vy);
      output += xnn_simd_size_f32;
    }
    // Remaining elements (fewer than one SIMD vector). With a SIMD width of 1
    // the loop above always consumes the whole batch, so no tail is emitted.
    $if SIMD_SIZE > 1:
      if XNN_UNLIKELY(batch != 0) {
        // Only the first `batch / sizeof(float)` lanes are loaded and stored.
        xnn_simd_f32_t vx = xnn_load_tail_f32(input, batch >> XNN_LOG2_SIZEOF_FLOAT);

        // See the loop above for comments.
        vx = xnn_mul_f32(vx, vsqrt2);
        const xnn_simd_f32_t vexp = xnn_signed_getexp_f32(vx);
        vx = xnn_or_f32(xnn_and_f32(vx, vmantissa_bits_mask), vone);
        vx = xnn_sub_f32(xnn_mul_f32(vx, vsqrt1_2), vone);
        xnn_simd_f32_t vp = xnn_fmadd_f32(vx, valpha_3, vone);
        vp = xnn_fmadd_f32(vx, vp, vone);
        vp = xnn_mul_f32(vx, vp);
        xnn_simd_f32_t vq = xnn_fmadd_f32(vx, vbeta_3, vbeta_2);
        vq = xnn_fmadd_f32(vx, vq, vbeta_1);
        vq = xnn_fmadd_f32(vx, vq, vone);
        $if DIV == "DIV":
          xnn_simd_f32_t vy =  xnn_div_f32(vp, vq);
        $else:
          xnn_simd_f32_t vrq = xnn_rcp_f32(vq);
          for (size_t iter = 0; iter < XNN_SIMD_NUM_RCP_ITER_F32; iter++) {
            vrq = xnn_mul_f32(vrq, xnn_fnmadd_f32(vrq, vq, vtwo));
          }
          xnn_simd_f32_t vy = xnn_mul_f32(vp, vrq);
        vy = xnn_fmadd_f32(vexp, vln2, vy);

        xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT);
      }
  }
