// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert OP in ("ABS", "NEG", "SQR")
$BATCH_TILES = tuple(map(int, BATCH_TILES.split(",")))
$SIMD_SIZE = BATCH_TILES[0]

#include <assert.h>
#include <stddef.h>

#include "xnnpack/simd/f32-${ARCH}.h"

#include "xnnpack/common.h"
#include "xnnpack/math.h"
#include "xnnpack/vunary.h"
#include "xnnpack/microparams.h"

$for BATCH_TILE in BATCH_TILES:
  $assert BATCH_TILE >= SIMD_SIZE
  $assert BATCH_TILE % SIMD_SIZE == 0

  // Elementwise f32 unary micro-kernel: writes abs(x), -x or x*x (whichever
  // operation this kernel was generated for) for every float of the input.
  //
  //   batch  - size of the input in BYTES; must be non-zero and a multiple of
  //            sizeof(float).
  //   input  - source floats; must not be NULL.
  //   output - destination floats; must not be NULL.
  //   params - not read by this kernel.
  //
  // The unroll factor of this variant is a whole multiple of the SIMD vector
  // width; the generator refuses to emit the kernel otherwise.
  void xnn_f32_v${OP.lower()}_ukernel__${ARCH}_u${BATCH_TILE}(
      size_t batch,
      const float* input,
      float* output,
      const struct xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
  {
    assert(batch != 0);
    assert(batch % sizeof(float) == 0);
    assert(input != NULL);
    assert(output != NULL);
    // The kernel was generated for one specific vector width; catch a build
    // against a SIMD wrapper whose width differs.
    assert(xnn_simd_size_f32 == ${SIMD_SIZE});

    $SIMD_TILE = BATCH_TILE // SIMD_SIZE
    $if SIMD_TILE > 1:
      // Unrolled loop: consumes one full batch tile (several vectors) per
      // iteration. All loads are issued before the arithmetic, and all stores
      // after it.
      for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
        const xnn_simd_f32_t vx${ABC[0]} = xnn_loadu_f32(input);
        $for N in range(1, SIMD_TILE):
          const xnn_simd_f32_t vx${ABC[N]} = xnn_loadu_f32(input + ${N} * xnn_simd_size_f32);
        input += ${SIMD_TILE} * xnn_simd_size_f32;

        $for N in range(SIMD_TILE):
          $if OP == "ABS":
            const xnn_simd_f32_t vy${ABC[N]} = xnn_abs_f32(vx${ABC[N]});
          $elif OP == "NEG":
            const xnn_simd_f32_t vy${ABC[N]} = xnn_neg_f32(vx${ABC[N]});
          $elif OP == "SQR":
            const xnn_simd_f32_t vy${ABC[N]} = xnn_mul_f32(vx${ABC[N]}, vx${ABC[N]});

        xnn_storeu_f32(output, vy${ABC[0]});
        $for N in range(1, SIMD_TILE):
          xnn_storeu_f32(output + ${N} * xnn_simd_size_f32, vy${ABC[N]});
        output += ${SIMD_TILE} * xnn_simd_size_f32;
      }

    // Single-vector loop: handles whatever whole vectors remain, one per
    // iteration (xnn_simd_bytes_f32 is presumably the byte size of one vector).
    for (; batch >= xnn_simd_bytes_f32; batch -= xnn_simd_bytes_f32) {
      const xnn_simd_f32_t vx = xnn_loadu_f32(input);
      input += xnn_simd_size_f32;

      $if OP == "ABS":
        const xnn_simd_f32_t vy = xnn_abs_f32(vx);
      $elif OP == "NEG":
        const xnn_simd_f32_t vy = xnn_neg_f32(vx);
      $elif OP == "SQR":
        const xnn_simd_f32_t vy = xnn_mul_f32(vx, vx);

      xnn_storeu_f32(output, vy);
      output += xnn_simd_size_f32;
    }

    // Remainder: less than one full vector is left. The shift turns the
    // remaining byte count into a float count (assuming XNN_LOG2_SIZEOF_FLOAT
    // is log2(sizeof(float))), which is passed to the tail load/store helpers.
    // NOTE(review): the tail helpers presumably touch only that many elements;
    // confirm against the SIMD wrapper header.
    if XNN_UNLIKELY(batch != 0) {
      const xnn_simd_f32_t vx =
          xnn_load_tail_f32(input, batch >> XNN_LOG2_SIZEOF_FLOAT);

      $if OP == "ABS":
        const xnn_simd_f32_t vy = xnn_abs_f32(vx);
      $elif OP == "NEG":
        const xnn_simd_f32_t vy = xnn_neg_f32(vx);
      $elif OP == "SQR":
        const xnn_simd_f32_t vy = xnn_mul_f32(vx, vx);

      xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT);
    }
  }
