// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(","))
$SIMD_SIZE = BATCH_TILES[0]

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/simd/s32-${ARCH}.h"

#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/packw.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"

// Loads the first `num_elements` int32 values of `input` into a SIMD vector
// without reading any element at or beyond input[num_elements].
//
// The values are staged through a stack buffer of ${SIMD_SIZE} elements. The
// buffer is zero-initialized so that lanes at index >= num_elements are 0
// instead of indeterminate stack contents; this keeps the padding written by
// the packing kernels deterministic.
//
// `num_elements` must not exceed the SIMD vector width (checked by assert).
static XNN_INLINE xnn_simd_s32_t
xnn_load_tail_no_oob_s32(const int32_t* input, size_t num_elements) {
  assert(num_elements <= xnn_simd_size_s32);
  int32_t buf[${SIMD_SIZE}] = {0};
  for (size_t i = 0; i < num_elements; ++i) {
    buf[i] = input[i];
  }
  return xnn_loadu_s32((const int32_t*) &buf[0]);
}

$for NR in BATCH_TILES:
  $SIMD_TILE = NR // SIMD_SIZE

  // Pack pre-transposed weights (GIO) for use by f32-gemm
  //
  // For each of the g groups, and for each tile of ${NR} columns, the output
  // receives in order:
  //   1. ${NR} bias values (zeros when bias is NULL),
  //   2. kc rows of ${NR} weights, read from rows k_stride elements apart,
  //   3. extra_bytes of space that is skipped over and left unwritten.
  // The last tile of a group may cover only 1..${NR-1} valid columns. It still
  // occupies ${NR} lanes per row; lanes past the valid columns hold whatever
  // xnn_load_tail_no_oob_s32 returns for them.
  void xnn_x32_packw_gemm_gio_ukernel_x${NR}__${ARCH}_u${KBLOCK}${"_prfm" if PREFETCH else ""}(
    size_t g,                  // Batch size (outer loop).  usually 1
    size_t nc,                 // Number of columns and typically large
    size_t kc,                 // Number of rows and typically small
    size_t nr,                 // Matches gemm and is a multiple of vector sizes
    size_t kr,                 // unused - must be 1
    size_t sr,                 // unused - must be 1
    size_t k_stride,           // Elements per row (typically same as nc)
    const uint32_t* weights,   // Weights to pack. unaligned, unpadded
    const uint32_t* bias,      // Bias to pack. unaligned, unpadded, can be NULL
    const void* scale,         // unused
    uint32_t* packed_weights,  // packed weights output buffer - aligned, padded
    size_t extra_bytes,        // number of extra bytes between weights. aligned
    const void* params)        // unused
  {
    assert(g != 0);
    assert(nc != 0);
    assert(kc != 0);
    assert(nr == ${NR});   // This kernel is for NR=${NR}
    assert(kr == 1);
    assert(sr == 1);
    assert(k_stride != 0);
    assert(weights != NULL);
    assert(packed_weights != NULL);

    const xnn_simd_s32_t vzero = xnn_set1_s32(0);
    const int32_t* b = (const int32_t*) bias;
    int32_t* packed_w = (int32_t*) packed_weights;
    do {
      // NC main loop multiple of ${NR}
      const int32_t* w = (const int32_t*) weights;
      size_t n = nc;

      for (; n >= ${NR}; n -= ${NR}) {
        // Bias row of the tile: copy ${NR} biases, or write zeros if there is no bias.
        if XNN_LIKELY(b != NULL) {
          $for N in range(SIMD_TILE):
            const xnn_simd_s32_t vb${N} = xnn_loadu_s32(b + ${N*SIMD_SIZE});
          $for N in range(SIMD_TILE):
            xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, vb${N});
          b += ${NR};
        } else {
          $for N in range(SIMD_TILE):
            xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, vzero);
        }
        packed_w += ${NR};

        // KC main loop ${KBLOCK}x${NR}
        size_t k = kc;
        $if KBLOCK > 1:
          for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
            $for K in range(KBLOCK):
              $for N in range(SIMD_TILE):
                const xnn_simd_s32_t v${N}_${K} = xnn_loadu_s32(w + ${N*SIMD_SIZE} + ${K} * k_stride);
            $if PREFETCH:
              // NOTE(review): the prefetch address below does not vary with K, so
              // every K step of the unroll emits the same prefetches - confirm.
              $for K in range(KBLOCK):
                $for N in range(SIMD_TILE):
                  xnn_prefetch_to_l1((const int8_t*) w + ${N*4+960});
            $for K in range(KBLOCK):
              $for N in range(SIMD_TILE):
                xnn_storeu_s32(packed_w + ${N*SIMD_SIZE+K*NR}, v${N}_${K});
            w += k_stride * ${KBLOCK};
            packed_w += ${NR*KBLOCK};
          }

        // KC remainder loop
        for (; k > 0; --k) {
          $for N in range(SIMD_TILE):
            const xnn_simd_s32_t v${N} = xnn_loadu_s32(w + ${N*SIMD_SIZE});
          $if PREFETCH:
            $for N in range(SIMD_TILE):
              xnn_prefetch_to_l1((const int8_t*) w + ${N*4+960});
          $for N in range(SIMD_TILE):
            xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, v${N});
          w += k_stride;
          packed_w += ${NR};
        }
        w = w - kc * k_stride + ${NR};  // Advance to next column of ${NR} int32_t
        // Skip the extra bytes that follow each packed tile.
        packed_w = (int32_t*) ((uintptr_t) packed_w + extra_bytes);
      }

      // NC remainder (1..${NR-1})
      if XNN_UNLIKELY(n != 0) {
        assert(n >= 1);
        assert(n <= ${NR-1});

        // Prepare count for valid 32-bit elements (depends on n).
        // vcountN is the number of valid columns in vector N, clamped to
        // [0, ${SIMD_SIZE}] using unsigned comparisons only.
        $for N in range(SIMD_TILE):
          $if SIMD_TILE == 1:
            const size_t vcount0 = n;
          $else:
            const size_t vcount${N} = n > ${N*SIMD_SIZE} ? (n - ${N*SIMD_SIZE} < ${SIMD_SIZE} ? n - ${N*SIMD_SIZE} : ${SIMD_SIZE}) : 0;

        // Bias row of the tail tile: only the n valid biases are read.
        if XNN_LIKELY(b != NULL) {
          $for N in range(SIMD_TILE):
            const xnn_simd_s32_t vb${N} = xnn_load_tail_no_oob_s32(b + ${N*SIMD_SIZE}, vcount${N});
          $for N in range(SIMD_TILE):
            xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, vb${N});
          b += n;
        } else {
          $for N in range(SIMD_TILE):
            xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, vzero);
        }
        packed_w += ${NR};

        // KC main loop
        // Each row reads only the n valid columns but still stores a full
        // ${NR}-wide row into the packed output.
        for (size_t k = kc; k > 0; --k) {
          $for N in range(SIMD_TILE):
            const xnn_simd_s32_t v${N} = xnn_load_tail_no_oob_s32(w + ${N*SIMD_SIZE}, vcount${N});
          $for N in range(SIMD_TILE):
            xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, v${N});
          w += k_stride;
          packed_w += ${NR};
        }
        // Skip the extra bytes that follow each packed tile.
        packed_w = (int32_t*) ((uintptr_t) packed_w + extra_bytes);
      }
      // Advance to the next group's weights.
      // NOTE(review): groups are taken to be nc * kc elements apart regardless
      // of k_stride - confirm against the caller.
      weights += nc * kc;
    } while (--g != 0);
  }
