// Copyright 2023 SiFive, Inc.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
$import math
$assert SIZE in [32]
$assert VLEN in [64, 128, 256, 512, 1024]
$LMUL = 1
$TILE_WIDTH = int(min(VLEN/SIZE, 8))
$TILE_HEIGHT = int(VLEN/SIZE)
$NUM_ITERS = int(math.log2(TILE_HEIGHT))

#include <riscv_vector.h>

#include <assert.h>

#include "xnnpack/common.h"
#include "xnnpack/math.h"
#include "xnnpack/transpose.h"

// Transposes a block of block_height rows by block_width columns of
// ${SIZE}-bit elements using RISC-V vector strided segment loads.
//
// input_stride and output_stride are byte distances between consecutive rows
// of the input and of the output. The block is processed in column strips of
// up to ${TILE_WIDTH} input columns, and each strip is walked top to bottom in
// tiles of up to ${TILE_HEIGHT} rows. For every tile, a strided segment load
// places input column N in field N of the tuple, and that field is stored
// contiguously through output row pointer oN.
void xnn_x${SIZE}_transposec_ukernel__${TILE_HEIGHT}x${TILE_WIDTH}_rvv(
  const uint${SIZE}_t* input,
  uint${SIZE}_t* output,
  size_t input_stride,
  size_t output_stride,
  size_t block_width,
  size_t block_height) XNN_OOB_READS
{
  assert(block_width == 1 || output_stride >= block_height * sizeof(uint${SIZE}_t));
  assert(block_height == 1 || input_stride >= block_width * sizeof(uint${SIZE}_t));

  const size_t tile_height = ${TILE_HEIGHT};
  const size_t tile_width = ${TILE_WIDTH};
  const size_t tile_hbytes = tile_height * sizeof(uint${SIZE}_t);
  const size_t tile_wbytes = tile_width * sizeof(uint${SIZE}_t);
  // Applied to i0 after each strip: step right by one tile width and undo the
  // downward steps taken by the full-tile loop (the remainder never moves i0).
  const size_t input_reset = tile_wbytes - round_down_po2(block_height, tile_height) * input_stride;
  const size_t input_offset = tile_height * input_stride;
  // Applied to every output pointer after each strip: step down tile_width
  // output rows and undo the in-row advance made while walking the strip.
  const size_t output_reset = tile_width * output_stride - round_down_po2(block_height, 2) * sizeof(uint${SIZE}_t);

  const uint${SIZE}_t* i0 = input;

  // One output row pointer per input column of the strip.
  uint${SIZE}_t* o0 = (uint${SIZE}_t*) output;
  $for N in range(1, TILE_WIDTH):
    uint${SIZE}_t* o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N-1} + output_stride);

  do {
    size_t bh = block_height;
    size_t vl = __riscv_vsetvl_e${SIZE}m${LMUL}(tile_height);
    // Full tiles: tile_height input rows per iteration.
    for (; bh >= ${TILE_HEIGHT}; bh -= ${TILE_HEIGHT}) {
      if (block_width >= tile_width) {
        vuint${SIZE}m${LMUL}x${TILE_WIDTH}_t tuple = __riscv_vlsseg${TILE_WIDTH}e${SIZE}_v_u${SIZE}m${LMUL}x${TILE_WIDTH}(i0, input_stride, vl);

        $for N in range(TILE_WIDTH):
          vuint${SIZE}m${LMUL}_t v_d${N} = __riscv_vget_v_u${SIZE}m${LMUL}x${TILE_WIDTH}_u${SIZE}m${LMUL}(tuple, ${N});
          __riscv_vse${SIZE}_v_u${SIZE}m${LMUL}(o${N}, v_d${N}, vl);

      } else {
        // Narrow last strip: load and store only block_width columns.
        switch (block_width) {
          $for M in reversed(range(2, TILE_WIDTH)):
            case ${M}: {
              vuint${SIZE}m${LMUL}x${M}_t tuple = __riscv_vlsseg${M}e${SIZE}_v_u${SIZE}m${LMUL}x${M}(i0, input_stride, vl);

              $for N in range(M):
                vuint${SIZE}m${LMUL}_t v_d${N} = __riscv_vget_v_u${SIZE}m${LMUL}x${M}_u${SIZE}m${LMUL}(tuple, ${N});
                __riscv_vse${SIZE}_v_u${SIZE}m${LMUL}(o${N}, v_d${N}, vl);
              break;
            }\n
          case 1: {
            vuint${SIZE}m${LMUL}_t v_d0 = __riscv_vlse${SIZE}_v_u${SIZE}m${LMUL}(i0, input_stride, vl);
            __riscv_vse${SIZE}_v_u${SIZE}m${LMUL}(o0, v_d0, vl);
            break;
          }

          default:
            XNN_UNREACHABLE;
        }
      }

      i0 = (const uint${SIZE}_t*) ((uintptr_t) i0 + input_offset);
      $for N in reversed(range(TILE_WIDTH)):
        o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N} + tile_hbytes);
    }

    // Remainder: fewer than tile_height rows are left, so shorten vl to bh.
    if (bh != 0) {
      vl = __riscv_vsetvl_e${SIZE}m${LMUL}(bh);
      if (block_width >= tile_width) {
        vuint${SIZE}m${LMUL}x${TILE_WIDTH}_t tuple = __riscv_vlsseg${TILE_WIDTH}e${SIZE}_v_u${SIZE}m${LMUL}x${TILE_WIDTH}(i0, input_stride, vl);

        $for N in range(TILE_WIDTH):
          vuint${SIZE}m${LMUL}_t v_d${N} = __riscv_vget_v_u${SIZE}m${LMUL}x${TILE_WIDTH}_u${SIZE}m${LMUL}(tuple, ${N});
          __riscv_vse${SIZE}_v_u${SIZE}m${LMUL}(o${N}, v_d${N}, vl);
      } else {
        switch(block_width) {
          $for M in reversed(range(2, TILE_WIDTH)):
            case ${M}: {
              vuint${SIZE}m${LMUL}x${M}_t tuple = __riscv_vlsseg${M}e${SIZE}_v_u${SIZE}m${LMUL}x${M}(i0, input_stride, vl);

              $for N in range(0, M):
                vuint${SIZE}m${LMUL}_t v_d${N} = __riscv_vget_v_u${SIZE}m${LMUL}x${M}_u${SIZE}m${LMUL}(tuple, ${N});
                __riscv_vse${SIZE}_v_u${SIZE}m${LMUL}(o${N}, v_d${N}, vl);
              break;
            }

          case 1: {
            vuint${SIZE}m${LMUL}_t v_d0 = __riscv_vlse${SIZE}_v_u${SIZE}m${LMUL}(i0, input_stride, vl);
            __riscv_vse${SIZE}_v_u${SIZE}m${LMUL}(o0, v_d0, vl);
            break;
          }

          default:
            XNN_UNREACHABLE;
        }
      }

      // Output pointers advance by the even part of bh only (no adjustment is
      // generated when the tile height is 2); output_reset subtracts
      // block_height rounded down to a multiple of 2, so an odd final row
      // must not move them.
      $for M in range(1, NUM_ITERS + 1):
        $if (TILE_HEIGHT>>M) > 1:
          if (bh & ${TILE_HEIGHT>>M}) {
            $for N in reversed(range(TILE_WIDTH)):
              o${N} += ${TILE_HEIGHT>>M};
          }
    }

    // Move to the top of the next strip of input columns / output rows.
    i0 = (const uint${SIZE}_t*) ((uintptr_t) i0 + input_reset);

    o0 = (uint${SIZE}_t*) ((uintptr_t) o0 + output_reset);
    $for N in range(1, TILE_WIDTH):
      o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N} + output_reset);

    block_width = doz(block_width, tile_width);
  } while (block_width != 0);
}
