// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR >= 8
$assert NR % 16 == 0
$assert KBLOCK == 4
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "xnnpack/intrinsics-polyfill.h"
#include "xnnpack/packw.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"


// Packs 32-bit GEMM weights and an optional bias into tiles of ${NR} output
// channels using AVX-512F.
//
// For each of the g groups, and for each tile of up to ${NR} output channels,
// the kernel writes to packed_weights, in this order:
//   1. ${NR} bias slots. With a non-NULL bias a full tile copies ${NR} values;
//      a partial tile copies only its n valid values and skips the remaining
//      slots without writing them. With a NULL bias all ${NR} slots are zeroed.
//   2. For every k in [0, kc): ${NR} weights, one per output channel of the
//      tile, i.e. the tile of the source matrix transposed.
//   3. extra_bytes bytes that are skipped (not written).
//
// Parameters:
//   g              - number of groups; must be non-zero.
//   nc             - output channels per group; must be non-zero.
//   kc             - elements per output channel; must be non-zero.
//   nr, kr, sr     - tile geometry; must be ${NR}, 1 and 1 respectively.
//   weights        - source, read as g * nc rows of kc consecutive elements.
//   bias           - optional; when non-NULL, g * nc values are consumed.
//   scale          - not used by this kernel.
//   packed_weights - destination. It is written with aligned stores
//                    (_mm512_store_ps / _mm_store_ps), so the caller must
//                    provide a suitably aligned buffer and extra_bytes value.
//   extra_bytes    - number of bytes to skip after every packed tile.
//   params         - not used by this kernel.
void xnn_x32_packw_gemm_goi_ukernel_x${NR}__avx512f_u${KBLOCK}${"_prfm" if PREFETCH else ""}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const uint32_t* weights,
  const uint32_t* bias,
  const void* scale,
  uint32_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});   // This kernel is for NR=${NR}
  assert(kr == 1);
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  // The data is only moved, never interpreted, so it is handled as float to
  // match the _ps intrinsics.
  const float* b = (const float*) bias;
  float* packed_w = (float*) packed_weights;
  do {
    // NC main loop multiple of ${NR}
    const float* w0 = (const float*) weights;
    size_t n = nc;

    for (; n >= ${NR}; n -= ${NR}) {
      if XNN_LIKELY(b != NULL) {
        const __m512 vb0 = _mm512_loadu_ps(b);
        $for N in range(16,NR,16):
          const __m512 vb${N} = _mm512_loadu_ps(b + ${N});
        _mm512_store_ps(packed_w, vb0);
        $for N in range(16,NR,16):
          _mm512_store_ps(packed_w + ${N}, vb${N});
        b += ${NR};
      } else {
        const __m512 vzero = _mm512_setzero_ps();
        _mm512_store_ps(packed_w, vzero);
        $for N in range(16,NR,16):
          _mm512_store_ps(packed_w + ${N}, vzero);
      }
      packed_w += ${NR};

      // One read pointer per output channel (row) of the tile.
      $for N in range(1, NR):
        const float* w${N} = w${N-1} + kc;
      $if PREFETCH:
        $for N in range(0, NR):
          xnn_prefetch_to_l1((const int8_t*) w${N});
          xnn_prefetch_to_l1((const int8_t*) w${N} + 64);

      // KC main loop multiple of ${KBLOCK}
      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        // Read blocks of 4x4
        // a b c d
        // e f g h
        // i j k l
        // m n o p
        // Load first 4 rows of N into low part of each register
        $for N in range(0,NR,16):
          $for L in range(4):
            __m512 v${N+L}x0123 = _mm512_castps128_ps512(_mm_loadu_ps(w${N+L}));
            w${N+L} += ${KBLOCK};
        // Load next 4 rows of N into the high part of each register
        $for N in range(0,NR,16):
          $for L in range(4):
            v${N+L}x0123 = _mm512_insertf32x4(v${N+L}x0123, _mm_loadu_ps(w${N+L+4}), 1);
            w${N+L+4} += ${KBLOCK};
          $for L in range(4):
            v${N+L}x0123 = _mm512_insertf32x4(v${N+L}x0123, _mm_loadu_ps(w${N+L+8}), 2);
            w${N+L+8} += ${KBLOCK};
          $for L in range(4):
            v${N+L}x0123 = _mm512_insertf32x4(v${N+L}x0123, _mm_loadu_ps(w${N+L+12}), 3);
            w${N+L+12} += ${KBLOCK};

        // Transpose 2x2
        $for N in range(0,NR,16):
          $for K in range(0,KBLOCK,4):
            const __m512 vtmp${N+0}x${ABC[K:K+4]} = _mm512_unpacklo_ps(v${N+0}x${ABC[K:K+4]}, v${N+1}x${ABC[K:K+4]});  // a e b f   from row 0, 1
            const __m512 vtmp${N+1}x${ABC[K:K+4]} = _mm512_unpacklo_ps(v${N+2}x${ABC[K:K+4]}, v${N+3}x${ABC[K:K+4]});  // i m j n   from row 2, 3
            const __m512 vtmp${N+2}x${ABC[K:K+4]} = _mm512_unpackhi_ps(v${N+0}x${ABC[K:K+4]}, v${N+1}x${ABC[K:K+4]});  // c g d h   from row 0, 1
            const __m512 vtmp${N+3}x${ABC[K:K+4]} = _mm512_unpackhi_ps(v${N+2}x${ABC[K:K+4]}, v${N+3}x${ABC[K:K+4]});  // k o l p   from row 2, 3
        $if PREFETCH:
          $for N in range(0, NR):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
        // Transpose 4x4
        $for N in range(0,NR,16):
          $for K in range(0,KBLOCK,4):
            v${N+0}x${ABC[K:K+4]} = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(vtmp${N+0}x${ABC[K:K+4]}), _mm512_castps_pd(vtmp${N+1}x${ABC[K:K+4]})));  // a e i m   from row 0, 1
            v${N+1}x${ABC[K:K+4]} = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(vtmp${N+0}x${ABC[K:K+4]}), _mm512_castps_pd(vtmp${N+1}x${ABC[K:K+4]})));  // b f j n   from row 0, 1
            v${N+2}x${ABC[K:K+4]} = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(vtmp${N+2}x${ABC[K:K+4]}), _mm512_castps_pd(vtmp${N+3}x${ABC[K:K+4]})));  // c g k o   from row 2, 3
            v${N+3}x${ABC[K:K+4]} = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(vtmp${N+2}x${ABC[K:K+4]}), _mm512_castps_pd(vtmp${N+3}x${ABC[K:K+4]})));  // d h l p   from row 2, 3

        // Each register now holds one k index for 16 consecutive rows.
        _mm512_store_ps(packed_w, v0x0123);
        $for K in range(KBLOCK):
          $for N in range(0,NR,16):
            $if N != 0 or K != 0:
              _mm512_store_ps(packed_w + ${N+K*NR}, v${N+K}x${ABC[K//4*4:K//4*4+4]});
        packed_w += ${NR*KBLOCK};
      }

      // KC remainder (1..3)
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= 3);
        if (k & 2) {
          // Read blocks of 4x2
          // a b
          // c d
          // e f
          // g h
          $for N in range(NR):
            __m128 v${N} = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*) w${N}));
            w${N} += 2;

          $for N in range(0,NR,4):
            // Transpose 2x2
            const __m128 vtmp${N+0} = _mm_unpacklo_ps(v${N+0}, v${N+1});  // a c b d   from row 0, 1
            const __m128 vtmp${N+1} = _mm_unpacklo_ps(v${N+2}, v${N+3});  // e g f h   from row 2, 3
            // Transpose 4x4
            v${N+0} = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(vtmp${N+0}), _mm_castps_pd(vtmp${N+1})));  // a c e g   from row 0, 1
            v${N+1} = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(vtmp${N+0}), _mm_castps_pd(vtmp${N+1})));  // b d f h   from row 0, 1

          _mm_store_ps(packed_w, v0);
          $for K in range(2):
            $for N in range(0,NR,4):
              $if N != 0 or K != 0:
                _mm_store_ps(packed_w + ${N//4*4+K*NR}, v${N+K});
          packed_w += ${NR*2};
        }
        if (k & 1) {
          // Read blocks of 4x1
          // a
          // b
          // c
          // d
          $for N in range(NR):
            __m128 v${N} = _mm_load_ss(w${N});
            w${N} += 1;

          $for N in range(0,NR,4):
            // Transpose 2x2
            const __m128 vtmp${N+0} = _mm_unpacklo_ps(v${N+0}, v${N+1});  // a b  from row 0, 1
            const __m128 vtmp${N+1} = _mm_unpacklo_ps(v${N+2}, v${N+3});  // c d  from row 2, 3
            // Transpose 4x4
            v${N+0} = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(vtmp${N+0}), _mm_castps_pd(vtmp${N+1})));  // a b c d   from row 0, 1

          _mm_store_ps(packed_w, v0);
          $for N in range(4,NR,4):
            _mm_store_ps(packed_w + ${N//4*4}, v${N});
          packed_w += ${NR};
        }
      }
      // Skip the per-tile extra bytes.
      packed_w = (float*) ((uintptr_t) packed_w + extra_bytes);
      // w${NR-1} has advanced by kc elements in total, so it now points at the
      // first row of the next tile.
      w0 = w${NR-1};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= ${NR-1});
      if XNN_LIKELY(b != NULL) {
        // Copy the n valid bias values; the remaining slots of the tile are
        // skipped without being written.
        size_t nb = n;
        do {
          *packed_w++  = *b++;
        } while (--nb != 0);
        packed_w += (${NR} - n);
      } else {
        const __m512 vzero = _mm512_setzero_ps();
        _mm512_store_ps(packed_w, vzero);
        $for N in range(16,NR,16):
          _mm512_store_ps(packed_w + ${N}, vzero);
        packed_w += ${NR};
      }

      // The NC remainder has fewer than ${NR} rows, so the last row is never
      // loaded and no pointer is created for it. A row pointer at or past the
      // n valid rows aliases the previous row pointer, which keeps every load
      // on a valid row; the duplicated values land in columns >= n of the tile.
      $for N in range(1, NR-1):
        const float* w${N} = w${N-1} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${N} = w${N-1};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${N} = w${N-1};
          }

      // KC main loop multiple of ${KBLOCK}
      size_t k = kc;
      for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
        // Read blocks of 4x4
        // a b c d
        // e f g h
        // i j k l
        // m n o p
        // Load first 4 rows of N into low part of each register
        $for N in range(0,NR,16):
          $for L in range(4):
            $if (N+L+12)==(NR-1):
              // castps leaves upper 128 bits undefined, so zero them.
              __m512 v${N+L}x0123 = _mm512_zextps128_ps512(_mm_loadu_ps(w${N+L}));
            $else:
              __m512 v${N+L}x0123 = _mm512_castps128_ps512(_mm_loadu_ps(w${N+L}));
            w${N+L} += ${KBLOCK};
        // Load next 4 rows of N into the high part of each register
        $for N in range(0,NR,16):
          $for L in range(4):
            v${N+L}x0123 = _mm512_insertf32x4(v${N+L}x0123, _mm_loadu_ps(w${N+L+4}), 1);
            w${N+L+4} += ${KBLOCK};
          $for L in range(4):
            v${N+L}x0123 = _mm512_insertf32x4(v${N+L}x0123, _mm_loadu_ps(w${N+L+8}), 2);
            w${N+L+8} += ${KBLOCK};
          $for L in range(4):
            $if (N+L+12)<(NR-1):
              v${N+L}x0123 = _mm512_insertf32x4(v${N+L}x0123, _mm_loadu_ps(w${N+L+12}), 3);
              w${N+L+12} += ${KBLOCK};

        // Transpose 2x2
        $for N in range(0,NR,16):
          $for K in range(0,KBLOCK,4):
            const __m512 vtmp${N+0}x${ABC[K:K+4]} = _mm512_unpacklo_ps(v${N+0}x${ABC[K:K+4]}, v${N+1}x${ABC[K:K+4]});  // a e b f   from row 0, 1
            const __m512 vtmp${N+1}x${ABC[K:K+4]} = _mm512_unpacklo_ps(v${N+2}x${ABC[K:K+4]}, v${N+3}x${ABC[K:K+4]});  // i m j n   from row 2, 3
            const __m512 vtmp${N+2}x${ABC[K:K+4]} = _mm512_unpackhi_ps(v${N+0}x${ABC[K:K+4]}, v${N+1}x${ABC[K:K+4]});  // c g d h   from row 0, 1
            const __m512 vtmp${N+3}x${ABC[K:K+4]} = _mm512_unpackhi_ps(v${N+2}x${ABC[K:K+4]}, v${N+3}x${ABC[K:K+4]});  // k o l p   from row 2, 3
        $if PREFETCH:
          $for N in range(0, NR-1):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
        // Transpose 4x4
        $for N in range(0,NR,16):
          $for K in range(0,KBLOCK,4):
            v${N+0}x${ABC[K:K+4]} = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(vtmp${N+0}x${ABC[K:K+4]}), _mm512_castps_pd(vtmp${N+1}x${ABC[K:K+4]})));  // a e i m   from row 0, 1
            v${N+1}x${ABC[K:K+4]} = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(vtmp${N+0}x${ABC[K:K+4]}), _mm512_castps_pd(vtmp${N+1}x${ABC[K:K+4]})));  // b f j n   from row 0, 1
            v${N+2}x${ABC[K:K+4]} = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(vtmp${N+2}x${ABC[K:K+4]}), _mm512_castps_pd(vtmp${N+3}x${ABC[K:K+4]})));  // c g k o   from row 2, 3
            v${N+3}x${ABC[K:K+4]} = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(vtmp${N+2}x${ABC[K:K+4]}), _mm512_castps_pd(vtmp${N+3}x${ABC[K:K+4]})));  // d h l p   from row 2, 3

        _mm512_store_ps(packed_w, v0x0123);
        $for K in range(KBLOCK):
          $for N in range(0,NR,16):
            $if N != 0 or K != 0:
              _mm512_store_ps(packed_w + ${N+K*NR}, v${N+K}x${ABC[K//4*4:K//4*4+4]});
        packed_w += ${NR*KBLOCK};
      }

      // KC remainder (1..3)
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= 3);
        if (k & 2) {
          // Read blocks of 4x2
          // a b
          // c d
          // e f
          // g h
          $for N in range(NR-1):
            __m128 v${N} = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*) w${N}));
            w${N} += 2;

          $for N in range(0,NR,4):
            // Transpose 2x2
            const __m128 vtmp${N+0} = _mm_unpacklo_ps(v${N+0}, v${N+1});  // a c b d   from row 0, 1
            const __m128 vtmp${N+1} = _mm_unpacklo_ps(v${N+2}, v${min(N+3, NR-2)});  // e g f h   from row 2, 3
            // Transpose 4x4
            v${N+0} = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(vtmp${N+0}), _mm_castps_pd(vtmp${N+1})));  // a c e g   from row 0, 1
            v${N+1} = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(vtmp${N+0}), _mm_castps_pd(vtmp${N+1})));  // b d f h   from row 0, 1

          _mm_store_ps(packed_w, v0);
          $for K in range(2):
            $for N in range(0,NR,4):
              $if N != 0 or K != 0:
                _mm_store_ps(packed_w + ${N//4*4+K*NR}, v${N+K});
          packed_w += ${NR*2};
        }
        if (k & 1) {
          // Read blocks of 4x1
          // a
          // b
          // c
          // d
          $for N in range(NR-1):
            __m128 v${N} = _mm_load_ss(w${N});
            w${N} += 1;

          $for N in range(0,NR,4):
            // Transpose 2x2
            const __m128 vtmp${N+0} = _mm_unpacklo_ps(v${N+0}, v${N+1});  // a b  from row 0, 1
            const __m128 vtmp${N+1} = _mm_unpacklo_ps(v${N+2}, v${min(N+3, NR-2)});  // c d  from row 2, 3
            // Transpose 4x4
            v${N+0} = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(vtmp${N+0}), _mm_castps_pd(vtmp${N+1})));  // a b c d   from row 0, 1

          _mm_store_ps(packed_w, v0);
          $for N in range(4,NR,4):
            _mm_store_ps(packed_w + ${N//4*4}, v${N});
          packed_w += ${NR};
        }
      }
      // The remainder tile is followed by extra_bytes just like a full tile;
      // without this skip the next group's tiles would start too early.
      packed_w = (float*) ((uintptr_t) packed_w + extra_bytes);
    }
    // Advance to the next group's weights.
    weights += nc * kc;
  } while (--g != 0);
}
