// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR == 4 or MR == 8
$assert KBLOCK >= 4
$assert KBLOCK % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"

#include <assert.h>

#include <arm_neon.h>

#include "xnnpack/packx.h"
$if PREFETCH:
  #include "xnnpack/prefetch.h"


$IFWHILE = "if" if KBLOCK==8 else "while"
// Packs up to ${MR} rows of 32-bit elements into a column-interleaved buffer.
//
// For every column index the kernel writes ${MR} consecutive elements to y,
// one taken from each row pointer, so ${MR} * k elements are written in total.
//
// Parameters:
//   m        - number of valid rows in x; must be in [1, ${MR}].
//   k        - number of 32-bit elements to read from each row; must be
//              non-zero.
//   x        - pointer to the first row; must not be NULL.
//   x_stride - distance between the starts of consecutive rows, in bytes
//              (it is added to the row pointer after a cast to uintptr_t).
//   y        - output pointer; must not be NULL.
//
// When m is less than ${MR}, the pointers for the missing rows alias the last
// valid row, so the corresponding output slots hold copies of that row.
void xnn_x32_packx_ukernel_${MR}x__neon_st4_x${KBLOCK}${"_prfm" if PREFETCH else ""}(
    size_t m,
    size_t k,
    const uint32_t* x,
    size_t x_stride,
    uint32_t* restrict y)
{
  assert(m != 0);
  assert(m <= ${MR});
  assert(k != 0);
  assert(x != NULL);
  assert(y != NULL);

  // One read pointer per row. A pointer whose row index is >= m is redirected
  // to the previous row's pointer, so no row beyond the m-th is ever read.
  const uint32_t* x0 = x;
  $for M in range(1, MR):
    const uint32_t* x${M} = (const uint32_t*) ((uintptr_t) x${M-1} + x_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(m <= ${M}) {
        x${M} = x${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(m != ${M+1}) {
        x${M} = x${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(m < ${M+1}) {
        x${M} = x${M-1};
      }

  // Main loop: each iteration consumes ${KBLOCK} elements from every row.
  $if MR == 4:
    for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
      // Load 4 elements from each of the 4 rows; row r goes into val[r].
      $for K in range(0, KBLOCK, 4):
        $for M in range(0, MR, 4):
          uint32x4x4_t vx${ABC[K:K+4]}x${ABC[M:M+4]};
      $for K in range(0, KBLOCK, 4):
        $for M in range(0, MR, 4):
          $for L in range(4):
            vx${ABC[K:K+4]}x${ABC[M:M+4]}.val[${L}] = vld1q_u32(x${M+L}); x${M+L} += 4;
      $if PREFETCH:
        // Issue a prefetch for the address 128 bytes past each row pointer.
        $for M in range(MR):
          xnn_prefetch_to_l1((const int8_t*) x${M} + 128);
      // vst4q_u32 interleaves val[0..3] element by element, which transposes
      // the 4x4 tile: 16 elements are written per store.
      $for K in range(0, KBLOCK, 4):
        $for M in range(0, MR, 4):
          vst4q_u32(y, vx${ABC[K:K+4]}x${ABC[M:M+4]}); y += 16;
    }
  $if MR == 8:
    for (; k >= ${KBLOCK}; k -= ${KBLOCK}) {
      // Load 4 elements from each of the 8 rows.
      $for K in range(0, KBLOCK, 4):
        $for M in range(MR):
          const uint32x4_t vx${ABC[K:K+4]}x${M} = vld1q_u32(x${M}); x${M} += 4;
      $if PREFETCH:
        // Issue a prefetch for the address 128 bytes past each row pointer.
        $for M in range(MR):
          xnn_prefetch_to_l1((const int8_t*) x${M} + 128);
      // Zip row M with row M+4: val[0] holds columns 0-1 and val[1] holds
      // columns 2-3, each as (row M, row M+4) pairs.
      $for K in range(0, KBLOCK, 4):
        $for M in range(0, MR // 2):
          const uint32x4x2_t vz${ABC[K:K+4]}x${M} = vzipq_u32(vx${ABC[K:K+4]}x${M}, vx${ABC[K:K+4]}x${M+4});

      // Group the zipped halves of rows 0-3 so that the interleaving store
      // below emits all 8 rows of one column before moving to the next column.
      $for K in range(0, KBLOCK, 4):
        $for L in range(0, 2):
          const uint32x4x4_t vy${ABC[K:K+4]}x${L} = { vz${ABC[K:K+4]}x0.val[${L}], vz${ABC[K:K+4]}x1.val[${L}], vz${ABC[K:K+4]}x2.val[${L}], vz${ABC[K:K+4]}x3.val[${L}] };
      // Each store writes two full columns (2 x 8 rows = 16 elements).
      $for K in range(0, KBLOCK, 4):
        $for L in range(0, 2):
          vst4q_u32(y, vy${ABC[K:K+4]}x${L}); y += 16;
    }

  // Remainder: fewer than ${KBLOCK} elements are left in each row, so process
  // one column per iteration.
  if XNN_UNLIKELY(k != 0) {
    // Zero-initialized once; every lane is overwritten on each iteration.
    $for M in range(0, MR, 4):
      uint32x4_t vt${ABC[M:M+4]} = vdupq_n_u32(0);
    do {
      // Gather one element from each of four rows into the lanes of a vector.
      $for M in range(0, MR, 4):
        vt${ABC[M:M+4]} = vld1q_lane_u32(x${M+0}, vt${ABC[M:M+4]}, 0); x${M+0} += 1;
        vt${ABC[M:M+4]} = vld1q_lane_u32(x${M+1}, vt${ABC[M:M+4]}, 1); x${M+1} += 1;
        vt${ABC[M:M+4]} = vld1q_lane_u32(x${M+2}, vt${ABC[M:M+4]}, 2); x${M+2} += 1;
        vt${ABC[M:M+4]} = vld1q_lane_u32(x${M+3}, vt${ABC[M:M+4]}, 3); x${M+3} += 1;
      $if PREFETCH:
        // Issue a prefetch for the address 128 bytes past each row pointer.
        $for M in range(MR):
          xnn_prefetch_to_l1((const int8_t*) x${M} + 128);
      // Write the gathered column: 4 elements per vector, in row order.
      $for M in range(0, MR, 4):
        vst1q_u32(y, vt${ABC[M:M+4]}); y += 4;
    } while (--k != 0);
  }
}
