// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert DATATYPE in ["QC8", "QS8"]
$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32"

#include "xnnpack/assembly.h"

$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8": "qs8"}[DATATYPE]
$PARAMS_UNION = {"QC8": "xnn_qs8_qc8w_conv_minmax_params", "QS8": "xnn_qs8_conv_minmax_params"}[DATATYPE]
# void xnn_${DATATYPE_SPEC}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_2x8c16__asm_aarch64_neon_mlal(
#     size_t mr,                 x0
#     size_t nc,                 x1
#     size_t kc,                 x2 / x0
#     const int8_t* restrict a,  x3
#     size_t a_stride,           x4
#     const void* restrict w,    x5
#     int8_t* restrict c,        x6
#     size_t cm_stride,          x7
#     size_t cn_stride,          [sp] -> x10
#     const union ${PARAMS_UNION} params)  [sp + 8] -> x11

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.

// Register usage
// A0  x3  v0
// A1  x4  v1
// B   x5  v4  v5  v6  v7
// C0  x6 v16 v18 v20 v22 v24 v26 v28 v30
// C1  x7 v17 v19 v21 v23 v25 v27 v29 v31
// temp0   v2 v10 v12 v14
// temp1   v3 v11 v13 v15
// unused  v8 v9

// Int8 GEMM microkernel producing up to 2 rows x 8 columns of int8 output per
// pass of the outer loop (label 0). K is consumed 16 bytes at a time: 8-bit
// products are formed in 16-bit temporaries (SMULL/SMLAL2) and pairwise
// accumulated into 32-bit accumulators (SADALP). After the K loop the
// accumulators are reduced horizontally, requantized, offset, clamped and
// stored. Leaf function: only d10-d15 are spilled (v10-v15 are temporaries).
BEGIN_FUNCTION xnn_${DATATYPE_SPEC}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_2x8c16__asm_aarch64_neon_mlal

        # Clamp A and C pointers
        // The d10-d15 spills are interleaved with the pointer setup below.
        CMP         x0, 2                   // if mr < 2
        STP         d10, d11, [sp, -48]!    // reserve 48 bytes of stack, save d10/d11
        ADD         x4, x3, x4              // a1 = a0 + a_stride
        STP         d12, d13, [sp, 16]
        ADD         x7, x6, x7              // c1 = c0 + cm_stride
        STP         d14, d15, [sp, 32]
        CSEL        x4, x3, x4, LO          //   a1 = a0
        ADD         x2, x2, 15              // kc = (kc + 15) & ~15
        CSEL        x7, x6, x7, LO          //   c1 = c0
        BIC         x2, x2, 15              // kc is now a multiple of 16, matching the 16-byte K step

        .p2align    3
0:
        // Outer loop: one iteration per group of 8 output columns.
        # Load initial bias from w into accumulators
        // Each scalar LDP fills lane 0 of two column accumulators (upper lanes
        // are zeroed by the scalar load); row 1 starts from the same bias.
        MOV         x0, x2                  // k = kc
        LDP         s16, s18, [x5], 8       // bias for columns 0, 1
        MOV         v17.16b, v16.16b
        MOV         v19.16b, v18.16b
        LDP         s20, s22, [x5], 8       // bias for columns 2, 3
        MOV         v21.16b, v20.16b
        MOV         v23.16b, v22.16b
        LDP         s24, s26, [x5], 8       // bias for columns 4, 5
        MOV         v25.16b, v24.16b
        MOV         v27.16b, v26.16b
        LDP         s28, s30, [x5], 8       // bias for columns 6, 7
        MOV         v29.16b, v28.16b
        LDP         x10, x11, [sp, 48]          // cn_stride, params: stack args sit above the 48-byte save area; reloaded every pass because x11 is post-incremented below
        MOV         v31.16b, v30.16b

        # Main loop - 16 bytes of A
        // Per iteration: 16 bytes from each A row and 128 bytes of w (one
        // 16-byte vector per output column). Loads of the next B vectors and
        // the SADALP accumulations are interleaved with the multiplies.
        .p2align    3
1:
        LDR         q0, [x3], 16            // a0: 16 bytes of row 0
        LDP         q4, q5, [x5]            // B for columns 0, 1
        LDR         q1, [x4], 16            // a1: 16 bytes of row 1
        LDP         q6, q7, [x5, 32]        // B for columns 2, 3
        SMULL       v2.8h, v4.8b, v0.8b     // low 8 bytes: 16-bit products, column 0 row 0
        SMULL       v3.8h, v4.8b, v1.8b     //                               column 0 row 1
        SMULL       v10.8h, v5.8b, v0.8b    //                               column 1 row 0
        SMULL       v11.8h, v5.8b, v1.8b    //                               column 1 row 1
        SMLAL2      v2.8h, v4.16b, v0.16b   // high 8 bytes: added into the same 16-bit temporaries
        SMLAL2      v3.8h, v4.16b, v1.16b
        SMLAL2      v10.8h, v5.16b, v0.16b
        SMLAL2      v11.8h, v5.16b, v1.16b
        SMULL       v12.8h, v6.8b, v0.8b    // columns 2, 3 begin while columns 0, 1 are accumulated
        SADALP      v16.4s,  v2.8h          // pairwise add 16-bit temporaries into 32-bit accumulator
        SMULL       v13.8h, v6.8b, v1.8b
        SADALP      v17.4s,  v3.8h
        SMULL       v14.8h, v7.8b, v0.8b
        SADALP      v18.4s, v10.8h
        SMULL       v15.8h, v7.8b, v1.8b
        SADALP      v19.4s, v11.8h
        LDP         q4, q5, [x5, 64]        // B for columns 4, 5
        SMLAL2      v12.8h, v6.16b, v0.16b
        SMLAL2      v13.8h, v6.16b, v1.16b
        SMLAL2      v14.8h, v7.16b, v0.16b
        SMLAL2      v15.8h, v7.16b, v1.16b
        SMULL       v2.8h, v4.8b, v0.8b     // columns 4, 5 begin while columns 2, 3 are accumulated
        SADALP      v20.4s, v12.8h
        SMULL       v3.8h, v4.8b, v1.8b
        SADALP      v21.4s, v13.8h
        SMULL       v10.8h, v5.8b, v0.8b
        SADALP      v22.4s, v14.8h
        SMULL       v11.8h, v5.8b, v1.8b
        SADALP      v23.4s, v15.8h
        LDP         q6, q7, [x5, 96]        // B for columns 6, 7

        SMLAL2      v2.8h, v4.16b, v0.16b
        SMLAL2      v3.8h, v4.16b, v1.16b
        SMLAL2      v10.8h, v5.16b, v0.16b
        SMLAL2      v11.8h, v5.16b, v1.16b
        ADD         x5, x5, 128             // w += 8 columns * 16 bytes
        SMULL       v12.8h, v6.8b, v0.8b    // columns 6, 7 begin while columns 4, 5 are accumulated
        SADALP      v24.4s,  v2.8h
        SMULL       v13.8h, v6.8b, v1.8b
        SADALP      v25.4s,  v3.8h
        SMULL       v14.8h, v7.8b, v0.8b
        SADALP      v26.4s, v10.8h
        SMULL       v15.8h, v7.8b, v1.8b
        SADALP      v27.4s, v11.8h
        SUBS        x0, x0, 16              // k -= 16; flags are consumed by B.HI below (SIMD ops leave them intact)
        SMLAL2      v12.8h, v6.16b, v0.16b
        SMLAL2      v13.8h, v6.16b, v1.16b
        SMLAL2      v14.8h, v7.16b, v0.16b
        SMLAL2      v15.8h, v7.16b, v1.16b
        SADALP      v28.4s, v12.8h
        SADALP      v29.4s, v13.8h
        SADALP      v30.4s, v14.8h
        SADALP      v31.4s, v15.8h
        B.HI        1b                      // loop while k > 0

        # Add columns
        // Two levels of pairwise adds collapse the 4 partial sums held in each
        // accumulator, leaving one 32-bit total per output element:
        //   v0 = row 0 columns 0-3, v1 = row 0 columns 4-7,
        //   v2 = row 1 columns 0-3, v3 = row 1 columns 4-7.
        // For RNDNU the first two 32-bit params are loaded in between.
        ADDP        v16.4s, v16.4s, v18.4s
        ADDP        v20.4s, v20.4s, v22.4s
        $if REQUANTIZATION == "RNDNU":
          LD1R        {v4.4s}, [x11], 4     // 1st 32-bit param: shift amount for SQSHL below
        ADDP        v24.4s, v24.4s, v26.4s
        ADDP        v28.4s, v28.4s, v30.4s
        $if REQUANTIZATION == "RNDNU":
          LD1R        {v7.4s}, [x11], 4     // 2nd 32-bit param: multiplier for SQDMULH below
        ADDP        v17.4s, v17.4s, v19.4s
        ADDP        v21.4s, v21.4s, v23.4s
        ADDP        v25.4s, v25.4s, v27.4s
        ADDP        v29.4s, v29.4s, v31.4s
        ADDP        v0.4s, v16.4s, v20.4s
        ADDP        v1.4s, v24.4s, v28.4s
        ADDP        v2.4s, v17.4s, v21.4s
        ADDP        v3.4s, v25.4s, v29.4s

        $if REQUANTIZATION == "RNDNU":
          # Apply params - preshift, scale, postshift, bias and clamp
          LD1R        {v5.4s}, [x11], 4       // 3rd 32-bit param: shift amount for SRSHL below
          SQSHL       v0.4s, v0.4s, v4.4s     // shift to upper bits
          SQSHL       v1.4s, v1.4s, v4.4s
          SQSHL       v2.4s, v2.4s, v4.4s
          SQSHL       v3.4s, v3.4s, v4.4s
          SQDMULH     v0.4s, v0.4s, v7.4s     // scale without rounding
          SQDMULH     v1.4s, v1.4s, v7.4s
          SQDMULH     v2.4s, v2.4s, v7.4s
          SQDMULH     v3.4s, v3.4s, v7.4s
          SRSHL       v0.4s, v0.4s, v5.4s     // signed rounding shift left
          SRSHL       v1.4s, v1.4s, v5.4s
          SRSHL       v2.4s, v2.4s, v5.4s
          SRSHL       v3.4s, v3.4s, v5.4s
        $elif REQUANTIZATION == "FP32":
          $if DATATYPE != "QC8":
            # Apply params - scale, bias and clamp
            SCVTF       v0.4s, v0.4s          // int32 -> float
            LD1R        {v4.4s}, [x11], 4     // one 32-bit scale from params, broadcast to all lanes
            SCVTF       v1.4s, v1.4s
            SCVTF       v2.4s, v2.4s
            SCVTF       v3.4s, v3.4s
            FMUL        v0.4s, v0.4s, v4.4s
            FMUL        v1.4s, v1.4s, v4.4s
            FMUL        v2.4s, v2.4s, v4.4s
            FMUL        v3.4s, v3.4s, v4.4s
          $else:
            # Load per channel scale values from weights
            SCVTF       v0.4s, v0.4s          // int32 -> float
            LDR         q4, [x5], 16          // scales for columns 0-3, read from w after the weights
            SCVTF       v1.4s, v1.4s
            LDR         q5, [x5], 16          // scales for columns 4-7
            SCVTF       v2.4s, v2.4s
            SCVTF       v3.4s, v3.4s
            FMUL        v0.4s, v0.4s, v4.4s   // row 0 columns 0-3
            FMUL        v1.4s, v1.4s, v5.4s   // row 0 columns 4-7
            FMUL        v2.4s, v2.4s, v4.4s   // row 1 columns 0-3
            FMUL        v3.4s, v3.4s, v5.4s   // row 1 columns 4-7

          FCVTNS      v0.4s, v0.4s            // float -> int32, round to nearest (ties to even)
          FCVTNS      v1.4s, v1.4s
          FCVTNS      v2.4s, v2.4s
          FCVTNS      v3.4s, v3.4s

        // Narrow 32 -> 16 bits with saturation, add the 16-bit param (the
        // "bias" named in the headings above; presumably the output zero
        // point - confirm against the params struct), narrow 16 -> 8 bits
        // with saturation, then clamp with the two 8-bit params.
        // Result: v0 bytes 0-7 = row 0, bytes 8-15 = row 1.
        LD1R        {v5.8h}, [x11], 2
        SQXTN       v0.4h, v0.4s
        SQXTN       v2.4h, v2.4s
        SQXTN2      v0.8h, v1.4s
        SQXTN2      v2.8h, v3.4s
        SUBS        x1, x1, 8               // nc -= 8; flags drive B.LO and B.HI below
        SQADD       v0.8h, v0.8h, v5.8h
        SQADD       v1.8h, v2.8h, v5.8h
        SQXTN       v0.8b, v0.8h
        SQXTN2      v0.16b, v1.8h
        LD1R        {v1.16b}, [x11], 1      // lower clamp bound (applied with SMAX)
        LD1R        {v2.16b}, [x11]         // upper clamp bound (applied with SMIN)
        SMAX        v0.16b, v0.16b, v1.16b
        SMIN        v0.16b, v0.16b, v2.16b
        B.LO        2f                      // fewer than 8 columns remained: partial store

        # Store full 2 x 8
        ST1         {v0.8b}, [x6], x10      // row 0 -> c0, then c0 += cn_stride
        SUB         x3, x3, x2              // a0 -= kc
        ST1         {v0.d}[1], [x7], x10    // row 1 -> c1, then c1 += cn_stride
        SUB         x4, x4, x2              // a1 -= kc
        B.HI        0b                      // more columns left (flags still from SUBS x1; SUB does not set flags)

        # Restore d10-d15 from stack
        LDP         d14, d15, [sp, 32]
        LDP         d12, d13, [sp, 16]
        LDP         d10, d11, [sp], 48
        RET

        # Store odd width
        // x1 = nc - 8 here; subtracting 8 leaves bits 0-2 unchanged, so they
        // still encode the remaining column count (1-7). Store 4, 2, then 1
        // columns as selected by those bits. Each EXT rotates v0 so the next
        // unstored column of row 0 is at byte 0 and that of row 1 at byte 8.
        .p2align    3
2:
        TBZ         x1, 2, 3f
        STR         s0, [x6], 4             // row 0: 4 columns
        ST1         {v0.s}[2], [x7], 4      // row 1: 4 columns (bytes 8-11)
        EXT         v0.16b, v0.16b, v0.16b, 4

3:
        TBZ         x1, 1, 4f
        STR         h0, [x6], 2             // row 0: 2 columns
        ST1         {v0.h}[4], [x7], 2      // row 1: 2 columns (bytes 8-9)
        EXT         v0.16b, v0.16b, v0.16b, 2
4:
        TBZ         x1, 0, 5f
        STR         b0, [x6]                // row 0: last column
        ST1         {v0.b}[8], [x7]         // row 1: last column (byte 8)
5:
        # Restore d10-d15 from stack
        LDP         d14, d15, [sp, 32]
        LDP         d12, d13, [sp, 16]
        LDP         d10, d11, [sp], 48
        RET

END_FUNCTION xnn_${DATATYPE_SPEC}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_2x8c16__asm_aarch64_neon_mlal

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif

