// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.


$assert NR == 8 or NR == 16
$assert KBLOCK == 16
#include <assert.h>
#include <stdint.h>
#include <string.h>

#include <immintrin.h>

#include "xnnpack/packw.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"
$#
$# Indentation state shared by the emitter macros below. INDENT is the
$# nesting depth and _ is the matching prefix of two spaces per level.
$INDENT = 0
$_ = ""
$#
$# Sets the nesting depth to NEW_INDENT and rebuilds the prefix string _.
$def SET_INDENT(NEW_INDENT):
  $global INDENT
  $global _
  $INDENT=NEW_INDENT
  $_ = "  " * INDENT
$#
$# Opens a C block. Raises the indent by one level and returns the '{' text
$# for the caller to emit.
$def BEGIN():
  $SET_INDENT(INDENT + 1)
  $return '{'
$#
$# Closes a C block. Lowers the indent by one level and returns '}' already
$# prefixed with the new, shallower indent.
$def END():
  $SET_INDENT(INDENT - 1)
  $return _ + '}'
$#
$# Emits code that declares REGISTER, fills it with SIZE 16-bit values read
$# from ROW and then advances ROW by SIZE. Value pairs come from a masked
$# 32-bit load; the switch inserts the trailing value when SIZE is odd.
$# NOTE(review) - `mask` is not defined in the visible part of this template
$# and no call to this macro is visible either; confirm it is still needed.
$def LOAD_TO_REGISTER(REGISTER, ROW, SIZE):
  ${_}__m256i ${REGISTER} = _mm256_maskload_epi32((const int*) ${ROW}, _mm256_loadu_si256((const __m256i*) (mask + (8 - ${SIZE} / 2))));
  ${_}switch (${SIZE}) {
    $for N in range(1, KBLOCK, 2):
      ${_}case ${N}: ${REGISTER} = _mm256_insert_epi16(${REGISTER}, (int16_t) ${ROW}[${N-1}], ${N-1}); break;
  ${_}  default:;
  ${_}}
  ${_}${ROW} += ${SIZE};
$#
$# Lane-mask constants for _mm256_set_epi32, eight entries of -1 followed by
$# eight zeros. An eight-entry window starting at index 8 - P holds P entries
$# of -1, which enables P 32-bit lanes (2 * P 16-bit values) of a masked load.
$MASKS = [str(i) for i in ([-1] * 8) + ([0] * 8)]
$# Emits the row loads for a partial K block. Declares v0..v(NR-1), switches
$# on the runtime k (cases 1..KBLOCK-1) to load k values per row, then
$# advances every loaded row pointer by k.
$def LOAD_ROWS_TO_REGISTER_K_REMAINDER(IS_N_REMAINDER):
  $# Number of rows actually loaded. The last row is skipped when IS_N_REMAINDER is set.
  $NUM = NR - IS_N_REMAINDER
  $for N in range(NR):
    ${_}__m256i v${N};
  ${_}__m256i vmask;
  ${_}switch(k) {
    $for K in range(1, KBLOCK):
        ${_}case ${K}:
          $# The window is reversed because _mm256_set_epi32 takes the highest lane first.
          $ARGS = ', '.join(list(reversed(MASKS[(8 - K // 2) : (16 - K // 2)])))
          $if K > 1:
            ${_}vmask = _mm256_set_epi32(${ARGS});
          $for N in range(NUM):
            $if K > 1:
              ${_}v${N} = _mm256_maskload_epi32((const int*) w${N}, vmask);
            $else:
              ${_}v${N} = _mm256_setzero_si256();
            $# An odd K leaves one value that no 32-bit lane covers, so it is inserted as a single 16-bit lane.
            $if K % 2 == 1:
              ${_}v${N} = _mm256_insert_epi16(v${N}, (int16_t) w${N}[${K-1}], ${K-1});
          ${_}break;
  ${_}}
  $for N in range(NUM):
    ${_}w${N} += k;
$#
$# Emits the loads for a full K block. Each row pointer w<N> is read with one
$# unaligned 256-bit load into v<N> and advanced by 16 elements. When
$# IS_N_REMAINDER is set the last register is only declared, not loaded.
$def LOAD_ROWS_TO_REGISTERS(IS_N_REMAINDER):
  $for N in range(0, NR-1):
    ${_}__m256i v${N} = _mm256_loadu_si256((const __m256i*) w${N});
    ${_}w${N} += 16;
  $N = NR-1
  $if not IS_N_REMAINDER:
    ${_}__m256i v${N} = _mm256_loadu_si256((const __m256i*) w${N});
    ${_}w${N} += 16;
  $else:
    ${_}__m256i v${N};
$#
$def PLUS_8_ELEMENT_WISE(ARR):
  $return [X + 8 for X in ARR]
$#
$def DOUBLE_FOR_NR_16_IF_NEEDED(ARR):
  $return ARR if NR == 8 else ARR + PLUS_8_ELEMENT_WISE(ARR)
$#
$# Emits the in-register transpose of v0..v(NR-1). Unpacks at 16-, 32- and
$# 64-bit granularity interleave the rows inside each 128-bit half, then the
$# halves are recombined across registers. The vt registers are declared
$# here and the v registers are overwritten with the result.
$def TRANSPOSE(IS_N_REMAINDER):
    $# IND gives the register read for each row slot. In the N remainder the
    $# last register was never loaded, so v(NR-2) is read in its place.
    $IND = list(range(0, NR))
    $if IS_N_REMAINDER:
      $IND[NR-1] = NR-2
    ${_}// Interleave 16-bit lanes
    $for FIRST in range(0, NR, 2):
      $SECOND = FIRST + 1
      ${_}__m256i vt${FIRST} = _mm256_unpacklo_epi16(v${IND[FIRST]}, v${IND[SECOND]});
      ${_}__m256i vt${SECOND} = _mm256_unpackhi_epi16(v${IND[FIRST]}, v${IND[SECOND]});

    $# Prefetch hints for the loaded rows, emitted only when PREFETCH is enabled.
    $PREFETCH_ROWS(NR - IS_N_REMAINDER)

    ${_}// Interleave 32-bit lanes
    $# Pairs vt[FIRST] with vt[FIRST + 2]; results go to consecutive v registers.
    $OUT_INDEX = 0
    $for FIRST in DOUBLE_FOR_NR_16_IF_NEEDED([0, 1, 4, 5]):
      $SECOND = FIRST + 2
      ${_}v${OUT_INDEX} = _mm256_unpacklo_epi32(vt${FIRST}, vt${SECOND});
      ${_}v${OUT_INDEX + 1} = _mm256_unpackhi_epi32(vt${FIRST}, vt${SECOND});
      $OUT_INDEX += 2

    ${_}// Interleave 64-bit lanes
    $# Pairs v[FIRST] with v[FIRST + 4]; results go to consecutive vt registers.
    $OUT_INDEX = 0
    $for FIRST in DOUBLE_FOR_NR_16_IF_NEEDED([0, 1, 2, 3]):
      $SECOND = FIRST + 4
      ${_}vt${OUT_INDEX} = _mm256_unpacklo_epi64(v${FIRST}, v${SECOND});
      ${_}vt${OUT_INDEX + 1} = _mm256_unpackhi_epi64(v${FIRST}, v${SECOND});
      $OUT_INDEX += 2

    $# Recombine 128-bit halves. v[FIRST] receives the low halves of vt[FIRST]
    $# and vt[SECOND], v[SECOND] receives their high halves (control 0x31).
    $# NR == 8 pairs neighbouring registers, NR == 16 pairs registers 8 apart.
    $if NR == 8:
      $FIRSTS = list(range(0, 8, 2))
      $SKIP = 1
    $else:
      $FIRSTS = list(range(8))
      $SKIP = 8
    $for FIRST in FIRSTS:
      $SECOND = FIRST + SKIP
      ${_}v${FIRST} = _mm256_inserti128_si256(vt${FIRST}, _mm256_castsi256_si128(vt${SECOND}), 1);
      ${_}v${SECOND} = _mm256_permute2x128_si256(vt${FIRST}, vt${SECOND}, 0x31);
$#
$# Emits one 256-bit store of v[ORDER[N]] into the N-th 32-byte slot counted
$# from packed_weights. NR == 16 emits the aligned store, any other NR the
$# unaligned one.
$# NOTE(review) - the aligned form presumes packed_weights is 32-byte aligned
$# at every store; confirm this against the caller contract.
$def STORE_REGISTER(N):
  $ALIGNED_SUFFIX = "" if NR == 16 else "u"
  $FUN = f'_mm256_store{ALIGNED_SUFFIX}_si256'
  ${_}${FUN}((__m256i*) packed_weights + ${N}, v${ORDER[N]});
$#
$# Emits NUM consecutive register stores (output slots 0..NUM-1) and then
$# advances packed_weights by 16 elements per stored register.
$def STORE_REGISTERS_FULLY(NUM=NR):
    $for N in range(NUM):
      $STORE_REGISTER(N)
    ${_}packed_weights += ${16 * NUM};
$#
$# Emits an unaligned 128-bit store of the low half of the first register in
$# ORDER and advances packed_weights by 8 elements.
$def STORE_LOW_FIRST_REGISTER():
      ${_}_mm_storeu_si128((__m128i*) packed_weights, _mm256_castsi256_si128(v${ORDER[0]}));
      ${_}packed_weights += 8;
$#
$# Emits the stores for a partial K block by testing the 8, 4, 2 and 1 bits
$# of the runtime k. Each taken step stores the registers covering that many
$# K values and then moves the still-unstored registers to the front, so the
$# next step can start again at the first output slot.
$def STORE_REGISTERS_WITH_K_REMAINDER():
  $# 8 K values occupy NR / 2 registers.
  ${_}if (k & 8) ${BEGIN()}
    $STORE_REGISTERS_FULLY(NR // 2)
    $MOVE_BOTTOM_HALF_REGISTERS(NR)
  ${END()}
  $# 4 K values occupy NR / 4 registers.
  ${_}if (k & 4) ${BEGIN()}
    $STORE_REGISTERS_FULLY(NR // 4)
    $MOVE_BOTTOM_HALF_REGISTERS(NR // 2)
  ${END()}
  $# 2 K values occupy NR / 8 registers.
  ${_}if (k & 2) ${BEGIN()}
    $STORE_REGISTERS_FULLY(NR // 8)
    $MOVE_BOTTOM_HALF_REGISTERS(NR // 4)
  ${END()}
  $# A single K value is NR elements, which is half a register for NR == 8
  $# and one full register otherwise.
  ${_}if (k & 1) ${BEGIN()}
    $if NR == 8:
      $STORE_LOW_FIRST_REGISTER()
    $else:
      $STORE_REGISTERS_FULLY(1)
  ${END()}
$#
$# Emits N // 2 register copies. Taking the first N registers in ORDER
$# sequence, the second half is copied over the first half.
$def MOVE_BOTTOM_HALF_REGISTERS(N):
  $for I in range(N//2):
    ${_}v${ORDER[I]} = v${ORDER[N//2 + I]};
$#
$# ORDER[N] is the register that STORE_REGISTER writes to output slot N.
$# For NR == 8 TRANSPOSE leaves the low 128-bit halves in the even registers
$# and the high halves in the odd ones, so the even registers are stored
$# first. Otherwise the registers are stored in numeric sequence.
$if NR == 8:
  $ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
$else:
  $ORDER = list(range(NR))
$#
$# When PREFETCH is enabled, emits one xnn_prefetch_to_l1 call per row for
$# the first NUM row pointers, each targeting 128 bytes past the pointer.
$# Emits nothing when PREFETCH is disabled.
$def PREFETCH_ROWS(NUM):
  $if PREFETCH:
    $for N in range(NUM):
      ${_}xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
$#
$def ASSERT_K_REMAINDER_VALUE():
  ${_}assert(k >= 1);
  ${_}assert(k < 16);


// Packs 16-bit GEMM weights stored as [g][nc][kc] into panels of ${NR} rows.
//
// For each of the g groups and each panel of ${NR} rows, the output receives
// ${NR} bias values (zeros when bias is NULL) followed by the panel's weights
// transposed so that the ${NR} values for one k index are adjacent. After
// each panel the output pointer skips extra_bytes without writing them.
// A final panel with fewer than ${NR} rows reads its last valid row again in
// place of the missing rows, and its unused bias lanes are zeroed.
//
// nr, kr and sr are only checked by assertions (nr == ${NR}, kr == 1,
// sr == 1). scale and params are not referenced.
void xnn_x16_packw_gemm_goi_ukernel_x${NR}__avx2_u${KBLOCK}${"_prfm" if PREFETCH else ""}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const uint16_t* weights,
  const uint16_t* bias,
  const void* scale,
  uint16_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});   // This kernel is for NR=${NR}
  assert(kr == 1);
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  do {
    // NC main loop, one full panel of ${NR} rows per iteration
    const uint16_t* w0 = weights;
    size_t n = nc;
    for (; n >= ${NR}; n -= ${NR}) {
      $if NR == 8:
        $PREFIX = ""
        $TYPE_SIZE = "128"
        $U = "u"
      $else:
        $PREFIX = "256"
        $TYPE_SIZE = "256"
        $U = ""
      $TYPE = "__m" + TYPE_SIZE + "i"
      {
        // Bias (or zeros) for the ${NR} rows of this panel
        ${TYPE} vtmp;
        if XNN_LIKELY(bias != NULL) {
          vtmp = _mm${PREFIX}_loadu_si${TYPE_SIZE}((const ${TYPE}*) bias);
          bias += ${NR};
        } else {
          vtmp = _mm${PREFIX}_setzero_si${TYPE_SIZE}();
        }
        _mm${PREFIX}_store${U}_si${TYPE_SIZE}((${TYPE}*) packed_weights, vtmp);
        packed_weights += ${NR};
      }
      // Consecutive rows are kc elements apart
      $for N in range(1, NR):
        const uint16_t* w${N} = w${N-1} + kc;
      $if PREFETCH:
        $for N in range(0, NR):
          xnn_prefetch_to_l1((const int8_t*) w${N});
          xnn_prefetch_to_l1((const int8_t*) w${N} + 64);

      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        $SET_INDENT(4)
        $LOAD_ROWS_TO_REGISTERS(False)
        $TRANSPOSE(False)
        $STORE_REGISTERS_FULLY()
      }
      // KC remainder
      if XNN_UNLIKELY(k != 0) {
        $SET_INDENT(4)
        $ASSERT_K_REMAINDER_VALUE()
        $LOAD_ROWS_TO_REGISTER_K_REMAINDER(False)
        $TRANSPOSE(False)
        $STORE_REGISTERS_WITH_K_REMAINDER()
      }
      packed_weights = (uint16_t*) ((uintptr_t) packed_weights + extra_bytes);
      // The last row pointer has advanced by kc, so it now addresses the next panel
      w0 = w${NR-1};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= ${NR-1});
      if XNN_LIKELY(bias != NULL) {
        memcpy(packed_weights, bias, n * sizeof(uint16_t));
        // Zero the unused bias lanes so the panel is fully defined, as in the bias == NULL path
        memset(packed_weights + n, 0, (${NR} - n) * sizeof(uint16_t));
        bias += n;
      } else {
        memset(packed_weights, 0, ${NR} * sizeof(uint16_t));
      }
      packed_weights += ${NR};
      // NR remainder has less than ${NR} rows so last row is not loaded
      $for N in range(1, NR-1):
        const uint16_t* w${N} = w${N-1} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${N} = w${N-1};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${N} = w${N-1};
          }

      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        $SET_INDENT(4)
        $LOAD_ROWS_TO_REGISTERS(True)
        $TRANSPOSE(True)
        $STORE_REGISTERS_FULLY()
      }

      // KC and NC remainder
      if XNN_UNLIKELY(k != 0) {
        $SET_INDENT(4)
        $ASSERT_K_REMAINDER_VALUE()
        $LOAD_ROWS_TO_REGISTER_K_REMAINDER(True)
        $TRANSPOSE(True)
        $STORE_REGISTERS_WITH_K_REMAINDER()
      }
      packed_weights = (uint16_t*) ((uintptr_t) packed_weights + extra_bytes);
    }
    // Next group
    weights += nc * kc;
  } while (--g != 0);
}
