// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QS8", "QU8"]
$assert CHANNEL_TILE % 16 == 0
$assert CHANNEL_TILE >= 16
$SIMD_TILE = CHANNEL_TILE // 16
$assert ACCUMULATORS <= SIMD_TILE

#include <assert.h>

#include <wasm_simd128.h>

#include "xnnpack/common.h"
#include "xnnpack/reduce.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
$XINT32_T = {"QS8": "int32_t", "QU8": "uint32_t"}[DATATYPE]
$WASM_X32X4_EXTRACT_LANE = {"QS8": "wasm_i32x4_extract_lane", "QU8": "wasm_u32x4_extract_lane"}[DATATYPE]
$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS
// Reduce-sum microkernel: adds the sum of `batch` 8-bit elements read from
// `input` to the 32-bit value already stored in `*output`.
//
//   batch  - number of input elements (bytes); must be non-zero.
//   input  - elements to sum. The tail is read with a full 16-byte load even
//            when fewer than 16 elements remain (hence XNN_OOB_READS); the
//            surplus bytes are masked to zero before accumulation.
//   output - running total, updated in place (`*output += sum`).
//   params - must be non-NULL; not otherwise read by this kernel.
void xnn_${DATATYPE.lower()}_rsum_ukernel__wasmrelaxedsimd_u${CHANNEL_TILE}${ACC_SUFFIX}(
    size_t batch,
    const ${XINT8_T}* input,
    ${XINT32_T}* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch != 0);
  assert(input != NULL);
  assert(output != NULL);
  assert(params != NULL);

  // Sliding mask: loading 16 bytes at offset (16 - k) yields k leading 0xFF
  // bytes followed by (16 - k) zero bytes. Only ever read as raw bytes.
  XNN_ALIGN(32) static const uint8_t mask_table[32] = {
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
  };

  // Multiplying by 1 in the dot product turns it into a 4-way horizontal add
  // of adjacent bytes into each 32-bit lane.
  const v128_t vone = wasm_i8x16_const_splat(1);
  $if DATATYPE == "QU8":
    // The relaxed dot product interprets its first operand as SIGNED bytes.
    // Flipping the top bit maps an unsigned byte x to the signed value
    // (x - 128); the 128-per-element offset is added back once at the end.
    // All arithmetic is modulo 2^32, matching the uint32_t output.
    const v128_t vsign_flip = wasm_i8x16_const_splat(-128);
    const uint32_t bias = (uint32_t) batch * UINT32_C(128);

  $for ACC in range(ACCUMULATORS):
    v128_t vacc${ACC} = wasm_i32x4_const_splat(0);

  // Main loop: ${CHANNEL_TILE} elements per iteration.
  for (; batch >= ${CHANNEL_TILE}; batch -= ${CHANNEL_TILE}) {
    $for N in range(SIMD_TILE):
      $if DATATYPE == "QU8":
        const v128_t vt${N} = wasm_v128_xor(wasm_v128_load(input), vsign_flip); input += 16;
      $else:
        const v128_t vt${N} = wasm_v128_load(input); input += 16;
    $for N in range(SIMD_TILE):
      vacc${N % ACCUMULATORS} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(vt${N}, vone, vacc${N % ACCUMULATORS});
  }

  $if CHANNEL_TILE > 16:
    // Remaining whole 16-element vectors (fewer than one full tile is left).
    if (XNN_UNLIKELY(batch != 0)) {
      assert(batch >= 1 && batch < ${CHANNEL_TILE});
      for (; batch >= 16; batch -= 16) {
        $if DATATYPE == "QU8":
          const v128_t vt = wasm_v128_xor(wasm_v128_load(input), vsign_flip); input += 16;
        $else:
          const v128_t vt = wasm_v128_load(input); input += 16;
        vacc0 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(vt, vone, vacc0);
      }
    }

  // Tail of 1..15 elements: full-width load, then zero the lanes past `batch`
  // so they contribute nothing to the sum.
  if (XNN_UNLIKELY(batch != 0)) {
    assert(batch >= 1 && batch <= 15);
    const v128_t vmask = wasm_v128_load(&mask_table[16 - batch]);
    $if DATATYPE == "QU8":
      // Flip the sign bit BEFORE masking so that masked-off lanes are exactly
      // zero and `bias` only has to account for the valid elements.
      const v128_t vt = wasm_v128_bitselect(wasm_v128_xor(wasm_v128_load(input), vsign_flip), wasm_i8x16_const_splat(0), vmask);
    $else:
      const v128_t vt = wasm_v128_bitselect(wasm_v128_load(input), wasm_i8x16_const_splat(0), vmask);
    vacc0 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(vt, vone, vacc0);
  }

  $if ACCUMULATORS > 1:
    // Pairwise tree reduction of the independent accumulators into vacc0.
    $ACC_SLICE = 1
    $while ACC_SLICE < ACCUMULATORS:
      $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
        $if A + ACC_SLICE < ACCUMULATORS:
          vacc${A} = wasm_i32x4_add(vacc${A}, vacc${A + ACC_SLICE});
      $ACC_SLICE *= 2

  // Horizontal add of the four 32-bit lanes: swap neighbours and add, then
  // swap halves and add; every lane of vacc_final then holds the total.
  const v128_t vacc_shifted = wasm_i32x4_shuffle(vacc0, vacc0, 1, 0, 3, 2);
  const v128_t vacc_lo = wasm_i32x4_add(vacc0, vacc_shifted);

  const v128_t vacc_final_shifted = wasm_i32x4_shuffle(vacc_lo, vacc_lo, 2, 3, 0, 1);
  const v128_t vacc_final = wasm_i32x4_add(vacc_lo, vacc_final_shifted);
  const ${XINT32_T} vsum = ${WASM_X32X4_EXTRACT_LANE}(vacc_final, 0);

  $if DATATYPE == "QU8":
    // Undo the (x - 128) shift applied to every summed element.
    *output += vsum + bias;
  $else:
    *output += vsum;
}
