// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR in [8, 16]
$assert KR == 8
$assert DATATYPE in ["QS8", "X8", "QS4"]
$assert TYPE in ["int8_t"]
$assert IZP in [0, 128]

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "xnnpack/packw.h"
#include "xnnpack/unaligned.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"
$if VARIANT == "MADD":
  // Stand-in for the AVXVNNI dot-product-accumulate built from vpmaddubsw.
  // u7 is the multiplier passed as the unsigned operand (vone at the call
  // sites); s8 carries the int8 weights.
  static XNN_INTRINSIC
  __m256i _mm256_dpbusd_epi32_madd(__m256i i32, const __m256i u7, const __m256i s8) {
    // u7 * s8, adjacent byte products summed into 16-bit lanes.
    const __m256i vpairs16 = _mm256_maddubs_epi16(u7, s8);
    // Multiply by 1 and sum adjacent 16-bit lanes to widen them to 32 bits.
    const __m256i vquads32 = _mm256_madd_epi16(vpairs16, _mm256_set1_epi16(1));
    // Fold the widened sums into the running accumulator.
    return _mm256_add_epi32(vquads32, i32);
  }

// Reads exactly n (<= 8) bytes starting at address and returns them as a
// uint64_t with byte 0 in the least significant position. Bytes at index n
// and above are never touched and appear as zero in the result.
XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) {
  assert(n <= sizeof(uint64_t));
  const uint8_t* src = (const uint8_t*) address;
  uint64_t packed = 0;
  // Consume from the highest requested byte downwards; every step moves the
  // bytes gathered so far up by one position and inserts the next lower byte.
  while (n != 0) {
    --n;
    packed = (packed << 8) | (uint64_t) src[n];
  }
  return packed;
}

$BTYPE = {"QS8": "int32_t", "QS4": "int32_t", "X8": "uint32_t"}[DATATYPE]
$WTYPE = {"QS8": "int8_t", "QS4": "uint8_t", "X8": "int8_t"}[DATATYPE]
$PACKEDWTYPE = {"QS8": "int8_t", "QS4": "void", "X8": "int8_t"}[DATATYPE]
$SCALETYPE = {"QS8": "void", "QS4": "float", "X8": "void"}[DATATYPE]
$PARAMTYPE = {"QS8": "void", "QS4": "struct xnn_qs8_qc4w_packing_params", "X8": "void"}[DATATYPE]
$if DATATYPE in ["QS8", "QS4"]:
  $_MM256_DPBUSD_EPI32 = "_mm256_dpbusd_epi32_madd" if VARIANT == "MADD" else "_mm256_dpbusd_avx_epi32" if AVX == 2 else "_mm256_dpbusd_epi32"
  $ISA = "avx2" if VARIANT == "MADD" else "avxvnni" if AVX == 2 else "avx256vnni"
$else:
  $ISA = "avx2" if AVX == 2 else "avx256skx"
$NAMETYPE = "qs8_to_qu8" if IZP == 128 else {"QS8": "qs8", "QS4": "qs8_qc4w", "X8": "x8"}[DATATYPE]

// Packs GOI weights (and optional bias) for GEMM into tiles of ${NR} output
// channels by ${KR} bytes of K.
//
// For each of the g groups, and for each tile of ${NR} output channels, the
// output receives in order:
//   - ${NR} bias values (zeros when bias is NULL),
//   - the weights, ${KR} bytes per channel per K step, channels interleaved;
//     a trailing partial K step is zero-padded to ${KR} bytes,
//   - extra_bytes that are skipped over without being written.
// In the NC remainder tile, channels past nc re-read the last valid row.
//
// For the QS8 and QS4 datatypes the packed bias of every channel is further
// reduced by the zero point (input_zero_point + ${IZP}) multiplied by that
// channel's weight sum accumulated while packing.
//
// nr, kr and sr must equal ${NR}, ${KR} and 1; scale is not read by this
// kernel (it is only forwarded to the scalar fallback for QS4).
void xnn_${NAMETYPE}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_madd" if VARIANT == "MADD" else ""}${"_prfm" if PREFETCH else ""}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const ${WTYPE}* weights,
  const ${BTYPE}* bias,
  const ${SCALETYPE}* scale,
  ${PACKEDWTYPE}* packed_weights,
  size_t extra_bytes,
  const ${PARAMTYPE}* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == ${KR});
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  $if DATATYPE == "QS4":
    // params is dereferenced unconditionally on the QS4 path.
    assert(params != NULL);
    // Use scalar pack if not an even block size.  The NC remainder path below
    // does not perform the int4 conversion, so any nc that is not a whole
    // number of ${NR}-channel tiles must take the scalar kernel as well.
    // TODO: Support NC remainder
    if ((kc & 1) != 0 || (nc % ${NR}) != 0) {
      xnn_${NAMETYPE}_packw_gemm_goi_ukernel_x${NR}c${KR}__scalar(
        g, nc, kc, nr, kr, sr,
        weights, bias, scale, packed_weights, extra_bytes, params);
      return;
    }
    // Two 4-bit weights per byte: from here on kc counts bytes.  kc is even
    // at this point, so the division is exact.
    kc = (kc + 1) / 2;

  ${TYPE}* out = (${TYPE}*) packed_weights;
  const ${BTYPE}* b = (const ${BTYPE}*) bias;

  $if DATATYPE in ["QS8"]:
    const __m256i vone = _mm256_set1_epi8(1);
    // A NULL params is accepted and treated as an input zero point of 0.
    const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP}));
  $elif DATATYPE in ["QS4"]:
    assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0);
    const __m256i vone = _mm256_set1_epi8(1);
    const __m256i vmask = _mm256_set1_epi8(0xF0);
    const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + ${IZP});
    // Replicate the 4-bit kernel zero point into every nibble.
    const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111);
    // Byte-shuffle tables used to interleave the two nibble planes.  An index
    // of -1 (high bit set) makes the shuffle produce a zero byte.
    static const uint8_t perm0[32] = {
      0, -1,  1, -1,  2, -1,  3, -1,
      8, -1,  9, -1, 10, -1, 11, -1,
      0, -1,  1, -1,  2, -1,  3, -1,
      8, -1,  9, -1, 10, -1, 11, -1 };

    static const uint8_t perm1[32] = {
      -1,  0, -1,  1, -1,  2, -1,  3,
      -1,  8, -1,  9, -1, 10, -1, 11,
      -1,  0, -1,  1, -1,  2, -1,  3,
      -1,  8, -1,  9, -1, 10, -1, 11 };

    static const uint8_t perm2[32] = {
       4, -1,  5, -1,  6, -1,  7, -1,
      12, -1, 13, -1, 14, -1, 15, -1,
       4, -1,  5, -1,  6, -1,  7, -1,
      12, -1, 13, -1, 14, -1, 15, -1 };

    static const uint8_t perm3[32] = {
      -1,  4, -1,  5, -1,  6, -1,  7,
      -1, 12, -1, 13, -1, 14, -1, 15,
      -1,  4, -1,  5, -1,  6, -1,  7,
      -1, 12, -1, 13, -1, 14, -1, 15 };
    // The tables are plain byte arrays with no 32-byte alignment guarantee,
    // so they must be read with unaligned loads.
    const __m256i vperm0 = _mm256_loadu_si256((const __m256i*) perm0);
    const __m256i vperm1 = _mm256_loadu_si256((const __m256i*) perm1);
    const __m256i vperm2 = _mm256_loadu_si256((const __m256i*) perm2);
    const __m256i vperm3 = _mm256_loadu_si256((const __m256i*) perm3);

  do {
    // NC main loop multiple of ${NR}
    const ${TYPE}* w0 = (const ${TYPE}*) weights;
    size_t n = nc;
    for (;n >= ${NR}; n -= ${NR}) {
      $if DATATYPE in ["QS8", "QS4"]:
        // Remember where this tile's bias lives; it is patched after packing.
        ${BTYPE}* packed_b = (${BTYPE}*) out;
      if XNN_LIKELY(b != NULL) {
        $for N in range(0, NR, 8):
          const __m256i vb${N} = _mm256_loadu_si256((const __m256i*) (b + ${N}));
        $for N in range(0, NR, 8):
          _mm256_storeu_si256((__m256i*) (out + ${N*4}), vb${N});
        b += ${NR};
      } else {
        $for N in range(0, NR, 8):
          _mm256_storeu_si256((__m256i*) (out + ${N*4}), _mm256_setzero_si256());
      }
      out += ${NR} * sizeof(${BTYPE});

      $for N in range(1, NR):
        const ${TYPE}* w${N} = w${N-1} + kc;
      $if PREFETCH:
        $for N in range(0, NR):
          $for OFFSET in range(0, 448, 64):
            xnn_prefetch_to_l1((const int8_t*) w${N} + ${OFFSET});

      $if DATATYPE in ["QS8", "QS4"]:
        $for N in range(0, NR, 4):
          __m256i vacc${N} = _mm256_setzero_si256();

      size_t k = kc;
      // KC main loop multiple of ${NR}x${4 * KR}
      for (; k >= ${4 * KR}; k -= ${4 * KR}) {
        $for N in range(NR):
          const __m256i v${N}_0123 = _mm256_loadu_si256((const __m256i*) w${N});

        $for N in range(0, NR, 2):
          const __m256i v${N}${N+1}_02 = _mm256_unpacklo_epi64(v${N}_0123, v${N+1}_0123);
          const __m256i v${N}${N+1}_13 = _mm256_unpackhi_epi64(v${N}_0123, v${N+1}_0123);

        $if PREFETCH:
          $for N in range(0, NR):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 448);

        $for N in range(0, NR, 4):
          $for I in range(0, 4):
            __m256i v${N}_${I} = _mm256_permute2f128_si256(v${N}${N+1}_${I%2}${I%2+2}, v${N+2}${N+3}_${I%2}${I%2+2}, _MM_SHUFFLE(0, ${I//2+2}, 0, ${I//2}));

        $if DATATYPE in ["QS8"]:
          $for N in range(0, NR, 4):
            $for I in range(0, 4):
              vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}_${I});
        $elif DATATYPE in ["QS4"]:
          $for N in range(0, NR, 4):
            $for I in range(0, 4):
              v${N}_${I} = _mm256_xor_si256(v${N}_${I}, vkernel_zero_point);    // uint4 -> int4
              const __m256i vt${N}_${I} = _mm256_slli_epi32(v${N}_${I}, 4);     // isolate lower int4
              const __m256i vh${N}_${I} = _mm256_and_si256(v${N}_${I}, vmask);  // isolate upper int4
              const __m256i vl${N}_${I} = _mm256_and_si256(vt${N}_${I}, vmask);
              vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, vh${N}_${I});
              vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, vl${N}_${I});
              const __m256i v0x${N}_${I} = _mm256_shuffle_epi8(vl${N}_${I}, vperm0);  // 0,2,4,6
              const __m256i v1x${N}_${I} = _mm256_shuffle_epi8(vh${N}_${I}, vperm1);  // 1,3,5,7
              const __m256i v01x${N}_${I} = _mm256_or_si256(v0x${N}_${I}, v1x${N}_${I});
              const __m256i v2x${N}_${I} = _mm256_shuffle_epi8(vl${N}_${I}, vperm2);  // 8,A,C,E
              const __m256i v3x${N}_${I} = _mm256_shuffle_epi8(vh${N}_${I}, vperm3);  // 9,B,D,F
              const __m256i v23x${N}_${I} = _mm256_or_si256(v2x${N}_${I}, v3x${N}_${I});
              const __m256i vt01${N}_${I} = _mm256_srli_epi32(v01x${N}_${I}, 4);  // first plane 0-7
              v${N}_${I} = _mm256_or_si256(vt01${N}_${I}, v23x${N}_${I});            // + second plane 8-F

        $for I in range(0, 4):
          $for N in range(0, NR, 4):
            _mm256_storeu_si256((__m256i *)&out[${(I*NR + N)*KR}],  v${N}_${I});

        $for N in range(NR):
          w${N} += ${4 * KR};
        out += ${4*NR*KR};
      }

      // KC main loop multiple of ${NR}x${KR}
      for (; k >= ${KR}; k -= ${KR}) {
        $for N in range(0, NR, 4):
          __m256i v${N} = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N}));
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+1})), 0x0C);
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+2})), 0x30);
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+3})), 0xC0);
        $if PREFETCH:
          $for N in range(0, NR):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 448);

        $if DATATYPE in ["QS8"]:
          $for N in range(0, NR, 4):
            vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N});
        $elif DATATYPE in ["QS4"]:
          $for N in range(0, NR, 4):
            v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point);    // uint4 -> int4
            const __m256i vt${N} = _mm256_slli_epi32(v${N}, 4);     // isolate lower int4
            const __m256i vh${N} = _mm256_and_si256(v${N}, vmask);  // isolate upper int4
            const __m256i vl${N} = _mm256_and_si256(vt${N}, vmask);
            vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, vh${N});
            vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, vl${N});
            const __m256i v0x${N} = _mm256_shuffle_epi8(vl${N}, vperm0);  // 0,2,4,6
            const __m256i v1x${N} = _mm256_shuffle_epi8(vh${N}, vperm1);  // 1,3,5,7
            const __m256i v01x${N} = _mm256_or_si256(v0x${N}, v1x${N});
            const __m256i v2x${N} = _mm256_shuffle_epi8(vl${N}, vperm2);  // 8,A,C,E
            const __m256i v3x${N} = _mm256_shuffle_epi8(vh${N}, vperm3);  // 9,B,D,F
            const __m256i v23x${N} = _mm256_or_si256(v2x${N}, v3x${N});
            const __m256i vt01${N} = _mm256_srli_epi32(v01x${N}, 4);  // first plane 0-7
            v${N} = _mm256_or_si256(vt01${N}, v23x${N});            // + second plane 8-F

        $for N in range(0, NR, 4):
          _mm256_storeu_si256((__m256i *)&out[${N * KR}],  v${N});

        $for N in range(NR):
          w${N} += ${KR};
        out += ${NR*KR};
      }

      // KC remainder of 1..${KR-1}
      if (k != 0) {
        assert(k >= 1 && k <= ${KR-1});

        // safe_load_u64 reads only the k valid bytes and zero-fills the rest.
        $for N in range(0, NR, 4):
          __m256i v${N} = _mm256_set1_epi64x((int64_t) safe_load_u64(w${N}, k));
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+1}, k)), 0x0C);
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+2}, k)), 0x30);
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+3}, k)), 0xC0);

        $for N in range(NR):
          w${N} += k;

        $if DATATYPE in ["QS8"]:
          $for N in range(0, NR, 4):
            vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N});
        $elif DATATYPE in ["QS4"]:
          $for N in range(0, NR, 4):
            v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point);    // uint4 -> int4
            const __m256i vt${N} = _mm256_slli_epi32(v${N}, 4);     // isolate lower int4
            const __m256i vh${N} = _mm256_and_si256(v${N}, vmask);  // isolate upper int4
            const __m256i vl${N} = _mm256_and_si256(vt${N}, vmask);
            vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, vh${N});
            vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, vl${N});
            const __m256i v0x${N} = _mm256_shuffle_epi8(vl${N}, vperm0);  // 0,2,4,6
            const __m256i v1x${N} = _mm256_shuffle_epi8(vh${N}, vperm1);  // 1,3,5,7
            const __m256i v01x${N} = _mm256_or_si256(v0x${N}, v1x${N});
            const __m256i v2x${N} = _mm256_shuffle_epi8(vl${N}, vperm2);  // 8,A,C,E
            const __m256i v3x${N} = _mm256_shuffle_epi8(vh${N}, vperm3);  // 9,B,D,F
            const __m256i v23x${N} = _mm256_or_si256(v2x${N}, v3x${N});
            const __m256i vt01${N} = _mm256_srli_epi32(v01x${N}, 4);  // first plane 0-7
            v${N} = _mm256_or_si256(vt01${N}, v23x${N});            // + second plane 8-F

        $for N in range(0, NR, 4):
          _mm256_storeu_si256((__m256i *)&out[${N * KR}],  v${N});

        out += ${NR*KR};
      }

      $if DATATYPE in ["QS8", "QS4"]:
        // Reduce the per-row accumulators to one sum per output channel and
        // subtract zero_point * sum from the bias stored at the tile start.
        $for N in range(0, NR, 8):
          __m256i vksum${N} = _mm256_hadd_epi32(vacc${N}, vacc${N+4});
          vksum${N} = _mm256_permute4x64_epi64(vksum${N}, _MM_SHUFFLE(3, 1, 2, 0));
        $for N in range(0, NR, 8):
          vksum${N} = _mm256_mullo_epi32(vksum${N}, vzeropoint);
        $for N in range(0, NR, 8):
          __m256i vpack${N} =  _mm256_loadu_si256((const __m256i*) (packed_b + ${N}));
        $for N in range(0, NR, 8):
          vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N});
        $for N in range(0, NR, 8):
          _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N});
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
      // The last row pointer has walked a full kc, so it now addresses the
      // first row of the next tile.
      w0 = w${NR-1};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1 && n <= ${NR-1});

      $if DATATYPE in ["QS8", "QS4"]:
        ${BTYPE}* packed_b = (${BTYPE}*) out;
      if XNN_LIKELY(b != NULL) {
        size_t nb = n;
        do {
          *((${BTYPE}*) out) = *b++;
          out += sizeof(${BTYPE});
        } while (--nb != 0);
      } else {
        size_t nb = n;
        do {
          *((${BTYPE}*) out) = 0;
          out += sizeof(${BTYPE});
        } while (--nb != 0);
      }
      // Skip the bias slots of the channels that do not exist in this tile.
      out += (${NR} - n) * sizeof(${BTYPE});

      // Rows past the last valid channel alias the previous row so that every
      // load stays inside the weights buffer.
      $for N in range(1, NR):
        const ${TYPE}* w${N} = w${N-1} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${N} = w${N-1};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${N} = w${N-1};
          }
      $if PREFETCH:
        $for N in range(0, NR):
          xnn_prefetch_to_l1((const int8_t*) w${N});
          xnn_prefetch_to_l1((const int8_t*) w${N} + 64);

      $if DATATYPE in ["QS8", "QS4"]:
        $for N in range(0, NR, 4):
          __m256i vacc${N} = _mm256_setzero_si256();

      // KC main loop multiple of ${NR}x${KR}
      size_t k = kc;
      for (; k >= ${KR}; k -= ${KR}) {
        $for N in range(0, NR, 4):
          __m256i v${N} = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N}));
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+1})), 0x0C);
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+2})), 0x30);
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+3})), 0xC0);
        $if PREFETCH:
          $for N in range(0, NR):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 448);

        $if DATATYPE in ["QS8", "QS4"]:
          $for N in range(0, NR, 4):
            vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N});

        $for N in range(0, NR, 4):
          _mm256_storeu_si256((__m256i *)&out[${N * KR}],  v${N});

        $for N in range(NR):
          w${N} += ${KR};
        out += ${NR*KR};
      }

      // KC remainder of 1..${KR-1}
      if (k != 0) {
        assert(k >= 1 && k <= ${KR-1});

        $for N in range(0, NR, 4):
          __m256i v${N} = _mm256_set1_epi64x((int64_t) safe_load_u64(w${N}, k));
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+1}, k)), 0x0C);
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+2}, k)), 0x30);
          v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+3}, k)), 0xC0);

        $for N in range(NR):
          w${N} += k;

        $if DATATYPE in ["QS8", "QS4"]:
          $for N in range(0, NR, 4):
            vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N});

        $for N in range(0, NR, 4):
          _mm256_storeu_si256((__m256i *)&out[${N * KR}],  v${N});

        out += ${NR*KR};
      }

      $if DATATYPE in ["QS8", "QS4"]:
        $for N in range(0, NR, 8):
          __m256i vksum${N} = _mm256_hadd_epi32(vacc${N}, vacc${N+4});
          vksum${N} = _mm256_permute4x64_epi64(vksum${N}, _MM_SHUFFLE(3, 1, 2, 0));
        $for N in range(0, NR, 8):
          vksum${N} = _mm256_mullo_epi32(vksum${N}, vzeropoint);
        $for N in range(0, NR, 8):
          __m256i vpack${N} =  _mm256_loadu_si256((const __m256i*) (packed_b + ${N}));
        $for N in range(0, NR, 8):
          vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N});
        $for N in range(0, NR, 8):
          _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N});
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
    }

    // Advance to the next group's weights (nc rows of kc bytes each).
    weights = (const ${WTYPE}*)((uintptr_t) weights + nc * kc);
  } while (--g != 0);
}
