// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR == 8
$assert KR == 8
$assert TYPE in ["int8_t"]
$assert IZP in [0, 128]

#include <assert.h>

#include <wasm_simd128.h>

#include "xnnpack/packw.h"


$ABC = "012345678"
$BTYPE = {"int8_t": "uint32_t"}[TYPE]
$WTYPE = {"int8_t": "int8_t"}[TYPE]
// Packs ${TYPE} GEMM weights (every output row is kc contiguous bytes) together
// with 32-bit bias values, for NR=${NR} and KR=${KR}.
//
// For each of the g groups, and for each tile of up to ${NR} rows, the output
// receives:
//   1. ${NR} 32-bit bias slots (copied from bias, or zero when bias is NULL);
//   2. the weights, as consecutive blocks of ${NR} rows x ${KR} bytes, with the
//      tail of K zero-padded up to a multiple of ${KR};
//   3. extra_bytes bytes that are skipped over, not written.
// After the weights of a tile are written, every bias slot of that tile is
// reduced by izp * (sum of the weight bytes accumulated for that row), where
// izp is params->input_zero_point + ${IZP}, or just ${IZP} when params is NULL.
//
// nr, kr and sr are only checked by assertions; scale is not used.
// The bias pointer keeps advancing across tiles and groups.
void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${KR}__wasmrelaxedsimd(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const ${WTYPE}* weights,
  const int32_t* bias,
  const void* scale,
  ${WTYPE}* packed_weights,
  size_t extra_bytes,
  const void* params) XNN_OOB_READS
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == ${KR});
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  // vone is the second operand of the relaxed dot product below: multiplying
  // every weight byte by 1 turns the dot product into a plain sum of bytes.
  const v128_t vone = wasm_i8x16_splat(1);
  const v128_t vzero = wasm_i32x4_splat(0);
  XNN_FORCE_REALIZATION(vone);
  XNN_FORCE_REALIZATION(vzero);
  ${TYPE}* out = (${TYPE}*) packed_weights;
  const ${BTYPE}* b = (const ${BTYPE}*) bias;
  // Effective zero point: the one from params (if any) plus the template offset ${IZP}.
  const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP});
  v128_t vzeropoint = wasm_i32x4_splat((int32_t) izp);

  do {
    // NC main loop multiple of ${NR}
    const ${TYPE}* w0 = (const ${TYPE}*) weights;
    size_t n = nc;
    for (;n >= ${NR}; n -= ${NR}) {
      // Remember where this tile's bias slots live; they are adjusted once the
      // row sums are known.
      int32_t* packed_b = (int32_t*) out;
      if XNN_LIKELY(b != NULL) {
        // Copy ${NR} bias values, 4 at a time (out is a byte pointer, hence the x4 offsets).
        $for N in range(0, NR, 4):
          const v128_t vb${N>>2} = wasm_v128_load(b + ${N});
          wasm_v128_store(out + ${N * 4}, vb${N>>2});
        b += ${NR};
      } else {
        // No bias supplied: start the slots from zero.
        $for N in range(0, NR, 4):
          wasm_v128_store(out + ${N * 4}, vzero);
      }
      out += ${NR} * sizeof(${BTYPE});

      // Row pointers for the tile; consecutive rows are kc bytes apart.
      $for N in range(1, NR):
        const ${TYPE}* w${N} = w${N-1} + kc;

      // One accumulator per pair of rows: lanes 0-1 belong to the first row of
      // the pair, lanes 2-3 to the second.
      $for N in range(0, NR, 2):
        v128_t vacc${ABC[N:N+2]} = wasm_i32x4_splat(0);

      // KC main loop: ${2 * KR} bytes of K per row per iteration (two ${NR}x${KR} blocks)
      size_t k = kc;
      for (; k >= ${2 * KR}; k -= ${2 * KR}) {
        // Load ${2 * KR} bytes from every row.
        $for N in range(NR):
          v128_t v${N}_01 = wasm_v128_load(w${N});

        // Interleave row pairs: _0 takes the first ${KR} bytes of both rows,
        // _1 takes the second ${KR} bytes of both rows.
        $for N in range(0, NR, 2):
          v128_t v${ABC[N:N+2]}_0 = wasm_i64x2_shuffle(v${N}_01, v${N+1}_01, 0, 2);
          v128_t v${ABC[N:N+2]}_1 = wasm_i64x2_shuffle(v${N}_01, v${N+1}_01, 1, 3);

        // Accumulate the byte sums of both halves.
        $for N in range(0, NR, 2):
          vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}_0, vone, vacc${ABC[N:N+2]});
          vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}_1, vone, vacc${ABC[N:N+2]});

        // First ${NR}x${KR} block: first ${KR} bytes of K for all rows.
        $for N in range(0, NR, 2):
          wasm_v128_store(out + ${N * KR}, v${ABC[N:N+2]}_0);

        // Second ${NR}x${KR} block, placed right after the first one.
        $for N in range(0, NR, 2):
          wasm_v128_store(out + ${(N + 8) * KR}, v${ABC[N:N+2]}_1);

        $for N in range(NR):
          w${N} += ${2 * KR};
        out += ${2*NR*KR};
      }

      // At most one leftover full block of ${KR} bytes per row.
      for (; k >= ${KR}; k -= ${KR}) {
        // Each row is loaded into both 64-bit lanes; the shuffle keeps lane 0
        // of the first row and lane 1 of the second row of the pair.
        $for N in range(0, NR, 2):
          const v128_t v${N} = wasm_v128_load64_splat(w${N});
          const v128_t v${N+1} = wasm_v128_load64_splat(w${N+1});
          const v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3);

        $for N in range(0, NR, 2):
          vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]});

        $for N in range(0, NR, 2):
          wasm_v128_store(out + ${N * KR}, v${ABC[N:N+2]});

        $for N in range(NR):
          w${N} += ${KR};
        out += ${NR*KR};
      }

      // KC remainder 1..KR-1
      if (k != 0) {
        assert(k >= 1 && k <= ${KR-1});

        // All-ones shifted right by (${KR} - k) bytes: keeps the low k bytes of
        // each 64-bit lane (the first k bytes in memory) and zeroes the rest.
        const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (${KR} - k) * sizeof(${WTYPE}) * 8);

        // Full 64-bit loads read past the k valid bytes of each row (the
        // function is declared XNN_OOB_READS); the mask discards the excess.
        $for N in range(0, NR, 2):
          const v128_t v${N} = wasm_v128_load64_splat(w${N});
          const v128_t v${N+1} = wasm_v128_load64_splat(w${N+1});
          v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3);
          v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, vmask);

        $for N in range(0, NR, 2):
          vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]});

        $for N in range(0, NR, 2):
          wasm_v128_store(out + ${N * KR}, v${ABC[N:N+2]});

        $for N in range(NR):
          w${N} += k;
        out += ${NR*KR};
      }

      // Each accumulator holds two partial sums per row; adding the even and
      // odd lanes of two accumulators yields one sum per row for 4 rows.
      v128_t vksum0123 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc01, vacc23, 0, 2, 4, 6), wasm_v32x4_shuffle(vacc01, vacc23, 1, 3, 5, 7));
      v128_t vksum4567 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc45, vacc67, 0, 2, 4, 6), wasm_v32x4_shuffle(vacc45, vacc67, 1, 3, 5, 7));

      // Scale the row sums by the zero point.
      vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint);
      vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint);

      // bias[i] -= izp * ksum[i], applied in place on the already packed bias.
      v128_t vpack0123 = wasm_v128_load(packed_b);
      v128_t vpack4567 = wasm_v128_load(packed_b + 4);

      wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
      wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));

      // Skip the per-tile extra bytes.
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
      // w${NR-1} has advanced kc bytes past the start of its row, i.e. to the
      // first row of the next tile.
      w0 = w${NR-1};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1 && n <= ${NR-1});

      int32_t* packed_b = (int32_t*) out;
      if XNN_LIKELY(b != NULL) {
        // Copy only the n available bias values.
        size_t nb = n;
        do {
          *((${BTYPE}*) out) = *b++;
          out += sizeof(${BTYPE});
        } while (--nb != 0);
      } else {
        size_t nb = n;
        do {
          *((${BTYPE}*) out) = 0;
          out += sizeof(${BTYPE});
        } while (--nb != 0);
      }
      // Skip the bias slots of the missing rows.
      // NOTE(review): those ${NR} - n slots are not written here, yet the
      // vpack loads further down read all ${NR} slots and store them back
      // adjusted, so their final contents depend on what the buffer held.
      out += (${NR} - n) * sizeof(${BTYPE});

      // Row pointers; a row index at or beyond n reuses the previous row's
      // pointer, so the padding rows repeat the last valid row.
      $for N in range(1, NR):
        const ${TYPE}* w${N} = w${N-1} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${N} = w${N-1};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${N} = w${N-1};
          }

      $for N in range(0, NR, 2):
        v128_t vacc${ABC[N:N+2]} = wasm_i32x4_splat(0);

      // KC main loop: ${KR} bytes of K per row per iteration (one ${NR}x${KR} block)
      size_t k = kc;
      for (; k >= ${KR}; k -= ${KR}) {
        // Each row is loaded into both 64-bit lanes; the shuffle keeps lane 0
        // of the first row and lane 1 of the second row of the pair.
        $for N in range(0, NR, 2):
          const v128_t v${N} = wasm_v128_load64_splat(w${N});
          const v128_t v${N+1} = wasm_v128_load64_splat(w${N+1});
          const v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3);

        $for N in range(0, NR, 2):
          vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]});

        $for N in range(0, NR, 2):
          wasm_v128_store(out + ${N * KR}, v${ABC[N:N+2]});

        $for N in range(NR):
          w${N} += ${KR};
        out += ${NR*KR};
      }

      // KC remainder of 1..${KR-1}
      if (k != 0) {
        assert(k >= 1 && k <= ${KR-1});

        // All-ones shifted right by (${KR} - k) bytes: keeps the low k bytes of
        // each 64-bit lane (the first k bytes in memory) and zeroes the rest.
        const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (${KR} - k) * sizeof(${WTYPE}) * 8);

        // Full 64-bit loads read past the k valid bytes of each row (the
        // function is declared XNN_OOB_READS); the mask discards the excess.
        $for N in range(0, NR, 2):
          const v128_t v${N} = wasm_v128_load64_splat(w${N});
          const v128_t v${N+1} = wasm_v128_load64_splat(w${N+1});
          v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3);
          v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, vmask);

        $for N in range(0, NR, 2):
          vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]});

        $for N in range(0, NR, 2):
          wasm_v128_store(out + ${N * KR}, v${ABC[N:N+2]});

        // The row pointers are not advanced here: they are not used again.
        out += ${NR*KR};
      }

      // Each accumulator holds two partial sums per row; adding the even and
      // odd lanes of two accumulators yields one sum per row for 4 rows.
      v128_t vksum0123 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc01, vacc23, 0, 2, 4, 6), wasm_v32x4_shuffle(vacc01, vacc23, 1, 3, 5, 7));
      v128_t vksum4567 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc45, vacc67, 0, 2, 4, 6), wasm_v32x4_shuffle(vacc45, vacc67, 1, 3, 5, 7));

      // Scale the row sums by the zero point.
      vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint);
      vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint);

      // bias[i] -= izp * ksum[i], applied in place on all ${NR} slots of the tile.
      v128_t vpack0123 = wasm_v128_load(packed_b);
      v128_t vpack4567 = wasm_v128_load(packed_b + 4);

      wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
      wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));

      // Skip the per-tile extra bytes.
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
    }
    // Advance to the next group's weights (nc rows of kc bytes each).
    weights += nc * kc;
  } while (--g != 0);
}
