// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/packw.h"


void xnn_x32_packw_gemm_gio_ukernel_x${NR}__scalar(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  size_t k_stride,
  const uint32_t* weights,
  const uint32_t* bias,
  const void* scale,
  uint32_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});   // This kernel is for NR=${NR}
  assert(kr == 1);
  assert(sr == 1);
  assert(k_stride != 0);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  const float* b = (const float*) bias;
  float* packed_w = (float*) packed_weights;
  do {
    // NC main loop multiple of ${NR}
    const float* w = (const float*) weights;
    size_t n = nc;

    for (; n >= ${NR}; n -= ${NR}) {
      if XNN_LIKELY(b != NULL) {
        $for N in range(0,NR,2):
          const uint64_t v${N//2} = ((const uint64_t*)b)[${N//2}];
        $for N in range(0,NR,2):
          ((uint64_t*)packed_w)[${N//2}] = v${N//2};
        b += ${NR};
      } else {
        $for N in range(0,NR,2):
          ((uint64_t*)packed_w)[${N//2}] = 0;
      }
      packed_w += ${NR};

      // KC main loop
      for (size_t k = kc; k > 0; --k) {
        $for N in range(0,NR,2):
          const uint64_t v${N//2} = ((const uint64_t*)w)[${N//2}];
        $for N in range(0,NR,2):
          ((uint64_t*)packed_w)[${N//2}] = v${N//2};
        w += k_stride;
        packed_w += ${NR};
      }
      w = w - kc * k_stride + ${NR};  // Advance to next column of ${NR} floats
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= ${NR-1});

      if XNN_LIKELY(b != NULL) {
        for (size_t i = 0; i < n; ++i) {
          packed_w[i] = b[i];
        }
        b += n;
      } else {
        $for N in range(0,NR,2):
          ((uint64_t*)packed_w)[${N//2}] = 0;
      }
      packed_w += ${NR};

      // KC main loop
      for (size_t k = kc; k > 0; --k) {
        for (size_t i = 0; i < n; ++i) {
          packed_w[i] = w[i];
        }
        w += k_stride;
        packed_w += ${NR};
      }
    }
    weights += nc * kc;
  } while (--g != 0);
}
