// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR >= 2
$assert KBLOCK >= 2

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <arm_neon.h>

#include "xnnpack/packw.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"


void xnn_x32_packw_gemm_goi_ukernel_x${NR}__neon_ld2lane_u${KBLOCK}${"_prfm" if PREFETCH else ""}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const uint32_t* weights,
  const uint32_t* bias,
  const void* scale,
  uint32_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == 1);
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  $for N in range(0,NR,2):
    $for K in range(0,KBLOCK,2):
      uint32x2x2_t v${N}${K};
      $for L in range(0,2):
        v${N}${K}.val[${L}] = vdup_n_u32(0);

  do {
    // NC main loop multiple of ${NR}
    const uint32_t* w0 = weights;
    size_t n = nc;

    for (; n >= ${NR}; n -= ${NR}) {
      if XNN_LIKELY(bias != NULL) {
        $for N in range(0,NR,2):
          uint32x2_t vb${N} = vld1_u32(bias); bias += 2;
        $for N in range(0,NR,2):
          vst1_u32(packed_weights, vb${N}); packed_weights += 2;
      } else {
        const uint32x2_t vzero = vmov_n_u32(0);
        $for N in range(0,NR,2):
          vst1_u32(packed_weights, vzero); packed_weights += 2;
      }

      $for N in range(1, NR):
        const uint32_t* w${N} = w${N-1} + kc;
      $if PREFETCH:
        $for N in range(0, NR):
          xnn_prefetch_to_l1((const int8_t*) w${N});
          xnn_prefetch_to_l1((const int8_t*) w${N} + 64);

      // KC main loop multiple of ${KBLOCK}
      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        $for K in range(0,KBLOCK,2):
          $for N in range(0,NR,2):
            $for L in range(0,2):
              v${N}${K} = vld2_lane_u32(w${N+L}, v${N}${K}, ${L}); w${N+L} += 2;
        $if PREFETCH:
          $for N in range(0, NR):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
        $for K in range(0,KBLOCK,2):
          $for L in range(0,2):
            $for N in range(0,NR,2):
              vst1_u32(packed_weights + ${N+L*NR+K*KBLOCK}, v${N}${K}.val[${L}]);
        packed_weights += ${NR*KBLOCK};
      }

      // KC remainder
      for (; k != 0; --k) {
        $for K in range(0,KBLOCK,2):
          $for N in range(0,NR,2):
            $for L in range(0,2):
              v${N}${K}.val[0] = vld1_lane_u32(w${N+L}, v${N}${K}.val[0], ${L});  w${N+L} += 1;
        $for K in range(0,KBLOCK,2):
          $for N in range(0,NR,2):
            vst1_u32(packed_weights + ${N+K*KBLOCK}, v${N}${K}.val[0]);
        packed_weights += ${NR*KBLOCK//2};
      }
      packed_weights = (uint32_t*) ((uintptr_t) packed_weights + extra_bytes);
      w0 = w${NR-1};
    }

    if XNN_UNLIKELY(n != 0) {
      $if NR == 2:
        // NC remainder of 1
        if XNN_LIKELY(bias != NULL) {
          *packed_weights = *bias++;
        } else {
          const uint32x2_t vzero = vmov_n_u32(0);
          $for N in range(0,NR,2):
            vst1_u32(packed_weights + ${N}, vzero);
        }
        packed_weights += ${NR};
        size_t k = kc;
        do {
          *packed_weights = *w0++;
          packed_weights += ${NR};
        } while (--k);
        packed_weights = (uint32_t*) ((uintptr_t) packed_weights + extra_bytes);
      $else:
        // NC remainder (1..${NR-1})
        if XNN_LIKELY(bias != NULL) {
          size_t nb = n;
          do {
            *packed_weights++  = *bias++;
          } while (--nb != 0);
          packed_weights += (${NR} - n);
        } else {
          packed_weights += ${NR};
        }

        size_t k = kc;
        do {
          const uint32_t* wn = w0;
          size_t nw = n;
          do {
            *packed_weights++ = wn[0];
            wn += kc;
          } while (--nw != 0);
          ++w0;
          packed_weights += (${NR} - n);
        } while (--k);
        packed_weights = (uint32_t*) ((uintptr_t) packed_weights + extra_bytes);
    }
    weights += nc * kc;
  } while (--g != 0);
}
