// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR != 0
$assert KR == 4

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <emmintrin.h>

#include "xnnpack/packw.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"


// Packs 32-bit weights (and an optional bias) into tiles of ${NR} rows with
// ${KR} consecutive K elements per row, using SSE2 loads and stores.
//
// For each of the g groups, `weights` is read as nc rows of kc elements each.
// Rows are consumed ${NR} at a time, and every tile is written to
// `packed_weights` as:
//   - ${NR} bias values (zeros when `bias` is NULL),
//   - for each block of ${KR} K elements: ${KR} elements of row 0, then ${KR}
//     elements of row 1, and so on; a trailing partial block of kc is
//     zero-padded to ${KR} elements,
//   - `extra_bytes` bytes that are skipped without being written.
//
// Values are only moved (through float / __m128 loads and stores); no
// arithmetic is performed on them.
//
// Parameters:
//   g               number of groups; must be non-zero.
//   nc              rows per group; must be non-zero.
//   kc              elements per row; must be non-zero.
//   nr, kr, sr      only checked by assertions (${NR}, ${KR} and 1).
//   weights         source weights, g * nc * kc elements; must not be NULL.
//   bias            optional; when non-NULL it is read sequentially across all
//                   groups, nc values per group.
//   scale           not used by this kernel.
//   packed_weights  destination buffer; must not be NULL.
//   extra_bytes     bytes skipped in the destination after every tile.
//   params          not used by this kernel.
void xnn_x32_packw_gemm_goi_ukernel_x${NR}c${KR}__sse2_u4${"_prfm" if PREFETCH else ""}(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const uint32_t* weights,
  const uint32_t* bias,
  const void* scale,
  uint32_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == ${NR});
  assert(kr == ${KR});
  assert(sr == 1);
  assert(weights != NULL);
  assert(packed_weights != NULL);
  // One vector register per tile row; each holds ${KR} elements of that row.
  $for N in range(NR):
    __m128 v${N};

  // The 32-bit values are accessed through float pointers so that they can be
  // moved with the _ps intrinsics. b is initialized once and keeps advancing
  // across groups.
  const float* b = (const float*) bias;
  float* packed_w = (float*) packed_weights;
  do {
    // NC main loop multiple of ${NR}
    const float* w0 = (const float*) weights;
    size_t n = nc;

    for (; n >= ${NR}; n -= ${NR}) {
      // Tile header: ${NR} bias values, or zeros when no bias is supplied.
      if XNN_LIKELY(b != NULL) {
        $for N in range(NR):
          packed_w[${N}] = b[${N}];
        b += ${NR};
      } else {
        $for N in range(NR):
          packed_w[${N}] = 0.0f;
      }
      packed_w += ${NR};

      // Consecutive rows are kc elements apart.
      $for N in range(1, NR):
        const float* w${N} = w${N-1} + kc;
      $if PREFETCH:
        $for N in range(0, NR):
          xnn_prefetch_to_l1((const int8_t*) w${N});
          xnn_prefetch_to_l1((const int8_t*) w${N} + 64);

      // KC main loop: full blocks of ${KR} elements from each of the ${NR} rows
      size_t k = kc;
      for (; k >= ${KR}; k -= ${KR}) {
        // Read a block of ${NR} rows x ${KR} elements (diagram shows 2 rows)
        // a b c d
        // e f g h
        $for N in range(NR):
          v${N} = _mm_loadu_ps(w${N});
          w${N} += ${KR};
        $if PREFETCH:
          $for N in range(0, NR):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
        // Rows are written back to back, ${KR} elements each.
        _mm_storeu_ps(packed_w, v0);
        $for N in range(1,NR):
          _mm_storeu_ps(packed_w + ${N*KR}, v${N});
        packed_w += ${NR*KR};
      }

      // KC remainder (1..${KR-1})
      // Only the k valid elements of each row are read; the partial loads
      // below zero the remaining vector lanes, so the stored block is
      // zero-padded to ${KR} elements per row.
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= ${KR-1});
        switch (k) {
          case 1:
            // Read a block of ${NR} rows x 1 element (diagram shows 2 rows)
            // a
            // e
            // _mm_load_ss zeroes the upper three lanes.
            $for N in range(NR):
              v${N} = _mm_load_ss(w${N});
              w${N} += 1;
            break;
          case 2:
            // Read a block of ${NR} rows x 2 elements (diagram shows 2 rows)
            // a b
            // e f
            // _mm_load_sd loads 64 bits and zeroes the upper 64 bits.
            $for N in range(NR):
              v${N} = _mm_castpd_ps(_mm_load_sd((const double*) w${N}));
              w${N} += 2;
            break;
          case 3:
          {
            // Read a block of ${NR} rows x 3 elements (diagram shows 2 rows)
            // a b c
            // e f g
            // Two elements come from a 64-bit load and the third from a scalar
            // load; _mm_movelh_ps joins the low halves into [a b c 0].
            $for N in range(NR):
              const __m128 v${N}lo = _mm_castpd_ps(_mm_load_sd((const double*) w${N}));
              const __m128 v${N}hi = _mm_load_ss(w${N} + 2);
              v${N} = _mm_movelh_ps(v${N}lo, v${N}hi);
              w${N} += 3;
            break;
          }
          default:
            XNN_UNREACHABLE;
        }
        _mm_storeu_ps(packed_w, v0);
        $for N in range(1,NR):
          _mm_storeu_ps(packed_w + ${N*KR}, v${N});
        packed_w += ${NR*KR};
      }
      // Skip the per-tile extra bytes; nothing is written into them here.
      packed_w = (float*) ((uintptr_t) packed_w + extra_bytes);
      // w${NR-1} has been advanced by kc in total, so it now points at the
      // first row of the next tile.
      w0 = w${NR-1};
    }

    // NC remainder (1..${NR-1})
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= ${NR-1});
      if XNN_LIKELY(b != NULL) {
        // Copy the n remaining bias values; the other ${NR} - n slots of the
        // tile header are skipped and left unwritten.
        size_t nb = n;
        do {
          *packed_w++  = *b++;
        } while (--nb != 0);
        packed_w += (${NR} - n);
      } else {
        $for N in range(NR):
          packed_w[${N}] = 0.0f;
        packed_w += ${NR};
      }

      // NR remainder has less than ${NR} rows so last row is not loaded.
      // A row pointer at or beyond n is aliased to the previous row, so every
      // load reads from a valid row and the packed data for the missing rows
      // duplicates an earlier row.
      $for N in range(1, NR-1):
        const float* w${N} = w${N-1} + kc;
        $if N % 2 == 0:
          if XNN_UNPREDICTABLE(n <= ${N}) {
            w${N} = w${N-1};
          }
        $else:
          if XNN_UNPREDICTABLE(n < ${N+1}) {
            w${N} = w${N-1};
          }

      // KC main loop: full blocks of ${KR} elements from the first ${NR-1} rows
      size_t k = kc;
      for (; k >= ${KR}; k -= ${KR}) {
        // Read a block of ${NR-1} rows x ${KR} elements (diagram shows 2 rows)
        // a b c d
        // e f g h
        $for N in range(NR-1):
          v${N} = _mm_loadu_ps(w${N});
          w${N} += ${KR};
        $if PREFETCH:
          $for N in range(0, NR-1):
            xnn_prefetch_to_l1((const int8_t*) w${N} + 128);
        // A full tile of ${NR} rows is still written; the last row reuses
        // v${NR-2} because row ${NR-1} is never loaded on this path.
        _mm_storeu_ps(packed_w, v0);
        $for N in range(1,NR):
          _mm_storeu_ps(packed_w + ${N*KR}, v${min(N,NR-2)});
        packed_w += ${NR*KR};
      }

      // KC remainder (1..${KR-1})
      // As in the main NC loop, the partial loads zero the unused lanes so
      // the stored block is zero-padded to ${KR} elements per row.
      if XNN_UNLIKELY(k != 0) {
        assert(k >= 1);
        assert(k <= ${KR-1});
        switch (k) {
          case 1:
            // Read a block of ${NR-1} rows x 1 element (diagram shows 1 row)
            // a
            $for N in range(NR-1):
              v${N} = _mm_load_ss(w${N});
              w${N} += 1;
            break;
          case 2:
            // Read a block of ${NR-1} rows x 2 elements (diagram shows 1 row)
            // a b
            $for N in range(NR-1):
              v${N} = _mm_castpd_ps(_mm_load_sd((const double*) w${N}));
              w${N} += 2;
            break;
          case 3:
          {
            // Read a block of ${NR-1} rows x 3 elements (diagram shows 1 row)
            // a b c
            $for N in range(NR-1):
              const __m128 v${N}lo = _mm_castpd_ps(_mm_load_sd((const double*) w${N}));
              const __m128 v${N}hi = _mm_load_ss(w${N} + 2);
              v${N} = _mm_movelh_ps(v${N}lo, v${N}hi);
              w${N} += 3;
            break;
          }
          default:
            XNN_UNREACHABLE;
        }
        // The last tile row again reuses v${NR-2}.
        _mm_storeu_ps(packed_w, v0);
        $for N in range(1,NR):
          _mm_storeu_ps(packed_w + ${N*KR}, v${min(N,NR-2)});
        packed_w += ${NR*KR};
      }
      packed_w = (float*) ((uintptr_t) packed_w + extra_bytes);
    }
    // Advance to the weights of the next group.
    weights += nc * kc;
  } while (--g != 0);
}
