// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR >= 4
$assert NR % 4 == 0
$assert KBLOCK >= 4
$assert KBLOCK % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <emmintrin.h>

#include "xnnpack/packw.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"

$ISA = "avx" if AVX else "sse2"
// Packs 32-bit GEMM weights (and an optional bias) into tiles of ${NR} rows,
// transposing 4x4 blocks with SSE shuffles.
//
// Input layout, as read by this code: g groups, each with nc rows of kc
// contiguous elements (presumably the "goi" in the kernel name). Rows are
// consumed ${NR} at a time. Each tile is written to packed_weights as:
//   - ${NR} bias values (zeros when bias is NULL),
//   - kc runs of ${NR} values: one value from each row of the tile for the
//     same k index,
//   - extra_bytes of space that is skipped, not written.
// A final partial tile (fewer than ${NR} rows) occupies the same amount of
// space. Its unused weight lanes repeat the last valid row, and its unused
// bias lanes are left unwritten when bias is not NULL.
//
// nr must equal ${NR}; kr and sr must be 1. scale and params are not used.
// Output is written with aligned stores (_mm_store_ps), so the running output
// pointer, including the extra_bytes skips, has to stay 16-byte aligned.
void xnn_x32_packw_gemm_goi_ukernel_x${NR}__${ISA}_u${KBLOCK}${"_prfm" if PREFETCH else ""}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const uint32_t* weights,
  const uint32_t* bias,
  const void* scale,
  uint32_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == 1);
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  // The float views exist only so the _ps intrinsics can be used: the code
  // below loads, shuffles and stores values without doing arithmetic on them.
  const float* b = (const float*) bias;
  float* packed_w = (float*) packed_weights;
  // One iteration per group. b and packed_w carry over from group to group.
  do {
    // NC main loop multiple of ${NR}
    const float* w0 = (const float*) weights;
    size_t n = nc;

    for (; n >= ${NR}; n -= ${NR}) {
      // Every tile starts with ${NR} bias values, or zeros when there is no bias.
      if XNN_LIKELY(b != NULL) {
        const __m128 vb${ABC[0:4]} = _mm_loadu_ps(b);
        $for N in range(4, NR, 4):
          const __m128 vb${ABC[N:N+4]} = _mm_loadu_ps(b + ${N});
        b += ${NR};

        _mm_store_ps(packed_w, vb${ABC[0:4]});
        $for N in range(4, NR, 4):
          _mm_store_ps(packed_w + ${N}, vb${ABC[N:N+4]});
      } else {
        const __m128 vzero = _mm_setzero_ps();
        _mm_store_ps(packed_w, vzero);
        $for N in range(4, NR, 4):
          _mm_store_ps(packed_w + ${N}, vzero);
      }
      packed_w += ${NR};

      // One pointer per row of the tile; consecutive rows are kc elements apart.
      $for N in range(1, NR):
        const float* w${N} = w${N-1} + kc;
      $if PREFETCH:
        $for N in range(0, NR):
          xnn_prefetch_to_l1((const int8_t*) w${N});
          xnn_prefetch_to_l1((const int8_t*) w${N} + 64);

      size_t k = kc;
      $if KBLOCK > 4:
        // KC main loop multiple of ${KBLOCK}
        for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
          $for N in range(NR):
            const __m128 v${ABC[N]}x${ABC[0:4]} = _mm_loadu_ps(w${N});
            $for K in range(4, KBLOCK, 4):
              const __m128 v${ABC[N]}x${ABC[K:K+4]} = _mm_loadu_ps(w${N} + ${K});
            w${N} += ${KBLOCK};

          // 4x4 transpose, step 1: interleave the rows pairwise.
          $for N in range(0, NR, 4):
            $for K in range(0, KBLOCK, 4):
              const __m128 v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]} = _mm_unpacklo_ps(v${ABC[N]}x${ABC[K:K+4]}, v${ABC[N+1]}x${ABC[K:K+4]});
              const __m128 v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]} = _mm_unpacklo_ps(v${ABC[N+2]}x${ABC[K:K+4]}, v${ABC[N+3]}x${ABC[K:K+4]});
              const __m128 v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]} = _mm_unpackhi_ps(v${ABC[N]}x${ABC[K:K+4]}, v${ABC[N+1]}x${ABC[K:K+4]});
              const __m128 v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]} = _mm_unpackhi_ps(v${ABC[N+2]}x${ABC[K:K+4]}, v${ABC[N+3]}x${ABC[K:K+4]});
          $if PREFETCH:
            $for N in range(0, NR):
              xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
          // 4x4 transpose, step 2: combine 64-bit halves so that each vector
          // holds a single k index across 4 rows.
          $for N in range(0, NR, 4):
            $for K in range(0, KBLOCK, 4):
              const __m128 v${ABC[N:N+4]}x${ABC[K]} = _mm_movelh_ps(v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]}, v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]});
              const __m128 v${ABC[N:N+4]}x${ABC[K+1]} = _mm_movehl_ps(v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]}, v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]});
              const __m128 v${ABC[N:N+4]}x${ABC[K+2]} = _mm_movelh_ps(v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]}, v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]});
              const __m128 v${ABC[N:N+4]}x${ABC[K+3]} = _mm_movehl_ps(v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]}, v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]});

          // Store k-major: all ${NR} rows for the first k index, then the next.
          _mm_store_ps(packed_w, v${ABC[0:4]}x${ABC[0]});
          $for K in range(KBLOCK):
            $for N in range(0, NR, 4):
              $if N != 0 or K != 0:
                _mm_store_ps(packed_w + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
          packed_w += ${NR*KBLOCK};
        }

      // KC multiple of 4
      for (; k >= 4; k -= 4) {
        $for N in range(NR):
          const __m128 v${ABC[N]}x${ABC[0:4]} = _mm_loadu_ps(w${N});
          w${N} += 4;

        // Same two-step 4x4 transpose: pairwise interleave, then combine halves.
        $for N in range(0, NR, 4):
          const __m128 v${ABC[N:N+2]}x${ABC[0]}_${ABC[N:N+2]}x${ABC[1]} = _mm_unpacklo_ps(v${ABC[N]}x${ABC[0:4]}, v${ABC[N+1]}x${ABC[0:4]});
          const __m128 v${ABC[N+2:N+4]}x${ABC[0]}_${ABC[N+2:N+4]}x${ABC[1]} = _mm_unpacklo_ps(v${ABC[N+2]}x${ABC[0:4]}, v${ABC[N+3]}x${ABC[0:4]});
          const __m128 v${ABC[N:N+2]}x${ABC[2]}_${ABC[N:N+2]}x${ABC[3]} = _mm_unpackhi_ps(v${ABC[N]}x${ABC[0:4]}, v${ABC[N+1]}x${ABC[0:4]});
          const __m128 v${ABC[N+2:N+4]}x${ABC[2]}_${ABC[N+2:N+4]}x${ABC[3]} = _mm_unpackhi_ps(v${ABC[N+2]}x${ABC[0:4]}, v${ABC[N+3]}x${ABC[0:4]});
        $if PREFETCH:
          $for N in range(0, NR):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
        $for N in range(0, NR, 4):
          const __m128 v${ABC[N:N+4]}x${ABC[0]} = _mm_movelh_ps(v${ABC[N:N+2]}x${ABC[0]}_${ABC[N:N+2]}x${ABC[1]}, v${ABC[N+2:N+4]}x${ABC[0]}_${ABC[N+2:N+4]}x${ABC[1]});
          const __m128 v${ABC[N:N+4]}x${ABC[1]} = _mm_movehl_ps(v${ABC[N+2:N+4]}x${ABC[0]}_${ABC[N+2:N+4]}x${ABC[1]}, v${ABC[N:N+2]}x${ABC[0]}_${ABC[N:N+2]}x${ABC[1]});
          const __m128 v${ABC[N:N+4]}x${ABC[2]} = _mm_movelh_ps(v${ABC[N:N+2]}x${ABC[2]}_${ABC[N:N+2]}x${ABC[3]}, v${ABC[N+2:N+4]}x${ABC[2]}_${ABC[N+2:N+4]}x${ABC[3]});
          const __m128 v${ABC[N:N+4]}x${ABC[3]} = _mm_movehl_ps(v${ABC[N+2:N+4]}x${ABC[2]}_${ABC[N+2:N+4]}x${ABC[3]}, v${ABC[N:N+2]}x${ABC[2]}_${ABC[N:N+2]}x${ABC[3]});

        _mm_store_ps(packed_w, v${ABC[0:4]}x${ABC[0]});
        $for K in range(4):
          $for N in range(0, NR, 4):
            $if N != 0 or K != 0:
              _mm_store_ps(packed_w + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
        packed_w += ${NR*4};
      }

      // KC remainder (1..3)
      // Each case loads exactly k elements per row, so no row is read past its
      // end, and emits exactly k runs of ${NR} values.
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= 3);
        switch (k) {
          case 1:
          {
            $for N in range(NR):
              const __m128 v${ABC[N]}x0 = _mm_load_ss(w${N});
              w${N} += 1;

            $for N in range(0, NR, 2):
              const __m128 v${ABC[N:N+2]}x0 = _mm_unpacklo_ps(v${ABC[N]}x0, v${ABC[N+1]}x0);

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+4]}x0 = _mm_movelh_ps(v${ABC[N:N+2]}x0, v${ABC[N+2:N+4]}x0);

            _mm_store_ps(packed_w, v${ABC[0:4]}x0);
            $for N in range(4, NR, 4):
              _mm_store_ps(packed_w + ${N}, v${ABC[N:N+4]}x0);
            packed_w += ${NR};
            break;
          }
          case 2:
          {
            // 64-bit loads bring in two elements per row.
            $for N in range(NR):
              const __m128 v${ABC[N]}x01 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*) w${N}));
              w${N} += 2;

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1 = _mm_unpacklo_ps(v${ABC[N]}x01, v${ABC[N+1]}x01);
              const __m128 v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1 = _mm_unpacklo_ps(v${ABC[N+2]}x01, v${ABC[N+3]}x01);

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+4]}x0 = _mm_movelh_ps(v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1, v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1);
              const __m128 v${ABC[N:N+4]}x1 = _mm_movehl_ps(v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1, v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1);

            _mm_store_ps(packed_w, v${ABC[0:4]}x0);
            $for K in range(2):
              $for N in range(0, NR, 4):
                $if N != 0 or K != 0:
                  _mm_store_ps(packed_w + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
            packed_w += ${NR*2};
            break;
          }
          case 3:
          {
            // Build [x0, x1, x2, 0] per row from a 1-element and a 2-element
            // load: put x2 in lane 0, copy the low half to the high half, then
            // overwrite the low half with x0, x1.
            $for N in range(NR):
              __m128 v${ABC[N]}x012 = _mm_load_ss(w${N} + 2);

            $for N in range(NR):
              v${ABC[N]}x012 = _mm_movelh_ps(v${ABC[N]}x012, v${ABC[N]}x012);

            $for N in range(NR):
              v${ABC[N]}x012 = _mm_loadl_pi(v${ABC[N]}x012, (const __m64*) w${N});
              w${N} += 3;

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1 = _mm_unpacklo_ps(v${ABC[N]}x012, v${ABC[N+1]}x012);
              const __m128 v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1 = _mm_unpacklo_ps(v${ABC[N+2]}x012, v${ABC[N+3]}x012);
              const __m128 v${ABC[N:N+2]}x2 = _mm_unpackhi_ps(v${ABC[N]}x012, v${ABC[N+1]}x012);
              const __m128 v${ABC[N+2:N+4]}x2 = _mm_unpackhi_ps(v${ABC[N+2]}x012, v${ABC[N+3]}x012);

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+4]}x0 = _mm_movelh_ps(v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1, v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1);
              const __m128 v${ABC[N:N+4]}x1 = _mm_movehl_ps(v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1, v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1);
              const __m128 v${ABC[N:N+4]}x2 = _mm_movelh_ps(v${ABC[N:N+2]}x2, v${ABC[N+2:N+4]}x2);

            _mm_store_ps(packed_w, v${ABC[0:4]}x0);
            $for K in range(3):
              $for N in range(0, NR, 4):
                $if N != 0 or K != 0:
                  _mm_store_ps(packed_w + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
            packed_w += ${NR*3};
            break;
          }
          default:
            XNN_UNREACHABLE;
        }
      }
      // Skip the per-tile extra_bytes region; nothing is written into it here.
      packed_w = (float*) ((uintptr_t) packed_w + extra_bytes);
      // The last row pointer has advanced by kc in total, so it now points at
      // the first row of the next tile.
      w0 = w${NR-1};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= ${NR-1});
      // Copy the n remaining bias values; the other bias lanes of the tile are
      // skipped without being written. With no bias the whole tile is zeroed.
      if XNN_LIKELY(b != NULL) {
        size_t nb = n;
        do {
          *packed_w++  = *b++;
        } while (--nb != 0);
        packed_w += (${NR} - n);
      } else {
        const __m128 vzero = _mm_setzero_ps();
        _mm_store_ps(packed_w, vzero);
        $for N in range(4, NR, 4):
          _mm_store_ps(packed_w + ${N}, vzero);
        packed_w += ${NR};
      }

      // Only ${NR-1} row pointers are declared because n is below ${NR}. A pointer
      // for a row at or beyond n is set to the previous row's pointer, so every
      // load below reads from a valid row and unused lanes repeat the last one.
      $for N in range(1, NR-1):
        const float* w${N} = w${N-1} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${N} = w${N-1};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${N} = w${N-1};
          }

      size_t k = kc;
      $if KBLOCK > 4:
        // KC main loop multiple of ${KBLOCK}
        for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
          $for N in range(NR-1):
            const __m128 v${ABC[N]}x${ABC[0:4]} = _mm_loadu_ps(w${N});
            $for K in range(4, KBLOCK, 4):
              const __m128 v${ABC[N]}x${ABC[K:K+4]} = _mm_loadu_ps(w${N} + ${K});
            w${N} += ${KBLOCK};

          // Same transpose as for full tiles; the last lane has no row of its
          // own and reuses the register of the row before it.
          $for N in range(0, NR, 4):
            $for K in range(0, KBLOCK, 4):
              const __m128 v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]} = _mm_unpacklo_ps(v${ABC[N]}x${ABC[K:K+4]}, v${ABC[N+1]}x${ABC[K:K+4]});
              const __m128 v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]} = _mm_unpacklo_ps(v${ABC[N+2]}x${ABC[K:K+4]}, v${ABC[min(N+3, NR-2)]}x${ABC[K:K+4]});
              const __m128 v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]} = _mm_unpackhi_ps(v${ABC[N]}x${ABC[K:K+4]}, v${ABC[N+1]}x${ABC[K:K+4]});
              const __m128 v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]} = _mm_unpackhi_ps(v${ABC[N+2]}x${ABC[K:K+4]}, v${ABC[min(N+3, NR-2)]}x${ABC[K:K+4]});
          $if PREFETCH:
            $for N in range(0, NR-1):
              xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
          $for N in range(0, NR, 4):
            $for K in range(0, KBLOCK, 4):
              const __m128 v${ABC[N:N+4]}x${ABC[K]} = _mm_movelh_ps(v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]}, v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]});
              const __m128 v${ABC[N:N+4]}x${ABC[K+1]} = _mm_movehl_ps(v${ABC[N+2:N+4]}x${ABC[K]}_${ABC[N+2:N+4]}x${ABC[K+1]}, v${ABC[N:N+2]}x${ABC[K]}_${ABC[N:N+2]}x${ABC[K+1]});
              const __m128 v${ABC[N:N+4]}x${ABC[K+2]} = _mm_movelh_ps(v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]}, v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]});
              const __m128 v${ABC[N:N+4]}x${ABC[K+3]} = _mm_movehl_ps(v${ABC[N+2:N+4]}x${ABC[K+2]}_${ABC[N+2:N+4]}x${ABC[K+3]}, v${ABC[N:N+2]}x${ABC[K+2]}_${ABC[N:N+2]}x${ABC[K+3]});

          _mm_store_ps(packed_w, v${ABC[0:4]}x${ABC[0]});
          $for K in range(KBLOCK):
            $for N in range(0, NR, 4):
              $if N != 0 or K != 0:
                _mm_store_ps(packed_w + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
          packed_w += ${NR*KBLOCK};
        }

      // KC multiple of 4
      for (; k >= 4; k -= 4) {
        $for N in range(NR-1):
          const __m128 v${ABC[N]}x${ABC[0:4]} = _mm_loadu_ps(w${N});
          w${N} += 4;

        // The last lane reuses the register of the row before it.
        $for N in range(0, NR, 4):
          const __m128 v${ABC[N:N+2]}x${ABC[0]}_${ABC[N:N+2]}x${ABC[1]} = _mm_unpacklo_ps(v${ABC[N]}x${ABC[0:4]}, v${ABC[N+1]}x${ABC[0:4]});
          const __m128 v${ABC[N+2:N+4]}x${ABC[0]}_${ABC[N+2:N+4]}x${ABC[1]} = _mm_unpacklo_ps(v${ABC[N+2]}x${ABC[0:4]}, v${ABC[min(N+3, NR-2)]}x${ABC[0:4]});
          const __m128 v${ABC[N:N+2]}x${ABC[2]}_${ABC[N:N+2]}x${ABC[3]} = _mm_unpackhi_ps(v${ABC[N]}x${ABC[0:4]}, v${ABC[N+1]}x${ABC[0:4]});
          const __m128 v${ABC[N+2:N+4]}x${ABC[2]}_${ABC[N+2:N+4]}x${ABC[3]} = _mm_unpackhi_ps(v${ABC[N+2]}x${ABC[0:4]}, v${ABC[min(N+3, NR-2)]}x${ABC[0:4]});
        $if PREFETCH:
          $for N in range(0, NR-1):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
        $for N in range(0, NR, 4):
          const __m128 v${ABC[N:N+4]}x${ABC[0]} = _mm_movelh_ps(v${ABC[N:N+2]}x${ABC[0]}_${ABC[N:N+2]}x${ABC[1]}, v${ABC[N+2:N+4]}x${ABC[0]}_${ABC[N+2:N+4]}x${ABC[1]});
          const __m128 v${ABC[N:N+4]}x${ABC[1]} = _mm_movehl_ps(v${ABC[N+2:N+4]}x${ABC[0]}_${ABC[N+2:N+4]}x${ABC[1]}, v${ABC[N:N+2]}x${ABC[0]}_${ABC[N:N+2]}x${ABC[1]});
          const __m128 v${ABC[N:N+4]}x${ABC[2]} = _mm_movelh_ps(v${ABC[N:N+2]}x${ABC[2]}_${ABC[N:N+2]}x${ABC[3]}, v${ABC[N+2:N+4]}x${ABC[2]}_${ABC[N+2:N+4]}x${ABC[3]});
          const __m128 v${ABC[N:N+4]}x${ABC[3]} = _mm_movehl_ps(v${ABC[N+2:N+4]}x${ABC[2]}_${ABC[N+2:N+4]}x${ABC[3]}, v${ABC[N:N+2]}x${ABC[2]}_${ABC[N:N+2]}x${ABC[3]});

        _mm_store_ps(packed_w, v${ABC[0:4]}x${ABC[0]});
        $for K in range(4):
          $for N in range(0, NR, 4):
            $if N != 0 or K != 0:
              _mm_store_ps(packed_w + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
        packed_w += ${NR*4};
      }

      // KC remainder (1..3)
      // Unlike the full-tile path, the row pointers are not advanced in these
      // cases: they are not read again after this tile.
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= 3);
        switch (k) {
          case 1:
          {
            $for N in range(NR-1):
              const __m128 v${ABC[N]}x0 = _mm_load_ss(w${N});

            $for N in range(0, NR, 2):
              const __m128 v${ABC[N:N+2]}x0 = _mm_unpacklo_ps(v${ABC[N]}x0, v${ABC[min(N+1, NR-2)]}x0);

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+4]}x0 = _mm_movelh_ps(v${ABC[N:N+2]}x0, v${ABC[N+2:N+4]}x0);

            _mm_store_ps(packed_w, v${ABC[0:4]}x0);
            $for N in range(4, NR, 4):
              _mm_store_ps(packed_w + ${N}, v${ABC[N:N+4]}x0);
            packed_w += ${NR};
            break;
          }
          case 2:
          {
            $for N in range(NR-1):
              const __m128 v${ABC[N]}x01 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*) w${N}));

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1 = _mm_unpacklo_ps(v${ABC[N]}x01, v${ABC[N+1]}x01);
              const __m128 v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1 = _mm_unpacklo_ps(v${ABC[N+2]}x01, v${ABC[min(N+3, NR-2)]}x01);

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+4]}x0 = _mm_movelh_ps(v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1, v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1);
              const __m128 v${ABC[N:N+4]}x1 = _mm_movehl_ps(v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1, v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1);

            _mm_store_ps(packed_w, v${ABC[0:4]}x0);
            $for K in range(2):
              $for N in range(0, NR, 4):
                $if N != 0 or K != 0:
                  _mm_store_ps(packed_w + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
            packed_w += ${NR*2};
            break;
          }
          case 3:
          {
            // Same [x0, x1, x2, 0] construction as in the full-tile path.
            $for N in range(NR-1):
              __m128 v${ABC[N]}x012 = _mm_load_ss(w${N} + 2);

            $for N in range(NR-1):
              v${ABC[N]}x012 = _mm_movelh_ps(v${ABC[N]}x012, v${ABC[N]}x012);

            $for N in range(NR-1):
              v${ABC[N]}x012 = _mm_loadl_pi(v${ABC[N]}x012, (const __m64*) w${N});

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1 = _mm_unpacklo_ps(v${ABC[N]}x012, v${ABC[N+1]}x012);
              const __m128 v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1 = _mm_unpacklo_ps(v${ABC[N+2]}x012, v${ABC[min(N+3, NR-2)]}x012);
              const __m128 v${ABC[N:N+2]}x2 = _mm_unpackhi_ps(v${ABC[N]}x012, v${ABC[N+1]}x012);
              const __m128 v${ABC[N+2:N+4]}x2 = _mm_unpackhi_ps(v${ABC[N+2]}x012, v${ABC[min(N+3, NR-2)]}x012);

            $for N in range(0, NR, 4):
              const __m128 v${ABC[N:N+4]}x0 = _mm_movelh_ps(v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1, v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1);
              const __m128 v${ABC[N:N+4]}x1 = _mm_movehl_ps(v${ABC[N+2:N+4]}x0_${ABC[N+2:N+4]}x1, v${ABC[N:N+2]}x0_${ABC[N:N+2]}x1);
              const __m128 v${ABC[N:N+4]}x2 = _mm_movelh_ps(v${ABC[N:N+2]}x2, v${ABC[N+2:N+4]}x2);

            _mm_store_ps(packed_w, v${ABC[0:4]}x0);
            $for K in range(3):
              $for N in range(0, NR, 4):
                $if N != 0 or K != 0:
                  _mm_store_ps(packed_w + ${N+K*NR}, v${ABC[N:N+4]}x${ABC[K]});
            packed_w += ${NR*3};
            break;
          }
          default:
            XNN_UNREACHABLE;
        }
      }
      packed_w = (float*) ((uintptr_t) packed_w + extra_bytes);
    }
    // Move on to the next group's nc rows of kc elements.
    weights += nc * kc;
  } while (--g != 0);
}
