// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert DATATYPE in ["QS8", "QU8"]
$assert CHANNELS % 8 == 0
$assert CHANNELS >= 8

#include <assert.h>
#include <math.h>

#include <wasm_simd128.h>

#include "xnnpack/common.h"
#include "xnnpack/reduce.h"

$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
$XINT32_T = {"QS8": "int32_t", "QU8": "uint32_t"}[DATATYPE]
$WASM_X16X8_LOAD8X8 = {"QS8": "wasm_i16x8_load8x8", "QU8": "wasm_u16x8_load8x8"}[DATATYPE]
$WASM_X32X4_EXTEND_LOW_X16X8 = {"QS8": "wasm_i32x4_extend_low_i16x8", "QU8": "wasm_u32x4_extend_low_u16x8"}[DATATYPE]
$WASM_X32X4_EXTEND_HIGH_X16X8 = {"QS8": "wasm_i32x4_extend_high_i16x8", "QU8": "wasm_u32x4_extend_high_u16x8"}[DATATYPE]
$WASM_X32X4_MAKE = {"QS8": "wasm_i32x4_make", "QU8": "wasm_u32x4_make"}[DATATYPE]
$WASM_X32X4_EXTRACT_LANE = {"QS8": "wasm_i32x4_extract_lane", "QU8": "wasm_u32x4_extract_lane"}[DATATYPE]
$assert ACCUMULATORS % 2 == 1
$OVERFLOW = (256 // ACCUMULATORS) * ACCUMULATORS
// Adds, for every channel, the sum of that channel over all `rows` input rows
// to the corresponding element of `output`.
//
//   rows, channels - must be non-zero (asserted).
//   input          - first row; row k starts k * input_stride bytes after it.
//   input_stride   - distance between consecutive rows, in bytes.
//   zero           - buffer substituted for row pointers that fall past the
//                    last row (presumably zero-filled so that those rows add
//                    nothing - confirm against the caller).
//   output         - 32-bit per-channel totals; the row sums are added to the
//                    values already stored there.
//   params         - not referenced by this kernel.
//
// Rows are consumed ${ACCUMULATORS} at a time. Loads are always 8 bytes wide, so the
// channel tail may read past the end of a row (XNN_OOB_READS).
void xnn_${DATATYPE.lower()}_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__wasmsimd_c${CHANNELS}(
    size_t rows,
    size_t channels,
    const ${XINT8_T}* input,
    size_t input_stride,
    const ${XINT8_T}* zero,
    ${XINT32_T}* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(rows != 0);
  assert(channels != 0);
  assert(input != NULL);
  assert(output != NULL);

  // Byte distance each row pointer advances after one step of ${ACCUMULATORS} rows.
  const size_t input_increment = ${ACCUMULATORS} * input_stride;

  // Main path: full groups of ${CHANNELS} channels.
  for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) {
    const ${XINT8_T}* i0 = input;
    $for ACC in range(1, ACCUMULATORS):
      const ${XINT8_T}* i${ACC} = (const ${XINT8_T}*) ((uintptr_t) input + ${ACC} * input_stride);

    $for C in range(0, CHANNELS, 4):
      v128_t vacc${C} = wasm_i32x4_const_splat(0);

    // Rows are first summed in 16-bit lanes. 256 8-bit values can be added in a
    // 16-bit lane without overflowing; to avoid handling a tail inside the
    // 16-bit loop, 256 is rounded down to ${OVERFLOW}, the nearest multiple of the
    // ${ACCUMULATORS} rows consumed per step. Each batch is then widened into the
    // 32-bit accumulators.
    int r = rows;
    while (r > 0) {
      $for C in range(0, CHANNELS, 8):
        v128_t vacc16_${C} = wasm_i16x8_const_splat(0);
      for (int current_batch = min(r, ${OVERFLOW}); current_batch > 0; current_batch -= ${ACCUMULATORS}) {
        // Redirect row pointers that lie beyond the last remaining row.
        $for N in range(1, ACCUMULATORS, 2):
          if XNN_UNPREDICTABLE(current_batch < ${N+1}) {
            i${N} = zero;
          }
          if XNN_UNPREDICTABLE(current_batch <= ${N+1}) {
            i${N+1} = zero;
          }
        $for C in range(0, CHANNELS, 8):
          v128_t vin${C};
        $for ACC in range(ACCUMULATORS):
          $for C in range(0, CHANNELS, 8):
            vin${C} = ${WASM_X16X8_LOAD8X8}(&i${ACC}[${C}]);
          $for C in range(0, CHANNELS, 8):
            vacc16_${C} = wasm_i16x8_add(vacc16_${C}, vin${C});
        $for N in range(0, ACCUMULATORS):
          i${N} = (const ${XINT8_T}*) ((uintptr_t) i${N} + input_increment);
      }
      $for C in range(0, CHANNELS, 8):
        vacc${C} = wasm_i32x4_add(vacc${C}, ${WASM_X32X4_EXTEND_LOW_X16X8}(vacc16_${C}));
        vacc${C+4} = wasm_i32x4_add(vacc${C+4}, ${WASM_X32X4_EXTEND_HIGH_X16X8}(vacc16_${C}));
      r = doz(r, ${OVERFLOW});
    }

    // Add the sums to the totals already stored in output.
    const ${XINT32_T}* o = output;
    $for C in range(0, CHANNELS, 4):
      v128_t vo${C} = wasm_v128_load(o); o += 4;
    $for C in range(0, CHANNELS, 4):
      vo${C} = wasm_i32x4_add(vo${C}, vacc${C});
    $for C in range(0, CHANNELS, 4):
      wasm_v128_store(output, vo${C}); output += 4;

    input = (const ${XINT8_T}*) ((uintptr_t) input + ${CHANNELS} * sizeof(${XINT8_T}));
  }
  // Remainder path: fewer than ${CHANNELS} channels left, handled 8 at a time.
  if (channels != 0) {
    do {
      int r = rows;
      const ${XINT8_T}* i0 = input;
      $for ACC in range(1, ACCUMULATORS):
        const ${XINT8_T}* i${ACC} = (const ${XINT8_T}*) ((uintptr_t) input + ${ACC} * input_stride);

      v128_t vacc0 = wasm_i32x4_const_splat(0);
      v128_t vacc1 = wasm_i32x4_const_splat(0);

      // Same batching as the main path: at most ${OVERFLOW} rows are summed in
      // 16-bit lanes before being widened into the 32-bit accumulators.
      while (r > 0) {
        v128_t vacc16 = wasm_i16x8_const_splat(0);
        for (int current_batch = min(r, ${OVERFLOW}); current_batch > 0; current_batch -= ${ACCUMULATORS}) {
          // Redirect row pointers that lie beyond the last remaining row.
          $for N in range(1, ACCUMULATORS, 2):
            if XNN_UNPREDICTABLE(current_batch < ${N+1}) {
              i${N} = zero;
            }
            if XNN_UNPREDICTABLE(current_batch <= ${N+1}) {
              i${N+1} = zero;
            }

          $for ACC in range(ACCUMULATORS):
            v128_t vin${ACC} = ${WASM_X16X8_LOAD8X8}(&i${ACC}[0]);
          $for ACC in range(ACCUMULATORS):
            vacc16 = wasm_i16x8_add(vacc16, vin${ACC});
          $for N in range(ACCUMULATORS):
            i${N} = (const ${XINT8_T}*) ((uintptr_t) i${N} + input_increment);
        }
        vacc0 = wasm_i32x4_add(vacc0, ${WASM_X32X4_EXTEND_LOW_X16X8}(vacc16));
        vacc1 = wasm_i32x4_add(vacc1, ${WASM_X32X4_EXTEND_HIGH_X16X8}(vacc16));
        r = doz(r, ${OVERFLOW});
      }

      if XNN_LIKELY(channels >= 8) {
        v128_t vo0 = wasm_v128_load(output);
        v128_t vo1 = wasm_v128_load(output + 4);
        vo0 = wasm_i32x4_add(vo0, vacc0);
        vo1 = wasm_i32x4_add(vo1, vacc1);
        wasm_v128_store(output, vo0); output += 4;
        wasm_v128_store(output, vo1); output += 4;
        channels -= 8;
        input = (const ${XINT8_T}*) ((uintptr_t) input + 8 * sizeof(${XINT8_T}));
      } else {
        // 1 to 7 channels remain: write back 4, 2 and 1 outputs as selected by
        // the bits of channels, moving the unwritten sums down to lane 0.
        if (channels & 4) {
          v128_t vo = wasm_v128_load(output);
          vo = wasm_i32x4_add(vo, vacc0);
          wasm_v128_store(output, vo); output += 4;
          vacc0 = vacc1;
        }
        if (channels & 2) {
          v128_t vo = ${WASM_X32X4_MAKE}(output[0], output[1], 0, 0);
          vo = wasm_i32x4_add(vo, vacc0);
          wasm_v128_store64_lane(output, vo, 0); output += 2;
          vacc0 = wasm_v64x2_shuffle(vacc0, vacc0, 1, 1);
        }
        if (channels & 1) {
          *output += ${WASM_X32X4_EXTRACT_LANE}(vacc0, 0);
        }
        channels = 0;
      }
    } while (channels != 0);
  }
}
