// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
$assert CHANNEL_TILE % 16 == 0
$assert CHANNEL_TILE >= 16
$SIMD_TILE = CHANNEL_TILE // 16
$assert ACCUMULATORS <= SIMD_TILE
#include <assert.h>

#include <emmintrin.h>

#include "xnnpack/common.h"
#include "xnnpack/reduce.h"

$ACC_SUFFIX = "" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS
// Sums `batch` unsigned 8-bit values starting at `input` and adds the total
// to `*output`. The previous value of `*output` is kept, not overwritten, and
// the arithmetic is done in 32 bits.
//
// `params` is only checked for being non-NULL; none of its fields are read.
//
// When `batch` is not a multiple of 16 the remainder is processed with a full
// 16-byte load whose excess bytes are masked to zero, so up to 15 bytes past
// the end of `input` may be read.
void xnn_qu8_rsum_ukernel__sse2_u${CHANNEL_TILE}${ACC_SUFFIX}(
    size_t batch,
    const uint8_t* input,
    uint32_t* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(batch != 0);
  assert(input != NULL);
  assert(output != NULL);
  assert(params != NULL);

  // 16 bytes of 0xFF followed by 16 bytes of 0x00. A 16-byte load starting at
  // offset (16 - n) yields n leading 0xFF bytes followed by (16 - n) zero bytes.
  // The element type is unsigned so that 0xFF is representable without an
  // implementation-defined narrowing conversion; the table is only ever read
  // as raw bytes.
  XNN_ALIGN(16) static const uint8_t mask_table[32] = {
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };

  const __m128i vzero = _mm_setzero_si128();
  $for ACC in range(ACCUMULATORS):
    __m128i vacc${ACC} = _mm_setzero_si128();

  $if CHANNEL_TILE > 16:
    // Main loop: ${CHANNEL_TILE} bytes per iteration. The sum of absolute
    // differences against zero is the sum of the bytes in each 8-byte half.
    for (; batch >= ${CHANNEL_TILE}; batch -= ${CHANNEL_TILE}) {
      $for N in range(SIMD_TILE):
        const __m128i vin${N} = _mm_loadu_si128((const __m128i*) (input + ${N * 16}));
      input += ${CHANNEL_TILE};
      $for N in range(SIMD_TILE):
        const __m128i vt${N} = _mm_sad_epu8(vin${N}, vzero);
      $for N in range(SIMD_TILE):
        vacc${N % ACCUMULATORS} = _mm_add_epi32(vacc${N % ACCUMULATORS}, vt${N});
    }
    $if ACCUMULATORS > 1:
      // Fold the partial accumulators pairwise into vacc0.
      $ACC_SLICE = 1
      $while ACC_SLICE < ACCUMULATORS:
        $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
          $if A + ACC_SLICE < ACCUMULATORS:
            vacc${A} = _mm_add_epi32(vacc${A}, vacc${A + ACC_SLICE});
        $ACC_SLICE *= 2

  // Remaining whole 16-byte vectors.
  for (; batch >= 16; batch -= 16) {
    const __m128i vin = _mm_loadu_si128((const __m128i*) input);
    input += 16;
    const __m128i vt = _mm_sad_epu8(vin, vzero);
    vacc0 = _mm_add_epi32(vacc0, vt);
  }

  // Tail of 1..15 bytes: load a full vector and zero the bytes beyond `batch`
  // so that they do not contribute to the sum.
  if (XNN_UNLIKELY(batch != 0)) {
    assert(batch >= 1 && batch <= 15);
    const __m128i vmask = _mm_loadu_si128((const __m128i*) &mask_table[16 - batch]);
    const __m128i vin = _mm_and_si128(_mm_loadu_si128((const __m128i*) input), vmask);
    const __m128i vt = _mm_sad_epu8(vin, vzero);
    vacc0 = _mm_add_epi32(vacc0, vt);
  }

  // Horizontal reduction. Only 32-bit lanes 0 and 2 of vacc0 can be non-zero:
  // every value added above is a _mm_sad_epu8 result, which is zero in the
  // upper bits of each 64-bit half, and 32-bit adds never carry across lanes.
  // Adding the high half onto the low half therefore leaves the complete sum
  // in lane 0; no further fold is needed.
  vacc0 = _mm_add_epi32(vacc0, _mm_castps_si128(_mm_movehl_ps(_mm_castsi128_ps(vacc0), _mm_castsi128_ps(vacc0))));

  *output += (uint32_t) _mm_cvtsi128_si32(vacc0);
}