// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert OP in ["ADD", "DIV", "RDIV", "MAX", "MIN", "MUL", "SUB", "RSUB", "SQRDIFF", "PRELU", "RPRELU"]
#include <assert.h>

#include "xnnpack/common.h"
#include "xnnpack/math.h"
#include "xnnpack/vbinary.h"


$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
$OP_FUNC = {
$  "ADD": lambda x: "%s + vb" % x,
$  "DIV": lambda x: "%s / vb" % x,
$  "RDIV": lambda x: "vb / %s" % x,
$  "MAX": lambda x: "%s(%s, vb)" % (MAX_F32, x),
$  "MIN": lambda x: "%s(%s, vb)" % (MIN_F32, x),
$  "MUL": lambda x: "%s * vb" % x,
$  "SUB": lambda x: "%s - vb" % x,
$  "RSUB": lambda x: "vb - %s" % x,
$  "SQRDIFF": lambda x: "%s - vb" % x,
$  "PRELU": lambda x: "%s * vb" % x,
$  "RPRELU": lambda x: "%s * vb" % x,
$}[OP]
// Applies the ${OP} operation between every element of input_a and the single
// float stored at input_b, writing one result per element to output.
//
//   batch   - number of BYTES to process; must be non-zero and a multiple of
//             sizeof(float).
//   input_a - array of batch / sizeof(float) floats; must not be NULL.
//   input_b - pointer to the second operand; only *input_b is read, once,
//             before any element is processed. Must not be NULL.
//   output  - array receiving batch / sizeof(float) floats; must not be NULL.
//   params  - not read by this kernel.
void xnn_f32_v${OP.lower()}c_ukernel__${"wasm" if WASM else "scalar"}_u${BATCH_TILE}(
    size_t batch,
    const float* input_a,
    const float* input_b,
    float* output,
    const struct xnn_f32_default_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(batch % sizeof(float) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  // The second operand is the same for every element, so load it once.
  const float vb = *input_b;

  $if BATCH_TILE == 1:
    // One element per iteration; no remainder handling is needed.
    for (; batch >= sizeof(float); batch -= sizeof(float)) {
      const float va = *input_a++;
      float vacc = ${OP_FUNC("va")};
      $if OP == "SQRDIFF":
        vacc = vacc * vacc;
      $elif OP == "PRELU":
        vacc = XNN_UNPREDICTABLE(va < 0.0f) ? vacc : va;
      $elif OP == "RPRELU":
        vacc = XNN_UNPREDICTABLE(vb < 0.0f) ? vacc : vb;
      *output++ = vacc;
    }
  $else:
    // Main loop: a full tile of elements per iteration, unrolled.
    for (; batch >= ${BATCH_TILE} * sizeof(float); batch -= ${BATCH_TILE} * sizeof(float)) {
      $for N in range(BATCH_TILE):
        const float va${ABC[N]} = input_a[${N}];
      input_a += ${BATCH_TILE};

      $for N in range(BATCH_TILE):
        float vacc${ABC[N]} = ${OP_FUNC("va" + ABC[N])};

      $if OP == "SQRDIFF":
        $for N in range(BATCH_TILE):
          vacc${ABC[N]} = vacc${ABC[N]} * vacc${ABC[N]};
      $elif OP == "PRELU":
        $for N in range(BATCH_TILE):
          vacc${ABC[N]} = XNN_UNPREDICTABLE(va${ABC[N]} < 0.0f) ? vacc${ABC[N]} : va${ABC[N]};
      $elif OP == "RPRELU":
        $for N in range(BATCH_TILE):
          vacc${ABC[N]} = XNN_UNPREDICTABLE(vb < 0.0f) ? vacc${ABC[N]} : vb;

      $for N in range(BATCH_TILE):
        output[${N}] = vacc${ABC[N]};
      output += ${BATCH_TILE};
    }
    // Remainder: fewer elements than a full tile are left.
    if XNN_UNLIKELY(batch != 0) {
      $if BATCH_TILE == 2:
        // With a tile of two, the only possible remainder is a single element.
        assert(batch == sizeof(float));
        const float va = *input_a;
        float vacc = ${OP_FUNC("va")};
        $if OP == "SQRDIFF":
          vacc = vacc * vacc;
        $elif OP == "PRELU":
          vacc = XNN_UNPREDICTABLE(va < 0.0f) ? vacc : va;
        $elif OP == "RPRELU":
          vacc = XNN_UNPREDICTABLE(vb < 0.0f) ? vacc : vb;
        *output = vacc;
      $else:
        // batch is known to be non-zero here, so test it after each element.
        do {
          const float va = *input_a++;
          float vacc = ${OP_FUNC("va")};
          $if OP == "SQRDIFF":
            vacc = vacc * vacc;
          $elif OP == "PRELU":
            vacc = XNN_UNPREDICTABLE(va < 0.0f) ? vacc : va;
          $elif OP == "RPRELU":
            vacc = XNN_UNPREDICTABLE(vb < 0.0f) ? vacc : vb;
          *output++ = vacc;
          batch -= sizeof(float);
        } while (batch != 0);
    }
}
