// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR > 1
$assert KR > 1
$assert TYPE in ["int8_t"]
$assert IZP in [0, 128]

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/packw.h"

$BITS = {"int8_t": 8, "uint16_t": 16, "uint32_t": 32, "float": 32}[TYPE]
$BTYPE = {"int8_t": "int32_t", "uint16_t": "uint16_t", "uint32_t": "uint32_t", "float": "float"}[TYPE]
$WTYPE = {"int8_t": "int8_t", "uint16_t": "uint16_t", "uint32_t": "uint32_t", "float": "uint32_t"}[TYPE]
// Scalar weight-packing microkernel generated for NR=${NR}, KR=${KR}.
//
// For each of the g groups, output channels (rows of the weight matrix) are
// consumed in tiles of ${NR}. Every tile is emitted to packed_weights as:
//   1. ${NR} bias slots of type ${BTYPE}, copied from bias or set to 0 when
//      bias is NULL;
//   2. the tile's weights, written ${KR} consecutive K elements per row with
//      the rows of the tile interleaved, i.e. ${NR*KR} output slots per K step;
//   3. extra_bytes bytes that are skipped over, not written.
//
// Input layout as read by this kernel: row n of a group starts at
// weights + n * kc, and consecutive groups are nc * kc elements apart. bias,
// when non-NULL, is read sequentially, nc values per group.
//
// For 8-bit weights the packed bias of each row is additionally reduced by
// (sum of that row's weights) * izp, where izp is
// params->input_zero_point + ${IZP}, or just ${IZP} when params is NULL.
//
// nr, kr and sr are only checked by assertions against the values this
// kernel was generated for; scale is not read.
//
// NOTE(review): slots belonging to K or N padding are skipped rather than
// written, so the caller presumably pre-initializes packed_weights - confirm.
void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${KR}__scalar(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const ${WTYPE}* weights,
  $if BITS == 8:
    const int32_t* bias,
  $else:
    const ${WTYPE}* bias,
  const void* scale,
  ${WTYPE}* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == ${KR});
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  ${TYPE}* out = (${TYPE}*) packed_weights;
  const ${BTYPE}* b = (const ${BTYPE}*) bias;
  // Effective input zero point used for the bias correction below; falls back
  // to the compile-time offset alone when no params are supplied.
  $if BITS == 8:
    const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP});

  do {
    // NC main loop multiple of ${NR}
    const ${TYPE}* w0 = (const ${TYPE}*) weights;
    size_t n = nc;
    for (;n >= ${NR}; n -= ${NR}) {
      // Remember where this tile's bias slots live so they can be corrected
      // once the row sums are known.
      $if BITS == 8:
        int32_t* packed_b = (int32_t*) out;
      if XNN_LIKELY(b != NULL) {
        $for N in range(NR):
          $if BTYPE == TYPE:
            out[${N}] = b[${N}];
          $else:
            ((${BTYPE}*) out)[${N}] = b[${N}];
        b += ${NR};
      } else {
        $for N in range(NR):
          $if BTYPE == TYPE:
            out[${N}] = 0;
          $else:
            ((${BTYPE}*) out)[${N}] = 0;
      }
      $if BTYPE == TYPE:
        out += ${NR};
      $else:
        out += ${NR} * sizeof(${BTYPE});

      // Row pointers for the tile: each row is kc elements after the previous one.
      $for N in range(1, NR):
        const ${TYPE}* w${N} = w${N-1} + kc;
      // Per-row running sums of the weights, accumulated with unsigned
      // (wrap-around) arithmetic.
      $if BITS == 8:
        $for N in range(NR):
          uint32_t ksum${N} = 0;

      // KC main loop multiple of ${NR}x${KR}
      size_t k = kc;
      for (; k >= ${KR}; k -= ${KR}) {
        $for N in range(NR):
          $for K in range(KR):
            const ${TYPE} v${N}x${K} = w${N}[${K}];
          $for K in range(KR):
            $if BITS == 8:
              ksum${N} += (uint32_t) v${N}x${K};
          $for K in range(KR):
            out[${N*KR+K}] = v${N}x${K};
          w${N} += ${KR};
        out += ${NR*KR};
      }

      // KC remainder of 1..${KR-1}
      // Only the first k elements of each row are read and stored; the other
      // slots of the ${KR}-wide group are skipped, and out still advances by
      // a full ${NR*KR} slots.
      if (k != 0) {
        assert(k >= 1 && k <= ${KR-1});
        $for N in range(NR):
          const ${TYPE} v${N}x0 = w${N}[0];
          $if BITS == 8:
            ksum${N} += (uint32_t) v${N}x0;
          out[${N*KR}] = v${N}x0;
          $for K in range(1, KR):
            if (${K} < k) {
              const ${TYPE} v${N}x${K} = w${N}[${K}];
              $if BITS == 8:
                ksum${N} += (uint32_t) v${N}x${K};
              out[${N*KR+K}] = v${N}x${K};
            }
          w${N} += k;
        out += ${NR*KR};
      }

      // Fold the zero-point correction into the packed bias of every row.
      $if BITS == 8:
        $for N in range(NR):
          packed_b[${N}] -= ksum${N} * izp;
      // Skip the per-tile extra bytes, then continue with the next tile: the
      // last row pointer has advanced by kc and now addresses the row that
      // follows this tile.
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
      w0 = w${NR-1};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      // Start of this partial tile's bias slots, kept for the correction below.
      $if BITS == 8:
        int32_t* packed_b = (int32_t*) out;
      if XNN_LIKELY(b != NULL) {
        size_t nb = n;
        do {
          $if BTYPE == TYPE:
            *out++ = *b++;
          $else:
            *((${BTYPE}*) out) = *b++;
            out += sizeof(${BTYPE});
        } while (--nb != 0);
      } else {
        size_t nb = n;
        do {
          $if BTYPE == TYPE:
            *out++ = 0;
          $else:
            *((${BTYPE}*) out) = 0;
            out += sizeof(${BTYPE});
        } while (--nb != 0);
      }
      // Skip the bias slots of the missing rows so the weights start at the
      // same offset as in a full tile.
      $if BTYPE == TYPE:
        out += (${NR} - n);
      $else:
        out += (${NR} - n) * sizeof(${BTYPE});

     $if NR > 2:
        // NR remainder has less than ${NR} rows so last row is not loaded
      $for N in range(1, NR-1):
        const ${TYPE}* w${N} = w${N-1} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${N} = w${N-1};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${N} = w${N-1};
          }

      // Row pointers at index n and above were clamped to the previous row
      // above, so the loops below re-read an existing row instead of reading
      // past the last one.
      $if BITS == 8:
        $for N in range(NR-1):
          uint32_t ksum${N} = 0;

      // KC main loop multiple of ${NR}x${KR}
      // Only ${NR-1} rows are stored per step; out still advances by a full
      // ${NR*KR} slots, leaving the last row's slots unwritten.
      size_t k = kc;
      for (; k >= ${KR}; k -= ${KR}) {
        $for N in range(NR-1):
          $for K in range(KR):
            const ${TYPE} v${N}x${K} = w${N}[${K}];
          $for K in range(KR):
            $if BITS == 8:
              ksum${N} += (uint32_t) v${N}x${K};
          $for K in range(KR):
            out[${N*KR+K}] = v${N}x${K};
          w${N} += ${KR};
        out += ${NR*KR};
      }

      // KC remainder of 1..${KR-1}
      if (k != 0) {
        assert(k >= 1 && k <= ${KR-1});
        $for N in range(NR-1):
          const ${TYPE} v${N}x0 = w${N}[0];
          $if BITS == 8:
            ksum${N} += (uint32_t) v${N}x0;
          out[${N*KR}] = v${N}x0;
          $for K in range(1, KR):
            if (${K} < k) {
              const ${TYPE} v${N}x${K} = w${N}[${K}];
              $if BITS == 8:
                ksum${N} += (uint32_t) v${N}x${K};
              out[${N*KR+K}] = v${N}x${K};
            }
          w${N} += k;
        out += ${NR*KR};
      }

      // NOTE(review): the correction is applied to the first ${NR-1} bias
      // slots regardless of n, so for rows at index n and above it updates
      // slots that the bias copy above skipped - confirm this is intended.
      $if BITS == 8:
        $for N in range(NR-1):
          packed_b[${N}] -= ksum${N} * izp;
      // Skip the per-tile extra bytes, as for full tiles.
      out = (${TYPE}*) ((uintptr_t) out + extra_bytes);
    }
    // Advance to the next group's weights.
    weights += nc * kc;
  } while (--g != 0);
}
