// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
$assert CHANNEL_TILE % 16 == 0
$assert CHANNEL_TILE >= 16
$SIMD_TILE = CHANNEL_TILE // 16
$assert ACCUMULATORS <= SIMD_TILE
#include <assert.h>

#include <tmmintrin.h>

#include "xnnpack/common.h"
#include "xnnpack/reduce.h"

$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS
// Sums `batch` signed 8-bit elements starting at `input` and ADDS the total to
// `*output` (the previous value of `*output` is kept, not overwritten).
//
// `params` is only checked for being non-NULL; it is not otherwise read here.
// The final partial vector is loaded as a full 16 bytes and masked afterwards,
// so up to 15 bytes past the end of `input` may be read (hence XNN_OOB_READS).
void xnn_qs8_rsum_ukernel__ssse3_u${CHANNEL_TILE}${ACC_SUFFIX}(
    size_t batch,
    const int8_t* input,
    int32_t* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch != 0);
  assert(input != NULL);
  assert(output != NULL);
  assert(params != NULL);

  // First 16 entries are ones, last 16 are zeros: loading 16 bytes at offset
  // (16 - k) yields a mask with k leading ones followed by zeros.
  XNN_ALIGN(16) static const int8_t onemask_table[32] = {
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };

  // vone: sixteen 8-bit ones (multiplier for _mm_maddubs_epi16).
  // vone_16: eight 16-bit ones (0x0101 >> 8 == 0x0001), multiplier for _mm_madd_epi16.
  const __m128i vone = _mm_load_si128((const __m128i*) &onemask_table[0]);
  const __m128i vone_16 = _mm_srli_epi16(vone, 8);
  $for ACC in range(ACCUMULATORS):
    __m128i vacc${ACC} = _mm_setzero_si128();

  // _mm_maddubs_epi16 against all-ones folds every pair of adjacent int8 inputs into one int16 lane,
  // so each 16-byte vector adds a value in [-256, 254] to every lane of a 16-bit accumulator.
  // At most 128 such vectors can be accumulated without int16 overflow (128 * -256 == -32768),
  // i.e. 128 * 16 bytes per accumulator. With ${ACCUMULATORS} accumulator(s) the block size is ${ACCUMULATORS*128*16}.
  for (; batch >= ${ACCUMULATORS*128*16}; batch -= ${ACCUMULATORS*128*16}) {
    $for ACC in range(ACCUMULATORS):
      __m128i vacc16_${ACC} = _mm_setzero_si128();
    for (size_t current_batch = ${ACCUMULATORS*128*16}; current_batch > 0; current_batch -= ${CHANNEL_TILE}) {
      $for N in range(SIMD_TILE):
        const __m128i vt${N} = _mm_maddubs_epi16(vone, _mm_loadu_si128((const __m128i*) input)); input += 16;
      $for N in range(SIMD_TILE):
        vacc16_${N % ACCUMULATORS} = _mm_add_epi16(vacc16_${N % ACCUMULATORS}, vt${N});
    }
    // Widen the 16-bit partial sums to 32 bits (pairwise add) and fold them into the 32-bit accumulators.
    $for ACC in range(ACCUMULATORS):
      vacc${ACC} = _mm_add_epi32(vacc${ACC}, _mm_madd_epi16(vone_16, vacc16_${ACC}));
  }

  $if CHANNEL_TILE > 16:
    if (XNN_UNLIKELY(batch != 0)) {
      // Fewer than one full block remains, so the 16-bit accumulators cannot overflow here.
      assert(batch >= 1 && batch < ${ACCUMULATORS*128*16});
      $for ACC in range(ACCUMULATORS):
        __m128i vacc16_${ACC} = _mm_setzero_si128();
      for (; batch >= ${CHANNEL_TILE}; batch -= ${CHANNEL_TILE}) {
        $for N in range(SIMD_TILE):
          const __m128i vt${N} = _mm_maddubs_epi16(vone, _mm_loadu_si128((const __m128i*) input)); input += 16;
        $for N in range(SIMD_TILE):
          vacc16_${N % ACCUMULATORS} = _mm_add_epi16(vacc16_${N % ACCUMULATORS}, vt${N});
      }
      $for ACC in range(ACCUMULATORS):
        vacc${ACC} = _mm_add_epi32(vacc${ACC}, _mm_madd_epi16(vone_16, vacc16_${ACC}));
    }
  if (XNN_UNLIKELY(batch != 0)) {
    // At most 127 full vectors plus one partial vector remain: still within the 128-vector int16 budget.
    assert(batch >= 1 && batch < ${128*16});
    __m128i vacc16 = _mm_setzero_si128();
    for (; batch >= 16; batch -= 16) {
      const __m128i vt = _mm_maddubs_epi16(vone, _mm_loadu_si128((const __m128i*) input)); input += 16;
      vacc16 = _mm_add_epi16(vacc16, vt);
    }
    if (XNN_UNLIKELY(batch != 0)) {
      assert(batch >= 1 && batch <= 15);
      // Multiply by a mask of `batch` ones followed by zeros so the bytes read
      // past the end of the input contribute nothing to the sum.
      const __m128i vonemask = _mm_loadu_si128((const __m128i*) &onemask_table[16 - batch]);
      const __m128i vt = _mm_maddubs_epi16(vonemask, _mm_loadu_si128((const __m128i*) input));
      vacc16 = _mm_add_epi16(vacc16, vt);
    }
    vacc0 = _mm_add_epi32(vacc0, _mm_madd_epi16(vone_16, vacc16));
  }
  $if ACCUMULATORS > 1:
    $ACC_SLICE = 1
    $while ACC_SLICE < ACCUMULATORS:
      $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
        $if A + ACC_SLICE < ACCUMULATORS:
          vacc${A} = _mm_add_epi32(vacc${A}, vacc${A + ACC_SLICE});
      $ACC_SLICE *= 2

  // Horizontal reduction of the four 32-bit lanes of vacc0 into a scalar.
  __m128i vacc_lo = _mm_hadd_epi32(vacc0, vacc0);
  vacc_lo = _mm_hadd_epi32(vacc_lo, vacc_lo);
  const int32_t vacc = _mm_cvtsi128_si32(vacc_lo);

  *output += vacc;
}
