// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR > 1
$assert KBLOCK >= 1
$assert TYPE in ["int8_t", "uint16_t", "uint32_t", "float"]

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/math.h"
#include "xnnpack/packw.h"

$TYPE_SUFFIX = {"int8_t": "int", "uint16_t": "int", "uint32_t": "int", "float": "float"}[TYPE]
$BITS = {"int8_t": 8, "uint16_t": 16, "uint32_t": 32, "float": 32}[TYPE]
$BTYPE = {"int8_t": "int32_t", "uint16_t": "uint16_t", "uint32_t": "uint32_t", "float": "float"}[TYPE]
$WTYPE = {"int8_t": "int8_t", "uint16_t": "uint16_t", "uint32_t": "uint32_t", "float": "uint32_t"}[TYPE]


// Packs weights, preceded by per-tile bias values, into tiles of ${NR} rows.
//
// Input layout, as read by this routine: g groups, each holding nc rows of kc
// contiguous ${TYPE} elements. Output layout per tile: ${NR} bias slots
// (${BTYPE} each), then kc groups of ${NR} elements (one element from each row
// of the tile), then extra_bytes of skipped space. A final partial tile of
// fewer than ${NR} rows still occupies a full-width tile.
//
// Parameters:
//   g              - number of groups; must be non-zero.
//   nc             - rows per group; must be non-zero.
//   kc             - elements per row; must be non-zero.
//   nr, kr, sr     - must equal ${NR}, 1 and 1 respectively (asserted).
//   weights        - source weights; must not be NULL.
//   bias           - optional; when NULL the bias slots are written as zero.
//                    When non-NULL it is consumed sequentially across all
//                    groups (the pointer is never rewound between groups).
//   scale          - not referenced by this routine.
//   packed_weights - destination buffer; must not be NULL.
//   extra_bytes    - number of bytes skipped in the output after each tile.
//   params         - read only by the 8-bit variant, which takes
//                    input_zero_point from struct xnn_qs8_packw_params.
void xnn_x${BITS}_packw_gemm_goi_ukernel_x${NR}__scalar_${TYPE_SUFFIX}_u${KBLOCK}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const ${WTYPE}* weights,
  $if BITS == 8:
    const int32_t* bias,
  $else:
    const ${WTYPE}* bias,
  const void* scale,
  ${WTYPE}* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == 1);
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  // Output cursor and bias cursor; both persist across tiles and groups.
  ${TYPE}* out = (${TYPE}*) packed_weights;
  const ${BTYPE}* b = (const ${BTYPE}*) bias;
  $if BITS == 8:
    // Zero point that multiplies each row's weight sum in the bias correction.
    const uint32_t izp = (uint32_t) ((const struct xnn_qs8_packw_params*) params)->input_zero_point;

  do {
    // NC main loop multiple of ${NR}
    const ${TYPE}* w0 = (const ${TYPE}*) weights;
    size_t n = nc;
    for (;n >= ${NR}; n -= ${NR}) {
      $if BITS == 8:
        // Keep a handle on this tile's bias slots for the zero-point correction.
        int32_t* packed_b = (int32_t*) out;
      // Tile header: copy ${NR} bias values, or write zeros when bias is NULL.
      if XNN_LIKELY(b != NULL) {
        $for N in range(NR):
          $if BTYPE == TYPE:
            out[${N}] = b[${N}];
          $else:
            ((${BTYPE}*) out)[${N}] = b[${N}];
        b += ${NR};
      } else {
        $for N in range(NR):
          $if BTYPE == TYPE:
            out[${N}] = 0;
          $else:
            ((${BTYPE}*) out)[${N}] = 0;
      }
      // Step over the bias slots just written.
      $if BTYPE == TYPE:
        out += ${NR};
      $else:
        // out is a ${TYPE} pointer while the bias slots are ${BTYPE}, so the
        // step is expressed as a byte count.
        out += ${NR} * sizeof(${BTYPE});

      // Row pointers: each row starts kc elements after the previous one.
      $for N in range(1, NR):
        const ${TYPE}* w${N} = w${N-1} + kc;
      $if BITS == 8:
        // Per-row running sums of the weights, used for the bias correction.
        $for N in range(NR):
          uint32_t ksum${N} = 0;

      // KC main loop multiple of ${NR}x${KBLOCK}
      // Loads ${KBLOCK} element(s) from every row, then stores them interleaved
      // so that element K of row N lands at out[N + K * ${NR}].
      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        $for N in range(NR):
          $for K in range(KBLOCK):
            const ${TYPE} v${N}${K} = w${N}[${K}];
            $if BITS == 8:
              ksum${N} += (uint32_t) v${N}${K};
          w${N} += ${KBLOCK};
        $for K in range(KBLOCK):
          $for N in range(NR):
            out[${N+K*NR}] = v${N}${K};
        out += ${NR*KBLOCK};
      }

      // KC remainder
      // Handles the leftover elements one per row per iteration.
      for (; k != 0; --k) {
        $for N in range(NR):
          const ${TYPE} v${N} = *w${N}++;
          $if BITS == 8:
            ksum${N} += (uint32_t) v${N};
          out[${N}] = v${N};
        out += ${NR};
      }
      $if BITS == 8:
        // Bias correction: subtract (sum of the row's weights) * izp from the
        // bias slot stored at the start of this tile.
        $for N in range(NR):
          packed_b[${N}] -= ksum${N} * izp;
      // Skip the per-tile extra bytes.
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
      // Every row pointer has advanced by kc, so the last one now addresses
      // the first row of the next tile.
      w0 = w${NR-1};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      $if BITS == 8:
        // Keep a handle on this tile's bias slots for the zero-point correction.
        int32_t* packed_b = (int32_t*) out;
      // Tile header: only the n valid bias slots are written.
      if XNN_LIKELY(b != NULL) {
        size_t nb = n;
        do {
          $if BTYPE == TYPE:
            *out++ = *b++;
          $else:
            *((${BTYPE}*) out) = *b++;
            out += sizeof(${BTYPE});
        } while (--nb != 0);
      } else {
        size_t nb = n;
        do {
          $if BTYPE == TYPE:
            *out++ = 0;
          $else:
            *((${BTYPE}*) out) = 0;
            out += sizeof(${BTYPE});
        } while (--nb != 0);
      }
      // Skip the remaining (${NR} - n) bias slots without writing them, so the
      // header keeps its full width.
      $if BTYPE == TYPE:
        out += (${NR} - n);
      $else:
        out += (${NR} - n) * sizeof(${BTYPE});

      $if NR > 2:
        // NR remainder has less than ${NR} rows so last row is not loaded
        // Row pointers at index >= n alias the previous row, so no row beyond
        // the n valid ones is read.
      $for N in range(1, NR-1):
        const ${TYPE}* w${N} = w${N-1} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${N} = w${N-1};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${N} = w${N-1};
          }
      $if BITS == 8:
        // Per-row running sums of the weights, used for the bias correction.
        $for N in range(NR-1):
          uint32_t ksum${N} = 0;

      // KC main loop multiple of ${NR}x${KBLOCK}
      // Only ${NR-1} row(s) are stored per group of ${NR}; the output still
      // advances by the full width, leaving slot ${NR-1} of each group unwritten.
      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        $for N in range(NR-1):
          $for K in range(KBLOCK):
            const ${TYPE} v${N}${K} = w${N}[${K}];
            $if BITS == 8:
              ksum${N} += (uint32_t) v${N}${K};
          w${N} += ${KBLOCK};
        $for K in range(KBLOCK):
          $for N in range(NR-1):
            out[${N+K*NR}] = v${N}${K};
        out += ${NR*KBLOCK};
      }

      // KC remainder of 1..${KBLOCK-1}
      for (; k != 0; --k) {
        $for N in range(NR-1):
          const ${TYPE} v${N} = *w${N}++;
          $if BITS == 8:
            ksum${N} += (uint32_t) v${N};
          out[${N}] = v${N};
        out += ${NR};
      }
      $if BITS == 8:
        // Bias correction for the first ${NR-1} bias slot(s) of this tile.
        // NOTE(review): this includes any slot at index >= n, which was skipped
        // rather than written above -- confirm callers pre-initialize or ignore
        // that padding.
        $for N in range(NR-1):
          packed_b[${N}] -= ksum${N} * izp;
      // Skip the per-tile extra bytes.
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
    }
    // Advance to the next group's weights (nc rows of kc elements).
    weights += nc * kc;
  } while (--g != 0);
}
