// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// BATCH_TILES is a comma-separated list of unroll factors; its first entry
// is taken as the SIMD width that every generated kernel asserts against.
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$BATCH_TILES = tuple(map(int, BATCH_TILES.split(",")))
$SIMD_SIZE = BATCH_TILES[0]
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/simd/f32-${ARCH}.h"

#include "xnnpack/common.h"
#include "xnnpack/microparams.h"
#include "xnnpack/vunary.h"

$for BATCH_TILE in BATCH_TILES:
  $assert BATCH_TILE % SIMD_SIZE == 0
  $assert BATCH_TILE >= SIMD_SIZE
  $SIMD_TILE = BATCH_TILE // SIMD_SIZE

  // Element-wise "reverse" copysign with a scalar magnitude:
  //   output[i] = copysign(*mag, sign[i])
  // The magnitude is read once and broadcast; the sign comes from each element
  // of `sign`. `batch` is a byte count (a multiple of sizeof(float)).
  void xnn_f32_vrcopysignc_ukernel__${ARCH}_u${BATCH_TILE}(
      size_t batch,
      const float* sign,
      const float* mag,
      float* output,
      const struct xnn_f32_default_params unused_params[restrict XNN_MIN_ELEMENTS(1)])
  {
    assert(batch != 0);
    assert(batch % sizeof(float) == 0);
    assert(sign != NULL);
    assert(mag != NULL);
    assert(output != NULL);
    assert(xnn_simd_size_f32 == ${SIMD_SIZE});

    // -0.0f has only the sign bit set, so AND-ing with it keeps just the sign.
    XNN_SIMD_CONST_F32(vsign_mask, -0.f);
    // |*mag| has a clear sign bit, so OR-ing a sign bit into it yields copysign.
    xnn_simd_f32_t vabsmag = xnn_abs_f32(xnn_set1_f32(*mag));

    $if SIMD_TILE > 1:
      // Unrolled main loop: all ${SIMD_TILE} input vectors are loaded before any
      // store is issued.
      for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
        xnn_simd_f32_t vs0 = xnn_and_f32(xnn_loadu_f32(sign), vsign_mask);
        $for N in range(1, SIMD_TILE):
          xnn_simd_f32_t vs${N} = xnn_and_f32(xnn_loadu_f32(sign + ${N} * xnn_simd_size_f32), vsign_mask);
        sign += ${BATCH_TILE};

        xnn_storeu_f32(output, xnn_or_f32(vs0, vabsmag));
        $for N in range(1, SIMD_TILE):
          xnn_storeu_f32(output + ${N} * xnn_simd_size_f32, xnn_or_f32(vs${N}, vabsmag));
        output += ${BATCH_TILE};
      }
    // One full SIMD vector per iteration.
    for (; batch >= xnn_simd_bytes_f32; batch -= xnn_simd_bytes_f32) {
      xnn_simd_f32_t vs = xnn_and_f32(xnn_loadu_f32(sign), vsign_mask);
      sign += xnn_simd_size_f32;

      xnn_storeu_f32(output, xnn_or_f32(vs, vabsmag));
      output += xnn_simd_size_f32;
    }
    $if SIMD_SIZE > 1:
      // Fewer than a full vector of elements remain.
      if XNN_UNLIKELY(batch != 0) {
        const size_t num_remaining = batch >> XNN_LOG2_SIZEOF_FLOAT;
        xnn_simd_f32_t vs = xnn_and_f32(xnn_load_tail_f32(sign, num_remaining), vsign_mask);

        xnn_store_tail_f32(output, xnn_or_f32(vs, vabsmag), num_remaining);
      }
  }
