// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
$assert REQUANTIZATION in ["FP32", "RNDNU"] or not REQUANTIZATION
$assert DATATYPE in ["QC8", "QS8", "QD8_F16", "QD8_F32"]
$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32"
$assert DATATYPE != "QD8_F32" or not REQUANTIZATION
#include <assert.h>

#include <arm_neon.h>

#include "xnnpack/gemm.h"
$if REQUANTIZATION == "FP32" and ARMV8:
  #include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/math.h"


$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8": "qs8", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w"}[DATATYPE]
$REQUANTIZATION_SPEC = "_" + REQUANTIZATION.lower() if REQUANTIZATION else ""
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if REQUANTIZATION == "FP32" and ARMV8 else "neon")
$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QS8": "union xnn_qs8_conv_minmax_params", "QD8_F16": "union xnn_f16_minmax_params", "QD8_F32": "union xnn_f32_minmax_params"}[DATATYPE]
$OUT_T = {"QC8": "int8_t", "QD8_F16": "xnn_float16", "QD8_F32": "float", "QS8": "int8_t"}[DATATYPE]
$ISA = "neonfp16arith" if DATATYPE in ["QD8_F16"] else "neonv8" if ARMV8 else "neon"
$ISA_SPEC = "" if DATATYPE in ["QD8_F16"] else "_mlal" if MLA else "_mull"
// GEMM microkernel producing up to ${MR} rows x ${NR} columns of C per pass of the outer loop.
//
// - Rows with index >= mr reuse the A/C pointers of the previous row, so they are
//   computed redundantly and stored to the same location as that row.
// - kc is rounded up to a multiple of 8 and A is always read 8 bytes at a time
//   (the function is marked XNN_OOB_READS).
// - w is consumed sequentially: first ${NR} int32 values from which the accumulators
//   are initialized, then the int8 weights in 8-byte groups, then any float values
//   that the output conversion reads from w.
// - Each K step multiplies with vmull_s8 (plus vmlal_s8 in the MLAL variant), adds
//   adjacent products into the int32 accumulators with vpadalq_s16, and rotates the
//   A vector by 2 bytes between its 4 sub-steps.
// - When fewer than ${NR} columns remain, they are stored piecewise and the loop ends.
void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c2s4__${ISA}${ISA_SPEC}(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    ${OUT_T}* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    $if DATATYPE in ["QD8_F16", "QD8_F32"]:
      const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)],
      const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
    $else:
      const ${PARAMS_TYPE} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  const int8_t* a0 = a;
  $if DATATYPE in ["QD8_F16"]:
    uint16_t* c0 = (uint16_t*) c;
  $else:
    ${OUT_T}* c0 = c;
  $for M in range(1, MR):
    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
    $if DATATYPE in ["QD8_F16"]:
      uint16_t* c${M} = (uint16_t*) ((uintptr_t) c${M-1} + cm_stride);
    $else:
      ${OUT_T}* c${M} = (${OUT_T}*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  // The K loop below consumes A in 8-byte steps, so work with kc rounded up to a multiple of 8.
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  do {
    $if DATATYPE in ["QD8_F16", "QD8_F32"]:
      $for M in range(0, MR, 2):
        $if M + 1 == MR:
          const int32x4_t vizp${M} = vld1q_dup_s32(&quantization_params[${M}].zero_point);
          $for N in range(0, NR, 4):
            $if M == 0:
              const int32x4_t vksum${ABC[N:N+4]} = vld1q_s32(w); w = (const int32_t*) w + 4;
            int32x4_t vacc${M}x${ABC[N:N+4]} = vmulq_s32(vksum${ABC[N:N+4]}, vizp${M});
        $else:
          const int32x4_t vizp${ABC[M:M+2]} = vld1q_s32(&quantization_params[${M}].zero_point);
          $for N in range(0, NR, 4):
            $if M == 0:
              const int32x4_t vksum${ABC[N:N+4]} = vld1q_s32(w); w = (const int32_t*) w + 4;
            int32x4_t vacc${M}x${ABC[N:N+4]} = vmulq_lane_s32(vksum${ABC[N:N+4]}, vget_low_s32(vizp${ABC[M:M+2]}), 0);
            int32x4_t vacc${M+1}x${ABC[N:N+4]} = vmulq_lane_s32(vksum${ABC[N:N+4]}, vget_high_s32(vizp${ABC[M:M+2]}), 0);
    $else:
      $for N in range(0, NR, 4):
        int32x4_t vacc0x${ABC[N:N+4]} = vld1q_s32(w); w = (const int32_t*) w + 4;
      $for M in range(1, MR):
        $for N in range(0, NR, 4):
          int32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};

    size_t k = kc;
    $if MLA:
      while (k >= 16 * sizeof(int8_t)) {
        $for M in range(MR):
          int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
          int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8;

        $for K in range(4):
          $for N in range(0, NR, 4):
            const int8x8_t vb${ABC[N:N+4]}c${K}x0 = vld1_s8(w); w = (const int8_t*) w + 8;

        $for K in range(4):
          $for N in range(0, NR, 4):
            $for M in range(MR):
              int16x8_t vprod${M}x${ABC[N:N+4]}c${K} = vmull_s8(vb${ABC[N:N+4]}c${K}x0, va${M}x0);
            const int8x8_t vb${ABC[N:N+4]}c${K}x1 = vld1_s8(w); w = (const int8_t*) w + 8;
            $for M in range(MR):
              vprod${M}x${ABC[N:N+4]}c${K} = vmlal_s8(vprod${M}x${ABC[N:N+4]}c${K}, vb${ABC[N:N+4]}c${K}x1, va${M}x1);
            $for M in range(MR):
              vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c${K});
          $if K + 1 != 4:
            $for M in range(MR):
              va${M}x0 = vext_s8(va${M}x0, va${M}x0, 2);
              va${M}x1 = vext_s8(va${M}x1, va${M}x1, 2);

        k -= 16 * sizeof(int8_t);
      }
    ${"if (k != 0)" if MLA else "do"} {
      $for M in range(MR):
        int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;

      $for K in range(4):
        $for N in range(0, NR, 4):
          const int8x8_t vb${ABC[N:N+4]}c${K}x0 = vld1_s8(w); w = (const int8_t*) w + 8;

      $for K in range(4):
        $for N in range(0, NR, 4):
          $for M in range(MR):
            int16x8_t vprod${M}x${ABC[N:N+4]}c${K} = vmull_s8(vb${ABC[N:N+4]}c${K}x0, va${M}x0);
          $for M in range(MR):
            vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c${K});
        $if K + 1 != 4:
          $for M in range(MR):
            va${M}x0 = vext_s8(va${M}x0, va${M}x0, 2);

      $if not MLA:
        k -= 8 * sizeof(int8_t);
    }${"" if MLA else " while (k != 0);"}

    $if DATATYPE in ["QD8_F16", "QD8_F32"]:
      $for M in range(0, MR):
        $for N in range(0, NR, 4):
          float32x4_t vout${M}x${ABC[N:N+4]} = vcvtq_f32_s32(vacc${M}x${ABC[N:N+4]});

      $for M in range(0, MR, 2):
        $if M + 1 == MR:
          const float32x4_t vinput_scale${M} = vld1q_dup_f32(&quantization_params[${M}].inv_scale);
          $for N in range(0, NR, 4):
            vout${M}x${ABC[N:N+4]} = vmulq_f32(vout${M}x${ABC[N:N+4]}, vinput_scale${M});
        $else:
          const float32x4_t vinput_scale${ABC[M:M+2]} = vreinterpretq_f32_s32(vld1q_s32(&quantization_params[${M}].zero_point));
          $for N in range(0, NR, 4):
            vout${M}x${ABC[N:N+4]} = vmulq_lane_f32(vout${M}x${ABC[N:N+4]}, vget_low_f32(vinput_scale${ABC[M:M+2]}), 1);
            vout${M+1}x${ABC[N:N+4]} = vmulq_lane_f32(vout${M+1}x${ABC[N:N+4]}, vget_high_f32(vinput_scale${ABC[M:M+2]}), 1);

      $for N in range(0, NR, 4):
        const float32x4_t vfilter_output_scale${ABC[N:N+4]} = vld1q_f32(w); w = (const float*) w + 4;

      $for N in range(0, NR, 4):
        const float32x4_t vbias${ABC[N:N+4]} = vld1q_f32(w); w = (const float*) w + 4;
        $if DATATYPE == "QD8_F16":
          $for M in range(MR):
            vout${M}x${ABC[N:N+4]} = vfmaq_f32(vbias${ABC[N:N+4]}, vout${M}x${ABC[N:N+4]}, vfilter_output_scale${ABC[N:N+4]});
        $else:
          #if XNN_ARCH_ARM64
            $for M in range(MR):
              vout${M}x${ABC[N:N+4]} = vfmaq_f32(vbias${ABC[N:N+4]}, vout${M}x${ABC[N:N+4]}, vfilter_output_scale${ABC[N:N+4]});
          #else
            $for M in range(MR):
              vout${M}x${ABC[N:N+4]} = vmlaq_f32(vbias${ABC[N:N+4]}, vout${M}x${ABC[N:N+4]}, vfilter_output_scale${ABC[N:N+4]});
          #endif

      $if DATATYPE in ["QD8_F16", "QC4_F16"]:
        $for M in range(0, MR):
          $for N in range(0, NR, 8):
            float16x8_t vfp16out${M}x${ABC[N:N+8]} = vcombine_f16(vcvt_f16_f32(vout${M}x${ABC[N:N+4]}), vcvt_f16_f32(vout${M}x${ABC[N+4:N+8]}));

        #if XNN_ARCH_ARM64
          const uint16x8x2_t voutput_minmax = vld2q_dup_u16((const uint16_t*) &params->scalar.min);
          const float16x8_t voutput_min = vreinterpretq_f16_u16(voutput_minmax.val[0]);
          const float16x8_t voutput_max = vreinterpretq_f16_u16(voutput_minmax.val[1]);
        #else
          const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.min));
          const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16((const uint16_t*) &params->scalar.max));
        #endif
        $for M in range(0, MR):
          $for N in range(0, NR, 8):
            vfp16out${M}x${ABC[N:N+8]} = vmaxq_f16(vfp16out${M}x${ABC[N:N+8]}, voutput_min);
        $for M in range(0, MR):
          $for N in range(0, NR, 8):
            vfp16out${M}x${ABC[N:N+8]} = vminq_f16(vfp16out${M}x${ABC[N:N+8]}, voutput_max);
        if XNN_LIKELY(nc >= ${NR}) {
          $for M in range(MR):
            vst1q_u16(c${M}, vreinterpretq_u16_f16(vfp16out${M}x${ABC[0:8]}));
            $for N in range(8, NR, 8):
              vst1q_u16(c${M} + ${N}, vreinterpretq_u16_f16(vfp16out${M}x${ABC[N:N+8]}));

          $for M in range(MR):
            a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

          $for M in range(MR):
            c${M} = (uint16_t*) ((uintptr_t) c${M} + cn_stride);

          nc -= ${NR};
        } else {
         $for LOG2N in reversed(range(NR.bit_length())):
            $if NR != 1 << LOG2N:
              $if LOG2N == 2:
                $for M in range(MR):
                  float16x4_t vfp16out${M}x${ABC[N:N+4]} = vget_low_f16(vfp16out${M}x${ABC[N:N+8]});
              if (nc & ${1 << LOG2N}) {
                $if LOG2N > 2:
                  $for N in range(0, 1 << LOG2N, 8):
                    $for M in range(MR):
                      vst1q_u16(c${M}, vreinterpretq_u16_f16(vfp16out${M}x${ABC[N:N+8]})); c${M} += 8;
                      vfp16out${M}x${ABC[N:N+8]} = vfp16out${M}x${ABC[N+(1 << LOG2N):N+(1 << LOG2N)+8]};
                $elif LOG2N == 2:
                  $for M in range(MR):
                    vst1_u16(c${M}, vreinterpret_u16_f16(vfp16out${M}x${ABC[N:N+4]})); c${M} += 4;
                  $for M in range(MR):
                    vfp16out${M}x${ABC[N:N+4]} = vget_high_f16(vfp16out${M}x${ABC[N:N+8]});
                $elif LOG2N == 1:
                  $for M in range(MR):
                    vst1_lane_u32((void*) c${M}, vreinterpret_u32_f16(vfp16out${M}x${ABC[N:N+4]}), 0); c${M} += 2;
                  $for M in range(MR):
                    vfp16out${M}x${ABC[N:N+4]} = vext_f16(vfp16out${M}x${ABC[N:N+4]}, vfp16out${M}x${ABC[N:N+4]}, 2);
                $elif LOG2N == 0:
                  $for M in range(MR):
                    vst1_lane_u16(c${M}, vreinterpret_u16_f16(vfp16out${M}x${ABC[N:N+4]}), 0);
              }
          nc = 0;
        }
      $else:
        const float32x4_t voutput_min = vld1q_dup_f32(&params->scalar.min);
        $for M in range(0, MR):
          $for N in range(0, NR, 4):
            vout${M}x${ABC[N:N+4]} = vmaxq_f32(vout${M}x${ABC[N:N+4]}, voutput_min);

        const float32x4_t voutput_max = vld1q_dup_f32(&params->scalar.max);
        $for M in range(0, MR):
          $for N in range(0, NR, 4):
            vout${M}x${ABC[N:N+4]} = vminq_f32(vout${M}x${ABC[N:N+4]}, voutput_max);

        if XNN_LIKELY(nc >= ${NR}) {
          // Full tile: store each group of 4 columns exactly once per row, then rewind A and advance C.
          $for M in range(MR):
            vst1q_f32(c${M}, vout${M}x${ABC[0:4]});
            $for N in range(4, NR, 4):
              vst1q_f32(c${M} + ${N}, vout${M}x${ABC[N:N+4]});

          $for M in range(MR):
            a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

          $for M in range(MR):
            c${M} = (float*) ((uintptr_t) c${M} + cn_stride);

          nc -= ${NR};
        } else {
          $for LOG2N in reversed(range(NR.bit_length())):
            $if NR != 1 << LOG2N:
              $if LOG2N == 1:
                $for M in range(MR):
                  float32x2_t vout${M}x${ABC[N:N+2]} = vget_low_f32(vout${M}x${ABC[N:N+4]});
              if (nc & ${1 << LOG2N}) {
                $if LOG2N > 1:
                  $for N in range(0, 1 << LOG2N, 4):
                    $for M in range(MR):
                      vst1q_f32(c${M}, vout${M}x${ABC[N:N+4]}); c${M} += 4;
                      vout${M}x${ABC[N:N+4]} = vout${M}x${ABC[N+(1 << LOG2N):N+(1 << LOG2N)+4]};
                $elif LOG2N == 1:
                  $for M in range(MR):
                    vst1_f32(c${M}, vout${M}x${ABC[N:N+2]}); c${M} += 2;
                  $for M in range(MR):
                    vout${M}x${ABC[N:N+2]} = vget_high_f32(vout${M}x${ABC[N:N+4]});
                $elif LOG2N == 0:
                  $for M in range(MR):
                    vst1_lane_f32(c${M}, vout${M}x${ABC[N:N+2]}, 0);
              }
          nc = 0;
        }
    $else:
      $if REQUANTIZATION == "RNDNU":
        const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_pre_shift);
        const int32x4_t vmultiplier = vld1q_dup_s32(&params->${PARAMS_STRUCT}.multiplier);
        const int32x4_t vright_post_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_post_shift);

        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vqshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_pre_shift);

        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vqdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);

        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_post_shift);
      $elif REQUANTIZATION == "FP32":
        $for M in range(MR):
          $for N in range(0, NR, 4):
            float32x4_t vfpacc${M}x${ABC[N:N+4]} = vcvtq_f32_s32(vacc${M}x${ABC[N:N+4]});

        $if DATATYPE == "QC8":
          $for N in range(0, NR, 4):
            const float32x4_t vscale${ABC[N:N+4]} = vld1q_f32(w); w = (const float*) w + 4;
            $for M in range(MR):
              vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale${ABC[N:N+4]});
        $else:
          const float32x4_t vscale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.scale);
          $for M in range(MR):
            $for N in range(0, NR, 4):
              vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale);

        $if ARMV8:
          $for M in range(MR):
            $for N in range(0, NR, 4):
              vacc${M}x${ABC[N:N+4]} = vcvtnq_s32_f32(vfpacc${M}x${ABC[N:N+4]});
        $else:
          const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
          $for M in range(MR):
            $for N in range(0, NR, 4):
              vacc${M}x${ABC[N:N+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${M}x${ABC[N:N+4]}, vmagic_bias));

          const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(&params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
          $for M in range(MR):
            $for N in range(0, NR, 4):
              vacc${M}x${ABC[N:N+4]} = vqsubq_s32(vacc${M}x${ABC[N:N+4]}, vmagic_bias_less_output_zero_point);

      $if REQUANTIZATION != "FP32" or ARMV8:
        const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->${PARAMS_STRUCT}.output_zero_point);
      #if XNN_ARCH_ARM64
        $for M in range(MR):
          $for N in range(0, NR, 8):
            int16x8_t vacc${M}x${ABC[N:N+8]} = vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]});

        $if REQUANTIZATION != "FP32" or ARMV8:
          $for M in range(MR):
            $for N in range(0, NR, 8):
              vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point);

        $for M in range(MR):
          $for N in range(0, NR, 16):
            $if N + 8 < NR:
              int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
            $elif M % 2 == 1:
              int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
            $elif M + 1 == MR:
              int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
      #else
        $for M in range(MR):
          $for N in range(0, NR, 8):
            int16x8_t vacc${M}x${ABC[N:N+8]} = vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]}));

        $if REQUANTIZATION != "FP32" or ARMV8:
          $for M in range(MR):
            $for N in range(0, NR, 8):
              vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point);

        $for M in range(MR):
          $for N in range(0, NR, 16):
            $if N + 8 < NR:
              int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
            $elif M % 2 == 1:
              int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
            $elif M + 1 == MR:
              int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
      #endif

      $if NR == 8 and MR == 1:
        const int8x8_t voutput_min = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_min);
      $else:
        const int8x16_t voutput_min = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_min);
      $for M in range(MR):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
          $elif M % 2 == 1:
            vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
          $elif M + 1 == MR:
            $if NR == 8 and MR == 1:
              vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
            $else:
              vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));

      $if NR == 8 and MR == 1:
        const int8x8_t voutput_max = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_max);
      $else:
        const int8x16_t voutput_max = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_max);
      $for M in range(MR):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
          $elif M % 2 == 1:
            vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
          $elif M + 1 == MR:
            $if NR == 8 and MR == 1:
              vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
            $else:
              vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));

      if (nc >= ${NR}) {
        $for M in range(MR):
          $for N in range(0, NR, 16):
            $if N + 8 < NR:
              vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
            $elif M % 2 == 1:
              vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
              vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
            $elif M + 1 == MR:
              vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});

        $for M in range(MR):
          c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

        $for M in range(MR):
          a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

        nc -= ${NR};
      } else {
        // Final case where not all of the ${NR} columns fit in the destination.
        $if NR == 16:
          $for M in range(MR):
            $if M % 2 == 1:
              int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
            $elif M + 1 == MR:
              int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
          if (nc & 8) {
            $for M in range(MR):
              $if M % 2 == 1:
                vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8;
                vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8;
              $elif M + 1 == MR:
                vst1_s8(c${M}, vout${M}x01234567); c${M} += 8;
            $for M in range(MR):
              $if M % 2 == 1:
                vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
              $elif M + 1 == MR:
                vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
          }
        if (nc & 4) {
          $for M in range(MR):
            $if M % 2 == 1:
              vst1q_lane_u32((void*) c${M-1}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
              vst1q_lane_u32((void*) c${M}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
            $elif M + 1 == MR:
              vst1_lane_u32((void*) c${M}, vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
          $for M in range(MR):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
            $elif M + 1 == MR:
              vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
        }
        if (nc & 2) {
          $for M in range(MR):
            $if M % 2 == 1:
              vst1q_lane_u16((void*) c${M-1}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
              vst1q_lane_u16((void*) c${M}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
            $elif M + 1 == MR:
              vst1_lane_u16((void*) c${M}, vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
          $for M in range(MR):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
            $elif M + 1 == MR:
              vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
        }
        if (nc & 1) {
          $for M in range(MR):
            $if M % 2 == 1:
              vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
              vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
            $elif M + 1 == MR:
              vst1_lane_s8(c${M}, vout${M}x01234567, 0);
        }

        nc = 0;
      }
  } while (nc != 0);
}
