// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR > 1
$assert KR > 1
$assert TYPE in ["int8_t"]
$assert IZP in [0, 128]

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/packw.h"

$BTYPE = {"QS8": "int32_t", "X8": "uint32_t"}[DATATYPE]
$WTYPE = {"int8_t": "int8_t", "uint16_t": "uint16_t", "uint32_t": "uint32_t", "float": "uint32_t"}[TYPE]
// Scalar weight-packing microkernel producing tiles of ${NR} columns with
// ${KR} consecutive K values stored together per column.
//
// Source layout: element (k, n) of a group is read from
// weights[k * k_stride + n], i.e. the n index is contiguous in memory.
//
// Packed layout, repeated for every tile of ${NR} columns:
//   1. ${NR} bias values of type ${BTYPE} (zeros when bias is NULL);
//   2. for every step of ${KR} along kc, ${NR} runs of ${KR} weights, so that
//      packed[n * ${KR} + k] = weights[k * k_stride + n];
//   3. extra_bytes that are skipped over and not written by this kernel.
// Padding slots of a partial tile (columns at or beyond the remaining n, K
// values at or beyond the remaining k) are skipped without being written.
//
// Parameters:
//   g              - number of groups to pack; must be non-zero.
//   nc             - number of columns per group; must be non-zero.
//   kc             - number of K rows per group; must be non-zero.
//   nr, kr, sr     - must be ${NR}, ${KR} and 1; checked by assert only.
//   k_stride       - distance, in elements, between consecutive K rows.
//   weights        - source weights; must not be NULL.
//   bias           - optional bias; NULL packs zeros. The pointer is not reset
//                    between groups, it keeps advancing by nc per group.
//   scale          - not referenced by this kernel.
//   packed_weights - destination buffer; must not be NULL.
//   extra_bytes    - number of bytes skipped after each packed tile.
//   params         - QS8 kernels read input_zero_point from it when it is
//                    non-NULL; otherwise it is not referenced.
void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_gio_ukernel_x${NR}c${KR}__scalar(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  size_t k_stride,
  const ${WTYPE}* weights,
  const ${BTYPE}* bias,
  const void* scale,
  ${WTYPE}* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == ${KR});
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  ${TYPE}* out = (${TYPE}*) packed_weights;
  const ${BTYPE}* b = (const ${BTYPE}*) bias;
  $if DATATYPE in ["QS8"]:
    // Zero point that multiplies the per-column weight sums: the value from
    // params (when provided) plus this variant's fixed offset of ${IZP}.
    const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP});

  // One iteration per group.
  do {
    // NC main loop multiple of ${NR}
    const ${TYPE}* w0 = (const ${TYPE}*) weights;
    size_t n = nc;
    for (;n >= ${NR}; n -= ${NR}) {
      $if DATATYPE in ["QS8"]:
        // Remember where this tile's bias starts so it can be adjusted once
        // the weight sums are known.
        int32_t* packed_b = (int32_t*) out;
      // Tile header: ${NR} bias values, or zeros when no bias was supplied.
      if XNN_LIKELY(b != NULL) {
        $for N in range(NR):
          $if BTYPE == TYPE:
            out[${N}] = b[${N}];
          $else:
            ((${BTYPE}*) out)[${N}] = b[${N}];
        b += ${NR};
      } else {
        $for N in range(NR):
          $if BTYPE == TYPE:
            out[${N}] = 0;
          $else:
            ((${BTYPE}*) out)[${N}] = 0;
      }
      // Step past the bias. out counts ${TYPE} elements; scaling by
      // sizeof(${BTYPE}) is a byte count, which is equivalent here because the
      // template restricts TYPE to int8_t.
      $if BTYPE == TYPE:
        out += ${NR};
      $else:
        out += ${NR} * sizeof(${BTYPE});

      // Row pointers: each one is k_stride elements after the previous, so
      // together they address ${KR} consecutive K rows.
      $for K in range(1, KR):
        const ${TYPE}* w${K} = w${K-1} + k_stride;
      $if DATATYPE in ["QS8"]:
        // Per-column running sums of the weights packed into this tile.
        $for N in range(NR):
          uint32_t ksum${N} = 0;

      // KC main loop multiple of ${NR}x${KR}
      size_t k = kc;
      for (; k >= ${KR}; k -= ${KR}) {
        // For each column, load one value from every row pointer and store the
        // ${KR} values contiguously in the output.
        $for N in range(NR):
          $for K in range(KR):
            const ${TYPE} v${K}x${N} = w${K}[${N}];
          $for K in range(KR):
            $if DATATYPE in ["QS8"]:
              ksum${N} += (uint32_t) v${K}x${N};
          $for K in range(KR):
            out[${N*KR+K}] = v${K}x${N};
        // Advance every row pointer to the next block of ${KR} rows.
        $for K in range(KR):
          w${K} += ${KR} * k_stride;
        out += ${NR*KR};
      }

      // KC remainder of 1..${KR-1}
      if (k != 0) {
        assert(k >= 1 && k <= ${KR-1});
        // Row 0 is always present here; each further row is copied only when
        // its index is below k. The unused slots of every column are left
        // unwritten, but out still advances by a full ${NR}x${KR} block.
        $for N in range(NR):
          const ${TYPE} v0x${N} = w0[${N}];
          $if DATATYPE in ["QS8"]:
            ksum${N} += (uint32_t) v0x${N};
          out[${N*KR}] = v0x${N};
          $for K in range(1, KR):
            if (${K} < k) {
              const ${TYPE} v${K}x${N} = w${K}[${N}];
              $if DATATYPE in ["QS8"]:
                ksum${N} += (uint32_t) v${K}x${N};
              out[${N*KR+K}] = v${K}x${N};
            }
        $for K in range(KR):
          w${K} += k * k_stride;
        out += ${NR*KR};
      }

      $if DATATYPE in ["QS8"]:
        // Fold the zero-point term into the packed bias:
        // bias[n] -= izp * (sum of the weights packed for column n).
        $for N in range(NR):
          packed_b[${N}] -= ksum${N} * izp;
      // Skip the per-tile extra bytes; nothing is written to them here.
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
      // w0 has advanced by kc rows in total; rewind it to the first row and
      // move on to the next ${NR} columns.
      w0 = w0 - kc * k_stride + ${NR};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      $if DATATYPE in ["QS8"]:
        // Start of this partial tile's bias, adjusted after packing.
        int32_t* packed_b = (int32_t*) out;
      // Write the n remaining bias values, or n zeros when bias is NULL.
      if XNN_LIKELY(b != NULL) {
        size_t nb = n;
        do {
          $if BTYPE == TYPE:
            *out++ = *b++;
          $else:
            *((${BTYPE}*) out) = *b++;
            out += sizeof(${BTYPE});
        } while (--nb != 0);
      } else {
        size_t nb = n;
        do {
          $if BTYPE == TYPE:
            *out++ = 0;
          $else:
            *((${BTYPE}*) out) = 0;
            out += sizeof(${BTYPE});
        } while (--nb != 0);
      }
      // Skip the bias slots of the missing columns without writing them, so
      // the partial tile keeps the same size as a full one.
      $if BTYPE == TYPE:
        out += (${NR} - n);
      $else:
        out += (${NR} - n) * sizeof(${BTYPE});

     $if NR > 2:
        // Fewer than ${NR} columns remain, so column ${NR-1} is never loaded.
      // Row pointers for the partial tile, spaced k_stride elements apart.
      $for K in range(1, KR):
        const ${TYPE}* w${K} = w${K-1} + k_stride;

      $if DATATYPE in ["QS8"]:
        // Per-column weight sums; sums of columns at or beyond n stay zero.
        $for N in range(NR-1):
          uint32_t ksum${N} = 0;

      // KC main loop multiple of ${NR}x${KR}
      size_t k = kc;
      for (; k >= ${KR}; k -= ${KR}) {
        // Column 0 always exists because n is at least 1 in this branch.
        $for K in range(KR):
          const ${TYPE} v${K}x0 = w${K}[0];
        $for K in range(KR):
          $if DATATYPE in ["QS8"]:
            ksum0 += (uint32_t) v${K}x0;
        $for K in range(KR):
          out[${K}] = v${K}x0;
        // Every further column is packed only when its index is below n; the
        // slots of the missing columns are left unwritten.
        $for N in range(1, NR-1):
          if (${N} < n) {
            $for K in range(KR):
              const ${TYPE} v${K}x${N} = w${K}[${N}];
            $for K in range(KR):
              $if DATATYPE in ["QS8"]:
                ksum${N} += (uint32_t) v${K}x${N};
            $for K in range(KR):
              out[${N*KR+K}] = v${K}x${N};
          }
        $for K in range(KR):
          w${K} += ${KR} * k_stride;
        out += ${NR*KR};
      }

      // KC remainder of 1..${KR-1}
      if (k != 0) {
        assert(k >= 1 && k <= ${KR-1});
        // Both guards apply here: a value is copied only when its column index
        // is below n and its row index is below k.
        $for N in range(NR-1):
          if (${N} < n) {
            const ${TYPE} v0x${N} = w0[${N}];
            $if DATATYPE in ["QS8"]:
              ksum${N} += (uint32_t) v0x${N};
            out[${N*KR}] = v0x${N};
            $for K in range(1, KR):
              if (${K} < k) {
                const ${TYPE} v${K}x${N} = w${K}[${N}];
                $if DATATYPE in ["QS8"]:
                  ksum${N} += (uint32_t) v${K}x${N};
                out[${N*KR+K}] = v${K}x${N};
              }
          }
        $for K in range(KR):
          w${K} += k * k_stride;
        out += ${NR*KR};
      }

      $if DATATYPE in ["QS8"]:
        // Same zero-point adjustment as in the full tiles. It is applied to
        // the first ${NR-1} bias slots regardless of n; for columns at or
        // beyond n the sum is zero, so the subtracted amount is zero.
        $for N in range(NR-1):
          packed_b[${N}] -= ksum${N} * izp;
      // Skip the per-tile extra bytes; nothing is written to them here.
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
    }
    // NOTE(review): the per-group advance is nc * kc elements and does not use
    // k_stride. This lands on the next group only when k_stride == nc;
    // confirm that callers with g > 1 guarantee it.
    weights += nc * kc;
  } while (--g != 0);
}
