// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert P == H + 1 or P == H + 2
$assert not PS or (P, H) == (4, 3)
$assert DIV in ["DIV", "RCP"]
$assert SAT in ["MINMAX", "SELECT"]
$assert AVX != 2 or FMA == 3
$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$SIMD_TILE = BATCH_TILE // 8
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "xnnpack/common.h"
#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/microparams.h"
#include "xnnpack/vunary.h"

$COEFS = {
$  3: "0x1.560722p+0f",
$  2: "0x1.01E2A2p+1f",
$}

$POLY_SUFFIX = "p%dh%d%s" % (P, H, "ps" if PS else "ts")
$DIV_SUFFIX = DIV.lower()
$PARAMS_STRUCT = "avx_expm1minus_rr1_" + ("p%dh%d" % (P, H))
$ISA = "avx2" if AVX == 2 else "fma3" if FMA == 3 else "f16c"
// Half-precision tanh micro-kernel (name-derived: expm1-based evaluation on
// the negated magnitude of the input).
//
//   batch  - size of the input in BYTES; asserted to be non-zero and a
//            multiple of sizeof(uint16_t).
//   input  - half-precision inputs, read as raw 16-bit words. Data is always
//            loaded 8 elements (128 bits) at a time, including for a partial
//            tail, so reads may extend past the last element (XNN_OOB_READS).
//   output - half-precision outputs; exactly batch bytes are written.
//   params - not referenced by this function body.
//
// Per group of 8 elements: the sign bit is forced on and the value widened to
// fp32 (vz = -|x|), a polynomial-based quantity vemo is computed from vz, the
// result is vemo / (vemo + 2), and the sign is fixed up with an XOR after
// narrowing back to half precision.
void xnn_f16_vtanh_ukernel__${ISA}_expm1minus_rr1_${POLY_SUFFIX}_${DIV_SUFFIX}_u${BATCH_TILE}(
    size_t batch,
    const xnn_float16* input,
    xnn_float16* output,
    const struct xnn_f16_default_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch != 0);
  assert(batch % sizeof(uint16_t) == 0);
  assert(input != NULL);
  assert(output != NULL);

  // Broadcast constants shared by all loops below.
  const __m128i vsign_mask = _mm_set1_epi16(UINT16_C(0x8000));
  const __m256 vsat_cutoff = _mm256_set1_ps(-0x1.208000p+2f);
  const __m256 vlog2e = _mm256_set1_ps(0x1.715476p+0f);
  const __m256 vmagic_bias = _mm256_set1_ps(0x1.8000FEp+22f);
  const __m256 vminus_ln2 = _mm256_set1_ps(-0x1.62E430p-1f);
  XNN_FORCE_REALIZATION(vsign_mask);
  XNN_FORCE_REALIZATION(vsat_cutoff);
  XNN_FORCE_REALIZATION(vlog2e);
  XNN_FORCE_REALIZATION(vmagic_bias);
  XNN_FORCE_REALIZATION(vminus_ln2);
  $for i in reversed(range(2, P+1)):
    const __m256 vc${i} = _mm256_set1_ps(${COEFS[i]});
    XNN_FORCE_REALIZATION(vc${i});
  $if P != H + 1:
    const __m256 vminus_one = _mm256_set1_ps(-1.0f);
    XNN_FORCE_REALIZATION(vminus_one);
  const __m256 vtwo = _mm256_set1_ps(2.0f);
  XNN_FORCE_REALIZATION(vtwo);
  $if P == H + 1:
    const __m256 vminus_one = _mm256_set1_ps(-1.0f);
    XNN_FORCE_REALIZATION(vminus_one);

  // Access the half-precision buffers as raw 16-bit words.
  const uint16_t* i = (const uint16_t*) input;
  uint16_t* o = (uint16_t*) output;
  $if BATCH_TILE > 8:
    // Unrolled main loop: BATCH_TILE elements per iteration, 8 per vector.
    for (; batch >= ${BATCH_TILE} * sizeof(uint16_t); batch -= ${BATCH_TILE} * sizeof(uint16_t)) {
      const __m128i vx0 = _mm_loadu_si128((const __m128i*) i);
      $for N in range(1, SIMD_TILE):
        const __m128i vx${N} = _mm_loadu_si128((const __m128i*) (i + ${N * 8}));
      i += ${BATCH_TILE};

      // Force the sign bit on; vinvsignx (below) is 0x8000 in every lane whose
      // input sign bit was clear and is used to flip the result sign at the end.
      $for N in range(SIMD_TILE):
        const __m128i vabsx${N} = _mm_or_si128(vx${N}, vsign_mask);

      $for N in range(SIMD_TILE):
        __m256 vz${N} = _mm256_cvtph_ps(vabsx${N});
        const __m128i vinvsignx${N} = _mm_xor_si128(vx${N}, vabsx${N});

      // MINMAX clamps vz from below at vsat_cutoff; SELECT records a mask of
      // lanes at or below the cutoff, which are overwritten with -1 later.
      // vn = vz * log2e + magic_bias.
      $for N in range(SIMD_TILE):
        $if SAT == "MINMAX":
          vz${N} = _mm256_max_ps(vsat_cutoff, vz${N});
        $elif SAT == "SELECT":
          const __m256 vm${N} = _mm256_cmp_ps(vz${N}, vsat_cutoff, _CMP_LE_OS);
        $if FMA == 3:
          __m256 vn${N} = _mm256_fmadd_ps(vz${N}, vlog2e, vmagic_bias);
        $else:
          __m256 vn${N} = _mm256_add_ps(_mm256_mul_ps(vz${N}, vlog2e), vmagic_bias);

      // vs = bits of vn shifted left by 23 (into the fp32 exponent field); the
      // magic bias is then subtracted back out of vn. Without AVX2 the 32-bit
      // shift is done on each 128-bit half separately.
      $if AVX == 1:
        $for N in range(SIMD_TILE):
          const __m128 vn${N}_hi = _mm256_extractf128_ps(vn${N}, 1);
          __m256 vs${N} = _mm256_castps128_ps256(_mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn${N})), 23)));
          vn${N} = _mm256_sub_ps(vn${N}, vmagic_bias);

        $for N in range(SIMD_TILE):
          const __m128 vs${N}_hi = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn${N}_hi), 23));

        $for N in range(SIMD_TILE):
          vs${N} = _mm256_insertf128_ps(vs${N}, vs${N}_hi, 1);
      $else:
        $for N in range(SIMD_TILE):
          const __m256 vs${N} = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn${N}), 23));
          vn${N} = _mm256_sub_ps(vn${N}, vmagic_bias);

      // vt = vz + vn * (-ln2).
      $for N in range(SIMD_TILE):
        $if FMA == 3:
          const __m256 vt${N} = _mm256_fmadd_ps(vn${N}, vminus_ln2, vz${N});
        $else:
          const __m256 vt${N} = _mm256_add_ps(_mm256_mul_ps(vn${N}, vminus_ln2), vz${N});

      // Horner-style evaluation of the polynomial in vt, highest coefficient first.
      $if FMA == 3:
        $for N in range(SIMD_TILE):
          __m256 vp${N} = vc${P};
        $for i in reversed(range(2, P)):
          $for N in range(SIMD_TILE):
            vp${N} = _mm256_fmadd_ps(vp${N}, vt${N}, vc${i});
      $else:
        $for N in range(SIMD_TILE):
          __m256 vp${N} = _mm256_add_ps(_mm256_mul_ps(vc${P}, vt${N}), vc${P-1});
        $for i in reversed(range(2, P-1)):
          $for N in range(SIMD_TILE):
            vp${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vt${N}), vc${i});
      $if P == H + 1:
        $for N in range(SIMD_TILE):
          $if FMA == 3:
            vp${N} = _mm256_fmadd_ps(vp${N}, vt${N}, vtwo);
          $else:
            vp${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vt${N}), vtwo);
      $else:
        $for N in range(SIMD_TILE):
          vp${N} = _mm256_mul_ps(vp${N}, vt${N});

      // Combine: vts = vt * vs, vsmo = vs - 1, then vemo from vp, vts and vsmo.
      $for N in range(SIMD_TILE):
        const __m256 vts${N} = _mm256_mul_ps(vt${N}, vs${N});
        const __m256 vsmo${N} = _mm256_add_ps(vs${N}, vminus_one);
      $if P == H + 1:
        $for N in range(SIMD_TILE):
          $if FMA == 3:
            const __m256 vemo${N} = _mm256_fmadd_ps(vp${N}, vts${N}, vsmo${N});
          $else:
            const __m256 vemo${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vts${N}), vsmo${N});
      $else:
        $if FMA == 3:
          $for N in range(SIMD_TILE):
            vp${N} = _mm256_fmadd_ps(vp${N}, vts${N}, vts${N});
          $for N in range(SIMD_TILE):
            const __m256 vemo${N} = _mm256_fmadd_ps(vp${N}, vtwo, vsmo${N});
        $else:
          $for N in range(SIMD_TILE):
            vp${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vts${N}), vts${N});
          $for N in range(SIMD_TILE):
            const __m256 vemo${N} = _mm256_add_ps(_mm256_mul_ps(vp${N}, vtwo), vsmo${N});

      // y = vemo / (vemo + 2), by true division or by reciprocal estimate.
      $for N in range(SIMD_TILE):
        const __m256 vepo${N} = _mm256_add_ps(vemo${N}, vtwo);

      $if DIV == "DIV":
        $for N in range(SIMD_TILE):
          __m256 vy${N} = _mm256_div_ps(vemo${N}, vepo${N});
      $else:
        $for N in range(SIMD_TILE):
          __m256 vrepo${N} = _mm256_rcp_ps(vepo${N});

        $for N in range(SIMD_TILE):
          __m256 vy${N} = _mm256_mul_ps(vemo${N}, vrepo${N});

      $if SAT == "SELECT":
        $for N in range(SIMD_TILE):
          vy${N} = _mm256_blendv_ps(vy${N}, vminus_one, vm${N});

      // Narrow to half precision and flip the sign where the input was non-negative.
      $for N in range(SIMD_TILE):
        __m128i vh${N} = _mm256_cvtps_ph(vy${N}, _MM_FROUND_TO_NEAREST_INT);
      $for N in range(SIMD_TILE):
        vh${N} = _mm_xor_si128(vh${N}, vinvsignx${N});

      _mm_storeu_si128((__m128i*) o, vh0);
      $for N in range(1, SIMD_TILE):
        _mm_storeu_si128((__m128i*) (o + ${N * 8}), vh${N});
      o += ${BATCH_TILE};
    }
  // Single-vector loop: same computation as above, 8 elements per iteration.
  for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) i);
    i += 8;

    const __m128i vabsx = _mm_or_si128(vx, vsign_mask);
    __m256 vz = _mm256_cvtph_ps(vabsx);

    const __m128i vinvsignx = _mm_xor_si128(vx, vabsx);
    $if SAT == "MINMAX":
      vz = _mm256_max_ps(vsat_cutoff, vz);
    $elif SAT == "SELECT":
      const __m256 vm = _mm256_cmp_ps(vz, vsat_cutoff, _CMP_LE_OS);

    $if FMA == 3:
      __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    $else:
      __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias);

    $if AVX == 1:
      const __m128 vn_hi = _mm256_extractf128_ps(vn, 1);
      __m256 vs = _mm256_castps128_ps256(_mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 23)));
      const __m128 vs_hi = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn_hi), 23));
      vs = _mm256_insertf128_ps(vs, vs_hi, 1);
    $else:
      const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));

    vn = _mm256_sub_ps(vn, vmagic_bias);

    $if FMA == 3:
      const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
    $else:
      const __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2), vz);

    $if FMA == 3:
      __m256 vp = vc${P};
      $for i in reversed(range(2, P)):
        vp = _mm256_fmadd_ps(vp, vt, vc${i});
    $else:
      __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc${P}, vt), vc${P-1});
      $for i in reversed(range(2, P-1)):
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc${i});
    $if P == H + 1:
      $if FMA == 3:
        vp = _mm256_fmadd_ps(vp, vt, vtwo);
      $else:
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vtwo);
    $else:
      vp = _mm256_mul_ps(vp, vt);

    const __m256 vts = _mm256_mul_ps(vt, vs);
    const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
    $if P == H + 1:
      $if FMA == 3:
        const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
      $else:
        const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vts), vsmo);
    $else:
      $if FMA == 3:
        vp = _mm256_fmadd_ps(vp, vts, vts);
        const __m256 vemo = _mm256_fmadd_ps(vp, vtwo, vsmo);
      $else:
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vts), vts);
        const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vtwo), vsmo);

    const __m256 vepo = _mm256_add_ps(vemo, vtwo);

    $if DIV == "DIV":
      __m256 vy = _mm256_div_ps(vemo, vepo);
    $else:
      __m256 vrepo = _mm256_rcp_ps(vepo);

      __m256 vy = _mm256_mul_ps(vemo, vrepo);

    $if SAT == "SELECT":
      vy = _mm256_blendv_ps(vy, vminus_one, vm);

    __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_TO_NEAREST_INT);
    vh = _mm_xor_si128(vh, vinvsignx);

    _mm_storeu_si128((__m128i*) o, vh);
    o += 8;
  }
  // Tail of 1..7 elements: a full 8-element vector is still loaded and
  // computed, but only the remaining elements are stored.
  if (batch != 0) {
    const __m128i vx = _mm_loadu_si128((const __m128i*) i);

    const __m128i vabsx = _mm_or_si128(vx, vsign_mask);
    __m256 vz = _mm256_cvtph_ps(vabsx);

    const __m128i vinvsignx = _mm_xor_si128(vx, vabsx);
    $if SAT == "MINMAX":
      vz = _mm256_max_ps(vsat_cutoff, vz);
    $elif SAT == "SELECT":
      const __m256 vm = _mm256_cmp_ps(vz, vsat_cutoff, _CMP_LE_OS);

    $if FMA == 3:
      __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
    $else:
      __m256 vn = _mm256_add_ps(_mm256_mul_ps(vz, vlog2e), vmagic_bias);

    $if AVX == 1:
      const __m128 vn_hi = _mm256_extractf128_ps(vn, 1);
      __m256 vs = _mm256_castps128_ps256(_mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(_mm256_castps256_ps128(vn)), 23)));
      const __m128 vs_hi = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(vn_hi), 23));
      vs = _mm256_insertf128_ps(vs, vs_hi, 1);
    $else:
      const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));

    vn = _mm256_sub_ps(vn, vmagic_bias);

    $if FMA == 3:
      const __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
    $else:
      const __m256 vt = _mm256_add_ps(_mm256_mul_ps(vn, vminus_ln2), vz);

    $if FMA == 3:
      __m256 vp = vc${P};
      $for i in reversed(range(2, P)):
        vp = _mm256_fmadd_ps(vp, vt, vc${i});
    $else:
      __m256 vp = _mm256_add_ps(_mm256_mul_ps(vc${P}, vt), vc${P-1});
      $for i in reversed(range(2, P-1)):
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vc${i});
    $if P == H + 1:
      $if FMA == 3:
        vp = _mm256_fmadd_ps(vp, vt, vtwo);
      $else:
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vt), vtwo);
    $else:
      vp = _mm256_mul_ps(vp, vt);

    const __m256 vts = _mm256_mul_ps(vt, vs);
    const __m256 vsmo = _mm256_add_ps(vs, vminus_one);
    $if P == H + 1:
      $if FMA == 3:
        const __m256 vemo = _mm256_fmadd_ps(vp, vts, vsmo);
      $else:
        const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vts), vsmo);
    $else:
      $if FMA == 3:
        vp = _mm256_fmadd_ps(vp, vts, vts);
        const __m256 vemo = _mm256_fmadd_ps(vp, vtwo, vsmo);
      $else:
        vp = _mm256_add_ps(_mm256_mul_ps(vp, vts), vts);
        const __m256 vemo = _mm256_add_ps(_mm256_mul_ps(vp, vtwo), vsmo);

    const __m256 vepo = _mm256_add_ps(vemo, vtwo);

    $if DIV == "DIV":
      __m256 vy = _mm256_div_ps(vemo, vepo);
    $else:
      __m256 vrepo = _mm256_rcp_ps(vepo);

      __m256 vy = _mm256_mul_ps(vemo, vrepo);

    $if SAT == "SELECT":
      vy = _mm256_blendv_ps(vy, vminus_one, vm);

    __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_TO_NEAREST_INT);
    vh = _mm_xor_si128(vh, vinvsignx);

    // Store 4, then 2, then 1 element according to the bits of the remaining
    // byte count, shifting the consumed lanes out of vh after each store.
    if (batch & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) o, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      o += 4;
    }
    if (batch & (2 * sizeof(uint16_t))) {
      _mm_storeu_si32(o, vh);
      vh = _mm_srli_epi64(vh, 32);
      o += 2;
    }
    if (batch & (1 * sizeof(uint16_t))) {
      *o = (uint16_t) _mm_extract_epi16(vh, 0);
    }
  }
}
