// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/math.h"
#include "xnnpack/packb.h"
#include "xnnpack/unaligned.h"

$assert CHANNEL_TILE >= 1
$assert CHANNEL_SUBTILE >= 1
$assert CHANNEL_TILE % CHANNEL_SUBTILE == 0
$CHANNEL_ROUND = 1
$assert TYPE in ["int8_t", "uint16_t", "uint32_t", "float"]
$assert BIAS in [0, 1]
$TYPE_SUFFIX = {"int8_t": "int", "uint16_t": "int", "uint32_t": "int", "float": "float"}[TYPE]
$BITS = {"int8_t": 8, "uint16_t": 16, "uint32_t": 32, "float": 32}[TYPE]
$BTYPE = {"int8_t": "int32_t", "uint16_t": "uint16_t", "uint32_t": "uint32_t", "float": "float"}[TYPE]
$WTYPE = {"int8_t": "int8_t", "uint16_t": "uint16_t", "uint32_t": "uint32_t", "float": "uint32_t"}[TYPE]
$NAME = {0: "zerob", 1: "packb"}[BIAS]
$UNALIGNED_STORE = {"int32_t": "unaligned_indexed_store_s32", "uint16_t": "unaligned_indexed_store_u16", "uint32_t": "unaligned_indexed_store_u32", "float": "unaligned_indexed_store_f32"}[BTYPE]
// Scalar bias-packing micro-kernel.
//
// For each of `groups` groups, walks `channels` channels and stores one bias
// slot per channel at the current position in `packed_weights`: the next entry
// of `bias` in the packb variant, zero in the zerob variant.
//
// The write pointer advances by `channel_tile_stride` bytes after every full
// channel tile and by `channel_subtile_stride` bytes after every subtile,
// including a final partial one. The bias pointer is never rewound, so
// consecutive groups consume consecutive bias entries (groups * channels in
// total). `params` is not read.
void xnn_x${BITS}_${NAME}_gemm_ukernel_${CHANNEL_TILE}c${CHANNEL_SUBTILE}s${CHANNEL_ROUND}r__scalar_${TYPE_SUFFIX}(
  size_t groups,
  size_t channels,
  $if BIAS:
    $if BITS == 8:
      const int32_t* bias,
    $else:
      const ${WTYPE}* bias,
  ${WTYPE}* packed_weights,
  size_t channel_tile_stride,
  size_t channel_subtile_stride,
  const struct xnn_x32_packb_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(groups != 0);
  assert(channels != 0);
  assert(packed_weights != NULL);

  ${TYPE}* w = (${TYPE}*) packed_weights;
  $if BIAS:
    const ${BTYPE}* b = (const ${BTYPE}*) bias;
  $else:
    const ${TYPE} vzero = 0;
  do {
    // channel tile loop multiple of ${CHANNEL_TILE}
    size_t c = channels;
    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
      $if BIAS:
        $for N in range(CHANNEL_TILE):
          const ${BTYPE} b${N} = b[${N}];
        $for N in range(CHANNEL_TILE):
          $if BTYPE == TYPE:
            ${UNALIGNED_STORE}(w, ${N}, b${N});
          $else:
            ${UNALIGNED_STORE}((${BTYPE}*) w, ${N}, b${N});
        b += ${CHANNEL_TILE};
      $else:
        $for N in range(CHANNEL_TILE):
          $if BTYPE == TYPE:
            ${UNALIGNED_STORE}(w, ${N}, vzero);
          $else:
            ${UNALIGNED_STORE}((${BTYPE}*) w, ${N}, vzero);

      w = (${TYPE}*) ((uintptr_t) w + channel_tile_stride);
    }

    $if CHANNEL_TILE != CHANNEL_SUBTILE:
      // channel subtile loop multiple of ${CHANNEL_SUBTILE}
      $if CHANNEL_TILE == 2 * CHANNEL_SUBTILE:
        if (c ${"!= 0" if CHANNEL_SUBTILE == 1 else ">= " + str(CHANNEL_SUBTILE)}) {
          $if BIAS:
            $for N in range(CHANNEL_SUBTILE):
              const ${BTYPE} b${N} = b[${N}];
            $for N in range(CHANNEL_SUBTILE):
              $if BTYPE == TYPE:
                ${UNALIGNED_STORE}(w, ${N}, b${N});
              $else:
                ${UNALIGNED_STORE}((${BTYPE}*) w, ${N}, b${N});
            b += ${CHANNEL_SUBTILE};
          $else:
            $for N in range(CHANNEL_SUBTILE):
              $if BTYPE == TYPE:
                ${UNALIGNED_STORE}(w, ${N}, vzero);
              $else:
                ${UNALIGNED_STORE}((${BTYPE}*) w, ${N}, vzero);

          w = (${TYPE}*) ((uintptr_t) w + channel_subtile_stride);
          $if CHANNEL_SUBTILE != 1:
            // consume the subtile so the remainder check below only sees leftover channels
            c -= ${CHANNEL_SUBTILE};
        }
      $else:
        for (; c >= ${CHANNEL_SUBTILE}; c -= ${CHANNEL_SUBTILE}) {
          $if BIAS:
            $for N in range(CHANNEL_SUBTILE):
              const ${BTYPE} b${N} = b[${N}];
            $for N in range(CHANNEL_SUBTILE):
              $if BTYPE == TYPE:
                ${UNALIGNED_STORE}(w, ${N}, b${N});
              $else:
                ${UNALIGNED_STORE}((${BTYPE}*) w, ${N}, b${N});
            b += ${CHANNEL_SUBTILE};
          $else:
            $for N in range(CHANNEL_SUBTILE):
              $if BTYPE == TYPE:
                ${UNALIGNED_STORE}(w, ${N}, vzero);
              $else:
                ${UNALIGNED_STORE}((${BTYPE}*) w, ${N}, vzero);

          w = (${TYPE}*) ((uintptr_t) w + channel_subtile_stride);
        }

    $if CHANNEL_SUBTILE != 1:
      if XNN_UNLIKELY(c != 0) {
        $if CHANNEL_SUBTILE == 2:
          // channels remainder (exactly 1): only slot 0 is written
          $if BTYPE == TYPE:
            $if BIAS:
              const ${BTYPE} b0 = b[0];
              ${UNALIGNED_STORE}(w, 0, b0);
            $else:
              ${UNALIGNED_STORE}(w, 0, vzero);
          $else:
            $if BIAS:
              const ${BTYPE} b0 = b[0];
              ${UNALIGNED_STORE}((${BTYPE}*) w, 0, b0);
            $else:
              ${UNALIGNED_STORE}((${BTYPE}*) w, 0, vzero);
          $if BIAS:
            b += 1;
          w = (${TYPE}*) ((uintptr_t) w + channel_subtile_stride);
        $else:
          // channels remainder (1..${CHANNEL_SUBTILE-1})
          ${TYPE}* prev_w = w;
          $for LOG2N in reversed(range(CHANNEL_SUBTILE.bit_length() - 1)):
            if (c & ${1 << LOG2N}) {
              $for N in range(1 << LOG2N):
                $if BIAS:
                  ${BTYPE} b${N} = b[${N}];
              $for N in range(1 << LOG2N):
                $if BTYPE == TYPE:
                  $if BIAS:
                    ${UNALIGNED_STORE}(w, ${N}, b${N});
                  $else:
                    ${UNALIGNED_STORE}(w, ${N}, vzero);
                $else:
                  $if BIAS:
                    ${UNALIGNED_STORE}((${BTYPE}*) w, ${N}, b${N});
                  $else:
                    ${UNALIGNED_STORE}((${BTYPE}*) w, ${N}, vzero);
              $if BIAS:
                b += ${1 << LOG2N};
              $if BTYPE == TYPE:
                w += ${1 << LOG2N};
              $else:
                // slots are ${BTYPE}-sized, so step in bytes rather than in ${TYPE} elements
                w = (${TYPE}*) ((uintptr_t) w + ${1 << LOG2N} * sizeof(${BTYPE}));
            }

          // the stride is applied from the start of the partial subtile
          w = (${TYPE}*) ((uintptr_t) prev_w + channel_subtile_stride);
      }
  } while (--groups != 0);
}
