// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "xnnpack/assembly.h"

.syntax unified

// void xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53${"_prfm" if PREFETCH else ""}(
//     size_t mr,                            (unused)
//     size_t nc,                            r1
//     size_t kc,                            r2 -> r0
//     size_t ks,                            (r3) -> sp + 4 -> r14
//     const float** restrict a,   sp + 24 -> r4
//     const void* restrict w,     sp + 28 -> r9
//     uint8_t* restrict c,        sp + 32 -> r12
//     size_t cm_stride,          sp + 36 -> (unused)
//     size_t cn_stride,          sp + 40 -> (r7)
//     size_t a_offset,           sp + 44 -> (r0)
//     const float* zero,         sp + 48 -> (r7)
//     minmax_params*params,      sp + 52 -> (r0)

// d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved.

// Register usage
// A0   r3  d0
// B    r9 d24, d25, d26, d27
// B       d28, d29, d30, d31
// C0  r12 d16-d17  q8  d18-d19  q9
//         d20-d21 q10  d22-d23 q11  (second accumulator set, summed into q8/q9 before the clamp)
// clamp  (r0) d4 d5 d6 d7

BEGIN_FUNCTION xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53${"_prfm" if PREFETCH else ""}
        // Indirect GEMM microkernel producing 1 row x 8 columns of clamped f32
        // output per pass of the outer loop.
        //
        // Loop structure:
        //   0: once per block of 8 output columns (nc loop): load the 8 bias
        //      floats from w and zero the second accumulator set.
        //   1: once per A pointer in the indirection buffer (ks loop).
        //   2: once per 2 floats of A (kc loop); 4: one trailing float of A.
        //   3: end of the kc loop; falls through to clamp and store when the
        //      ks loop is finished.
        //   5-8: partial store of 4 / 2 / 1 floats when fewer than 8 columns remain.
        .arm
#ifndef __APPLE__
        .arch       armv7-a
        .fpu        neon
#endif
        # Push 24 bytes
        # r3 is ks
        // Stack after the prologue:
        //   [sp+0] pad, [sp+4] ks (r3), [sp+8] r4, [sp+12] r7, [sp+16] r9,
        //   [sp+20] lr, [sp+24...] the caller's stack arguments.
        // ks is saved so that it can be reloaded for each block of 8 columns.
        PUSH        {r3, r4, r7, r9, lr}    // 20
        SUB         sp, sp, 4               // +4 = 24

        LDR         r4, [sp, 24]            // a
        LDR         r9, [sp, 28]            // w
        LDR         r12, [sp, 32]           // c
        LDR         r0, [sp, 52]            // params
        MOV         r14, r3                 // p = ks

        # Load min/max values
        // Each load reads one 32-bit value and replicates it to every lane.
        // q2 (first value) is used as the lower bound by VMAX below and
        // q3 (second value) as the upper bound by VMIN.
        VLD1.32     {d4[], d5[]}, [r0]!
        VLD1.32     {d6[], d7[]}, [r0]

0:
        # Load initial bias from w into accumulators
        // q8/q9 start from the bias; q10/q11 start from zero and are added
        // into q8/q9 after the ks loop.
        VLDM        r9!, {d16-d19}          // Bias
        VMOV.I32    q10, 0                  // second set of C for pipelining VMLA
        $if PREFETCH:
          PLD         [r9]                    // Prefetch B
        VMOV.I32    q11, 0

        $if PREFETCH:
          PLD         [r9,  64]
          PLD         [r9, 128]
          PLD         [r9, 192]
          PLD         [r9, 256]
          PLD         [r9, 320]
          PLD         [r9, 384]
          PLD         [r9, 448]
          PLD         [r9, 512]
          PLD         [r9, 576]
1:
        # Load next A pointer
        LDR         r3, [r4], 4             // a0 = *a++

        # Add a_offset
        // The comparison uses the pointer as loaded; the offset is added
        // unconditionally and then undone by MOVEQ when a0 was the zero pointer.
        LDR         r0, [sp, 44]            // a_offset
        LDR         r7, [sp, 48]            // zero
        CMP         r3,  r7                 // if a0 == zero
        ADD         r3,  r3, r0             // a0 += a_offset
        MOVEQ       r3,  r7                 //   a0 = zero (drop the offset); otherwise a0 + a_offset is kept

        // r0 = remaining bytes of A minus 8. The flags set here are consumed
        // by the BLO below; PLD does not modify them.
        SUBS        r0, r2, 8               // kc - 8

        $if PREFETCH:
          PLD         [r3,  0]                // Prefetch A
          PLD         [r3, 64]

        BLO         4f                      // less than 2 channels?

        # Main loop - 2 floats of A (8 bytes)
        // Each iteration consumes 2 floats of A and 16 floats of B (two groups
        // of 8). The first A element feeds q8/q9, the second feeds q10/q11.
2:
        VLDM        r9!, {d24-d27}          // B0
        VLD1.32     {d0}, [r3]!             // A0
        VLDM        r9!, {d28-d31}          // B1

        VMLA.F32    q8, q12, d0[0]
        VMLA.F32    q9, q13, d0[0]
        $if PREFETCH:
          PLD         [r9, 576]               // Prefetch B
        VMLA.F32    q10, q14, d0[1]
        VMLA.F32    q11, q15, d0[1]
        SUBS        r0, r0, 8               // loop while at least 8 more bytes of A remain
        $if PREFETCH:
          PLD         [r3, 128]               // Prefetch A0
        BHS         2b

        # Is there a remainder?- 1 float of A (4 bytes)
        // r0 is now (remaining bytes - 8), so its low 3 bits equal those of
        // the remaining byte count; bit 2 set means one float is left.
        TST         r0, 4
        BNE         4f

3:
        # ks loop
        SUBS        r14, r14, 4             // ks -= MR * sizeof(void*)
        BHI         1b                      // more A pointers to process

        // Loads are interleaved with the accumulator merge.
        LDR         r7, [sp, 40]            // cn_stride
        VADD.F32    q8, q8, q10
        LDR         r14, [sp, 4]            // p = ks
        VADD.F32    q9, q9, q11

        # Clamp
        // The flags from this SUBS are consumed by both BLO 5f and BHI 0b;
        // none of the instructions in between modifies the condition flags.
        VMAX.F32    q8,  q8, q2
        SUBS        r1, r1, 8               // nc -= 8
        VMAX.F32    q9,  q9, q2
        VMIN.F32    q8,  q8, q3
        VMIN.F32    q9,  q9, q3

        # Store full 1 x 8
        BLO         5f                      // fewer than 8 columns remained
        VST1.32     {d16-d19}, [r12], r7    // store 8 floats, c += cn_stride
        SUB         r4, r4, r14             // a -= ks
        BHI         0b                      // more columns to produce

        // Return: drop the pad word and the saved ks, restore r4/r7/r9 and
        // load the saved lr into pc.
        ADD         sp, sp, 8               // skip pad, r3
        POP         {r4, r7, r9, pc}

4:
        # Remainder- 1 float of A (4 bytes)
        // Only the first accumulator set is updated with the last A element
        // and one group of 8 B floats.
        VLDM        r3!,  {s0}              // A0
        VLDM        r9!, {d24-d27}          // B0
        VMLA.F32    q8, q12, d0[0]
        VMLA.F32    q9, q13, d0[0]
        B           3b

        # Store odd width
        // r1 is (remaining columns - 8) here, so its low 3 bits equal those of
        // the remaining column count. After each partial store the still
        // unstored values are moved down into the low part of q8.
5:
        TST         r1, 4
        BEQ         6f
        VST1.32     {d16-d17}, [r12]!       // store 4 floats
        VMOV        q8,  q9

6:
        TST         r1, 2
        BEQ         7f
        VST1.32     {d16}, [r12]!           // store 2 floats
        VMOV        d16, d17

7:
        TST         r1, 1
        BEQ         8f
        VST1.32     {d16[0]}, [r12]!        // store 1 float

8:
        // Same epilogue as the full-width return path.
        ADD         sp, sp, 8               // skip pad, r3
        POP         {r4, r7, r9, pc}

END_FUNCTION xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53${"_prfm" if PREFETCH else ""}

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif


