// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR % 16 == 0
$assert NR <= 64
$SIMD_SIZE = 16
$SIMD_TILE = NR // 16

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/packw.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"


// Pack pre-transposed weights (GIO) for use by f32-gemm.
//
// For each of the g groups, the nc columns are consumed in tiles of ${NR}.
// Every tile is written to packed_weights as:
//   1. ${NR} bias values (zeros when bias is NULL),
//   2. kc rows of ${NR} weights each, read from rows k_stride elements apart,
//   3. extra_bytes of untouched space.
// A final partial tile (nc % ${NR} columns) is emitted in the same layout; its
// missing columns are read through a zeroing mask, so they are stored as 0 and
// no memory past the valid columns is loaded.
void xnn_x32_packw_gemm_gio_ukernel_x${NR}__avx512f_u${KBLOCK}${"_prfm" if PREFETCH else ""}(
  size_t g,                  // Batch size (outer loop).  usually 1
  size_t nc,                 // Number of columns and typically large
  size_t kc,                 // Number of rows and typically small
  size_t nr,                 // Matches gemm and is a multiple of vector sizes
  size_t kr,                 // unused - must be 1
  size_t sr,                 // unused - must be 1
  size_t k_stride,           // Elements per row (typically same as nc)
  const uint32_t* weights,   // Weights to pack. unaligned, unpadded
  const uint32_t* bias,      // Bias to pack. unaligned, unpadded, can be NULL
  const void* scale,         // unused
  uint32_t* packed_weights,  // packed weights output buffer - aligned, padded
  size_t extra_bytes,        // bytes skipped after each packed tile. aligned
  const void* params)        // unused
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});   // This kernel is for NR=${NR}
  assert(kr == 1);
  assert(sr == 1);
  assert(k_stride != 0);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  const __m512 vzero = _mm512_setzero_ps();
  // The data is only moved, never interpreted; float pointers are used because
  // the AVX512F load/store intrinsics below take float*.
  const float* b = (const float*) bias;
  float* packed_w = (float*) packed_weights;
  do {
    // NC main loop multiple of ${NR}
    const float* w = (const float*) weights;
    size_t n = nc;

    for (; n >= ${NR}; n -= ${NR}) {
      // Tile header: bias values, or zeros when no bias was supplied.
      if XNN_LIKELY(b != NULL) {
        $for N in range(SIMD_TILE):
          const __m512 vb${N} = _mm512_loadu_ps(b + ${N*SIMD_SIZE});
        $for N in range(SIMD_TILE):
          _mm512_store_ps(packed_w + ${N*SIMD_SIZE}, vb${N});
        b += ${NR};
      } else {
        $for N in range(SIMD_TILE):
          _mm512_store_ps(packed_w + ${N*SIMD_SIZE}, vzero);
      }
      packed_w += ${NR};

      size_t k = kc;
      $if KBLOCK > 1:
        // KC main loop ${KBLOCK}x${NR}
        for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
          $for K in range(KBLOCK):
            $for N in range(SIMD_TILE):
              const __m512 v${N}_${K} = _mm512_loadu_ps(w + ${N*SIMD_SIZE} + ${K} * k_stride);
          $if PREFETCH:
            // k_stride counts elements, so the row offset is applied before the
            // cast to a byte pointer; only the look-ahead distance is in bytes.
            $for K in range(KBLOCK):
              $for N in range(0,SIMD_TILE*SIMD_SIZE,64):
                xnn_prefetch_to_l1((const int8_t*) (w + ${K} * k_stride) + ${N+960});
          $for K in range(KBLOCK):
            $for N in range(SIMD_TILE):
              _mm512_store_ps(packed_w + ${N*SIMD_SIZE+K*NR}, v${N}_${K});
          w += k_stride * ${KBLOCK};
          packed_w += ${NR*KBLOCK};
        }

      // KC remainder loop
      for (; k > 0; --k) {
        $for N in range(SIMD_TILE):
          const __m512 v${N} = _mm512_loadu_ps(w + ${N*SIMD_SIZE});
        $if PREFETCH:
          $for N in range(0,SIMD_TILE*SIMD_SIZE,64):
            xnn_prefetch_to_l1((const int8_t*) w + ${N+960});
        $for N in range(SIMD_TILE):
          _mm512_store_ps(packed_w + ${N*SIMD_SIZE}, v${N});
        w += k_stride;
        packed_w += ${NR};
      }
      // Skip the per-tile padding requested by the caller (no-op when 0).
      packed_w = (float*) ((uintptr_t) packed_w + extra_bytes);
      // Rewind the kc rows just walked and step to the next tile of columns.
      w = w - kc * k_stride + ${NR};  // Advance to next column of ${NR} floats
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= ${NR-1});

      // Prepare mask for valid 32-bit elements (depends on n).
      // The low n bits of a 64-bit value are set, then each 16-lane vector
      // takes its own 16-bit slice of that value.
      $for N in range(SIMD_TILE):
        const __mmask16 vmask${N} = _cvtu32_mask16((uint32_t) (((UINT64_C(1) << n) - 1) >> ${N*SIMD_SIZE}));

      // Masked-out lanes are zeroed by the maskz loads, so the stores below
      // always write a full, zero-padded tile row.
      if XNN_LIKELY(b != NULL) {
        $for N in range(SIMD_TILE):
          const __m512 vb${N} = _mm512_maskz_loadu_ps(vmask${N}, b + ${N*SIMD_SIZE});
        $for N in range(SIMD_TILE):
          _mm512_store_ps(packed_w + ${N*SIMD_SIZE}, vb${N});
        b += n;
      } else {
        $for N in range(SIMD_TILE):
          _mm512_store_ps(packed_w + ${N*SIMD_SIZE}, vzero);
      }
      packed_w += ${NR};

      // KC main loop
      for (size_t k = kc; k > 0; --k) {
        $for N in range(SIMD_TILE):
          const __m512 v${N} = _mm512_maskz_loadu_ps(vmask${N}, w + ${N*SIMD_SIZE});
        $for N in range(SIMD_TILE):
          _mm512_store_ps(packed_w + ${N*SIMD_SIZE}, v${N});
        w += k_stride;
        packed_w += ${NR};
      }
      // Skip the per-tile padding after the partial tile as well.
      packed_w = (float*) ((uintptr_t) packed_w + extra_bytes);
    }
    // Advance to the next group's weights.
    weights += nc * kc;
  } while (--g != 0);
}
