// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <assert.h>
#include <math.h>

#include <immintrin.h>

#include "xnnpack/common.h"
#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/math.h"
#include "xnnpack/reduce.h"
#include "xnnpack/unaligned.h"


// Per-channel sum of int8 rows, added into an int32 output vector.
//
// For every channel c in [0, channels) the kernel adds the values found at
// byte offset c of each of the `rows` input rows and accumulates the total
// into output[c] (the previous contents of output are loaded, added to, and
// stored back).
//
//   rows         - number of rows to reduce; must be non-zero.
//   channels     - number of int8 channels per row; must be non-zero.
//   input        - pointer to the first row.
//   input_stride - distance between consecutive rows, in bytes.
//   zero         - replacement row pointer used for the row slots that lie past
//                  the last real row in the final step. It is read at the same
//                  offsets as a real row. NOTE(review): presumably points at a
//                  zero-filled buffer so that those slots do not change the sum
//                  - confirm against the caller.
//   output       - int32 accumulators, one per channel; read and written.
//   params       - not referenced by this kernel.
//
// NOTE(review): the XNN_OOB_READS annotation suggests that loads may touch
// bytes outside the logical input extent - confirm against xnnpack/common.h.
void xnn_qs8_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512skx_c${CHANNELS}(
    size_t rows,
    size_t channels,
    const int8_t* input,
    size_t input_stride,
    const int8_t* zero,
    int32_t* output,
    const struct xnn_qs8_rsum_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(rows != 0);
  assert(channels != 0);
  assert(input != NULL);
  assert(output != NULL);

  // Bytes to advance each row pointer per inner-loop step; one step consumes
  // ACCUMULATORS rows.
  const size_t input_increment = ${ACCUMULATORS} * input_stride;

  // Main loop: one full group of CHANNELS channels per iteration.
  for (; channels >= ${CHANNELS}; channels -= ${CHANNELS}) {
    // One row pointer per accumulator slot, input_stride bytes apart.
    const int8_t* i0 = input;
    $for ACC in range(1, ACCUMULATORS):
      const int8_t* i${ACC} = (const int8_t*) ((uintptr_t) input + ${ACC} * input_stride);

    // 32-bit running sums, 16 channels per register.
    $for C in range(0, CHANNELS, 16):
      __m512i vacc${C}_${C+16} = _mm512_setzero_si512();

    // 256 int8s may be summed into an int16 before overflowing
    // To prevent handling the tails of the inner 256 loop, we round 256 down to
    // the nearest integer multiple of ACCUMULATORS.
    $OVERFLOW = (256 // ACCUMULATORS) * ACCUMULATORS
    // Number of batches needed to cover all rows (integer ceiling division).
    size_t num_batches = (rows + ${OVERFLOW - 1}) / ${OVERFLOW};
    // Rows still to be consumed by the remaining batches.
    size_t r = rows;
    for (; num_batches > 0; --num_batches) {
      // 16-bit partial sums for the current batch, 32 channels per register.
      $for C in range(0, CHANNELS, 32):
        __m512i v16acc_${C}_${C+32} = _mm512_setzero_si512();
      for (int current_batch = min(r, ${OVERFLOW}); current_batch > 0; current_batch -= ${ACCUMULATORS}) {
        // When fewer than ACCUMULATORS rows remain, redirect the unused row
        // pointers to the `zero` row. This can only happen on the last step.
        $for N in range(1, ACCUMULATORS, 2):
          if XNN_UNPREDICTABLE(current_batch < ${N+1}) {
            i${N} = zero;
          }
          if XNN_UNPREDICTABLE(current_batch <= ${N+1}) {
            i${N+1} = zero;
          }
        $for C in range(0, CHANNELS, 32):
          __m512i vin${C}_${C+32};
        // Load 32 int8 channels from each row, sign-extend to int16, accumulate.
        $for ACC in range(ACCUMULATORS):
          $for C in range(0, CHANNELS, 32):
            vin${C}_${C+32} = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*) &i${ACC}[${C}]));
          $for C in range(0, CHANNELS, 32):
            v16acc_${C}_${C+32} = _mm512_add_epi16(v16acc_${C}_${C+32}, vin${C}_${C+32});
        $for N in range(0, ACCUMULATORS):
          i${N} = (const int8_t*) ((uintptr_t) i${N} + input_increment);
      }
      // Widen the batch's int16 partial sums to int32 (low half, then high half)
      // and fold them into the running sums.
      $for C in range(0, CHANNELS, 32):
        vacc${C}_${C+16} = _mm512_add_epi32(vacc${C}_${C+16}, _mm512_cvtepi16_epi32(_mm512_castsi512_si256(v16acc_${C}_${C+32})));
        vacc${C+16}_${C+32} = _mm512_add_epi32(vacc${C+16}_${C+32}, _mm512_cvtepi16_epi32(_mm512_extracti32x8_epi32(v16acc_${C}_${C+32}, 1)));
      // NOTE(review): doz() comes from xnnpack/math.h; presumably a
      // difference-or-zero (saturating) subtraction - confirm.
      r = doz(r, ${OVERFLOW});
    }

    // Add the sums to the existing output values and store them back.
    const int32_t* o = output;
    $for C in range(0, CHANNELS, 16):
      __m512i vo${C}_${C+16} = _mm512_loadu_si512((const __m512i*) o); o += 16;
    $for C in range(0, CHANNELS, 16):
      vo${C}_${C+16} = _mm512_add_epi32(vacc${C}_${C+16}, vo${C}_${C+16});
    $for C in range(0, CHANNELS, 16):
      _mm512_storeu_si512((__m512i*) output, vo${C}_${C+16}); output += 16;

    input = (const int8_t*) ((uintptr_t) input + ${CHANNELS} * sizeof(int8_t));
  }
  // Remainder: fewer than CHANNELS channels are left; process them in groups of
  // up to 32 channels using masked loads.
  if (channels != 0) {
    // 256 int8s may be summed into an int16 before overflowing.
    do {
      size_t num_batches = (rows + ${OVERFLOW - 1}) / ${OVERFLOW};
      size_t r = rows;
      const int8_t* i0 = input;
      $for ACC in range(1, ACCUMULATORS):
        const int8_t* i${ACC} = (const int8_t*) ((uintptr_t) input + ${ACC} * input_stride);

      // 32-bit running sums for channels [0, 16) and [16, 32) of this group.
      __m512i vacc0_16 = _mm512_setzero_si512();
      __m512i vacc16_32 = _mm512_setzero_si512();

      // Byte mask with one bit per remaining channel (at most 32). The shift is
      // performed on a 64-bit value so that shifting by 32 is well defined.
      const size_t shift = channels < 32 ? channels : 32;
      const __mmask32 vmask = _cvtu32_mask32((uint32_t) ((UINT64_C(1) << shift) - UINT64_C(1)));
      for (; num_batches > 0; --num_batches) {
        // 16-bit partial sums for the current batch.
        __m512i v16acc_0_32 = _mm512_setzero_si512();
        for (int current_batch = min(r, ${OVERFLOW}); current_batch > 0; current_batch -= ${ACCUMULATORS}) {
          // Redirect row pointers past the last remaining row to `zero`.
          $for N in range(1, ACCUMULATORS, 2):
            if XNN_UNPREDICTABLE(current_batch < ${N+1}) {
              i${N} = zero;
            }
            if XNN_UNPREDICTABLE(current_batch <= ${N+1}) {
              i${N+1} = zero;
            }

          // Masked loads zero the lanes of channels that are not selected.
          $for ACC in range(ACCUMULATORS):
            __m512i vin${ACC} = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(vmask, (const __m256i*)&i${ACC}[0]));
          $for ACC in range(ACCUMULATORS):
            v16acc_0_32 = _mm512_add_epi16(v16acc_0_32, vin${ACC});
          $for N in range(ACCUMULATORS):
            i${N} = (const int8_t*) ((uintptr_t) i${N} + input_increment);
        }
        // Widen the int16 partial sums to int32 and fold into the running sums.
        vacc0_16 = _mm512_add_epi32(vacc0_16, _mm512_cvtepi16_epi32(_mm512_castsi512_si256(v16acc_0_32)));
        vacc16_32 = _mm512_add_epi32(vacc16_32, _mm512_cvtepi16_epi32(_mm512_extracti32x8_epi32(v16acc_0_32, 1)));
        r = doz(r, ${OVERFLOW});
      }

      if XNN_LIKELY(channels >= 32) {
        // A full group of 32 channels: update two whole output registers.
        __m512i vo0_16 = _mm512_loadu_si512((const __m512i*) output);
        __m512i vo16_32 = _mm512_loadu_si512((const __m512i*) (output + 16));
        vo0_16 = _mm512_add_epi32(vo0_16, vacc0_16);
        vo16_32 = _mm512_add_epi32(vo16_32, vacc16_32);
        _mm512_storeu_si512((__m512i*) output, vo0_16); output += 16;
        _mm512_storeu_si512((__m512i*) output, vo16_32); output += 16;
        channels -= 32;
        input = (const int8_t*) ((uintptr_t) input + 32 * sizeof(int8_t));
      } else {
        // Fewer than 32 channels: write 16, 8, 4, 2 and 1 outputs according to
        // the bits of `channels`, moving the not-yet-written sums down into the
        // low lanes after each partial store.
        if (channels & 16) {
          __m512i vo0_16 = _mm512_loadu_si512((const __m512i*) output);
          vo0_16 = _mm512_add_epi32(vo0_16, vacc0_16);
          _mm512_storeu_si512((__m512i*) output, vo0_16); output += 16;
          vacc0_16 = vacc16_32;
        }
        __m256i vacc0_8 = _mm512_castsi512_si256(vacc0_16);
        if (channels & 8) {
          __m256i vo0_8 = _mm256_loadu_si256((const __m256i*) output);
          vo0_8 = _mm256_add_epi32(vo0_8, vacc0_8);
          _mm256_storeu_si256((__m256i*) output, vo0_8); output += 8;
          vacc0_8 = _mm512_extracti32x8_epi32(vacc0_16, 1);
        }
        if (channels & 4) {
          __m128i vo0_4 = _mm_loadu_si128((const __m128i*) output);
          vo0_4 = _mm_add_epi32(vo0_4, _mm256_castsi256_si128(vacc0_8));
          _mm_storeu_si128((__m128i*) output, vo0_4); output += 4;
          vacc0_8 = _mm256_castsi128_si256(_mm256_extractf128_si256(vacc0_8, 1));
        }
        if (channels & 2) {
          __m128i vo0_2 = _mm_loadl_epi64((const __m128i*) output);
          vo0_2 = _mm_add_epi32(vo0_2, _mm256_castsi256_si128(vacc0_8));
          _mm_storel_epi64((__m128i*) output, vo0_2); output += 2;
          vacc0_8 = _mm256_srli_si256(vacc0_8, 8);
        }
        if (channels & 1) {
          *output += _mm_cvtsi128_si32(_mm256_castsi256_si128(vacc0_8));
        }
        channels = 0;
      }
    } while (channels != 0);
  }
}
