// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// Template-time alphabet: ABC[N] supplies the per-lane suffix for the unrolled
// temporaries generated inside the batch-tile loop.
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>
#include <math.h>

#include "xnnpack/common.h"
#include "xnnpack/vunary.h"

// 1/sqrt(2). M_SQRT1_2 is not guaranteed by ISO C's <math.h>, so provide a
// fallback definition when the platform header does not supply one.
// Note: this fallback is an unsuffixed (double) literal.
#ifndef M_SQRT1_2
#define M_SQRT1_2 0.7071067811865475244
#endif

$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(","))
$for BATCH_TILE in BATCH_TILES:
  // Element-wise GELU: y = 0.5 * x * (1 + erf(x / sqrt(2))), evaluated with
  // libm's erff. `batch` is a size in bytes (a multiple of sizeof(float)).
  // Main loop handles ${BATCH_TILE} element(s) per iteration.
  // NOTE(review): `x * M_SQRT1_2` is evaluated in double when M_SQRT1_2 is a
  // double constant, then narrowed to float for erff; kept as-is so results
  // match the existing kernels bit-for-bit.
  void xnn_f32_vgelu_ukernel__scalar_u${BATCH_TILE}(
      size_t batch,
      const float* input,
      float* output,
      const struct xnn_f32_default_params unused_params[restrict XNN_MIN_ELEMENTS(1)])
  {
    assert(batch != 0);
    assert(batch % sizeof(float) == 0);
    assert(input != NULL);
    assert(output != NULL);

    $if BATCH_TILE > 1:
      // Unrolled main loop over full tiles of ${BATCH_TILE} floats.
      while (batch >= ${BATCH_TILE} * sizeof(float)) {
        $for N in range(BATCH_TILE):
          const float vin${ABC[N]} = input[${N}];
        input += ${BATCH_TILE};

        // erf(x / sqrt(2)) for every lane.
        $for N in range(BATCH_TILE):
          const float verf${ABC[N]} = erff(vin${ABC[N]} * M_SQRT1_2);
        // 0.5 * x for every lane.
        $for N in range(BATCH_TILE):
          const float vhalfx${ABC[N]} = vin${ABC[N]} * 0.5f;

        // y = (0.5 * x) * (1 + erf(x / sqrt(2))).
        $for N in range(BATCH_TILE):
          output[${N}] = vhalfx${ABC[N]} * (1.0f + verf${ABC[N]});
        output += ${BATCH_TILE};

        batch -= ${BATCH_TILE} * sizeof(float);
      }
      if XNN_UNLIKELY(batch != 0) {
        $if BATCH_TILE > 2:
          // Up to ${BATCH_TILE - 1} trailing elements: process them one by one.
          do {
            const float vin = *input++;
            *output++ = vin * 0.5f * (1.0f + erff(vin * M_SQRT1_2));
            batch -= sizeof(float);
          } while (batch != 0);
        $else:
          // A tile of 2 leaves at most one trailing element.
          const float vin = *input;
          *output = vin * 0.5f * (1.0f + erff(vin * M_SQRT1_2));
      }
    $else:
      // Tile of 1: straightforward per-element loop.
      while (batch >= sizeof(float)) {
        const float vin = *input++;
        *output++ = vin * 0.5f * (1.0f + erff(vin * M_SQRT1_2));
        batch -= sizeof(float);
      }
  }
