// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert CHANNEL_TILE % 32 == 0
$assert CHANNEL_TILE >= 32
$SIMD_TILE = CHANNEL_TILE // 32
$assert ACCUMULATORS <= SIMD_TILE
#include <assert.h>

#include <immintrin.h>

#include "xnnpack/common.h"
#include "xnnpack/reduce.h"

$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS
$_MM256_DPBUSD_EPI32 = "_mm256_dpbusd_avx_epi32" if AVX == 2 else "_mm256_dpbusd_epi32"
$ISA = "avxvnni" if AVX == 2 else "avx256vnni"

void xnn_qs8_rsum_ukernel__${ISA}_u${CHANNEL_TILE}${ACC_SUFFIX}(
    size_t batch,
    const int8_t* input,
    int32_t* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  // Adds the sum of the `batch` signed 8-bit values at `input` to `*output`.
  // `params` is not read by this kernel.
  assert(batch != 0);
  assert(input != NULL);
  assert(output != NULL);

  $if AVX == 2:
    // Sliding-window multiplier table: 32 ones followed by 32 zeros. An
    // unaligned 32-byte load starting at offset (32 - k) yields k ones
    // followed by (32 - k) zeros.
    XNN_ALIGN(32) static const int8_t onemask_table[64] = {
      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    };

  // Multiplying every input byte by 1 turns the dot-product-accumulate
  // instruction into a plain byte sum accumulated in 32-bit lanes.
  const __m256i vones = _mm256_set1_epi8(1);
  $for ACC in range(ACCUMULATORS):
    __m256i vsum${ACC} = _mm256_setzero_si256();
  $if CHANNEL_TILE > 32:
    // Main loop: ${CHANNEL_TILE} bytes per iteration.
    while (batch >= ${CHANNEL_TILE}) {
      $for N in range(SIMD_TILE):
        const __m256i vin${N} = _mm256_loadu_si256((const __m256i*) (input + ${N * 32}));
      input += ${CHANNEL_TILE};
      $for N in range(SIMD_TILE):
        vsum${N % ACCUMULATORS} = ${_MM256_DPBUSD_EPI32}(vsum${N % ACCUMULATORS}, vones, vin${N});
      batch -= ${CHANNEL_TILE};
    }
  if (XNN_UNLIKELY(batch != 0)) {
    // Consume the remaining full 32-byte vectors.
    while (batch >= 32) {
      const __m256i vin = _mm256_loadu_si256((const __m256i*) input);
      input += 32;
      vsum0 = ${_MM256_DPBUSD_EPI32}(vsum0, vones, vin);
      batch -= 32;
    }

    $if AVX == 2:
      if (XNN_UNLIKELY(batch != 0)) {
        assert(batch >= 1 && batch <= 31);
        // `batch` ones followed by zeros: bytes past the end are multiplied by 0.
        const __m256i vtailmask = _mm256_loadu_si256((const __m256i*) &onemask_table[32 - batch]);
        __m256i vtail;
        if (XNN_UNLIKELY(batch >= 17)) {
          // 17..31 bytes left: read a full 32-byte vector (over-reads up to 15 bytes).
          vtail = _mm256_loadu_si256((const __m256i*) input);
        } else {
          // 1..16 bytes left: read only 16 bytes (over-reads up to 15 bytes).
          // The upper half of the widened vector is unspecified, but it only
          // ever meets zero multipliers because at most 16 mask bytes are set.
          vtail = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*) input));
        }
        vsum0 = ${_MM256_DPBUSD_EPI32}(vsum0, vtailmask, vtail);
      }
    $else:
      if (XNN_UNLIKELY(batch != 0)) {
        assert(batch >= 1 && batch <= 31);
        // Low `batch` bits set: the zero-masking load fills the other bytes with 0.
        const uint32_t tail_bits = (UINT32_C(1) << batch) - 1;
        const __mmask32 vtailmask = _cvtu32_mask32(tail_bits);
        const __m256i vtail = _mm256_maskz_loadu_epi8(vtailmask, (const __m256i*) input);
        vsum0 = ${_MM256_DPBUSD_EPI32}(vsum0, vones, vtail);
        input += 32;
      }
  }
  $if ACCUMULATORS > 1:
    // Pairwise (tree) reduction of the accumulators into vsum0.
    $ACC_SLICE = 1
    $while ACC_SLICE < ACCUMULATORS:
      $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
        $if A + ACC_SLICE < ACCUMULATORS:
          vsum${A} = _mm256_add_epi32(vsum${A}, vsum${A + ACC_SLICE});
      $ACC_SLICE *= 2

  // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
  const __m128i vlower = _mm256_castsi256_si128(vsum0);
  const __m128i vupper = _mm256_extractf128_si256(vsum0, 1);
  const __m128i vquad = _mm_add_epi32(vlower, vupper);
  const __m128i vpair = _mm_hadd_epi32(vquad, vquad);
  const __m128i vtotal = _mm_hadd_epi32(vpair, vpair);

  // Accumulate into the caller's running sum rather than overwriting it.
  *output += _mm_cvtsi128_si32(vtotal);
}
