// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB", "SQRDIFF", "PRELU"]
#include <assert.h>

#include "xnnpack/common.h"
#include "xnnpack/math.h"
#include "xnnpack/vbinary.h"


$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
$OP_FUNC = {
$  "ADD": lambda x, y: "%s + %s" % (x, y),
$  "DIV": lambda x, y: "%s / %s" % (x, y),
$  "MAX": lambda x, y: "%s(%s, %s)" % (MAX_F32, x, y),
$  "MIN": lambda x, y: "%s(%s, %s)" % (MIN_F32, x, y),
$  "MUL": lambda x, y: "%s * %s" % (x, y),
$  "SUB": lambda x, y: "%s - %s" % (x, y),
$  "SQRDIFF": lambda x, y: "%s - %s" % (x, y),
$  "PRELU": lambda x, y: "%s * %s" % (x, y),
$}[OP]
// Element-wise f32 ${OP} micro-kernel: combines input_a[i] with input_b[i] and
// stores the result in output[i].
//
//   batch    size of each array in bytes; must be a non-zero multiple of
//            sizeof(float).
//   input_a  first operand array, must not be NULL.
//   input_b  second operand array, must not be NULL.
//   output   destination array, must not be NULL.
//   params   not referenced by this kernel.
void xnn_f32_v${OP.lower()}_ukernel__${"wasm" if WASM else "scalar"}_u${BATCH_TILE}(
    size_t batch,
    const float* input_a,
    const float* input_b,
    float* output,
    const struct xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(float) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  $if BATCH_TILE == 1:
    // One element per iteration; batch counts bytes, so step by sizeof(float).
    for (; batch >= sizeof(float); batch -= sizeof(float)) {
      const float va = *input_a++;
      const float vb = *input_b++;
      float vacc = ${OP_FUNC("va", "vb")};
      $if OP == "SQRDIFF":
        // vacc holds the difference; square it.
        vacc = vacc * vacc;
      $elif OP == "PRELU":
        // vacc holds va * vb; keep it only for negative va, otherwise pass va through.
        vacc = XNN_UNPREDICTABLE(va < 0.0f) ? vacc : va;
      *output++ = vacc;
    }
  $else:
    // Main loop: ${BATCH_TILE} elements per iteration.
    for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
      $for N in range(BATCH_TILE):
        const float va${ABC[N]} = input_a[${N}];
      input_a += ${BATCH_TILE};

      $for N in range(BATCH_TILE):
        const float vb${ABC[N]} = input_b[${N}];
      input_b += ${BATCH_TILE};

      $for N in range(BATCH_TILE):
        float vacc${ABC[N]} = ${OP_FUNC("va" + ABC[N], "vb" + ABC[N])};

      $if OP == "SQRDIFF":
        // The accumulators hold the differences; square them.
        $for N in range(BATCH_TILE):
          vacc${ABC[N]} = vacc${ABC[N]} * vacc${ABC[N]};
      $elif OP == "PRELU":
        // The accumulators hold the products; keep them only for negative inputs.
        $for N in range(BATCH_TILE):
          vacc${ABC[N]} = XNN_UNPREDICTABLE(va${ABC[N]} < 0.0f) ? vacc${ABC[N]} : va${ABC[N]};

      $for N in range(BATCH_TILE):
        output[${N}] = vacc${ABC[N]};
      output += ${BATCH_TILE};
    }
    // Remainder: fewer than ${BATCH_TILE} elements are left.
    if XNN_UNLIKELY(batch != 0) {
      $if BATCH_TILE == 2:
        // With a tile of two, the remainder can only be a single element.
        assert(batch == sizeof(float));
        const float va = *input_a;
        const float vb = *input_b;
        float vacc = ${OP_FUNC("va", "vb")};
        $if OP == "SQRDIFF":
          vacc = vacc * vacc;
        $elif OP == "PRELU":
          vacc = XNN_UNPREDICTABLE(va < 0.0f) ? vacc : va;
        *output = vacc;
      $else:
        // Process the leftover elements one at a time.
        do {
          const float va = *input_a++;
          const float vb = *input_b++;
          float vacc = ${OP_FUNC("va", "vb")};
          $if OP == "SQRDIFF":
            vacc = vacc * vacc;
          $elif OP == "PRELU":
            vacc = XNN_UNPREDICTABLE(va < 0.0f) ? vacc : va;
          *output++ = vacc;
          batch -= sizeof(float);
        } while (batch != 0);
    }
}
