// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "xnnpack/assembly.h"

.syntax unified

// void xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53${"_prfm" if PREFETCH else ""}(
//     size_t mr,                            r0
//     size_t nc,                            r1
//     size_t kc,                            r2 -> r5 -> sp + 68
//     size_t ks,                            r3 -> sp + 72 -> r14
//     const float** restrict a,  sp + 112 -> (r5)
//     const void* restrict w,    sp + 116 -> r9
//     uint8_t* restrict c,       sp + 120 -> r11
//     size_t cm_stride,         sp + 124 -> (r6)
//     size_t cn_stride,         sp + 128 -> (r0)
//     size_t a_offset,          sp + 132 -> (r5)
//     const float* zero,        sp + 136 -> (r0)
//     minmax_params*params,     sp + 140 -> (r2)

// d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved.

// Register usage
// A0   r3  d0 d4
// A1  r12  d1 d5
// A2  r10  d2 d6
// A3   r7  d3 d7
// B    r9  d8,  d9, d10, d11
// B       d12, d13, d14, d15
// C0  r11 d16-d17  q8  d18-d19  q9
// C1   r4 d20-d21 q10  d22-d23 q11
// C2   r8 d24-d25 q12  d26-d27 q13
// C3   r6 d28-d29 q14  d30-d31 q15
// clamp  (r2) d4 d5 d6 d7
// temp r0, r2 for Cortex-A53 loads

BEGIN_FUNCTION xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53${"_prfm" if PREFETCH else ""}
        .arm
#ifndef __APPLE__
        .arch       armv7-a
        .fpu        neon
#endif
        # Push 112 bytes
        # r2 will be reloaded in outer loop.  r3 is ks
        PUSH        {r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, lr}      // +44
        SUB         sp, sp, 4                                           // 4
        VPUSH       {d8-d15}                                            // +64 = 112

        LDR         r11, [sp, 120]          // c
        LDR         r6, [sp, 124]           // cm_stride
        LDR         r5, [sp, 112]           // a
        LDR         r9, [sp, 116]           // w
        MOV         r14, r3                 // p = ks

        # Clamp C pointers
        CMP         r0, 2                   // if mr >= 2
        ADD         r4, r11, r6             //   c1 = c0 + cm_stride
        MOVLO       r4, r11                 // c1
                                     // if mr > 2
        ADD         r8, r4, r6              //   c2 = c1 + cm_stride
        MOVLS       r8, r4                  // c2
        CMP         r0, 4                   // if mr >=4
        ADD         r6, r8, r6              //   c3 = c2 + cm_stride
        MOVLO       r6, r8                  // c3


        .p2align    3
0:
        # Load initial bias from w into accumulators
        VLDM        r9!, {d16-d19}          // Bias

        VMOV        q10, q8
        $if PREFETCH:
          PLD         [r9,   0]               // Prefetch B
        VMOV        q11, q9
        $if PREFETCH:
          PLD         [r9,  64]
        VMOV        q12, q8
        $if PREFETCH:
          PLD         [r9, 128]
        VMOV        q13, q9
        $if PREFETCH:
          PLD         [r9, 192]
        VMOV        q14, q8
        $if PREFETCH:
          PLD         [r9, 256]
        VMOV        q15, q9
        $if PREFETCH:
          PLD         [r9, 320]

1:
        # Load next 4 A pointers
        LDR         r3, [r5,  0]
        LDR         r12, [r5,  4]
        LDR         r10, [r5,  8]
        LDR         r7, [r5, 12]
        ADD         r5, r5, 16              // a += MR * sizeof(void*)
        $if PREFETCH:
          PLD         [r3,  0]                // Prefetch A
        STR         r5, [sp, 112]           // a
        $if PREFETCH:
          PLD         [r3, 64]
        LDR         r0, [sp, 136]           // zero
        $if PREFETCH:
          PLD         [r12,  0]
        LDR         r5, [sp, 132]           // a_offset
        $if PREFETCH:
          PLD         [r12, 64]
        LDR         r2, [sp, 68]            // kc
        $if PREFETCH:
          PLD         [r10,  0]
        $if PREFETCH:
          PLD         [r10, 64]
        $if PREFETCH:
          PLD         [r7,  0]
        $if PREFETCH:
          PLD         [r7, 64]

        # Add a_offset
        CMP         r3,  r0                 // if a0 == zero
        ADD         r3,  r3, r5             // a0 += a_offset
        MOVEQ       r3,  r0                 //   a0 = zero, else += a0 + a_offset
        CMP         r12,  r0                // if a1 == zero
        ADD         r12, r12, r5            // a1 += a_offset
        MOVEQ       r12,  r0                //   a1 = zero, else += a1 + a_offset
        CMP         r10,  r0                // if a2 == zero
        ADD         r10, r10, r5            // a2 += a_offset
        MOVEQ       r10,  r0                //   a2 = zero, else += a2 + a_offset
        CMP         r7,  r0                 // if a3 == zero
        ADD         r7,  r7, r5             // a3 += a_offset
        MOVEQ       r7,  r0                 //   a3 = zero, else += a3 + a_offset

        SUBS        r5, r2, 16              // kc - 16
        BLO         5f                      // less than 4 channels?

        # Prologue
        VLD1.32     {d0},  [r3]!            // A0
        VLD1.32     {d1}, [r12]!            // A1
        VLD1.32     {d2}, [r10]!            // A2
        VLD1.32     {d3},  [r7]!            // A3
        SUBS        r5, r5, 16
        VLDM        r9, {d8-d11}            // B0
        LDR         r0, [r9, 56]            // B1 low   VMOV is in BLOCK 0
        LDR         r2, [r9, 60]            // B1 high
        VLDR        d13, [r9, 40]           // B1

        BLO         3f                      // less than 4 channels?  skip main loop

        # Main loop - 4 floats of A (16 bytes)
        # 32 FMA + 8 LD64 A + 8 LDR B
        .p2align    3
2:
        # First group of 16 FMA, Second group loads
        # BLOCK 0
        VLD1.32     {d4}, [r3]!             // A0
        VMOV        d15, r0, r2             // b1 VMOV b from second group
        VMLA.F32    q8, q4, d0[0]
        LDR         r0, [r12]               // A1 low
        VMLA.F32    q10, q4, d1[0]
        LDR         r2, [r12, 4]            // A1 high
        VMLA.F32    q12, q4, d2[0]
        $if PREFETCH:
          PLD         [r3, 128]               // Prefetch A0

        # BLOCK 1
        VLDR        d12, [r9, 32]           // B1
        VMOV        d5, r0, r2              // a1 VMOV
        VMLA.F32    q14, q4, d3[0]
        LDR         r0, [r9, 72]            // B0 low
        VMLA.F32    q9, q5, d0[0]
        LDR         r2, [r9, 76]            // B0 high
        VMLA.F32    q11, q5, d1[0]
        $if PREFETCH:
          PLD         [r12, 128]              // Prefetch A1

        # BLOCK 2
        VLD1.32     {d6}, [r10]!            // A2
        VMOV        d9, r0, r2              // b0 VMOV
        VMLA.F32    q13, q5, d2[0]
        LDR         r0, [r7]                // A3 low
        VMLA.F32    q15, q5, d3[0]
        LDR         r2, [r7, 4]             // A3 high
        VMLA.F32    q8, q6, d0[1]
        $if PREFETCH:
          PLD         [r10, 128]              // Prefetch A2

        # BLOCK 3
        VLDR        d14, [r9, 48]           // B1
        VMOV        d7, r0, r2              // a3 VMOV
        VMLA.F32    q10, q6, d1[1]
        LDR         r0, [r9, 88]            // B0 low
        VMLA.F32    q12, q6, d2[1]
        LDR         r2, [r9, 92]            // B0 high
        VMLA.F32    q14, q6, d3[1]
        $if PREFETCH:
          PLD         [r7, 128]               // Prefetch A3

        # BLOCK 4
        VLDR        d8, [r9, 64]            // B0
        VMOV        d11, r0, r2             // B0 VMOV
        VMLA.F32    q9, q7, d0[1]
        LDR         r0, [r9, 104]           // B1 low   VMOV is in BLOCK 0
        VMLA.F32    q11, q7, d1[1]
        LDR         r2, [r9, 108]           // B1 high
        VMLA.F32    q13, q7, d2[1]
        $if PREFETCH:
          PLD         [r9, 384]               // Prefetch B

        # BLOCK 5
        VLDR        d10, [r9, 80]           // B0
        VMOV        d13, r0, r2             // b1 VMOV b from second group
        VMLA.F32    q15, q7, d3[1]
        LDR         r0, [r9, 120]           // B1 low   VMOV is in BLOCK 0
        NOP
        LDR         r2, [r9, 124]           // B1 high
        NOP
        $if PREFETCH:
          PLD         [r9, 448]               // Prefetch B

        # Second group of 16 FMA, First group of loads
        # BLOCK 0
        VLD1.32     {d0}, [r3]!             // A0
        VMOV        d15, r0, r2             // b1 VMOV b from second group
        VMLA.F32    q8, q4, d4[0]
        LDR         r0, [r12, 8]            // A1 low
        VMLA.F32    q10, q4, d5[0]
        LDR         r2, [r12, 12]           // A1 high
        VMLA.F32    q12, q4, d6[0]
        # NOP

        # BLOCK 1
        VLDR        d12, [r9, 96]           // B1
        VMOV        d1, r0, r2              // a1 VMOV
        VMLA.F32    q14, q4, d7[0]
        LDR         r0, [r9, 136]           // B0 low
        VMLA.F32    q9, q5, d4[0]
        LDR         r2, [r9, 140]           // B0 high
        VMLA.F32    q11, q5, d5[0]
        # NOP

        # BLOCK 2
        VLD1.32     {d2}, [r10]!            // A2
        VMOV        d9, r0, r2              // b0 VMOV
        VMLA.F32    q13, q5, d6[0]
        LDR         r0, [r7, 8]             // A3 low
        VMLA.F32    q15, q5, d7[0]
        LDR         r2, [r7, 12]            // A3 high
        VMLA.F32    q8, q6, d4[1]
        # NOP

        # BLOCK 3
        VLDR        d14, [r9, 112]          // B1
        VMOV        d3, r0, r2              // a3 VMOV
        VMLA.F32    q10, q6, d5[1]
        LDR         r0, [r9, 152]           // B0 low
        VMLA.F32    q12, q6, d6[1]
        LDR         r2, [r9, 156]           // B0 high
        VMLA.F32    q14, q6, d7[1]
        ADD         r12, r12, 16            // A1++

        # BLOCK 4
        VLDR        d8, [r9, 128]           // B0
        VMOV        d11, r0, r2             // B0 VMOV
        VMLA.F32    q9, q7, d4[1]
        LDR         r0, [r9, 168]           // B1 low
        VMLA.F32    q11, q7, d5[1]
        LDR         r2, [r9, 172]           // B1 high
        VMLA.F32    q13, q7, d6[1]
        ADD         r7, r7, 16              // A3++

        # BLOCK 5
        VLDR        d10, [r9, 144]          // B0
        VMOV        d13, r0, r2             // b1 VMOV b
        VMLA.F32    q15, q7, d7[1]
        LDR         r0, [r9, 184]           // B1 low   VMOV is in BLOCK 0
        SUBS        r5, r5, 16
        LDR         r2, [r9, 188]           // B1 high
        ADD         r9, r9, 128             // B++
        BHS         2b

        # Epilogue - 4 floats of A (16 bytes)
3:
        # First group of 16 FMA, Second group loads
        # BLOCK 0
        VLD1.32     {d4}, [r3]!             // A0
        VMOV        d15, r0, r2             // b1 VMOV b from second group
        VMLA.F32    q8, q4, d0[0]
        LDR         r0, [r12]               // A1 low
        VMLA.F32    q10, q4, d1[0]
        LDR         r2, [r12, 4]            // A1 high
        VMLA.F32    q12, q4, d2[0]
        # NOP

        # BLOCK 1
        VLDR        d12, [r9, 32]           // B1
        VMOV        d5, r0, r2              // a1 VMOV
        VMLA.F32    q14, q4, d3[0]
        LDR         r0, [r9, 72]            // B0 low
        VMLA.F32    q9, q5, d0[0]
        LDR         r2, [r9, 76]            // B0 high
        VMLA.F32    q11, q5, d1[0]
        # NOP

        # BLOCK 2
        VLD1.32     {d6}, [r10]!            // A2
        VMOV        d9, r0, r2              // b0 VMOV
        VMLA.F32    q13, q5, d2[0]
        LDR         r0, [r7]                // A3 low
        VMLA.F32    q15, q5, d3[0]
        LDR         r2, [r7, 4]             // A3 high
        VMLA.F32    q8, q6, d0[1]
        # NOP

        # BLOCK 3
        VLDR        d14, [r9, 48]           // B1
        VMOV        d7, r0, r2              // a3 VMOV
        VMLA.F32    q10, q6, d1[1]
        LDR         r0, [r9, 88]            // B0 low
        VMLA.F32    q12, q6, d2[1]
        LDR         r2, [r9, 92]            // B0 high
        VMLA.F32    q14, q6, d3[1]
        # NOP

        # BLOCK 4
        VLDR        d8, [r9, 64]            // B0
        VMOV        d11, r0, r2             // B0 VMOV
        VMLA.F32    q9, q7, d0[1]
        LDR         r0, [r9, 104]           // B1 low
        VMLA.F32    q11, q7, d1[1]
        LDR         r2, [r9, 108]           // B1 high
        VMLA.F32    q13, q7, d2[1]
        # NOP

        # BLOCK 5
        VLDR        d10, [r9, 80]           // B0
        VMOV        d13, r0, r2             // b1 VMOV b
        VMLA.F32    q15, q7, d3[1]
        LDR         r0, [r9, 120]           // B1 low   VMOV is in BLOCK 0
        NOP
        LDR         r2, [r9, 124]           // B1 high
        NOP
        NOP

        # Second group of 16 FMA, First group of loads
        # BLOCK 0
        VLDR        d12, [r9, 96]           // B1
        VMOV        d15, r0, r2             // b1 VMOV b from second group
        VMLA.F32    q8, q4, d4[0]
        VMLA.F32    q10, q4, d5[0]
        VMLA.F32    q12, q4, d6[0]

        # BLOCK 1
        VLDR        d14, [r9, 112]          // B1
        VMLA.F32    q14, q4, d7[0]
        VMLA.F32    q9, q5, d4[0]
        VMLA.F32    q11, q5, d5[0]
        ADD         r12, r12, 8             // A1++

        # BLOCK 2
        ADD         r7, r7, 8               // A3++ VLDR B1 lands here
        ADD         r9, r9, 128             // B++
        VMLA.F32    q13, q5, d6[0]
        VMLA.F32    q15, q5, d7[0]
        VMLA.F32    q8, q6, d4[1]

        # BLOCK 3
        VMLA.F32    q10, q6, d5[1]
        VMLA.F32    q12, q6, d6[1]
        VMLA.F32    q14, q6, d7[1]
        TST         r5, 15

        # BLOCK 4
        VMLA.F32    q9, q7, d4[1]
        VMLA.F32    q11, q7, d5[1]
        VMLA.F32    q13, q7, d6[1]

        # BLOCK 5
        VMLA.F32    q15, q7, d7[1]

        # Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes)
        BNE         5f

        .p2align    3
4:
        LDR         r5, [sp, 112]           // a
        SUBS        r14, r14, 16            // ks -= MR * sizeof(void*)

        # ks loop
        BHI         1b

        # Load params pointer
        LDR         r0, [sp, 128]           // cn_stride
        LDR         r2, [sp, 140]           // params
        LDR         r14, [sp, 72]           // p = ks
        SUBS        r1, r1, 8

        # Load min/max values
        VLD1.32     {d4[],d5[]}, [r2]!
        VLD1.32     {d6[],d7[]}, [r2]

        # Clamp
        VMAX.F32    q8,  q8, q2
        VMAX.F32    q9,  q9, q2
        VMAX.F32    q10, q10, q2
        VMAX.F32    q11, q11, q2
        VMAX.F32    q12, q12, q2
        VMAX.F32    q13, q13, q2
        VMAX.F32    q14, q14, q2
        VMAX.F32    q15, q15, q2
        VMIN.F32    q8,  q8, q3
        VMIN.F32    q9,  q9, q3
        VMIN.F32    q10, q10, q3
        VMIN.F32    q11, q11, q3
        VMIN.F32    q12, q12, q3
        VMIN.F32    q13, q13, q3
        VMIN.F32    q14, q14, q3
        VMIN.F32    q15, q15, q3

        # Store full 4 x 8
        BLO         7f
        VST1.32     {d28-d31},  [r6], r0
        VST1.32     {d24-d27},  [r8], r0
        VST1.32     {d20-d23},  [r4], r0
        VST1.32     {d16-d19}, [r11], r0

        SUB         r5, r5, r14             // a -= ks

        BHI         0b

        VPOP        {d8-d15}
        ADD         sp, sp, 12              // skip pad, r2, r3
        POP         {r4, r5, r6, r7, r8, r9, r10, r11, pc}

        .p2align    3
5:
        # Is there a remainder?- 2 floats of A (8 bytes)
        TST         r5, 8
        BEQ         6f

        # Remainder - 2 floats of A (8 bytes)
        VLD1.32     {d0}, [r3]!             // A0
        VLDM        r9!, {d8-d11}           // B0
        VLD1.32     {d1}, [r12]!            // A1
        VLD1.32     {d2}, [r10]!            // A2
        VLD1.32     {d3}, [ r7]!            // A3

        VMLA.F32    q8, q4, d0[0]
        VMLA.F32    q9, q5, d0[0]
        VMLA.F32    q10, q4, d1[0]
        VMLA.F32    q11, q5, d1[0]
        VLDM        r9!, {d12-d15}          // B1
        VMLA.F32    q12, q4, d2[0]
        VMLA.F32    q13, q5, d2[0]
        VMLA.F32    q14, q4, d3[0]
        VMLA.F32    q15, q5, d3[0]
        VMLA.F32    q8, q6, d0[1]
        VMLA.F32    q9, q7, d0[1]
        VMLA.F32    q10, q6, d1[1]
        VMLA.F32    q11, q7, d1[1]
        VMLA.F32    q12, q6, d2[1]
        VMLA.F32    q13, q7, d2[1]
        VMLA.F32    q14, q6, d3[1]
        VMLA.F32    q15, q7, d3[1]

        # Is there a remainder?- 1 float of A (4 bytes)
        TST         r5, 4
        BEQ         4b

6:
        # Remainder- 1 float of A (4 bytes)
        VLDM        r3!, {s0}               // A0
        VLDM        r9!, {d8-d11}           // B0
        VLDM        r12!, {s2}              // A1
        VLDM        r10!, {s4}              // A2
        VLDM        r7!, {s6}               // A3
        VMLA.F32    q8, q4, d0[0]
        VMLA.F32    q9, q5, d0[0]
        VMLA.F32    q10, q4, d1[0]
        VMLA.F32    q11, q5, d1[0]
        VMLA.F32    q12, q4, d2[0]
        VMLA.F32    q13, q5, d2[0]
        VMLA.F32    q14, q4, d3[0]
        VMLA.F32    q15, q5, d3[0]
        B           4b

        # Store odd width
7:
        TST         r1, 4
        BEQ         8f
        VST1.32     {d28-d29},  [r6]!
        VST1.32     {d24-d25},  [r8]!
        VMOV        q14, q15
        VMOV        q12, q13
        VST1.32     {d20-d21},  [r4]!
        VST1.32     {d16-d17}, [r11]!
        VMOV        q10, q11
        VMOV        q8,  q9

8:
        TST         r1, 2
        BEQ         9f
        VST1.32     {d28},  [r6]!
        VST1.32     {d24},  [r8]!
        VMOV        d28, d29
        VMOV        d24, d25
        VST1.32     {d20},  [r4]!
        VST1.32     {d16}, [r11]!
        VMOV        d20, d21
        VMOV        d16, d17

9:
        TST         r1, 1
        BEQ         10f
        VST1.32     {d28[0]},  [r6]!
        VST1.32     {d24[0]},  [r8]!
        VST1.32     {d20[0]},  [r4]!
        VST1.32     {d16[0]}, [r11]!

10:
        VPOP        {d8-d15}
        ADD         sp, sp, 12              // skip pad, r2, r3
        POP         {r4, r5, r6, r7, r8, r9, r10, r11, pc}

END_FUNCTION xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53${"_prfm" if PREFETCH else ""}

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif
