// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert CHANNEL_TILE % 64 == 0
$assert CHANNEL_TILE >= 64
$SIMD_TILE = CHANNEL_TILE // 64
$assert ACCUMULATORS <= SIMD_TILE
#include <assert.h>

#include <immintrin.h>

#include "xnnpack/common.h"
#include "xnnpack/reduce.h"

$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS
// Sums `batch` signed 8-bit elements read from `input` and adds the total to
// `*output` (the output is accumulated into, not overwritten).
//
//   batch:  number of int8 elements to sum; asserted to be non-zero.
//   input:  pointer to the elements; asserted to be non-NULL.
//   output: pointer to the int32 running total; asserted to be non-NULL.
//   params: not referenced by this kernel.
//
// NOTE(review): the per-lane sums are kept in 32-bit lanes using the
// non-saturating VNNI dot product, so the caller presumably bounds `batch`
// to keep them in range - confirm against the call sites.
void xnn_qs8_rsum_ukernel__avx512vnni_u${CHANNEL_TILE}${ACC_SUFFIX}(
    size_t batch,
    const int8_t* input,
    int32_t* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(batch != 0);
  assert(input != NULL);
  assert(output != NULL);

  $for ACC in range(ACCUMULATORS):
    __m512i vacc${ACC} = _mm512_setzero_si512();
  // A dot product against a vector of byte 1s adds each group of four signed
  // input bytes into the corresponding 32-bit accumulator lane. The constant
  // is the unsigned operand of the instruction; the input is the signed one.
  const __m512i vone = _mm512_set1_epi8(1);
  // Main loop: one full channel tile per iteration, spread over the
  // accumulators in round-robin order.
  for (; batch >= ${CHANNEL_TILE}; batch -= ${CHANNEL_TILE}) {
    $for N in range(SIMD_TILE):
      vacc${N % ACCUMULATORS} = _mm512_dpbusd_epi32(vacc${N % ACCUMULATORS}, vone, _mm512_loadu_si512((const __m512i*) input)); input += 64;
  }
  if (XNN_UNLIKELY(batch != 0)) {
    // Remaining whole 64-byte vectors all go into the first accumulator.
    for (; batch >= 64; batch -= 64) {
      vacc0 = _mm512_dpbusd_epi32(vacc0, vone, _mm512_loadu_si512((const __m512i*) input)); input += 64;
    }
    if (XNN_UNLIKELY(batch != 0)) {
      assert(batch >= 1 && batch <= 63);
      // Select only the low `batch` bytes; the zero-masking load fills the
      // remaining bytes with 0, so they contribute nothing to the sum.
      const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << (batch & 63)) - UINT64_C(1)));
      // This is the last read, so `input` is deliberately not advanced:
      // stepping a full vector here would move it past the end of the data.
      vacc0 = _mm512_dpbusd_epi32(vacc0, vone, _mm512_maskz_loadu_epi8(vmask, (const __m512i*) input));
    }
  }

  $if ACCUMULATORS > 1:
    $ACC_SLICE = 1
    $while ACC_SLICE < ACCUMULATORS:
      $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
        $if A + ACC_SLICE < ACCUMULATORS:
          vacc${A} = _mm512_add_epi32(vacc${A}, vacc${A + ACC_SLICE});
      $ACC_SLICE *= 2

  // Horizontal add of the sixteen 32-bit lanes of the first accumulator.
  int32_t res = _mm512_reduce_add_epi32(vacc0);

  *output += res;
}
