// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR >= 4
$assert NR % 4 == 0
$assert KBLOCK == 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <wasm_simd128.h>

#include "xnnpack/common.h"
#include "xnnpack/packw.h"


// Packs 32-bit weights (and optional bias) into tiles of ${NR} rows using
// WAsm SIMD loads, shuffles and stores.
//
// For each of the g groups, nc rows of kc words are read from weights and
// consumed ${NR} rows at a time. Every packed tile is written as:
//   1. ${NR} bias words (all zeros when bias is NULL),
//   2. for every step of 4 words along kc, a block of 4 x ${NR} words; the
//      4 words of row (N + L), L = 0..3, are rotated by L lanes and each
//      group of 4 rows is then transposed, so each run of ${NR} output words
//      holds one lane from every row,
//   3. extra_bytes bytes that are skipped over without being written.
// A kc remainder of 1..3 words is zero-padded to a full 4 x ${NR} block.
// An nc remainder of 1..${NR-1} rows still produces a full-width tile.
//
// Parameters:
//   g              - number of groups; must be non-zero.
//   nc             - rows per group; must be non-zero.
//   kc             - words per row; must be non-zero.
//   nr, kr, sr     - only checked by assertions (nr == ${NR}, kr == 1, sr == 4).
//   weights        - input; advanced by nc * kc words after each group.
//   bias           - optional input; when non-NULL it is advanced as it is
//                    consumed and is not reset between groups.
//   scale, params  - not referenced by this kernel.
//   packed_weights - output buffer; must not be NULL.
//   extra_bytes    - byte count added to the output pointer after each tile.
void xnn_x32_packw_gemm_goi_ukernel_x${NR}s4__wasmsimd_u${KBLOCK}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const uint32_t* weights,
  const uint32_t* bias,
  const void* scale,
  uint32_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == 1);
  assert(sr == 4);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  // One iteration per group.
  do {
    // NC main loop multiple of ${NR}
    const uint32_t* w0 = (const uint32_t*) weights;

    size_t n = nc;
    for (; n >= ${NR}; n -= ${NR}) {
      // Tile header: copy ${NR} bias words, or write zeros when there is no bias.
      if XNN_LIKELY(bias != NULL) {
        const v128_t vb${ABC[0:4]} = wasm_v128_load(bias);
        $for N in range(4, NR, 4):
          const v128_t vb${ABC[N:N+4]} = wasm_v128_load(bias + ${N});
        bias += ${NR};

        wasm_v128_store(packed_weights, vb${ABC[0:4]});
        $for N in range(4, NR, 4):
          wasm_v128_store(packed_weights + ${N}, vb${ABC[N:N+4]});
      } else {
        const v128_t vzero = wasm_i32x4_const_splat(0);
        wasm_v128_store(packed_weights, vzero);
        $for N in range(4, NR, 4):
          wasm_v128_store(packed_weights + ${N}, vzero);
      }
      packed_weights += ${NR};

      // One read pointer per row of the tile; consecutive rows are kc words apart.
      $for N in range(1, NR):
        const uint32_t* w${ABC[N]} = w${ABC[N-1]} + kc;

      // KC main loop multiple of ${NR}x${KBLOCK}
      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        // Load ${KBLOCK} consecutive words from every row and advance the row pointers.
        $for N in range(NR):
          v128_t v${ABC[N]}x${ABC[0:4]} = wasm_v128_load(w${ABC[N]});
          $for K in range(4, KBLOCK, 4):
            v128_t v${ABC[N]}x${ABC[K:K+4]} = wasm_v128_load(w${ABC[N]} + ${K});
          w${ABC[N]} += ${KBLOCK};

        // Rotate row (N + L) so that output lane i takes input lane (i + L) % 4.
        // Row N + 0 of each group of 4 rows is left as loaded.
        $for N in range(0, NR, 4):
          $for L in range(1, 4):
            $for K in range(0, KBLOCK, 4):
              v${ABC[N+L]}x${ABC[K:K+4]} = wasm_v32x4_shuffle(v${ABC[N+L]}x${ABC[K:K+4]}, v${ABC[N+L]}x${ABC[K:K+4]}, ${L}, ${(L+1)%4}, ${(L+2)%4}, ${(L+3)%4});

        // Transpose step 1: interleave 32-bit lanes of row pairs (low halves with
        // indices 0, 4, 1, 5; high halves with indices 2, 6, 3, 7).
        $for N in range(0, NR, 4):
          $for K in range(0, KBLOCK, 4):
            const v128_t v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]} = wasm_v32x4_shuffle(v${ABC[N]}x${ABC[K:K+4]}, v${ABC[N+1]}x${ABC[K:K+4]}, 0, 4, 1, 5);
            const v128_t v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]} = wasm_v32x4_shuffle(v${ABC[N+2]}x${ABC[K:K+4]}, v${ABC[N+3]}x${ABC[K:K+4]}, 0, 4, 1, 5);
            const v128_t v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]} = wasm_v32x4_shuffle(v${ABC[N]}x${ABC[K:K+4]}, v${ABC[N+1]}x${ABC[K:K+4]}, 2, 6, 3, 7);
            const v128_t v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]} = wasm_v32x4_shuffle(v${ABC[N+2]}x${ABC[K:K+4]}, v${ABC[N+3]}x${ABC[K:K+4]}, 2, 6, 3, 7);

        // Transpose step 2: combine 64-bit halves so that each result vector holds
        // one lane position from 4 consecutive (rotated) rows.
        $for N in range(0, NR, 4):
          $for K in range(0, KBLOCK, 4):
            const v128_t v${ABC[N:N+4]}x${ABC[K]} = wasm_v64x2_shuffle(v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]}, v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]}, 0, 2);
            const v128_t v${ABC[N:N+4]}x${ABC[K+1]} = wasm_v64x2_shuffle(v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]}, v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]}, 1, 3);
            const v128_t v${ABC[N:N+4]}x${ABC[K+2]} = wasm_v64x2_shuffle(v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]}, v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]}, 0, 2);
            const v128_t v${ABC[N:N+4]}x${ABC[K+3]} = wasm_v64x2_shuffle(v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]}, v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]}, 1, 3);

        // Store the block: lane position K of all ${NR} rows goes to offset K * ${NR}.
        wasm_v128_store(packed_weights, v${ABC[0:4]}x${ABC[0]});
        $for K in range(KBLOCK):
          $for N in range(0, NR, 4):
            $if N != 0 or K != 0:
              wasm_v128_store(packed_weights + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
        packed_weights += ${NR*KBLOCK};
      }

      // KC remainder (1..3)
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= 3);
        // Start from all-zero vectors; each case below overwrites them with a
        // partial load of k words per row.
        $for N in range(NR):
          v128_t v${ABC[N]}x0123 = wasm_i32x4_const_splat(0);

        switch (k) {
          case 1:
            $for N in range(NR):
              v${ABC[N]}x0123 = wasm_v128_load32_zero(w${ABC[N]});
              w${ABC[N]} += 1;
            break;
          case 2:
            $for N in range(NR):
              v${ABC[N]}x0123 = wasm_v128_load64_zero(w${ABC[N]});
              w${ABC[N]} += 2;
            break;
          case 3:
            // Two words via a 64-bit load, then the third word into lane 2.
            $for N in range(NR):
              v${ABC[N]}x0123 = wasm_v128_load64_zero(w${ABC[N]});
              w${ABC[N]} += 2;

            $for N in range(NR):
              v${ABC[N]}x0123 = wasm_v128_load32_lane(w${ABC[N]}, v${ABC[N]}x0123, 2);
              w${ABC[N]} += 1;
            break;
          default:
            XNN_UNREACHABLE;
        }

        // Same lane rotation and 4x4 transpose as in the KC main loop.
        $for N in range(0, NR, 4):
          $for L in range(1, 4):
            v${ABC[N+L]}x0123 = wasm_v32x4_shuffle(v${ABC[N+L]}x0123, v${ABC[N+L]}x0123, ${L}, ${(L+1)%4}, ${(L+2)%4}, ${(L+3)%4});

        $for N in range(0, NR, 4):
          const v128_t v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1 = wasm_v32x4_shuffle(v${ABC[N]}x0123, v${ABC[N+1]}x0123, 0, 4, 1, 5);
          const v128_t v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1 = wasm_v32x4_shuffle(v${ABC[N+2]}x0123, v${ABC[N+3]}x0123, 0, 4, 1, 5);
          const v128_t v${ABC[N:N+2]}x2_${ABC[N:N+2]}x3 = wasm_v32x4_shuffle(v${ABC[N]}x0123, v${ABC[N+1]}x0123, 2, 6, 3, 7);
          const v128_t v${ABC[N+2:N+4]}x2_${ABC[N+2:N+4]}x3 = wasm_v32x4_shuffle(v${ABC[N+2]}x0123, v${ABC[N+3]}x0123, 2, 6, 3, 7);

        $for N in range(0, NR, 4):
          const v128_t v${ABC[N:N+4]}x0 = wasm_v64x2_shuffle(v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1, v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1, 0, 2);
          const v128_t v${ABC[N:N+4]}x1 = wasm_v64x2_shuffle(v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1, v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1, 1, 3);
          const v128_t v${ABC[N:N+4]}x2 = wasm_v64x2_shuffle(v${ABC[N:N+2]}x2_${ABC[N:N+2]}x3, v${ABC[N+2:N+4]}x2_${ABC[N+2:N+4]}x3, 0, 2);
          const v128_t v${ABC[N:N+4]}x3 = wasm_v64x2_shuffle(v${ABC[N:N+2]}x2_${ABC[N:N+2]}x3, v${ABC[N+2:N+4]}x2_${ABC[N+2:N+4]}x3, 1, 3);

        // A full 4 x ${NR} block is always written, regardless of k.
        wasm_v128_store(packed_weights, v${ABC[0:4]}x0);
        $for K in range(4):
          $for N in range(0, NR, 4):
            $if N != 0 or K != 0:
              wasm_v128_store(packed_weights + ${N+K*NR}, v${ABC[N:N+4]}x${K});
        packed_weights += ${NR*4};
      }
      // Skip the per-tile extra bytes; nothing is written into them here.
      packed_weights = (uint32_t*) ((uintptr_t) packed_weights + extra_bytes);
      // The last row pointer has advanced by kc words, so it now points at the
      // first row of the next tile.
      w0 = w${ABC[NR-1]};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= ${NR-1});
      // Tile header: only n bias words are copied; the remaining ${NR} - n slots
      // are skipped without being written. Without bias, all ${NR} slots are zeroed.
      if XNN_LIKELY(bias != NULL) {
        size_t nb = n;
        do {
          *packed_weights++  = *bias++;
        } while (--nb != 0);
        packed_weights += (${NR} - n);
      } else {
        const v128_t vzero = wasm_i32x4_const_splat(0);
        wasm_v128_store(packed_weights, vzero);
        $for N in range(4, NR, 4):
          wasm_v128_store(packed_weights + ${N}, vzero);
        packed_weights += ${NR};
      }

      // Row pointers for rows 1..${NR-2}. A row at index >= n reuses the previous
      // row's pointer, so every such row re-reads row n - 1 (presumably to keep
      // loads on rows that exist). No pointer is created for row ${NR-1}.
      $for N in range(1, NR-1):
        const uint32_t* w${ABC[N]} = w${ABC[N-1]} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${ABC[N]} = w${ABC[N-1]};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${ABC[N]} = w${ABC[N-1]};
          }

      // Same structure as the KC main loop above, but only rows 0..${NR-2} are
      // loaded and rotated; the vector of row ${NR-2} stands in for row ${NR-1}
      // in the transpose.
      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        $for N in range(NR-1):
          v128_t v${ABC[N]}x${ABC[0:4]} = wasm_v128_load(w${ABC[N]});
          $for K in range(4, KBLOCK, 4):
            v128_t v${ABC[N]}x${ABC[K:K+4]} = wasm_v128_load(w${ABC[N]} + ${K});
          w${ABC[N]} += ${KBLOCK};

        $for N in range(0, NR, 4):
          $for L in range(1, 4):
            $if N + L != NR - 1:
              $for K in range(0, KBLOCK, 4):
                v${ABC[N+L]}x${ABC[K:K+4]} = wasm_v32x4_shuffle(v${ABC[N+L]}x${ABC[K:K+4]}, v${ABC[N+L]}x${ABC[K:K+4]}, ${L}, ${(L+1)%4}, ${(L+2)%4}, ${(L+3)%4});

        $for N in range(0, NR, 4):
          $for K in range(0, KBLOCK, 4):
            const v128_t v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]} = wasm_v32x4_shuffle(v${ABC[N]}x${ABC[K:K+4]}, v${ABC[N+1]}x${ABC[K:K+4]}, 0, 4, 1, 5);
            const v128_t v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]} = wasm_v32x4_shuffle(v${ABC[N+2]}x${ABC[K:K+4]}, v${ABC[min(N+3, NR-2)]}x${ABC[K:K+4]}, 0, 4, 1, 5);
            const v128_t v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]} = wasm_v32x4_shuffle(v${ABC[N]}x${ABC[K:K+4]}, v${ABC[N+1]}x${ABC[K:K+4]}, 2, 6, 3, 7);
            const v128_t v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]} = wasm_v32x4_shuffle(v${ABC[N+2]}x${ABC[K:K+4]}, v${ABC[min(N+3, NR-2)]}x${ABC[K:K+4]}, 2, 6, 3, 7);

        $for N in range(0, NR, 4):
          $for K in range(0, KBLOCK, 4):
            const v128_t v${ABC[N:N+4]}x${ABC[K]} = wasm_v64x2_shuffle(v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]}, v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]}, 0, 2);
            const v128_t v${ABC[N:N+4]}x${ABC[K+1]} = wasm_v64x2_shuffle(v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]}, v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]}, 1, 3);
            const v128_t v${ABC[N:N+4]}x${ABC[K+2]} = wasm_v64x2_shuffle(v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]}, v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]}, 0, 2);
            const v128_t v${ABC[N:N+4]}x${ABC[K+3]} = wasm_v64x2_shuffle(v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]}, v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]}, 1, 3);

        // The stored block is still ${NR} words wide per lane position.
        wasm_v128_store(packed_weights, v${ABC[0:4]}x${ABC[0]});
        $for K in range(KBLOCK):
          $for N in range(0, NR, 4):
            $if N != 0 or K != 0:
              wasm_v128_store(packed_weights + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
        packed_weights += ${NR*KBLOCK};
      }

      // KC remainder (1..3)
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= 3);
        // Zero-initialised vectors for rows 0..${NR-2}; partial loads fill k words.
        $for N in range(NR-1):
          v128_t v${ABC[N]}x0123 = wasm_i32x4_const_splat(0);

        switch (k) {
          case 1:
            $for N in range(NR-1):
              v${ABC[N]}x0123 = wasm_v128_load32_zero(w${ABC[N]});
              w${ABC[N]} += 1;
            break;
          case 2:
            $for N in range(NR-1):
              v${ABC[N]}x0123 = wasm_v128_load64_zero(w${ABC[N]});
              w${ABC[N]} += 2;
            break;
          case 3:
            // Two words via a 64-bit load, then the third word into lane 2.
            $for N in range(NR-1):
              v${ABC[N]}x0123 = wasm_v128_load64_zero(w${ABC[N]});
              w${ABC[N]} += 2;

            $for N in range(NR-1):
              v${ABC[N]}x0123 = wasm_v128_load32_lane(w${ABC[N]}, v${ABC[N]}x0123, 2);
              w${ABC[N]} += 1;
            break;
          default:
            XNN_UNREACHABLE;
        }

        // Rotate and transpose as above, again substituting row ${NR-2} for row ${NR-1}.
        $for N in range(0, NR, 4):
          $for L in range(1, 4):
            $if N + L != NR - 1:
              v${ABC[N+L]}x0123 = wasm_v32x4_shuffle(v${ABC[N+L]}x0123, v${ABC[N+L]}x0123, ${L}, ${(L+1)%4}, ${(L+2)%4}, ${(L+3)%4});

        $for N in range(0, NR, 4):
          const v128_t v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1 = wasm_v32x4_shuffle(v${ABC[N]}x0123, v${ABC[N+1]}x0123, 0, 4, 1, 5);
          const v128_t v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1 = wasm_v32x4_shuffle(v${ABC[N+2]}x0123, v${ABC[min(N+3, NR-2)]}x0123, 0, 4, 1, 5);
          const v128_t v${ABC[N:N+2]}x2_${ABC[N:N+2]}x3 = wasm_v32x4_shuffle(v${ABC[N]}x0123, v${ABC[N+1]}x0123, 2, 6, 3, 7);
          const v128_t v${ABC[N+2:N+4]}x2_${ABC[N+2:N+4]}x3 = wasm_v32x4_shuffle(v${ABC[N+2]}x0123, v${ABC[min(N+3, NR-2)]}x0123, 2, 6, 3, 7);

        $for N in range(0, NR, 4):
          const v128_t v${ABC[N:N+4]}x0 = wasm_v64x2_shuffle(v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1, v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1, 0, 2);
          const v128_t v${ABC[N:N+4]}x1 = wasm_v64x2_shuffle(v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1, v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1, 1, 3);
          const v128_t v${ABC[N:N+4]}x2 = wasm_v64x2_shuffle(v${ABC[N:N+2]}x2_${ABC[N:N+2]}x3, v${ABC[N+2:N+4]}x2_${ABC[N+2:N+4]}x3, 0, 2);
          const v128_t v${ABC[N:N+4]}x3 = wasm_v64x2_shuffle(v${ABC[N:N+2]}x2_${ABC[N:N+2]}x3, v${ABC[N+2:N+4]}x2_${ABC[N+2:N+4]}x3, 1, 3);

        wasm_v128_store(packed_weights, v${ABC[0:4]}x0);
        $for K in range(4):
          $for N in range(0, NR, 4):
            $if N != 0 or K != 0:
              wasm_v128_store(packed_weights + ${N+K*NR}, v${ABC[N:N+4]}x${K});
        packed_weights += ${NR*4};
      }
      // Skip the per-tile extra bytes after the remainder tile as well.
      packed_weights = (uint32_t*) ((uintptr_t) packed_weights + extra_bytes);
    }
    // Advance to the first row of the next group.
    weights += nc * kc;
  } while (--g != 0);
}
