// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert DATATYPE in ["QC8", "QS8"]
$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32"

#include "xnnpack/assembly.h"

$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8": "qs8"}[DATATYPE]
$PARAMS_UNION = {"QC8": "xnn_qs8_qc8w_conv_minmax_params", "QS8": "xnn_qs8_conv_minmax_params"}[DATATYPE]
$REWIND_DECREMENT = 3 if DATATYPE == "QC8" else {"RNDNU": 15, "FP32": 7}[REQUANTIZATION]
# void xnn_${DATATYPE_SPEC}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_2x8c16__asm_aarch64_neon_mlal(
#     size_t mr,                 x0
#     size_t nc,                 x1
#     size_t kc,                 x2 / x0
#     size_t ks,                 x3 / x9
#     const int8_t** restrict a,  x4
#     const int8_t* restrict w,  x5
#     int8_t* restrict c,        x6
#     size_t cm_stride,          x7
#     size_t cn_stride,                  [sp] -> x10
#     size_t a_offset,                   [sp + 8] -> x8
#     const int8_t* zero,                [sp + 16] -> x12
#     const union ${PARAMS_UNION} params [sp + 24] -> x11

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
# This kernel uses d10-d15 (saved on the stack in the prologue) and none of x19-x30.

// Register usage
// A0 x13  v0
// A1 x15  v1
// B   x5  v4  v5  v6  v7
// C0  x6 v16 v18 v20 v22 v24 v26 v28 v30
// C1  x7 v17 v19 v21 v23 v25 v27 v29 v31
// temp0   v2 v10 v12 v14
// temp1   v3 v11 v13 v15
// unused  v8 v9

BEGIN_FUNCTION xnn_${DATATYPE_SPEC}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_2x8c16__asm_aarch64_neon_mlal

        // Overview of the code below:
        //   label 0 (nc loop): for each group of up to 8 output columns
        //     - seed the 2x8 int32 accumulators with 8 bias words read from w
        //     - label 1 (ks loop): for each pair of A row pointers
        //         - label 2 (kc loop): consume 16 bytes of each A row and
        //           128 bytes of w (8 columns x 16 bytes) per iteration
        //     - reduce each accumulator to one sum, requantize, clamp to int8
        //     - store 8 bytes per row (or 4/2/1 byte pieces for a partial tile)

        # Clamp C pointers
        // The stack arguments are read before SP is moved by the STP below.
        LDP         x10, x8, [sp]           // Load cn_stride, a_offset
        CMP         x0, 2                   // if mr < 2
        LDP         x12, x11, [sp, 16]      // Load zero, params pointer
        ADD         x7, x6, x7              // c1 = c0 + cm_stride
        STP         d10, d11, [sp, -48]!    // allocate 48 bytes, save d10-d15
        ADD         x2, x2, 15              // kc = (kc + 15) & ~15
        STP         d12, d13, [sp, 16]
        CSEL        x7, x6, x7, LO          //   c1 = c0  (flags still from the CMP above)
        STP         d14, d15, [sp, 32]
        BIC         x2, x2, 15

        .p2align    3
0:
        # Load initial bias from w into accumulators
        // Each scalar load fills lane 0 of an accumulator and zeroes the other
        // lanes. Row 1 accumulators (odd registers) start as copies of row 0.
        LDP         s16, s18, [x5], 8
        MOV         v17.16b, v16.16b
        MOV         v19.16b, v18.16b
        LDP         s20, s22, [x5], 8
        MOV         v21.16b, v20.16b
        MOV         v23.16b, v22.16b
        LDP         s24, s26, [x5], 8
        MOV         v25.16b, v24.16b
        MOV         v27.16b, v26.16b
        LDP         s28, s30, [x5], 8
        MOV         v29.16b, v28.16b
        MOV         v31.16b, v30.16b
        MOV         x9, x3                  // p = ks

        .p2align    3
1:
        # Load next 2 A pointers
        LDP         x13, x15, [x4], 16

        // A pointer equal to `zero` is used unchanged; any other pointer gets
        // a_offset added.
        CMP         x13, x12                // if a0 == zero
        ADD         x13, x13, x8            // a0 += a_offset
        CSEL        x13, x12, x13, EQ       //   a0 = zero, else keep a0 + a_offset
        CMP         x15, x12                // if a1 == zero
        ADD         x15, x15, x8            // a1 += a_offset
        CSEL        x15, x12, x15, EQ       //   a1 = zero, else keep a1 + a_offset

        MOV         x0, x2                  // k = kc

        # Main loop - 16 bytes of A
        // Pattern per 16-byte B vector and A row:
        //   SMULL  - 8 int16 products of the low 8 bytes
        //   SMLAL2 - add the 8 int16 products of the high 8 bytes
        //   SADALP - pairwise add the int16 values into the int32 accumulator
        // B vectors at w+0..w+112 feed accumulator pairs v16/v17, v18/v19, ...,
        // v30/v31 in that order (even = row 0, odd = row 1).
        .p2align    3
2:
        LDR         q0, [x13], 16           // 16 bytes of row 0
        LDP         q4, q5, [x5]            // B columns 0, 1
        LDR         q1, [x15], 16           // 16 bytes of row 1
        LDP         q6, q7, [x5, 32]        // B columns 2, 3
        SMULL       v2.8h, v4.8b, v0.8b
        SMULL       v3.8h, v4.8b, v1.8b
        SMULL       v10.8h, v5.8b, v0.8b
        SMULL       v11.8h, v5.8b, v1.8b
        SMLAL2      v2.8h, v4.16b, v0.16b
        SMLAL2      v3.8h, v4.16b, v1.16b
        SMLAL2      v10.8h, v5.16b, v0.16b
        SMLAL2      v11.8h, v5.16b, v1.16b
        SMULL       v12.8h, v6.8b, v0.8b
        SADALP      v16.4s,  v2.8h
        SMULL       v13.8h, v6.8b, v1.8b
        SADALP      v17.4s,  v3.8h
        SMULL       v14.8h, v7.8b, v0.8b
        SADALP      v18.4s, v10.8h
        SMULL       v15.8h, v7.8b, v1.8b
        SADALP      v19.4s, v11.8h
        LDP         q4, q5, [x5, 64]        // B columns 4, 5 (v4/v5 no longer needed)
        SMLAL2      v12.8h, v6.16b, v0.16b
        SMLAL2      v13.8h, v6.16b, v1.16b
        SMLAL2      v14.8h, v7.16b, v0.16b
        SMLAL2      v15.8h, v7.16b, v1.16b
        SMULL       v2.8h, v4.8b, v0.8b
        SADALP      v20.4s, v12.8h
        SMULL       v3.8h, v4.8b, v1.8b
        SADALP      v21.4s, v13.8h
        SMULL       v10.8h, v5.8b, v0.8b
        SADALP      v22.4s, v14.8h
        SMULL       v11.8h, v5.8b, v1.8b
        SADALP      v23.4s, v15.8h
        LDP         q6, q7, [x5, 96]        // B columns 6, 7 (v6/v7 no longer needed)

        SMLAL2      v2.8h, v4.16b, v0.16b
        SMLAL2      v3.8h, v4.16b, v1.16b
        SMLAL2      v10.8h, v5.16b, v0.16b
        SMLAL2      v11.8h, v5.16b, v1.16b
        ADD         x5, x5, 128             // w += 8 columns * 16 bytes
        SMULL       v12.8h, v6.8b, v0.8b
        SADALP      v24.4s,  v2.8h
        SMULL       v13.8h, v6.8b, v1.8b
        SADALP      v25.4s,  v3.8h
        SMULL       v14.8h, v7.8b, v0.8b
        SADALP      v26.4s, v10.8h
        SMULL       v15.8h, v7.8b, v1.8b
        SADALP      v27.4s, v11.8h
        SUBS        x0, x0, 16              // k -= 16; flags consumed by B.HI below
        SMLAL2      v12.8h, v6.16b, v0.16b
        SMLAL2      v13.8h, v6.16b, v1.16b
        SMLAL2      v14.8h, v7.16b, v0.16b
        SMLAL2      v15.8h, v7.16b, v1.16b
        SADALP      v28.4s, v12.8h
        SADALP      v29.4s, v13.8h
        SADALP      v30.4s, v14.8h
        SADALP      v31.4s, v15.8h
        B.HI        2b                      // loop while k > 0

        # ks loop
        SUBS        x9, x9, 16              // ks -= MR * sizeof(int8_t*)
        B.HI        1b

        # Add columns
        // Two levels of pairwise adds collapse each 4-lane accumulator into a
        // single sum. Result layout:
        //   v0 = row 0, columns 0-3     v1 = row 0, columns 4-7
        //   v2 = row 1, columns 0-3     v3 = row 1, columns 4-7
        ADDP        v16.4s, v16.4s, v18.4s
        ADDP        v20.4s, v20.4s, v22.4s
        $if REQUANTIZATION == "RNDNU":
          LD1R        {v4.4s}, [x11], 4
        ADDP        v24.4s, v24.4s, v26.4s
        ADDP        v28.4s, v28.4s, v30.4s
        $if REQUANTIZATION == "RNDNU":
          LD1R        {v7.4s}, [x11], 4
        ADDP        v17.4s, v17.4s, v19.4s
        ADDP        v21.4s, v21.4s, v23.4s
        ADDP        v25.4s, v25.4s, v27.4s
        ADDP        v29.4s, v29.4s, v31.4s
        ADDP        v0.4s, v16.4s, v20.4s
        ADDP        v1.4s, v24.4s, v28.4s
        ADDP        v2.4s, v17.4s, v21.4s
        ADDP        v3.4s, v25.4s, v29.4s

        $if REQUANTIZATION == "RNDNU":
          # Apply params - preshift, scale, postshift, bias and clamp
          // v4, v7 and v5 hold the first, second and third 32-bit words at the
          // params pointer, each broadcast to all lanes.
          LD1R        {v5.4s}, [x11], 4
          SQSHL       v0.4s, v0.4s, v4.4s     // shift to upper bits
          SQSHL       v1.4s, v1.4s, v4.4s
          SQSHL       v2.4s, v2.4s, v4.4s
          SQSHL       v3.4s, v3.4s, v4.4s
          SQDMULH     v0.4s, v0.4s, v7.4s     // scale without rounding
          SQDMULH     v1.4s, v1.4s, v7.4s
          SQDMULH     v2.4s, v2.4s, v7.4s
          SQDMULH     v3.4s, v3.4s, v7.4s
          SRSHL       v0.4s, v0.4s, v5.4s     // signed rounding shift left
          SRSHL       v1.4s, v1.4s, v5.4s
          SRSHL       v2.4s, v2.4s, v5.4s
          SRSHL       v3.4s, v3.4s, v5.4s
        $elif REQUANTIZATION == "FP32":
          $if DATATYPE != "QC8":
            # Apply params - scale, bias and clamp
            // One 32-bit scale is broadcast from params and applied to all columns.
            SCVTF       v0.4s, v0.4s
            LD1R        {v4.4s}, [x11], 4
            SCVTF       v1.4s, v1.4s
            SCVTF       v2.4s, v2.4s
            SCVTF       v3.4s, v3.4s
            FMUL        v0.4s, v0.4s, v4.4s
            FMUL        v1.4s, v1.4s, v4.4s
            FMUL        v2.4s, v2.4s, v4.4s
            FMUL        v3.4s, v3.4s, v4.4s
          $else:
            # Load per channel scale values from weights
            // v4 scales columns 0-3 and v5 scales columns 4-7, for both rows.
            // The two loads advance w past the 32 bytes of scales.
            SCVTF       v0.4s, v0.4s
            LDR         q4, [x5], 16
            SCVTF       v1.4s, v1.4s
            LDR         q5, [x5], 16
            SCVTF       v2.4s, v2.4s
            SCVTF       v3.4s, v3.4s
            FMUL        v0.4s, v0.4s, v4.4s
            FMUL        v1.4s, v1.4s, v5.4s
            FMUL        v2.4s, v2.4s, v4.4s
            FMUL        v3.4s, v3.4s, v5.4s

          // Convert back to int32, rounding to nearest with ties to even.
          FCVTNS      v0.4s, v0.4s
          FCVTNS      v1.4s, v1.4s
          FCVTNS      v2.4s, v2.4s
          FCVTNS      v3.4s, v3.4s

        // Narrow with saturation to int16, add the broadcast 16-bit value from
        // params (presumably the output zero point - confirm against the params
        // struct), narrow to int8, then clamp.
        // After the final narrow: v0 bytes 0-7 = row 0, bytes 8-15 = row 1.
        LD1R        {v5.8h}, [x11], 2
        SQXTN       v0.4h, v0.4s
        SQXTN       v2.4h, v2.4s
        SQXTN2      v0.8h, v1.4s
        SQXTN2      v2.8h, v3.4s
        SUBS        x1, x1, 8               // nc -= 8; flags consumed by B.LO / B.HI below
        SQADD       v0.8h, v0.8h, v5.8h
        SQADD       v1.8h, v2.8h, v5.8h
        SQXTN       v0.8b, v0.8h
        SQXTN2      v0.16b, v1.8h
        LD1R        {v1.16b}, [x11], 1      // lower clamp bound (used by SMAX)
        LD1R        {v2.16b}, [x11]         // upper clamp bound (used by SMIN), no post-increment
        SMAX        v0.16b, v0.16b, v1.16b
        // The rewind amount equals the sum of the post-increments applied to x11
        // above: 2 + 1 bytes always, plus 4 for the FP32 params scale, or
        // 3 * 4 for the three RNDNU words.
        SUB         x11, x11, ${REWIND_DECREMENT}           // rewind params pointer
        SMIN        v0.16b, v0.16b, v2.16b
        B.LO        3f                      // fewer than 8 columns remained: partial store

        # Store full 2 x 8
        // Row 1 is stored before row 0. When mr < 2, c1 aliases c0 (see the CSEL
        // in the prologue), so the row 0 store lands last.
        ST1         {v0.d}[1], [x7], x10    // c1: 8 bytes, then c1 += cn_stride
        SUB         x4, x4, x3              // a -= ks
        ST1         {v0.8b}, [x6], x10      // c0: 8 bytes, then c0 += cn_stride

        # nc loop
        B.HI        0b                      // flags still from SUBS x1: loop while nc > 0

        # Restore d10-d15 from stack
        LDP         d14, d15, [sp, 32]
        LDP         d12, d13, [sp, 16]
        LDP         d10, d11, [sp], 48
        RET

        # Store odd width
        // x1 holds nc - 8 here; its low 3 bits equal those of the remaining
        // column count, so bits 2, 1 and 0 select 4-, 2- and 1-byte stores.
        // Each EXT rotates v0 so the next unstored bytes of row 0 and row 1
        // move to byte positions 0 and 8.
        .p2align    3
3:
        TBZ         x1, 2, 4f
        ST1         {v0.s}[2], [x7], 4      // row 1: 4 bytes
        STR         s0, [x6], 4             // row 0: 4 bytes
        EXT         v0.16b, v0.16b, v0.16b, 4

4:
        TBZ         x1, 1, 5f
        ST1         {v0.h}[4], [x7], 2      // row 1: 2 bytes
        STR         h0, [x6], 2             // row 0: 2 bytes
        EXT         v0.16b, v0.16b, v0.16b, 2
5:
        TBZ         x1, 0, 6f
        ST1         {v0.b}[8], [x7]         // row 1: 1 byte
        STR         b0, [x6]                // row 0: 1 byte
6:
        # Restore d10-d15 from stack
        LDP         d14, d15, [sp, 32]
        LDP         d12, d13, [sp, 16]
        LDP         d10, d11, [sp], 48
        RET

END_FUNCTION xnn_${DATATYPE_SPEC}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_2x8c16__asm_aarch64_neon_mlal

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif

