// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include "xnnpack/common.h"
#include "xnnpack/math.h"
#include "xnnpack/vbinary.h"


// Element-wise complex multiplication, scalar implementation.
//
// This is a template: BATCH_TILE is substituted when the kernel is generated
// and selects how many complex elements the main loop handles per iteration.
//
// Each of input_a, input_b and output is treated as two consecutive planes:
// the real parts start at the pointer itself and the imaginary parts start
// `batch` bytes after it. For every element the kernel computes
//   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br).
//
// Parameters:
//   batch   - size of ONE plane in bytes; asserted to be non-zero and a
//             multiple of sizeof(float).
//   input_a - first operand (real plane followed by imaginary plane).
//   input_b - second operand, same layout as input_a.
//   output  - destination, same layout as the inputs.
//   params  - not referenced by this kernel.
void xnn_f32_vcmul_ukernel__scalar_u${BATCH_TILE}(
    size_t batch,
    const float* input_a,
    const float* input_b,
    float* output,
    const struct xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(float) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  // Split every buffer into its real plane (at the base pointer) and its
  // imaginary plane (`batch` bytes past the base pointer).
  const float* ar = input_a;
  const float* ai = (const float*) ((uintptr_t) input_a + batch);
  const float* br = input_b;
  const float* bi = (const float*) ((uintptr_t) input_b + batch);
  float* or = output;
  float* oi = (float*) ((uintptr_t) output + batch);
  $if BATCH_TILE == 1:
    // Tile of one: a single loop consumes every element, no remainder needed.
    for (; batch >= sizeof(float); batch -= sizeof(float)) {
      const float var = *ar++;
      const float vai = *ai++;
      const float vbr = *br++;
      const float vbi = *bi++;
      // Real part: ar*br - ai*bi; imaginary part: ar*bi + ai*br.
      const float vaccr = var * vbr - vai * vbi;
      const float vacci = var * vbi + vai * vbr;
      *or++ = vaccr;
      *oi++ = vacci;
    }
  $else:
    // Main loop: runs while at least BATCH_TILE elements remain. All loads of
    // a tile are issued before any of its stores.
    for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
      // Load real and imaginary parts of operand A.
      $for N in range(BATCH_TILE):
        const float va${ABC[N]}r = ar[${N}];
      ar += ${BATCH_TILE};

      $for N in range(BATCH_TILE):
        const float va${ABC[N]}i = ai[${N}];
      ai += ${BATCH_TILE};

      // Load real and imaginary parts of operand B.
      $for N in range(BATCH_TILE):
        const float vb${ABC[N]}r = br[${N}];
      br += ${BATCH_TILE};

      $for N in range(BATCH_TILE):
        const float vb${ABC[N]}i = bi[${N}];
      bi += ${BATCH_TILE};

      // Real parts of the products: ar*br - ai*bi.
      $for N in range(BATCH_TILE):
        const float vacc${ABC[N]}r = va${ABC[N]}r * vb${ABC[N]}r - va${ABC[N]}i * vb${ABC[N]}i;

      // Imaginary parts of the products: ar*bi + ai*br.
      $for N in range(BATCH_TILE):
        const float vacc${ABC[N]}i = va${ABC[N]}r * vb${ABC[N]}i + va${ABC[N]}i * vb${ABC[N]}r;

      // Store the real plane, then the imaginary plane, of the results.
      $for N in range(BATCH_TILE):
        or[${N}] = vacc${ABC[N]}r;
      or += ${BATCH_TILE};

      $for N in range(BATCH_TILE):
        oi[${N}] = vacc${ABC[N]}i;
      oi += ${BATCH_TILE};
    }
    // Remainder: fewer than BATCH_TILE elements are left after the main loop.
    if XNN_UNLIKELY(batch != 0) {
      $if BATCH_TILE == 2:
        // With a tile of two exactly one element can remain, so it is handled
        // straight-line without advancing the pointers or updating batch.
        assert(batch == sizeof(float));
        const float var = *ar;
        const float vai = *ai;
        const float vbr = *br;
        const float vbi = *bi;
        const float vaccr = var * vbr - vai * vbi;
        const float vacci = var * vbi + vai * vbr;
        *or = vaccr;
        *oi = vacci;
      $else:
        // Process the leftover elements one at a time until batch reaches 0.
        do {
          const float var = *ar++;
          const float vai = *ai++;
          const float vbr = *br++;
          const float vbi = *bi++;
          const float vaccr = var * vbr - vai * vbi;
          const float vacci = var * vbi + vai * vbr;
          *or++ = vaccr;
          *oi++ = vacci;
          batch -= sizeof(float);
        } while (batch != 0);
    }
}
