// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// Template parameters consumed by this file:
//   REQUANTIZATION   - "FP32" or "RNDNU".
//   VARIANT          - "FMAGIC", "IMAGIC" or "LRINTF" when REQUANTIZATION is "FP32".
//   DATATYPE         - "QC8", "QS8" or "QU8".
//   CHANNEL_TILE     - channels handled per main-loop iteration (>= 1).
//   FIRST_PASS_TILE, MIDDLE_PASS_TILE, LAST_PASS_TILE
//                    - kernel taps handled per pass (each >= 1, and
//                      MIDDLE_PASS_TILE <= LAST_PASS_TILE).
//   WASM             - selects the wasm float min/max builtins and the "wasm"
//                      name suffix instead of "scalar".
// CHANNEL_SUBTILE and CHANNEL_ROUND are fixed to 1 by this template.
$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert REQUANTIZATION != "FP32" or VARIANT in ["FMAGIC", "IMAGIC", "LRINTF"]
$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert CHANNEL_TILE >= 1
$CHANNEL_SUBTILE = 1
$assert CHANNEL_TILE % CHANNEL_SUBTILE == 0
$CHANNEL_ROUND = 1
$assert MIDDLE_PASS_TILE <= LAST_PASS_TILE
$assert FIRST_PASS_TILE >= 1
$assert MIDDLE_PASS_TILE >= 1
$assert LAST_PASS_TILE >= 1
#include <assert.h>
$if VARIANT == "LRINTF":
  #include <math.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/common.h"
#include "xnnpack/dwconv.h"
#include "xnnpack/math.h"
#include "xnnpack/microparams.h"
$if (CHANNEL_TILE % 4 != 0) or ((LAST_PASS_TILE * CHANNEL_SUBTILE % 4) != 0):
  #include "xnnpack/unaligned.h"


$DATATYPE_SPEC = {"QS8": "qs8", "QC8": "qs8_qc8w", "QU8": "qu8"}[DATATYPE]
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar"
$PARAMS_TYPE = "union xnn_qs8_qc8w_conv_minmax_params" if DATATYPE == "QC8" else "union xnn_%s_conv_minmax_params" % DATATYPE.lower()
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
// Multipass depthwise-convolution microkernel for 8-bit quantized data
// (scalar code path).
//
// For every output position the kernel consumes kernel_size input pointers in
// three stages:
//   1. First pass:    FIRST_PASS_TILE taps. Accumulators start from the int32
//                     values stored in the weights ahead of the kernel bytes
//                     and are written to the scratch buffer.
//   2. Middle passes: MIDDLE_PASS_TILE taps each, repeated while more than
//                     LAST_PASS_TILE taps are left. Accumulators are updated
//                     in place in the scratch buffer.
//   3. Last pass:     LAST_PASS_TILE taps. The accumulators are requantized,
//                     clamped to [output_min, output_max] and stored.
//
// Parameters:
//   channels         - channels per output position; must be non-zero.
//   output_width     - number of output positions; must be non-zero.
//   input            - array of input pointers; the first pass consumes
//                      FIRST_PASS_TILE entries and each middle pass
//                      MIDDLE_PASS_TILE entries.
//   weights          - packed weights; read from the start again for every
//                      output position.
//   output           - destination for the requantized 8-bit results.
//   input_stride     - byte step applied to input after each output position.
//   output_increment - byte step applied to output after each output position.
//   input_offset     - byte offset added to every input pointer that is not
//                      equal to zero.
//   zero             - pointer value that is exempt from input_offset
//                      (presumably a shared zero-filled row used for padding;
//                      TODO confirm against the caller).
//   kernel_size      - total number of taps; must exceed FIRST_PASS_TILE.
//   buffer           - scratch area holding one int32 accumulator per channel.
//   params           - requantization parameters, read once on entry.
//
// NOTE(review): unaligned_load_s32 and the unaligned_*_f32 helpers are used
// below for every configuration, while xnnpack/unaligned.h is only included
// when CHANNEL_TILE or LAST_PASS_TILE is not a multiple of 4 - confirm the
// helpers are still declared in the remaining configurations.
void xnn_${DATATYPE_SPEC}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_${FIRST_PASS_TILE}f${MIDDLE_PASS_TILE}m${LAST_PASS_TILE}l${CHANNEL_TILE}c${CHANNEL_SUBTILE}s${CHANNEL_ROUND}r__${"wasm" if WASM else "scalar"}${"_" + VARIANT.lower() if VARIANT else ""}(
    size_t channels,
    size_t output_width,
    const ${XINT8_T}** input,
    const void* weights,
    ${XINT8_T}* output,
    intptr_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const ${XINT8_T}* zero,
    size_t kernel_size,
    int32_t* buffer,
    const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(channels != 0);
  assert(output_width != 0);
  assert(kernel_size > ${FIRST_PASS_TILE});

  // Read the requantization parameters once, outside all loops.
  $if REQUANTIZATION == "FP32":
    $if DATATYPE != "QC8":
      const float vscale = params->${PARAMS_STRUCT}.scale;
    $if VARIANT == "FMAGIC":
      // vmagic_bias is 1.5 * 2^23; its IEEE-754 bit pattern is 0x4B400000. Once
      // the bias is added to a clamped value, the low mantissa bits hold the
      // rounded integer, which is recovered with an integer subtraction.
      const float voutput_min_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_min - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
      const float voutput_max_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
      const float vmagic_bias = 12582912.0f;
      const int32_t vmagic_bias_less_output_zero_point = INT32_C(0x4B400000) - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
    $elif VARIANT == "IMAGIC":
      // The magic bias (1.5 * 2^23, bit pattern 0x4B400000) is added first;
      // clamping and the zero-point adjustment are then done on the integer
      // bit pattern of the biased float.
      const float vmagic_bias = 12582912.0f;
      const int32_t output_min_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_min - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
      const int32_t output_max_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
      const int32_t vmagic_min = (int32_t) float_as_uint32(12582912.0f + output_min_less_zero_point);
      const int32_t vmagic_max = (int32_t) float_as_uint32(12582912.0f + output_max_less_zero_point);
      const int32_t vmagic_bias_less_zero_point = INT32_C(0x4B400000) - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
    $elif VARIANT == "LRINTF":
      const float voutput_min_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_min - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
      const float voutput_max_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
      const int32_t voutput_zero_point = params->${PARAMS_STRUCT}.output_zero_point;
  $elif REQUANTIZATION == "RNDNU":
    // Fixed-point path: the accumulator is multiplied by vmultiplier into a
    // 64-bit product (math_mulext_s32), vrounding is added and the sum is
    // shifted by vshift (math_asr_s64).
    const int32_t vmultiplier = params->${PARAMS_STRUCT}.multiplier;
    const int64_t vrounding = params->${PARAMS_STRUCT}.rounding;
    const uint32_t vshift = params->${PARAMS_STRUCT}.shift;
    const int32_t voutput_min_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_min - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
    const int32_t voutput_max_less_zero_point = (int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point;
    const int32_t voutput_zero_point = params->${PARAMS_STRUCT}.output_zero_point;
  $if DATATYPE == "QU8":
    // Subtracted from every kernel byte before it is multiplied.
    const int32_t vkernel_zero_point = params->${PARAMS_STRUCT}.kernel_zero_point;
  // One iteration per output position.
  do {
    // The packed weights are consumed from the start for every output position.
    const void* w = weights;

    // First pass to process ${FIRST_PASS_TILE} inputs.
    {
      int32_t* b = buffer;
      // Input pointers equal to `zero` are used as-is; all others are shifted
      // by input_offset bytes.
      $for K in range(FIRST_PASS_TILE):
        const ${XINT8_T}* i${K} = input[${K}];
        assert(i${K} != NULL);
        if XNN_UNPREDICTABLE(i${K} != zero) {
          i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset);
        }
      input += ${FIRST_PASS_TILE};

      // Weights for this pass are packed per channel group: the int32 initial
      // accumulators (presumably the bias) followed by the kernel bytes of the
      // taps handled here. Results are stored to the scratch buffer.
      size_t c = channels;
      $if CHANNEL_TILE == 1:
        do {
          int32_t vacc = unaligned_load_s32(w);
          $for K in range(FIRST_PASS_TILE):
            $if DATATYPE == "QU8":
              const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++;
            $else:
              const int32_t vi${K} = (int32_t) *i${K}++;
            $if DATATYPE == "QU8":
              const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + sizeof(int32_t)))[${K}] - vkernel_zero_point;
            $else:
              const int32_t vk${K} = ((const ${XINT8_T}*) ((uintptr_t) w + sizeof(int32_t)))[${K}];
            vacc += vi${K} * vk${K};

          w = (const void*) ((uintptr_t) w + sizeof(int32_t) + ${FIRST_PASS_TILE} * sizeof(${XINT8_T}));
          *b++ = vacc;
        } while (--c != 0);
      $else:
        for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
          $if CHANNEL_TILE % 4 != 0:
            $for C in range(CHANNEL_TILE):
              int32_t vacc${C} = unaligned_indexed_load_s32(w, ${C});
          $else:
            $for C in range(CHANNEL_TILE):
              int32_t vacc${C} = ((const int32_t*) w)[${C}];
          $for K in range(FIRST_PASS_TILE):

            $for C in range(CHANNEL_TILE):
              $if DATATYPE == "QU8":
                const int32_t vi${K}x${C} = (int32_t) (uint32_t) i${K}[${C}];
              $else:
                const int32_t vi${K}x${C} = (int32_t) i${K}[${C}];
            i${K} += ${CHANNEL_TILE};

            $for C in range(CHANNEL_TILE):
              $if DATATYPE == "QU8":
                const int32_t vk${K}x${C} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE + C}] - vkernel_zero_point;
              $else:
                const int32_t vk${K}x${C} = (int32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE + C}];

            $for C in range(CHANNEL_TILE):
              vacc${C} += vi${K}x${C} * vk${K}x${C};

          w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${FIRST_PASS_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}));
          $for C in range(CHANNEL_TILE):
            b[${C}] = vacc${C};
          b += ${CHANNEL_TILE};
        }
        if XNN_UNLIKELY(c != 0) {
          // Leftover channels are processed one at a time. With CHANNEL_TILE == 2
          // at most one channel can be left, so that variant needs no loop.
          $if CHANNEL_TILE == 2:
            int32_t vacc = unaligned_load_s32(w);

            $for K in range(FIRST_PASS_TILE):
              $if DATATYPE == "QU8":
                const int32_t vi${K} = (int32_t) (uint32_t) *i${K};
              $else:
                const int32_t vi${K} = (int32_t) *i${K};
              $if DATATYPE == "QU8":
                const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_SUBTILE} * sizeof(int32_t)))[${K * CHANNEL_SUBTILE}] - vkernel_zero_point;
              $else:
                const int32_t vk${K} = (int32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_SUBTILE} * sizeof(int32_t)))[${K * CHANNEL_SUBTILE}];
              vacc += vi${K} * vk${K};

            w = (const void*) ((uintptr_t) w + ${CHANNEL_SUBTILE} * sizeof(int32_t) + ${FIRST_PASS_TILE * CHANNEL_SUBTILE} * sizeof(${XINT8_T}));
            *b++ = vacc;
          $else:
            do {
              int32_t vacc = unaligned_load_s32(w);
              $for K in range(FIRST_PASS_TILE):
                $if DATATYPE == "QU8":
                  const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++;
                $else:
                  const int32_t vi${K} = (int32_t) *i${K}++;
                $if DATATYPE == "QU8":
                  const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_SUBTILE} * sizeof(int32_t)))[${K * CHANNEL_SUBTILE}] - vkernel_zero_point;
                $else:
                  const int32_t vk${K} = (int32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_SUBTILE} * sizeof(int32_t)))[${K * CHANNEL_SUBTILE}];
                vacc += vi${K} * vk${K};

              w = (const void*) ((uintptr_t) w + ${CHANNEL_SUBTILE} * sizeof(int32_t) + ${FIRST_PASS_TILE * CHANNEL_SUBTILE} * sizeof(${XINT8_T}));
              *b++ = vacc;
            } while (--c != 0);
        }
    }

    // Middle pass to process ${MIDDLE_PASS_TILE} inputs in each iteration.
    // ks counts the taps not yet consumed; iteration stops once at most
    // LAST_PASS_TILE of them are left. The counter cannot wrap because the
    // template asserts MIDDLE_PASS_TILE <= LAST_PASS_TILE.
    for (size_t ks = kernel_size - ${FIRST_PASS_TILE}; ks > ${LAST_PASS_TILE}; ks -= ${MIDDLE_PASS_TILE}) {
      int32_t* b = buffer;
      $for K in range(MIDDLE_PASS_TILE):
        const ${XINT8_T}* i${K} = input[${K}];
        assert(i${K} != NULL);
        if XNN_UNPREDICTABLE(i${K} != zero) {
          i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset);
        }
      input += ${MIDDLE_PASS_TILE};

      // Middle-pass weights hold kernel bytes only; accumulators are read from
      // and written back to the scratch buffer.
      size_t c = channels;
      $if CHANNEL_TILE == 1:
        do {
          int32_t vacc = *b;
          $for K in range(MIDDLE_PASS_TILE):
            $if DATATYPE == "QU8":
              const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++;
            $else:
              const int32_t vi${K} = (int32_t) *i${K}++;
            $if DATATYPE == "QU8":
              const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) w)[${K}] - vkernel_zero_point;
            $else:
              const int32_t vk${K} = ((const ${XINT8_T}*) w)[${K}];
            vacc += vi${K} * vk${K};

          w = (const void*) ((uintptr_t) w + ${MIDDLE_PASS_TILE} * sizeof(${XINT8_T}));
          *b++ = vacc;
        } while (--c != 0);
      $else:
        for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
          $for C in range(CHANNEL_TILE):
            int32_t vacc${C} = b[${C}];
          $for K in range(MIDDLE_PASS_TILE):

            $for C in range(CHANNEL_TILE):
              $if DATATYPE == "QU8":
                const int32_t vi${K}x${C} = (int32_t) (uint32_t) i${K}[${C}];
              $else:
                const int32_t vi${K}x${C} = (int32_t) i${K}[${C}];
            i${K} += ${CHANNEL_TILE};

            $for C in range(CHANNEL_TILE):
              $if DATATYPE == "QU8":
                const int32_t vk${K}x${C} = (int32_t) (uint32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_TILE + C}] - vkernel_zero_point;
              $else:
                const int32_t vk${K}x${C} = (int32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_TILE + C}];

            $for C in range(CHANNEL_TILE):
              vacc${C} += vi${K}x${C} * vk${K}x${C};

          w = (const void*) ((uintptr_t) w + ${MIDDLE_PASS_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}));
          $for C in range(CHANNEL_TILE):
            b[${C}] = vacc${C};
          b += ${CHANNEL_TILE};
        }
        if XNN_UNLIKELY(c != 0) {
          // Leftover channels are processed one at a time.
          $if CHANNEL_TILE == 2:
            int32_t vacc = b[0];

            $for K in range(MIDDLE_PASS_TILE):
              $if DATATYPE == "QU8":
                const int32_t vi${K} = (int32_t) (uint32_t) *i${K};
              $else:
                const int32_t vi${K} = (int32_t) *i${K};
              $if DATATYPE == "QU8":
                const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_SUBTILE}] - vkernel_zero_point;
              $else:
                const int32_t vk${K} = (int32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_SUBTILE}];
              vacc += vi${K} * vk${K};
            w = (const void*) ((uintptr_t) w + ${MIDDLE_PASS_TILE * CHANNEL_SUBTILE} * sizeof(${XINT8_T}));
            *b++ = vacc;
          $else:
            do {
              int32_t vacc = *b;
              $for K in range(MIDDLE_PASS_TILE):
                $if DATATYPE == "QU8":
                  const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++;
                $else:
                  const int32_t vi${K} = (int32_t) *i${K}++;
                $if DATATYPE == "QU8":
                  const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_SUBTILE}] - vkernel_zero_point;
                $else:
                  const int32_t vk${K} = (int32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_SUBTILE}];
                vacc += vi${K} * vk${K};

              w = (const void*) ((uintptr_t) w + ${MIDDLE_PASS_TILE * CHANNEL_SUBTILE} * sizeof(${XINT8_T}));
              *b++ = vacc;
            } while (--c != 0);
        }
    }

    // Last pass to process up to ${LAST_PASS_TILE} inputs.
    {
      const int32_t* b = buffer;
      // NOTE(review): all LAST_PASS_TILE pointers are dereferenced even when
      // fewer taps are left - presumably the caller pads the pointer array
      // (e.g. with `zero`); confirm. `input` is not advanced in this pass; only
      // the input_stride step at the end of the outer loop is applied to it.
      $for K in range(LAST_PASS_TILE):
        const ${XINT8_T}* i${K} = input[${K}];
        assert(i${K} != NULL);
        if XNN_UNPREDICTABLE(i${K} != zero) {
          i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset);
        }

      // Accumulators come from the scratch buffer; once the final taps are
      // added, each value is requantized, clamped and stored to output.
      size_t c = channels;
      $if CHANNEL_TILE == 1:
        do {
          int32_t vacc = unaligned_load_s32(b++);
          $for K in range(LAST_PASS_TILE):
            $if DATATYPE == "QU8":
              const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++;
            $else:
              const int32_t vi${K} = (int32_t) *i${K}++;
            $if DATATYPE == "QU8":
              const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) w)[${K}] - vkernel_zero_point;
            $else:
              const int32_t vk${K} = ((const ${XINT8_T}*) w)[${K}];
            vacc += vi${K} * vk${K};

          w = (const void*) ((uintptr_t) w + ${LAST_PASS_TILE} * sizeof(${XINT8_T}));

          $if REQUANTIZATION == "FP32":
            $if DATATYPE == "QC8":
              // The per-channel scale is stored as a float right after the kernel bytes.
              const float vscale = unaligned_load_f32(w);
              w = (const void*) ((const float*) w + 1);
            float vfpacc = (float) vacc * vscale;

            $if VARIANT == "FMAGIC":
              vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
              vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
              vfpacc += vmagic_bias;
              int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
            $elif VARIANT == "IMAGIC":
              vfpacc += vmagic_bias;
              int32_t vout = (int32_t) float_as_uint32(vfpacc);
              vout = math_max_s32(vout, vmagic_min);
              vout = math_min_s32(vout, vmagic_max);
              vout -= vmagic_bias_less_zero_point;
            $elif VARIANT == "LRINTF":
              vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
              vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
              const int32_t vrndacc = (int32_t) lrintf(vfpacc);
              int32_t vout = vrndacc + voutput_zero_point;
          $elif REQUANTIZATION == "RNDNU":
            const int64_t vextacc = math_mulext_s32(vacc, vmultiplier) + vrounding;
            int32_t vout = (int32_t) math_asr_s64(vextacc, vshift);
            vout = math_max_s32(vout, voutput_min_less_zero_point);
            vout = math_min_s32(vout, voutput_max_less_zero_point);
            vout += voutput_zero_point;

          *output++ = (${XINT8_T}) vout;
        } while (--c != 0);
      $else:
        for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
          $for C in range(CHANNEL_TILE):
            int32_t vacc${C} = b[${C}];
          b += ${CHANNEL_TILE};
          $for K in range(LAST_PASS_TILE):

            $for C in range(CHANNEL_TILE):
              $if DATATYPE == "QU8":
                const int32_t vi${K}x${C} = (int32_t) (uint32_t) i${K}[${C}];
              $else:
                const int32_t vi${K}x${C} = (int32_t) i${K}[${C}];
            i${K} += ${CHANNEL_TILE};

            $for C in range(CHANNEL_TILE):
              $if DATATYPE == "QU8":
                const int32_t vk${K}x${C} = (int32_t) (uint32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_TILE + C}] - vkernel_zero_point;
              $else:
                const int32_t vk${K}x${C} = (int32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_TILE + C}];

            $for C in range(CHANNEL_TILE):
              vacc${C} += vi${K}x${C} * vk${K}x${C};

          w = (const void*) ((uintptr_t) w + ${LAST_PASS_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}));

          $if REQUANTIZATION == "FP32":
            $for C in range(CHANNEL_TILE):
              float vfpacc${C} = (float) vacc${C};

            $if DATATYPE == "QC8":
              // Per-channel scales are stored as floats right after the kernel bytes.
              $for C in range(CHANNEL_TILE):
                const float vscale${C} = unaligned_indexed_load_f32(w, ${C});
              w = (const void*) ((const float*) w + ${CHANNEL_TILE});

              $for C in range(CHANNEL_TILE):
                vfpacc${C} *= vscale${C};
            $else:
              $for C in range(CHANNEL_TILE):
                vfpacc${C} *= vscale;

            $if VARIANT == "FMAGIC":
              $for C in range(CHANNEL_TILE):
                vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);

              $for C in range(CHANNEL_TILE):
                vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);

              $for C in range(CHANNEL_TILE):
                vfpacc${C} += vmagic_bias;

              $for C in range(CHANNEL_TILE):
                int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}) - vmagic_bias_less_output_zero_point;
            $elif VARIANT == "IMAGIC":
              $for C in range(CHANNEL_TILE):
                vfpacc${C} += vmagic_bias;

              $for C in range(CHANNEL_TILE):
                int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C});

              $for C in range(CHANNEL_TILE):
                vout${C} = math_max_s32(vout${C}, vmagic_min);

              $for C in range(CHANNEL_TILE):
                vout${C} = math_min_s32(vout${C}, vmagic_max);

              $for C in range(CHANNEL_TILE):
                vout${C} -= vmagic_bias_less_zero_point;
            $elif VARIANT == "LRINTF":
              $for C in range(CHANNEL_TILE):
                vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);

              $for C in range(CHANNEL_TILE):
                vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);

              $for C in range(CHANNEL_TILE):
                const int32_t vrndacc${C} = (int32_t) lrintf(vfpacc${C});

              $for C in range(CHANNEL_TILE):
                int32_t vout${C} = (int32_t) vrndacc${C} + voutput_zero_point;
          $elif REQUANTIZATION == "RNDNU":
            $for C in range(CHANNEL_TILE):
              const int64_t vextacc${C} = math_mulext_s32(vacc${C}, vmultiplier) + vrounding;

            $for C in range(CHANNEL_TILE):
              int32_t vout${C} = (int32_t) math_asr_s64(vextacc${C}, vshift);

            $for C in range(CHANNEL_TILE):
              vout${C} = math_max_s32(vout${C}, voutput_min_less_zero_point);

            $for C in range(CHANNEL_TILE):
              vout${C} = math_min_s32(vout${C}, voutput_max_less_zero_point);

            $for C in range(CHANNEL_TILE):
              vout${C} += voutput_zero_point;

          $for C in range(CHANNEL_TILE):
            output[${C}] = (${XINT8_T}) vout${C};
          output += ${CHANNEL_TILE};
        }
        if XNN_UNLIKELY(c != 0) {
          // Leftover channels are processed one at a time.
          $if CHANNEL_TILE == 2:
            int32_t vacc = b[0];

            $for K in range(LAST_PASS_TILE):
              $if DATATYPE == "QU8":
                const int32_t vi${K} = (int32_t) (uint32_t) *i${K};
              $else:
                const int32_t vi${K} = (int32_t) *i${K};
              $if DATATYPE == "QU8":
                const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_SUBTILE}] - vkernel_zero_point;
              $else:
                const int32_t vk${K} = (int32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_SUBTILE}];
              vacc += vi${K} * vk${K};

            $if REQUANTIZATION == "FP32":
              $if DATATYPE == "QC8":
                // The scale follows the kernel bytes; w is read at an offset and not
                // advanced because this is the final channel of the output position.
                const float vscale = unaligned_load_f32((const void*) ((uintptr_t) w + ${LAST_PASS_TILE * CHANNEL_SUBTILE} * sizeof(${XINT8_T})));
              float vfpacc = (float) vacc * vscale;

              $if VARIANT == "FMAGIC":
                vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
                vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
                vfpacc += vmagic_bias;
                int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
              $elif VARIANT == "IMAGIC":
                vfpacc += vmagic_bias;
                int32_t vout = (int32_t) float_as_uint32(vfpacc);
                vout = math_max_s32(vout, vmagic_min);
                vout = math_min_s32(vout, vmagic_max);
                vout -= vmagic_bias_less_zero_point;
              $elif VARIANT == "LRINTF":
                vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
                vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
                const int32_t vrndacc = (int32_t) lrintf(vfpacc);
                int32_t vout = vrndacc + voutput_zero_point;
            $elif REQUANTIZATION == "RNDNU":
              const int64_t vextacc = math_mulext_s32(vacc, vmultiplier) + vrounding;
              int32_t vout = (int32_t) math_asr_s64(vextacc, vshift);
              vout = math_max_s32(vout, voutput_min_less_zero_point);
              vout = math_min_s32(vout, voutput_max_less_zero_point);
              vout += voutput_zero_point;

            *output++ = (${XINT8_T}) vout;
          $else:
            do {
              int32_t vacc = *b++;
              $for K in range(LAST_PASS_TILE):
                $if DATATYPE == "QU8":
                  const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++;
                $else:
                  const int32_t vi${K} = (int32_t) *i${K}++;
                $if DATATYPE == "QU8":
                  const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_SUBTILE}] - vkernel_zero_point;
                $else:
                  const int32_t vk${K} = (int32_t) ((const ${XINT8_T}*) w)[${K * CHANNEL_SUBTILE}];
                vacc += vi${K} * vk${K};

              w = (const void*) ((uintptr_t) w + ${LAST_PASS_TILE * CHANNEL_SUBTILE} * sizeof(${XINT8_T}));

              $if REQUANTIZATION == "FP32":
                $if DATATYPE == "QC8":
                  const float vscale = unaligned_load_f32(w);
                  w = (const void*) ((const float*) w + ${CHANNEL_SUBTILE});
                float vfpacc = (float) vacc * vscale;

                $if VARIANT == "FMAGIC":
                  vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
                  vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
                  vfpacc += vmagic_bias;
                  int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
                $elif VARIANT == "IMAGIC":
                  vfpacc += vmagic_bias;
                  int32_t vout = (int32_t) float_as_uint32(vfpacc);
                  vout = math_max_s32(vout, vmagic_min);
                  vout = math_min_s32(vout, vmagic_max);
                  vout -= vmagic_bias_less_zero_point;
                $elif VARIANT == "LRINTF":
                  vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
                  vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
                  const int32_t vrndacc = (int32_t) lrintf(vfpacc);
                  int32_t vout = vrndacc + voutput_zero_point;
              $elif REQUANTIZATION == "RNDNU":
                const int64_t vextacc = math_mulext_s32(vacc, vmultiplier) + vrounding;
                int32_t vout = (int32_t) math_asr_s64(vextacc, vshift);
                vout = math_max_s32(vout, voutput_min_less_zero_point);
                vout = math_min_s32(vout, voutput_max_less_zero_point);
                vout += voutput_zero_point;

              *output++ = (${XINT8_T}) vout;
            } while (--c != 0);
        }
    }

    // Both steps below are expressed in bytes.
    input = (const ${XINT8_T}**) ((uintptr_t) input + input_stride);
    output = (${XINT8_T}*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}
