// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QS8", "QU8"]
$assert CHANNEL_TILE % 8 == 0
$assert CHANNEL_TILE >= 8
$SIMD_TILE = CHANNEL_TILE // 8
$assert ACCUMULATORS >= 1, "at least one accumulator is required"
$assert ACCUMULATORS <= SIMD_TILE
$assert SIMD_TILE % ACCUMULATORS == 0, "every 16-bit accumulator must receive the same number of vectors per iteration, otherwise accumulator 0 can exceed 256 additions and overflow"
$assert (ACCUMULATORS * 256 * 8) % CHANNEL_TILE == 0, "the main-loop block must be a whole number of channel tiles, otherwise the unsigned inner-loop counter wraps past zero"

#include <assert.h>

#include <wasm_simd128.h>

#include "xnnpack/common.h"
#include "xnnpack/reduce.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
$XINT32_T = {"QS8": "int32_t", "QU8": "uint32_t"}[DATATYPE]
$WASM_X16X8_LOAD8X8 = {"QS8": "wasm_i16x8_load8x8", "QU8": "wasm_u16x8_load8x8"}[DATATYPE]
$WASM_X32X4_EXTEND_LOW_X16X8 = {"QS8": "wasm_i32x4_extend_low_i16x8", "QU8": "wasm_u32x4_extend_low_u16x8"}[DATATYPE]
$WASM_X32X4_EXTEND_HIGH_X16X8 = {"QS8": "wasm_i32x4_extend_high_i16x8", "QU8": "wasm_u32x4_extend_high_u16x8"}[DATATYPE]
$WASM_X32X4_EXTRACT_LANE = {"QS8": "wasm_i32x4_extract_lane", "QU8": "wasm_u32x4_extract_lane"}[DATATYPE]
$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS
// Sums `batch` ${XINT8_T} elements starting at `input` and adds the total to
// `*output` (the previous value of `*output` is kept, not overwritten).
//
// Elements are widened to 16 bits on load, accumulated in 16-bit lanes for a
// bounded number of additions, and then flushed into 32-bit accumulators.
// `params` is only checked for non-NULL here. The tail load always fetches
// 8 bytes, so up to 7 bytes past the end of `input` may be read
// (see XNN_OOB_READS); the extra lanes are masked to zero before use.
void xnn_${DATATYPE.lower()}_rsum_ukernel__wasmsimd_u${CHANNEL_TILE}${ACC_SUFFIX}(
    size_t batch,
    const ${XINT8_T}* input,
    ${XINT32_T}* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch != 0);
  assert(input != NULL);
  assert(output != NULL);
  assert(params != NULL);

  // Sliding-window lane mask: loading 8 entries starting at index (8 - n)
  // yields all-ones in the first n 16-bit lanes and zeros in the rest.
  XNN_ALIGN(16) static const int16_t mask_table[16] = {
    -1, -1, -1, -1, -1, -1, -1, -1,
    0, 0, 0, 0, 0, 0, 0, 0
  };

  $for ACC in range(ACCUMULATORS):
    v128_t vacc${ACC} = wasm_i32x4_const_splat(0);

  // 256 8-bit values may be summed into a 16-bit lane before overflowing.
  // Each register has 8 lanes and there are ${ACCUMULATORS} accumulators so batch size is ${ACCUMULATORS*256*8}
  for (; batch >= ${ACCUMULATORS*256*8}; batch -= ${ACCUMULATORS*256*8}) {
    $for ACC in range(ACCUMULATORS):
      v128_t vacc16_${ACC} = wasm_i16x8_const_splat(0);
    for (size_t current_batch = ${ACCUMULATORS*256*8}; current_batch > 0; current_batch -= ${CHANNEL_TILE}) {
      $for N in range(SIMD_TILE):
        const v128_t vt${N} = ${WASM_X16X8_LOAD8X8}(input); input += 8;
      $for N in range(SIMD_TILE):
        vacc16_${N % ACCUMULATORS} = wasm_i16x8_add(vacc16_${N % ACCUMULATORS}, vt${N});
    }
    // Flush the 16-bit partial sums into the 32-bit accumulators.
    $for ACC in range(ACCUMULATORS):
      vacc${ACC} = wasm_i32x4_add(vacc${ACC}, ${WASM_X32X4_EXTEND_LOW_X16X8}(vacc16_${ACC}));
      vacc${ACC} = wasm_i32x4_add(vacc${ACC}, ${WASM_X32X4_EXTEND_HIGH_X16X8}(vacc16_${ACC}));
  }

  $if CHANNEL_TILE > 8:
    // Fewer than one full block remains: consume whole channel tiles.
    if (XNN_UNLIKELY(batch != 0)) {
      assert(batch >= 1 && batch < ${ACCUMULATORS*256*8});
      $for ACC in range(ACCUMULATORS):
        v128_t vacc16_${ACC} = wasm_i16x8_const_splat(0);
      for (; batch >= ${CHANNEL_TILE}; batch -= ${CHANNEL_TILE}) {
        $for N in range(SIMD_TILE):
          const v128_t vt${N} = ${WASM_X16X8_LOAD8X8}(input); input += 8;
        $for N in range(SIMD_TILE):
          vacc16_${N % ACCUMULATORS} = wasm_i16x8_add(vacc16_${N % ACCUMULATORS}, vt${N});
      }
      $for ACC in range(ACCUMULATORS):
        vacc${ACC} = wasm_i32x4_add(vacc${ACC}, ${WASM_X32X4_EXTEND_LOW_X16X8}(vacc16_${ACC}));
        vacc${ACC} = wasm_i32x4_add(vacc${ACC}, ${WASM_X32X4_EXTEND_HIGH_X16X8}(vacc16_${ACC}));
    }
  // Remaining elements: groups of 8, then a masked partial group of 1..7.
  if (XNN_UNLIKELY(batch != 0)) {
    assert(batch >= 1 && batch < ${256*8});
    v128_t vacc16 = wasm_i16x8_const_splat(0);
    for (; batch >= 8; batch -= 8) {
      const v128_t vt = ${WASM_X16X8_LOAD8X8}(input); input += 8;
      vacc16 = wasm_i16x8_add(vacc16, vt);
    }
    if (XNN_UNLIKELY(batch != 0)) {
      assert(batch >= 1 && batch <= 7);
      // Keep only the first `batch` lanes of the (over-reading) 8-element load.
      const v128_t vmask = wasm_v128_load(&mask_table[8 - batch]);
      const v128_t vt = wasm_v128_bitselect(${WASM_X16X8_LOAD8X8}(input), wasm_i16x8_const_splat(0), vmask);
      vacc16 = wasm_i16x8_add(vacc16, vt);
    }
    vacc0 = wasm_i32x4_add(vacc0, ${WASM_X32X4_EXTEND_LOW_X16X8}(vacc16));
    vacc0 = wasm_i32x4_add(vacc0, ${WASM_X32X4_EXTEND_HIGH_X16X8}(vacc16));
  }
  $if ACCUMULATORS > 1:
    $ACC_SLICE = 1
    $while ACC_SLICE < ACCUMULATORS:
      $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
        $if A + ACC_SLICE < ACCUMULATORS:
          vacc${A} = wasm_i32x4_add(vacc${A}, vacc${A + ACC_SLICE});
      $ACC_SLICE *= 2

  // Horizontal reduction of the four 32-bit lanes: add adjacent pairs, then
  // add the two pair sums, leaving the total in every lane.
  v128_t vacc_shifted = wasm_i32x4_shuffle(vacc0, vacc0, 1, 0, 3, 2);
  v128_t vacc_lo = wasm_i32x4_add(vacc0, vacc_shifted);

  v128_t vacc_final_shifted = wasm_i32x4_shuffle(vacc_lo, vacc_lo, 2, 3, 0, 1);
  v128_t vacc_final = wasm_i32x4_add(vacc_lo, vacc_final_shifted);
  // Use the output element type so the unsigned variant never round-trips
  // its sum through a signed integer.
  const ${XINT32_T} vacc = ${WASM_X32X4_EXTRACT_LANE}(vacc_final, 0);

  *output += vacc;
}
