// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// Template setup:
//  - ABC is indexed by N to give each unrolled vector a one-character suffix.
//  - BATCH_TILES is parsed from a comma-separated string into a tuple of ints.
//  - SIMD_SIZE is the first batch tile; each kernel asserts that it matches
//    xnn_simd_size_f32 and xnn_simd_size_s32.
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(","))
$SIMD_SIZE = BATCH_TILES[0]

#include <assert.h>
#include <stddef.h>

#include "xnnpack/simd/f32-${ARCH}.h"
#include "xnnpack/simd/s32-${ARCH}.h"

#include "xnnpack/common.h"
#include "xnnpack/microparams.h"

$for BATCH_TILE in BATCH_TILES:

  // Converts int32_t elements to float: for each element x read from `input`,
  // writes (float) (x - params->scalar.zero_point) to `output`.
  //
  //   batch  - size of the input in bytes; must be non-zero and a multiple of
  //            sizeof(int32_t).
  //   input  - source elements; must not be NULL.
  //   output - destination elements; must not be NULL.
  //   params - supplies scalar.zero_point, subtracted before the conversion.
  void xnn_s32_f32_vcvt_ukernel__${ARCH}_u${BATCH_TILE}(
      size_t batch,
      const int32_t* input,
      float* output,
      const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
  {
    assert(batch != 0);
    // batch counts bytes of int32_t input, matching the loop bounds below.
    assert(batch % sizeof(int32_t) == 0);
    assert(input != NULL);
    assert(output != NULL);
    // The unrolling below was generated for this many lanes per vector.
    assert(xnn_simd_size_f32 == ${SIMD_SIZE});
    assert(xnn_simd_size_s32 == ${SIMD_SIZE});

    // Zero point broadcast to every lane; subtracted from each input vector.
    const xnn_simd_s32_t vzero_point = xnn_set1_s32(params->scalar.zero_point);

    $SIMD_TILE = BATCH_TILE // SIMD_SIZE
    $if SIMD_TILE > 1:
      // Unrolled main loop: handles ${SIMD_TILE} vectors per iteration.
      for (; batch >= ${BATCH_TILE} * sizeof(int32_t); batch -= ${BATCH_TILE} * sizeof(int32_t)) {
        const xnn_simd_s32_t vx${ABC[0]} = xnn_loadu_s32(input);
        $for N in range(1, SIMD_TILE):
          const xnn_simd_s32_t vx${ABC[N]} = xnn_loadu_s32(input + ${N} * xnn_simd_size_s32);
        input += ${SIMD_TILE} * xnn_simd_size_s32;

        $for N in range(SIMD_TILE):
          const xnn_simd_f32_t vy${ABC[N]} = xnn_cvt_f32_s32(xnn_sub_s32(vx${ABC[N]}, vzero_point));

        xnn_storeu_f32(output, vy${ABC[0]});
        $for N in range(1, SIMD_TILE):
          xnn_storeu_f32(output + ${N} * xnn_simd_size_f32, vy${ABC[N]});
        output += ${SIMD_TILE} * xnn_simd_size_f32;
      }

    // One full vector per iteration.
    for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) {
      const xnn_simd_s32_t vx = xnn_loadu_s32(input);
      // Advance the int32_t pointer by the s32 lane count, as the unrolled
      // loop does.
      input += xnn_simd_size_s32;

      const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, vzero_point));

      xnn_storeu_f32(output, vy);
      output += xnn_simd_size_f32;
    }

    // Remainder smaller than one vector. The shifts turn the remaining byte
    // count into the element count passed to the tail load/store helpers.
    if (batch != 0) {
      const xnn_simd_s32_t vx =
          xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T);

      const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, vzero_point));

      xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT);
    }
  }
