// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR >= 4
$assert NR % 4 == 0
$assert KBLOCK >= 4
$assert KBLOCK % 4 == 0
$assert SR == 1 or SR == 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <arm_neon.h>

#include "xnnpack/packw.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"


$IFWHILE = "if" if KBLOCK==8 else "while"
// Packs 32-bit weights, and an optional bias, into tiles of ${NR} rows
// (nr == ${NR}, kr == 1, sr == ${SR}).
//
// Input: `weights` holds g groups of nc rows with kc elements each, rows
// stored back to back. `bias`, when not NULL, supplies nc values per group.
// NOTE(review): "goi" in the name presumably means groups x output channels x
// input channels - confirm against xnnpack/packw.h.
//
// Output: for every tile of ${NR} rows, `packed_weights` receives
//   1. ${NR} bias values (zeros when bias == NULL),
//   2. the tile's weights, transposed through vld4q_lane_u32 so that the
//      values of 4 rows sharing one k index land in the same vector,
//   3. extra_bytes bytes that are skipped over, not written.
// Kernels generated with sr == 4 additionally rotate each group of four
// vectors after every lane load and always store whole groups of four vectors
// for the kc remainder.
//
// A final partial tile (nc not a multiple of ${NR}) still occupies a full
// tile: missing rows re-read the last valid row, row ${NR-1} is never loaded,
// and, when bias is given, the unused bias slots are skipped without being
// written.
//
// g, nc and kc must be non-zero; weights and packed_weights must not be NULL.
// nr, kr and sr are only checked by assert; scale and params are not used.
void xnn_x32_packw_gemm_goi_ukernel_x${NR}${"s4" if SR == 4 else ""}__neon_ld4lane_u${KBLOCK}${"_prfm" if PREFETCH else ""}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const uint32_t* weights,
  const uint32_t* bias,
  const void* scale,
  uint32_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == 1);
  assert(sr == ${SR});
  assert(weights != NULL);
  assert(packed_weights != NULL);
  // Transpose registers shared by every tile. They are zeroed once here, so a
  // lane that the NC remainder path skips holds zero or data left over from an
  // earlier tile, never an uninitialised value.
  $for N in range(0, NR, 4):
    $for K in range(0, KBLOCK, 4):
      uint32x4x4_t vtmp${ABC[K:K+4]}x${ABC[N:N+4]};
      $for L in range(4):
        vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[${L}] = vdupq_n_u32(0);

  // One iteration per group.
  do {
    // NC main loop multiple of ${NR}
    const uint32_t* w0 = weights;
    size_t n = nc;

    for (; n >= ${NR}; n -= ${NR}) {
      // Tile header: ${NR} bias values, or zeros when no bias is supplied.
      if XNN_LIKELY(bias != NULL) {
        $for N in range(0, NR, 4):
          uint32x4_t vb${N} = vld1q_u32(bias); bias += 4;
        $for N in range(0, NR, 4):
          vst1q_u32(packed_weights, vb${N}); packed_weights += 4;
      } else {
        const uint32x4_t vzero = vmovq_n_u32(0);
        $for N in range(0, NR, 4):
          vst1q_u32(packed_weights, vzero); packed_weights += 4;
      }

      // One pointer per row of the tile; consecutive rows are kc elements apart.
      $for N in range(1, NR):
        const uint32_t* w${N} = w${N-1} + kc;
      $if PREFETCH:
        $for N in range(0, NR):
          xnn_prefetch_to_l1((const int8_t*) w${N});
          xnn_prefetch_to_l1((const int8_t*) w${N} + 64);

      // KC main loop multiple of ${KBLOCK}
      // Each vld4q_lane_u32 reads 4 consecutive elements of one row into one
      // lane of 4 vectors, so vector j gathers element j of 4 different rows.
      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        $if SR == 4:
          // Scratch for the rotation val[0]->val[1]->val[2]->val[3]->val[0]
          // that is applied after every lane load below.
          uint32x4_t vsrtmp;
        $for K in range(0, KBLOCK, 4):
          $for N in range(0, NR, 4):
            $for L in range(4):
              vtmp${ABC[K:K+4]}x${ABC[N:N+4]} = vld4q_lane_u32(w${N+L}, vtmp${ABC[K:K+4]}x${ABC[N:N+4]}, ${L}); w${N+L} += 4;
              $if SR == 4:
                vsrtmp = vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[3];
                vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[3] = vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[2];
                vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[2] = vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[1];
                vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[1] = vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[0];
                vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[0] = vsrtmp;
        $if PREFETCH:
          $for N in range(0, NR):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
        // Store vector by vector: for each val[] index, the vectors of all row
        // groups are written back to back.
        $for K in range(0, KBLOCK, 4):
          $for L in range(4):
            $for N in range(0, NR, 4):
              vst1q_u32(packed_weights, vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[${L}]); packed_weights += 4;
      }
      $if KBLOCK > 4:
        // KC remainder multiple of 4
        ${IFWHILE} (k >= 4) {
          $if SR == 4:
            uint32x4_t vsrtmp;
          $for N in range(0, NR, 4):
            $for L in range(4):
              vtmp${ABC[0:4]}x${ABC[N:N+4]} = vld4q_lane_u32(w${N+L}, vtmp${ABC[0:4]}x${ABC[N:N+4]}, ${L}); w${N+L} += 4;
              $if SR == 4:
                vsrtmp = vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[3];
                vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[3] = vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[2];
                vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[2] = vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[1];
                vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[1] = vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[0];
                vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[0] = vsrtmp;
          $if PREFETCH:
            $for N in range(0, NR):
              xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
          $for L in range(4):
            $for N in range(0, NR, 4):
              vst1q_u32(packed_weights, vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[${L}]); packed_weights += 4;
          k -= 4;
        }

      // KC remainder of 1..3
      // Same as main loop but ld1, ld2 or ld3
      // Each case uses its own zero-initialised registers of the matching width.
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= 3);
        switch (k) {
          $for K in range(1, 4):
            // KC remainder of ${K}
            case ${K}:
            {
              $if K == 1:
                $for N in range(0, NR, 4):
                  uint32x4_t vtmp${ABC[0:K]}x${ABC[N:N+4]} = vdupq_n_u32(0);
              $else:
                $for N in range(0, NR, 4):
                  uint32x4x${K}_t vtmp${ABC[0:K]}x${ABC[N:N+4]};
                  $for L in range(0, K):
                    vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[${L}] = vdupq_n_u32(0);
              $if SR == 4:
                // The zero-initialised vsrtmp<k> vectors extend the loaded
                // vectors to a ring of four for the rotation; they are stored
                // after the loaded vectors, so four vectors per row group are
                // always written.
                uint32x4_t vsrtmp;
                $for N in range(0, NR, 4):
                  $for L in range(K, 4):
                    uint32x4_t vsrtmp${ABC[L]}x${ABC[N:N+4]} = vdupq_n_u32(0);
              $for N in range(0, NR, 4):
                $for L in range(4):
                  vtmp${ABC[0:K]}x${ABC[N:N+4]} = vld${K}q_lane_u32(w${N+L}, vtmp${ABC[0:K]}x${ABC[N:N+4]}, ${L}); w${N+L} += ${K};
                  $if SR == 4:
                    $if K == 3:
                      vsrtmp = vsrtmp${ABC[3]}x${ABC[N:N+4]};
                      vsrtmp${ABC[3]}x${ABC[N:N+4]} = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[2];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[2] = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[1];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[1] = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[0];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[0] = vsrtmp;
                    $if K == 2:
                      vsrtmp = vsrtmp${ABC[3]}x${ABC[N:N+4]};
                      vsrtmp${ABC[3]}x${ABC[N:N+4]} = vsrtmp${ABC[2]}x${ABC[N:N+4]};
                      vsrtmp${ABC[2]}x${ABC[N:N+4]} = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[1];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[1] = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[0];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[0] = vsrtmp;
                    $if K == 1:
                      vsrtmp = vsrtmp${ABC[3]}x${ABC[N:N+4]};
                      vsrtmp${ABC[3]}x${ABC[N:N+4]} = vsrtmp${ABC[2]}x${ABC[N:N+4]};
                      vsrtmp${ABC[2]}x${ABC[N:N+4]} = vsrtmp${ABC[1]}x${ABC[N:N+4]};
                      vsrtmp${ABC[1]}x${ABC[N:N+4]} = vtmp${ABC[0:K]}x${ABC[N:N+4]};
                      vtmp${ABC[0:K]}x${ABC[N:N+4]} = vsrtmp;
              $for L in range(K):
                $for N in range(0, NR, 4):
                  $if K == 1:
                    vst1q_u32(packed_weights, vtmp${ABC[0:K]}x${ABC[N:N+4]}); packed_weights += 4;
                  $else:
                    vst1q_u32(packed_weights, vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[${L}]); packed_weights += 4;
              $if SR == 4:
                $for L in range(K, 4):
                  $for N in range(0, NR, 4):
                    vst1q_u32(packed_weights, vsrtmp${ABC[L]}x${ABC[N:N+4]}); packed_weights += 4;
              break;
            }
          default:
            XNN_UNREACHABLE;
        }
      }
      // Skip the per-tile extra space without writing to it.
      packed_weights = (uint32_t*) ((uintptr_t) packed_weights + extra_bytes);
      // The last row pointer has advanced by kc in total, so it now addresses
      // the first row of the next tile.
      w0 = w${NR-1};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= ${NR-1});
      if XNN_LIKELY(bias != NULL) {
        // Copy the n remaining bias values; the other ${NR} - n slots of the
        // tile header are skipped without being written.
        size_t nb = n;
        do {
          *packed_weights++  = *bias++;
        } while (--nb != 0);
        packed_weights += (${NR} - n);
      } else {
        const uint32x4_t vzero = vmovq_n_u32(0);
        $for N in range(0, NR, 4):
          vst1q_u32(packed_weights, vzero); packed_weights += 4;
      }

      // The remainder tile has fewer than ${NR} rows, so row ${NR-1} is never
      // loaded. A pointer for a row index >= n is redirected to the previous
      // row pointer, so every load reads from one of the n rows that exist.
      $for N in range(1, NR-1):
        const uint32_t* w${N} = w${N-1} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${N} = w${N-1};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${N} = w${N-1};
          }

      // KC main loop multiple of ${KBLOCK}
      // Same as the full-tile loop, except that the lane for row ${NR-1} is
      // left as it was in the shared registers.
      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        $if SR == 4:
          // Scratch for the rotation val[0]->val[1]->val[2]->val[3]->val[0];
          // the rotation is applied for every lane slot, including the
          // skipped one.
          uint32x4_t vsrtmp;
        $for K in range(0, KBLOCK, 4):
          $for N in range(0, NR, 4):
            $for L in range(4):
              $if N+L != NR-1:
                vtmp${ABC[K:K+4]}x${ABC[N:N+4]} = vld4q_lane_u32(w${N+L}, vtmp${ABC[K:K+4]}x${ABC[N:N+4]}, ${L}); w${N+L} += 4;
              $if SR == 4:
                vsrtmp = vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[3];
                vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[3] = vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[2];
                vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[2] = vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[1];
                vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[1] = vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[0];
                vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[0] = vsrtmp;
        $if PREFETCH:
          $for N in range(0, NR-1):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
        $for K in range(0, KBLOCK, 4):
          $for L in range(4):
            $for N in range(0, NR, 4):
              vst1q_u32(packed_weights, vtmp${ABC[K:K+4]}x${ABC[N:N+4]}.val[${L}]); packed_weights += 4;
      }
      $if KBLOCK > 4:
        // KC remainder multiple of 4
        ${IFWHILE} (k >= 4) {
          $if SR == 4:
            uint32x4_t vsrtmp;
          $for N in range(0, NR, 4):
            $for L in range(4):
              $if N+L != NR-1:
                vtmp${ABC[0:4]}x${ABC[N:N+4]} = vld4q_lane_u32(w${N+L}, vtmp${ABC[0:4]}x${ABC[N:N+4]}, ${L}); w${N+L} += 4;
              $if SR == 4:
                vsrtmp = vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[3];
                vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[3] = vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[2];
                vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[2] = vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[1];
                vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[1] = vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[0];
                vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[0] = vsrtmp;
          $if PREFETCH:
            $for N in range(0, NR-1):
              xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
          $for L in range(4):
            $for N in range(0, NR, 4):
              vst1q_u32(packed_weights, vtmp${ABC[0:4]}x${ABC[N:N+4]}.val[${L}]); packed_weights += 4;
          k -= 4;
        }

      // KC remainder of 1..3
      // Same as main loop but ld1, ld2 or ld3
      // The registers are local and zeroed, so the lane of the skipped last
      // row contributes zeros here.
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= 3);
        switch (k) {
          $for K in range(1, 4):
            // KC remainder of ${K}
            case ${K}:
            {
              $if K == 1:
                $for N in range(0, NR, 4):
                  uint32x4_t vtmp${ABC[0:K]}x${ABC[N:N+4]} = vdupq_n_u32(0);
              $else:
                $for N in range(0, NR, 4):
                  uint32x4x${K}_t vtmp${ABC[0:K]}x${ABC[N:N+4]};
                  $for L in range(0, K):
                    vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[${L}] = vdupq_n_u32(0);
              $if SR == 4:
                // The zero-initialised vsrtmp<k> vectors extend the loaded
                // vectors to a ring of four for the rotation; they are stored
                // after the loaded vectors.
                uint32x4_t vsrtmp;
                $for N in range(0, NR, 4):
                  $for L in range(K, 4):
                    uint32x4_t vsrtmp${ABC[L]}x${ABC[N:N+4]} = vdupq_n_u32(0);
              // Row pointers are not advanced by these loads: this is the
              // last read of the group's rows.
              $for N in range(0, NR, 4):
                $for L in range(4):
                  $if N+L != NR-1:
                    vtmp${ABC[0:K]}x${ABC[N:N+4]} = vld${K}q_lane_u32(w${N+L}, vtmp${ABC[0:K]}x${ABC[N:N+4]}, ${L});
                  $if SR == 4:
                    $if K == 3:
                      vsrtmp = vsrtmp${ABC[3]}x${ABC[N:N+4]};
                      vsrtmp${ABC[3]}x${ABC[N:N+4]} = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[2];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[2] = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[1];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[1] = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[0];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[0] = vsrtmp;
                    $if K == 2:
                      vsrtmp = vsrtmp${ABC[3]}x${ABC[N:N+4]};
                      vsrtmp${ABC[3]}x${ABC[N:N+4]} = vsrtmp${ABC[2]}x${ABC[N:N+4]};
                      vsrtmp${ABC[2]}x${ABC[N:N+4]} = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[1];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[1] = vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[0];
                      vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[0] = vsrtmp;
                    $if K == 1:
                      vsrtmp = vsrtmp${ABC[3]}x${ABC[N:N+4]};
                      vsrtmp${ABC[3]}x${ABC[N:N+4]} = vsrtmp${ABC[2]}x${ABC[N:N+4]};
                      vsrtmp${ABC[2]}x${ABC[N:N+4]} = vsrtmp${ABC[1]}x${ABC[N:N+4]};
                      vsrtmp${ABC[1]}x${ABC[N:N+4]} = vtmp${ABC[0:K]}x${ABC[N:N+4]};
                      vtmp${ABC[0:K]}x${ABC[N:N+4]} = vsrtmp;
              $for L in range(K):
                $for N in range(0, NR, 4):
                  $if K == 1:
                    vst1q_u32(packed_weights, vtmp${ABC[0:K]}x${ABC[N:N+4]}); packed_weights += 4;
                  $else:
                    vst1q_u32(packed_weights, vtmp${ABC[0:K]}x${ABC[N:N+4]}.val[${L}]); packed_weights += 4;
              $if SR == 4:
                $for L in range(K, 4):
                  $for N in range(0, NR, 4):
                    vst1q_u32(packed_weights, vsrtmp${ABC[L]}x${ABC[N:N+4]}); packed_weights += 4;
              break;
            }
          default:
            XNN_UNREACHABLE;
        }
      }
      // Skip the per-tile extra space without writing to it.
      packed_weights = (uint32_t*) ((uintptr_t) packed_weights + extra_bytes);
    }

    // Advance to the next group's nc x kc block of weights.
    weights += nc * kc;
  } while (--g != 0);
}
