// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <assert.h>
#include <math.h>

#include <tmmintrin.h>

#include "xnnpack/common.h"
#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/math.h"
#include "xnnpack/reduce.h"
#include "xnnpack/unaligned.h"


// Column-wise sum of a uint8 matrix, accumulated into a uint32 vector (SSSE3).
//
// For every channel c in [0, channels) this adds
//   input[0 * input_stride + c] + ... + input[(rows - 1) * input_stride + c]
// to output[c]. The previous contents of `output` are kept and added to.
//
// Parameters:
//   rows          number of input rows, must be non-zero.
//   channels      number of bytes summed per row, must be non-zero.
//   input         pointer to the first row.
//   input_stride  distance in bytes between consecutive rows.
//   zero          pointer substituted for rows past the last one. It is read
//                 like a regular row, so it must be readable and must not
//                 contribute to the sum - presumably a zero-filled buffer
//                 supplied by the caller (TODO confirm).
//   output        `channels` uint32 accumulators, accessed with unaligned
//                 loads and stores.
//   params        not referenced by this kernel.
//
// Template parameters:
//   ACCUMULATORS  number of rows consumed per inner-loop iteration.
//   CHANNELS      number of channels handled per main-loop tile.
// NOTE(review): the row-clamping code names i${N+1} for odd N, and the main
// loop names vacc16_${C+8} for C in steps of 16, so the generated code only
// refers to declared variables when ACCUMULATORS is odd and CHANNELS is a
// multiple of 16 - confirm against the generator invocations.
//
// The remainder path always loads 8 bytes per row, even when fewer than 8
// channels are left, so it may read past the last channel of a row.
void xnn_qu8_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__ssse3_c${CHANNELS}(
    size_t rows,
    size_t channels,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    uint32_t* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(rows != 0);
  assert(channels != 0);
  assert(input != NULL);
  assert(output != NULL);

  // Byte distance covered by one inner-loop iteration (ACCUMULATORS rows).
  const size_t input_increment = ${ACCUMULATORS} * input_stride;
  // Multiplier of all ones: _mm_maddubs_epi16(x, vone) adds each pair of
  // adjacent unsigned bytes of x into one 16-bit lane.
  const __m128i vone = _mm_set1_epi8(1);

  // Main path: full tiles of CHANNELS channels.
  for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) {
    const uint8_t* i0 = input;
    $for ACC in range(1, ACCUMULATORS):
      const uint8_t* i${ACC} = (const uint8_t*) ((uintptr_t) input + ${ACC} * input_stride);

    // 32-bit totals, four channels per register.
    $for C in range(0, CHANNELS, 4):
      __m128i vacc${C} = _mm_setzero_si128();

    // 256 uint8 values can be summed in a uint16 lane without overflowing
    // (256 * 255 = 65280). To avoid handling a tail inside a batch, 256 is
    // rounded down to the nearest integer multiple of ACCUMULATORS.
    $OVERFLOW = (256 // ACCUMULATORS) * ACCUMULATORS
    size_t r = rows;

    // Each pass handles one batch of at most ${OVERFLOW} rows in 16-bit
    // lanes, then widens the batch result into the 32-bit totals.
    while (r != 0) {
      $for C in range(0, CHANNELS, 8):
        __m128i vacc16_${C} = _mm_setzero_si128();
      for (int current_batch = (int) min(r, ${OVERFLOW}); current_batch > 0; current_batch -= ${ACCUMULATORS}) {
        // current_batch is the number of rows left in this batch; row
        // pointers with an index >= current_batch are redirected to `zero`.
        $for N in range(1, ACCUMULATORS, 2):
          if XNN_UNPREDICTABLE(current_batch < ${N+1}) {
            i${N} = zero;
          }
          if XNN_UNPREDICTABLE(current_batch <= ${N+1}) {
            i${N+1} = zero;
          }

        __m128i vin_lo;
        __m128i vin_hi;
        $for ACC in range(ACCUMULATORS):
          __m128i vin${ACC};

        $for C in range(0, CHANNELS, 16):
          $for ACC in range(ACCUMULATORS):
            vin${ACC} = _mm_loadu_si128((const __m128i*)&i${ACC}[${C}]);
          // Interleave two rows byte-wise so that maddubs with ones yields
          // row_a[c] + row_b[c] per 16-bit lane.
          $for ACC in range(0, ACCUMULATORS - 1, 2):
            vin_lo = _mm_unpacklo_epi8(vin${ACC}, vin${ACC+1});
            vin_hi = _mm_unpackhi_epi8(vin${ACC}, vin${ACC+1});
            vin_lo = _mm_maddubs_epi16(vin_lo, vone);
            vin_hi = _mm_maddubs_epi16(vin_hi, vone);
            vacc16_${C} = _mm_add_epi16(vacc16_${C}, vin_lo);
            vacc16_${C+8} = _mm_add_epi16(vacc16_${C+8}, vin_hi);

          // The unpaired last row is zero-extended to 16 bits and added as is.
          $if ACCUMULATORS % 2 != 0:
            $ACC = ACCUMULATORS - 1
            vin_lo = _mm_unpacklo_epi8(vin${ACC}, _mm_setzero_si128());
            vin_hi = _mm_unpackhi_epi8(vin${ACC}, _mm_setzero_si128());
            vacc16_${C} = _mm_add_epi16(vacc16_${C}, vin_lo);
            vacc16_${C+8} = _mm_add_epi16(vacc16_${C+8}, vin_hi);

        $for N in range(0, ACCUMULATORS):
          i${N} = (const uint8_t*) ((uintptr_t) i${N} + input_increment);
      }
      // Zero-extend the eight 16-bit batch sums to 32 bits and accumulate.
      $for C in range(0, CHANNELS, 8):
        vacc${C} = _mm_add_epi32(vacc${C}, _mm_unpacklo_epi16(vacc16_${C}, _mm_setzero_si128()));
        vacc${C+4} = _mm_add_epi32(vacc${C+4}, _mm_unpacklo_epi16(_mm_srli_si128(vacc16_${C}, 8), _mm_setzero_si128()));
      r = doz(r, ${OVERFLOW});
    }

    // output[c] += total[c] for the whole tile.
    $for C in range(0, CHANNELS, 4):
      __m128i vo${C} = _mm_loadu_si128((const __m128i*) ((uintptr_t) output + ${C} * sizeof(uint32_t)));
    $for C in range(0, CHANNELS, 4):
      vo${C} = _mm_add_epi32(vo${C}, vacc${C});
    $for C in range(0, CHANNELS, 4):
      _mm_storeu_si128((__m128i*) output, vo${C}); output += 4;

    input = (const uint8_t*) ((uintptr_t) input + ${CHANNELS} * sizeof(uint8_t));
  }

  // Remainder path: fewer than CHANNELS channels left, handled 8 at a time.
  if (channels != 0) {
    do {
      size_t r = rows;
      const uint8_t* i0 = input;
      $for ACC in range(1, ACCUMULATORS):
        const uint8_t* i${ACC} = (const uint8_t*) ((uintptr_t) input + ${ACC} * input_stride);

      __m128i vacc0123 = _mm_setzero_si128();
      __m128i vacc4567 = _mm_setzero_si128();

      // Same batching as the main path: at most ${OVERFLOW} rows are summed
      // in 16-bit lanes before being widened to 32 bits.
      while (r != 0) {
        __m128i vacc16_01234567 = _mm_setzero_si128();
        for (int current_batch = (int) min(r, ${OVERFLOW}); current_batch > 0; current_batch -= ${ACCUMULATORS}) {
          // Row pointers with an index >= current_batch are redirected to `zero`.
          $for N in range(1, ACCUMULATORS, 2):
            if XNN_UNPREDICTABLE(current_batch < ${N+1}) {
              i${N} = zero;
            }
            if XNN_UNPREDICTABLE(current_batch <= ${N+1}) {
              i${N+1} = zero;
            }

          __m128i vin_lo;
          __m128i vin_hi;
          // 8-byte loads: the upper half of every vin register is zero, so
          // the vin_hi terms below add nothing to the accumulator.
          $for ACC in range(ACCUMULATORS):
            __m128i vin${ACC} = _mm_loadl_epi64((const __m128i*)&i${ACC}[0]);
          $for ACC in range(0, ACCUMULATORS - 1, 2):
            vin_lo = _mm_unpacklo_epi8(vin${ACC}, vin${ACC+1});
            vin_hi = _mm_unpackhi_epi8(vin${ACC}, vin${ACC+1});
            vin_lo = _mm_maddubs_epi16(vin_lo, vone);
            vin_hi = _mm_maddubs_epi16(vin_hi, vone);
            vacc16_01234567 = _mm_add_epi16(vacc16_01234567, vin_lo);
            vacc16_01234567 = _mm_add_epi16(vacc16_01234567, vin_hi);

          $if ACCUMULATORS % 2 != 0:
            $ACC = ACCUMULATORS - 1
            vin_lo = _mm_unpacklo_epi8(vin${ACC}, _mm_setzero_si128());
            vin_hi = _mm_unpackhi_epi8(vin${ACC}, _mm_setzero_si128());
            vacc16_01234567 = _mm_add_epi16(vacc16_01234567, vin_lo);
            vacc16_01234567 = _mm_add_epi16(vacc16_01234567, vin_hi);

          $for N in range(ACCUMULATORS):
            i${N} = (const uint8_t*) ((uintptr_t) i${N} + input_increment);
        }
        vacc0123 = _mm_add_epi32(vacc0123, _mm_unpacklo_epi16(vacc16_01234567, _mm_setzero_si128()));
        vacc4567 = _mm_add_epi32(vacc4567, _mm_unpacklo_epi16(_mm_srli_si128(vacc16_01234567, 8), _mm_setzero_si128()));
        r = doz(r, ${OVERFLOW});
      }

      if XNN_LIKELY(channels >= 8) {
        // A full group of 8 channels: update both halves and continue.
        __m128i vo0123 = _mm_loadu_si128((const __m128i*) output);
        __m128i vo4567 = _mm_loadu_si128((const __m128i*) ((uintptr_t) output + 4 * sizeof(uint32_t)));
        vo0123 = _mm_add_epi32(vo0123, vacc0123);
        vo4567 = _mm_add_epi32(vo4567, vacc4567);
        _mm_storeu_si128((__m128i*) output, vo0123); output += 4;
        _mm_storeu_si128((__m128i*) output, vo4567); output += 4;
        channels -= 8;
        input = (const uint8_t*) ((uintptr_t) input + 8 * sizeof(uint8_t));
      } else {
        // 1..7 channels left: write 4, 2 and 1 elements according to the bits
        // of `channels`, moving the not-yet-written sums into the low lanes
        // after each step.
        if (channels & 4) {
          __m128i vo0123 = _mm_loadu_si128((const __m128i*) output);
          vo0123 = _mm_add_epi32(vo0123, vacc0123);
          _mm_storeu_si128((__m128i*) output, vo0123); output += 4;
          vacc0123 = vacc4567;
        }
        if (channels & 2) {
          __m128i vo01 = _mm_loadl_epi64((const __m128i*) output);
          vo01 = _mm_add_epi32(vo01, vacc0123);
          _mm_storel_epi64((__m128i*) output, vo01); output += 2;
          vacc0123 = _mm_srli_si128(vacc0123, 8);
        }
        if (channels & 1) {
          __m128i vo0 = _mm_cvtsi32_si128(unaligned_load_u32(output));
          vo0 = _mm_add_epi32(vo0, vacc0123);
          _mm_storeu_si32(output, vo0);
        }
        channels = 0;
      }
    } while (channels != 0);
  }
}