// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// Generator constraints: CHANNEL_TILE is a positive multiple of 16 (one
// int8x16_t vector covers 16 elements), SIMD_TILE is the number of such vectors
// per tile, and ACCUMULATORS may not exceed SIMD_TILE.
$assert CHANNEL_TILE % 16 == 0
$assert CHANNEL_TILE >= 16
$SIMD_TILE = CHANNEL_TILE // 16
$assert ACCUMULATORS <= SIMD_TILE
#include <assert.h>

#include <arm_neon.h>

#include "xnnpack/common.h"
#include "xnnpack/reduce.h"

$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS
// Adds the sum of the `batch` signed 8-bit values at `input` to `*output`.
//
// - `batch`:  number of int8 elements to sum; must be non-zero.
// - `input`:  elements to sum. The final partial vector is loaded as a full
//             16 bytes, so up to 15 bytes past the last element may be read
//             (hence XNN_OOB_READS); those bytes are multiplied by zero.
// - `output`: the existing value is kept and incremented, not overwritten.
// - `params`: only checked for being non-NULL; otherwise unused here.
void xnn_qs8_rsum_ukernel__neondot_u${CHANNEL_TILE}${ACC_SUFFIX}(
    size_t batch,
    const int8_t* input,
    int32_t* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch != 0);
  assert(input != NULL);
  assert(output != NULL);
  assert(params != NULL);

  // 16 ones followed by 16 zeros. Loading 16 bytes starting at offset
  // (16 - batch) gives a vector with `batch` leading ones and zeros after.
  XNN_ALIGN(16) static const int8_t onemask_table[32] = {
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };

  // A dot product against all-ones adds each group of four int8 values into
  // the corresponding int32 lane of the accumulator.
  const int8x16_t vone = vdupq_n_s8(INT8_C(1));
  $for A in range(ACCUMULATORS):
    int32x4_t vacc${A} = vmovq_n_s32(0);

  // Main loop: one full tile per iteration, vectors spread round-robin over
  // the accumulators.
  for (; batch >= ${CHANNEL_TILE}; batch -= ${CHANNEL_TILE}) {
    $for N in range(SIMD_TILE):
      const int8x16_t vt${N} = vld1q_s8(input); input += 16;
    $for N in range(SIMD_TILE):
      vacc${N % ACCUMULATORS} = vdotq_s32(vacc${N % ACCUMULATORS}, vt${N}, vone);
  }
  if (XNN_UNLIKELY(batch != 0)) {
    $if CHANNEL_TILE > 16:
      // Whole 16-byte vectors left over after the main loop.
      for (; batch >= 16; batch -= 16) {
        const int8x16_t vt = vld1q_s8(input); input += 16;
        vacc0 = vdotq_s32(vacc0, vt, vone);
      }
    if (XNN_UNLIKELY(batch != 0)) {
      // 1..15 elements remain: load a full vector and zero out the lanes past
      // the end by multiplying with the shifted one-mask.
      const int8x16_t vt = vld1q_s8(input);
      const int8x16_t vonemask = vld1q_s8(&onemask_table[16 - batch]);
      vacc0 = vdotq_s32(vacc0, vt, vonemask);
    }
  }
  $if ACCUMULATORS > 1:
    // Pairwise-combine the accumulators into vacc0.
    $ACC_SLICE = 1
    $while ACC_SLICE < ACCUMULATORS:
      $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
        $if A + ACC_SLICE < ACCUMULATORS:
          vacc${A} = vaddq_s32(vacc${A}, vacc${A + ACC_SLICE});
      $ACC_SLICE *= 2

  // Horizontal sum of the four int32 lanes of vacc0.
  #if XNN_ARCH_ARM64
    const int32_t vacc = vaddvq_s32(vacc0);
  #else
    int32x2_t vacc_lo = vadd_s32(vget_low_s32(vacc0), vget_high_s32(vacc0));
    vacc_lo = vpadd_s32(vacc_lo, vacc_lo);
    const int32_t vacc = vget_lane_s32(vacc_lo, 0);
  #endif

  *output += vacc;
}