// Auto-generated file. Do not edit! // Template: src/qs8-igemm/c4-avx512amx.c.in // Generator: tools/xngen // // Copyright 2024 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. #include #if defined(__has_feature) #if __has_feature(memory_sanitizer) #include #endif #endif #include #include "xnnpack/gemm.h" #include "xnnpack/intrinsics-polyfill.h" #include "xnnpack/math.h" #include "xnnpack/unaligned.h" void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx( size_t mr, size_t nc, size_t kc, size_t ks, const int8_t** restrict a, const void* restrict w, float* restrict c, size_t cm_stride, size_t cn_stride, size_t a_offset, const int8_t* zero, const int8_t* zero_data, const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)], const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) { assert(mr != 0); assert(mr <= 16); assert(nc != 0); assert(kc != 0); assert(kc % sizeof(int8_t) == 0); assert(a != NULL); assert(w != NULL); assert(c != NULL); // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; const __mmask16 kremainder_mask = _cvtu32_mask16((UINT32_C(1) << (kremainder >> 2)) - 1); // Define tile config data structure struct __tile_config { uint8_t palette_id; uint8_t start_row; uint8_t reserved_0[14]; uint16_t colsb[8]; uint16_t reserved_1[8]; uint8_t rows[8]; uint8_t reserved_2[8]; }; // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; tile_data.rows[0] = mr; // tmm0 = res[0] tile_data.rows[1] = mr; // tmm1 = res[1] tile_data.rows[2] = mr; // tmm2 = res[2] tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder tile_data.colsb[0] = 64; // tmm0 = res[0] tile_data.colsb[1] = 64; // tmm1 = res[1] tile_data.colsb[2] = 64; // tmm2 = res[2] tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder tile_data.colsb[7] = 64; // tmm7 = weights remainder //_tile_loadconfig(&tile_data); __asm__ volatile ("ldtilecfg %0" :: "m" (tile_data)); float* c0 = c; float* c1 = (float*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { c1 = c0; } float* c2 = (float*) ((uintptr_t) c1 + cm_stride); if XNN_UNPREDICTABLE(mr <= 2) { c2 = c1; } float* c3 = (float*) ((uintptr_t) c2 + cm_stride); if XNN_UNPREDICTABLE(mr < 4) { c3 = c2; } float* c4 = (float*) ((uintptr_t) c3 + cm_stride); if XNN_UNPREDICTABLE(mr <= 4) { c4 = c3; } float* c5 = (float*) ((uintptr_t) c4 + cm_stride); if XNN_UNPREDICTABLE(mr < 6) { c5 = c4; } float* c6 = (float*) ((uintptr_t) c5 + cm_stride); if XNN_UNPREDICTABLE(mr <= 6) { c6 = c5; } float* c7 = (float*) ((uintptr_t) c6 + cm_stride); if XNN_UNPREDICTABLE(mr < 8) { c7 = c6; } float* c8 = (float*) ((uintptr_t) c7 + cm_stride); if XNN_UNPREDICTABLE(mr <= 8) { c8 = c7; } float* c9 = (float*) ((uintptr_t) c8 + cm_stride); if XNN_UNPREDICTABLE(mr < 10) { c9 = c8; } float* c10 = (float*) ((uintptr_t) c9 + cm_stride); if XNN_UNPREDICTABLE(mr <= 10) { c10 = c9; } float* c11 = (float*) ((uintptr_t) c10 + cm_stride); if XNN_UNPREDICTABLE(mr < 12) { c11 = c10; } float* c12 = (float*) ((uintptr_t) c11 + cm_stride); if XNN_UNPREDICTABLE(mr <= 12) { c12 = c11; } float* c13 = (float*) ((uintptr_t) c12 + cm_stride); if XNN_UNPREDICTABLE(mr < 14) { c13 = c12; } float* c14 = (float*) ((uintptr_t) c13 + cm_stride); if XNN_UNPREDICTABLE(mr <= 14) { c14 = c13; } float* c15 = (float*) ((uintptr_t) c14 + cm_stride); if XNN_UNPREDICTABLE(mr != 16) { c15 = c14; } const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); // XNN_FORCE_REALIZATION(voutput_max); do { const __m512i vksum0123456789ABCDEF = _mm512_loadu_epi32((const int32_t*) w + 0); w = (const int32_t*) w + 16; // Zero tile accumulator __asm__ volatile ( "tilezero %%tmm0\n" ::); size_t p = ks; do { const int8_t* restrict a0 = a[0]; if XNN_UNPREDICTABLE(a0 != zero) { a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); } else { a0 = zero_data; } const int8_t* restrict a1 = a[1]; if XNN_UNPREDICTABLE(a1 != zero) { a1 = (const int8_t*) ((uintptr_t) a1 + a_offset); } else { a1 = zero_data; } const int8_t* restrict a2 = a[2]; if XNN_UNPREDICTABLE(a2 != zero) { a2 = (const int8_t*) ((uintptr_t) a2 + a_offset); } else { a2 = zero_data; } const int8_t* restrict a3 = a[3]; if XNN_UNPREDICTABLE(a3 != zero) { a3 = (const int8_t*) ((uintptr_t) a3 + a_offset); } else { a3 = zero_data; } const int8_t* restrict a4 = a[4]; if XNN_UNPREDICTABLE(a4 != zero) { a4 = (const int8_t*) ((uintptr_t) a4 + a_offset); } else { a4 = zero_data; } const int8_t* restrict a5 = a[5]; if XNN_UNPREDICTABLE(a5 != zero) { a5 = (const int8_t*) ((uintptr_t) a5 + a_offset); } else { a5 = zero_data; } const int8_t* restrict a6 = a[6]; if XNN_UNPREDICTABLE(a6 != zero) { a6 = (const int8_t*) ((uintptr_t) a6 + a_offset); } else { a6 = zero_data; } const int8_t* restrict a7 = a[7]; if XNN_UNPREDICTABLE(a7 != zero) { a7 = (const int8_t*) ((uintptr_t) a7 + a_offset); } else { a7 = zero_data; } const int8_t* restrict a8 = a[8]; if XNN_UNPREDICTABLE(a8 != zero) { a8 = (const int8_t*) ((uintptr_t) a8 + a_offset); } else { a8 = zero_data; } const int8_t* restrict a9 = a[9]; if XNN_UNPREDICTABLE(a9 != zero) { a9 = (const int8_t*) ((uintptr_t) a9 + a_offset); } else { a9 = zero_data; } const int8_t* restrict a10 = a[10]; if XNN_UNPREDICTABLE(a10 != zero) { a10 = (const int8_t*) ((uintptr_t) a10 + a_offset); } else { a10 = zero_data; } const int8_t* restrict a11 = a[11]; if XNN_UNPREDICTABLE(a11 != zero) { a11 = (const int8_t*) ((uintptr_t) a11 + a_offset); } else { a11 = zero_data; } const int8_t* restrict a12 = a[12]; if XNN_UNPREDICTABLE(a12 != zero) { a12 = (const int8_t*) ((uintptr_t) a12 + a_offset); } else { a12 = zero_data; } const int8_t* restrict a13 = a[13]; if XNN_UNPREDICTABLE(a13 != zero) { a13 = (const int8_t*) ((uintptr_t) a13 + a_offset); } else { a13 = zero_data; } const int8_t* restrict a14 = a[14]; if XNN_UNPREDICTABLE(a14 != zero) { a14 = (const int8_t*) ((uintptr_t) a14 + a_offset); } else { a14 = zero_data; } const int8_t* restrict a15 = a[15]; if XNN_UNPREDICTABLE(a15 != zero) { a15 = (const int8_t*) ((uintptr_t) a15 + a_offset); } else { a15 = zero_data; } a += 16; size_t k = kc; if (mr == 1) { while (k >= 64 * sizeof(int8_t)) { _tile_loadd(4, a0, 64); // Directly load input for mr=1 a15 += 64; _tile_loadd(5, (const int8_t*) w + 0, 64); _tile_dpbssd(0, 4, 5); w = (const int8_t*) w + 1024; k -= 64 * sizeof(int8_t); } } else { while (k >= 64 * sizeof(int8_t)) { const __m512i vin0 = _mm512_loadu_epi32(a0); a0 += 64; _mm512_store_epi32(vintile + 0, vin0); const __m512i vin1 = _mm512_loadu_epi32(a1); a1 += 64; _mm512_store_epi32(vintile + 16, vin1); const __m512i vin2 = _mm512_loadu_epi32(a2); a2 += 64; _mm512_store_epi32(vintile + 32, vin2); const __m512i vin3 = _mm512_loadu_epi32(a3); a3 += 64; _mm512_store_epi32(vintile + 48, vin3); const __m512i vin4 = _mm512_loadu_epi32(a4); a4 += 64; _mm512_store_epi32(vintile + 64, vin4); const __m512i vin5 = _mm512_loadu_epi32(a5); a5 += 64; _mm512_store_epi32(vintile + 80, vin5); const __m512i vin6 = _mm512_loadu_epi32(a6); a6 += 64; _mm512_store_epi32(vintile + 96, vin6); const __m512i vin7 = _mm512_loadu_epi32(a7); a7 += 64; _mm512_store_epi32(vintile + 112, vin7); const __m512i vin8 = _mm512_loadu_epi32(a8); a8 += 64; _mm512_store_epi32(vintile + 128, vin8); const __m512i vin9 = _mm512_loadu_epi32(a9); a9 += 64; _mm512_store_epi32(vintile + 144, vin9); const __m512i vin10 = _mm512_loadu_epi32(a10); a10 += 64; _mm512_store_epi32(vintile + 160, vin10); const __m512i vin11 = _mm512_loadu_epi32(a11); a11 += 64; _mm512_store_epi32(vintile + 176, vin11); const __m512i vin12 = _mm512_loadu_epi32(a12); a12 += 64; _mm512_store_epi32(vintile + 192, vin12); const __m512i vin13 = _mm512_loadu_epi32(a13); a13 += 64; _mm512_store_epi32(vintile + 208, vin13); const __m512i vin14 = _mm512_loadu_epi32(a14); a14 += 64; _mm512_store_epi32(vintile + 224, vin14); const __m512i vin15 = _mm512_loadu_epi32(a15); a15 += 64; _mm512_store_epi32(vintile + 240, vin15); _tile_loadd(4, vintile, 64); _tile_loadd(5, (const int8_t*) w + 0, 64); _tile_dpbssd(0, 4, 5); w = (const int8_t*) w + 1024; k -= 64 * sizeof(int8_t); } } if XNN_UNLIKELY(k != 0) { const __m512i vin0 = _mm512_maskz_loadu_epi32(kremainder_mask, a0); a0 += kremainder; _mm512_store_epi32(vintile + 0, vin0); const __m512i vin1 = _mm512_maskz_loadu_epi32(kremainder_mask, a1); a1 += kremainder; _mm512_store_epi32(vintile + 16, vin1); const __m512i vin2 = _mm512_maskz_loadu_epi32(kremainder_mask, a2); a2 += kremainder; _mm512_store_epi32(vintile + 32, vin2); const __m512i vin3 = _mm512_maskz_loadu_epi32(kremainder_mask, a3); a3 += kremainder; _mm512_store_epi32(vintile + 48, vin3); const __m512i vin4 = _mm512_maskz_loadu_epi32(kremainder_mask, a4); a4 += kremainder; _mm512_store_epi32(vintile + 64, vin4); const __m512i vin5 = _mm512_maskz_loadu_epi32(kremainder_mask, a5); a5 += kremainder; _mm512_store_epi32(vintile + 80, vin5); const __m512i vin6 = _mm512_maskz_loadu_epi32(kremainder_mask, a6); a6 += kremainder; _mm512_store_epi32(vintile + 96, vin6); const __m512i vin7 = _mm512_maskz_loadu_epi32(kremainder_mask, a7); a7 += kremainder; _mm512_store_epi32(vintile + 112, vin7); const __m512i vin8 = _mm512_maskz_loadu_epi32(kremainder_mask, a8); a8 += kremainder; _mm512_store_epi32(vintile + 128, vin8); const __m512i vin9 = _mm512_maskz_loadu_epi32(kremainder_mask, a9); a9 += kremainder; _mm512_store_epi32(vintile + 144, vin9); const __m512i vin10 = _mm512_maskz_loadu_epi32(kremainder_mask, a10); a10 += kremainder; _mm512_store_epi32(vintile + 160, vin10); const __m512i vin11 = _mm512_maskz_loadu_epi32(kremainder_mask, a11); a11 += kremainder; _mm512_store_epi32(vintile + 176, vin11); const __m512i vin12 = _mm512_maskz_loadu_epi32(kremainder_mask, a12); a12 += kremainder; _mm512_store_epi32(vintile + 192, vin12); const __m512i vin13 = _mm512_maskz_loadu_epi32(kremainder_mask, a13); a13 += kremainder; _mm512_store_epi32(vintile + 208, vin13); const __m512i vin14 = _mm512_maskz_loadu_epi32(kremainder_mask, a14); a14 += kremainder; _mm512_store_epi32(vintile + 224, vin14); const __m512i vin15 = _mm512_maskz_loadu_epi32(kremainder_mask, a15); a15 += kremainder; _mm512_store_epi32(vintile + 240, vin15); _tile_loadd(6, vintile, 64); _tile_loadd(7, (const int8_t*) w + 0, 64); _tile_dpbssd(0, 6, 7); w = (const int8_t*) w + kremainder * 16; k -= kremainder * sizeof(int8_t); } p -= 16 * sizeof(void*); } while (p != 0); // TODO: Instead of processing up to 4 tiles (16x64) consider // quantizing 1 tile at a time (16 registers) _tile_stored(0, &res[0][0], 64); // TODO: Fix msan for AMX #if defined(__has_feature) #if __has_feature(memory_sanitizer) __msan_unpoison(res, sizeof(res)); #endif #endif // TODO: Instead of processing up to 4 tiles (16x64) consider // quantizing 1 row at a time. // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc3x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc4x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc5x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc6x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc7x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc8x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc9x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc10x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc11x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc12x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc13x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc14x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); __m512 vscaled7x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc7x0123456789ABCDEF); __m512 vscaled8x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc8x0123456789ABCDEF); __m512 vscaled9x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc9x0123456789ABCDEF); __m512 vscaled10x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc10x0123456789ABCDEF); __m512 vscaled11x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc11x0123456789ABCDEF); __m512 vscaled12x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc12x0123456789ABCDEF); __m512 vscaled13x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc13x0123456789ABCDEF); __m512 vscaled14x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc14x0123456789ABCDEF); __m512 vscaled15x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc15x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled2x0123456789ABCDEF = _mm512_mul_ps(vscaled2x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled3x0123456789ABCDEF = _mm512_mul_ps(vscaled3x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled4x0123456789ABCDEF = _mm512_mul_ps(vscaled4x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled5x0123456789ABCDEF = _mm512_mul_ps(vscaled5x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled6x0123456789ABCDEF = _mm512_mul_ps(vscaled6x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled7x0123456789ABCDEF = _mm512_mul_ps(vscaled7x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled8x0123456789ABCDEF = _mm512_mul_ps(vscaled8x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled9x0123456789ABCDEF = _mm512_mul_ps(vscaled9x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled10x0123456789ABCDEF = _mm512_mul_ps(vscaled10x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled11x0123456789ABCDEF = _mm512_mul_ps(vscaled11x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled12x0123456789ABCDEF = _mm512_mul_ps(vscaled12x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled13x0123456789ABCDEF = _mm512_mul_ps(vscaled13x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled14x0123456789ABCDEF = _mm512_mul_ps(vscaled14x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled15x0123456789ABCDEF = _mm512_mul_ps(vscaled15x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); const __m512 vfilter_output_scale0123456789ABCDEF = _mm512_loadu_ps((const float*) w + 0); w = (const int32_t*) w + 16; const __m512 vbias0123456789ABCDEF = _mm512_loadu_ps((const float*) w + 0); w = (const int32_t*) w + 16; vscaled0x0123456789ABCDEF = _mm512_fmadd_ps(vscaled0x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled1x0123456789ABCDEF = _mm512_fmadd_ps(vscaled1x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled2x0123456789ABCDEF = _mm512_fmadd_ps(vscaled2x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled3x0123456789ABCDEF = _mm512_fmadd_ps(vscaled3x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled4x0123456789ABCDEF = _mm512_fmadd_ps(vscaled4x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled5x0123456789ABCDEF = _mm512_fmadd_ps(vscaled5x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled6x0123456789ABCDEF = _mm512_fmadd_ps(vscaled6x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled7x0123456789ABCDEF = _mm512_fmadd_ps(vscaled7x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled8x0123456789ABCDEF = _mm512_fmadd_ps(vscaled8x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled9x0123456789ABCDEF = _mm512_fmadd_ps(vscaled9x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled10x0123456789ABCDEF = _mm512_fmadd_ps(vscaled10x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled11x0123456789ABCDEF = _mm512_fmadd_ps(vscaled11x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled12x0123456789ABCDEF = _mm512_fmadd_ps(vscaled12x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled13x0123456789ABCDEF = _mm512_fmadd_ps(vscaled13x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled14x0123456789ABCDEF = _mm512_fmadd_ps(vscaled14x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled15x0123456789ABCDEF = _mm512_fmadd_ps(vscaled15x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_max_ps(vscaled0x0123456789ABCDEF, voutput_min); vscaled1x0123456789ABCDEF = _mm512_max_ps(vscaled1x0123456789ABCDEF, voutput_min); vscaled2x0123456789ABCDEF = _mm512_max_ps(vscaled2x0123456789ABCDEF, voutput_min); vscaled3x0123456789ABCDEF = _mm512_max_ps(vscaled3x0123456789ABCDEF, voutput_min); vscaled4x0123456789ABCDEF = _mm512_max_ps(vscaled4x0123456789ABCDEF, voutput_min); vscaled5x0123456789ABCDEF = _mm512_max_ps(vscaled5x0123456789ABCDEF, voutput_min); vscaled6x0123456789ABCDEF = _mm512_max_ps(vscaled6x0123456789ABCDEF, voutput_min); vscaled7x0123456789ABCDEF = _mm512_max_ps(vscaled7x0123456789ABCDEF, voutput_min); vscaled8x0123456789ABCDEF = _mm512_max_ps(vscaled8x0123456789ABCDEF, voutput_min); vscaled9x0123456789ABCDEF = _mm512_max_ps(vscaled9x0123456789ABCDEF, voutput_min); vscaled10x0123456789ABCDEF = _mm512_max_ps(vscaled10x0123456789ABCDEF, voutput_min); vscaled11x0123456789ABCDEF = _mm512_max_ps(vscaled11x0123456789ABCDEF, voutput_min); vscaled12x0123456789ABCDEF = _mm512_max_ps(vscaled12x0123456789ABCDEF, voutput_min); vscaled13x0123456789ABCDEF = _mm512_max_ps(vscaled13x0123456789ABCDEF, voutput_min); vscaled14x0123456789ABCDEF = _mm512_max_ps(vscaled14x0123456789ABCDEF, voutput_min); vscaled15x0123456789ABCDEF = _mm512_max_ps(vscaled15x0123456789ABCDEF, voutput_min); vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max); vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max); vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max); vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max); vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max); vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max); vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max); vscaled7x0123456789ABCDEF = _mm512_min_ps(vscaled7x0123456789ABCDEF, voutput_max); vscaled8x0123456789ABCDEF = _mm512_min_ps(vscaled8x0123456789ABCDEF, voutput_max); vscaled9x0123456789ABCDEF = _mm512_min_ps(vscaled9x0123456789ABCDEF, voutput_max); vscaled10x0123456789ABCDEF = _mm512_min_ps(vscaled10x0123456789ABCDEF, voutput_max); vscaled11x0123456789ABCDEF = _mm512_min_ps(vscaled11x0123456789ABCDEF, voutput_max); vscaled12x0123456789ABCDEF = _mm512_min_ps(vscaled12x0123456789ABCDEF, voutput_max); vscaled13x0123456789ABCDEF = _mm512_min_ps(vscaled13x0123456789ABCDEF, voutput_max); vscaled14x0123456789ABCDEF = _mm512_min_ps(vscaled14x0123456789ABCDEF, voutput_max); vscaled15x0123456789ABCDEF = _mm512_min_ps(vscaled15x0123456789ABCDEF, voutput_max); if XNN_LIKELY(nc >= 16) { _mm512_storeu_ps(c15 + 0, vscaled15x0123456789ABCDEF); c15 = (float*) ((uintptr_t) c15 + cn_stride); _mm512_storeu_ps(c14 + 0, vscaled14x0123456789ABCDEF); c14 = (float*) ((uintptr_t) c14 + cn_stride); _mm512_storeu_ps(c13 + 0, vscaled13x0123456789ABCDEF); c13 = (float*) ((uintptr_t) c13 + cn_stride); _mm512_storeu_ps(c12 + 0, vscaled12x0123456789ABCDEF); c12 = (float*) ((uintptr_t) c12 + cn_stride); _mm512_storeu_ps(c11 + 0, vscaled11x0123456789ABCDEF); c11 = (float*) ((uintptr_t) c11 + cn_stride); _mm512_storeu_ps(c10 + 0, vscaled10x0123456789ABCDEF); c10 = (float*) ((uintptr_t) c10 + cn_stride); _mm512_storeu_ps(c9 + 0, vscaled9x0123456789ABCDEF); c9 = (float*) ((uintptr_t) c9 + cn_stride); _mm512_storeu_ps(c8 + 0, vscaled8x0123456789ABCDEF); c8 = (float*) ((uintptr_t) c8 + cn_stride); _mm512_storeu_ps(c7 + 0, vscaled7x0123456789ABCDEF); c7 = (float*) ((uintptr_t) c7 + cn_stride); _mm512_storeu_ps(c6 + 0, vscaled6x0123456789ABCDEF); c6 = (float*) ((uintptr_t) c6 + cn_stride); _mm512_storeu_ps(c5 + 0, vscaled5x0123456789ABCDEF); c5 = (float*) ((uintptr_t) c5 + cn_stride); _mm512_storeu_ps(c4 + 0, vscaled4x0123456789ABCDEF); c4 = (float*) ((uintptr_t) c4 + cn_stride); _mm512_storeu_ps(c3 + 0, vscaled3x0123456789ABCDEF); c3 = (float*) ((uintptr_t) c3 + cn_stride); _mm512_storeu_ps(c2 + 0, vscaled2x0123456789ABCDEF); c2 = (float*) ((uintptr_t) c2 + cn_stride); _mm512_storeu_ps(c1 + 0, vscaled1x0123456789ABCDEF); c1 = (float*) ((uintptr_t) c1 + cn_stride); _mm512_storeu_ps(c0 + 0, vscaled0x0123456789ABCDEF); c0 = (float*) ((uintptr_t) c0 + cn_stride); a = (const int8_t**restrict) ((uintptr_t) a - ks); nc -= 16; } else { // Prepare mask for valid 32-bit elements (depends on nc). const __mmask16 vmask0 = _cvtu32_mask16((uint32_t) ((((UINT64_C(1) << nc) - 1) >> 0) & 0xFFFF)); _mm512_mask_storeu_ps(c15 + 0, vmask0, vscaled15x0123456789ABCDEF); _mm512_mask_storeu_ps(c14 + 0, vmask0, vscaled14x0123456789ABCDEF); _mm512_mask_storeu_ps(c13 + 0, vmask0, vscaled13x0123456789ABCDEF); _mm512_mask_storeu_ps(c12 + 0, vmask0, vscaled12x0123456789ABCDEF); _mm512_mask_storeu_ps(c11 + 0, vmask0, vscaled11x0123456789ABCDEF); _mm512_mask_storeu_ps(c10 + 0, vmask0, vscaled10x0123456789ABCDEF); _mm512_mask_storeu_ps(c9 + 0, vmask0, vscaled9x0123456789ABCDEF); _mm512_mask_storeu_ps(c8 + 0, vmask0, vscaled8x0123456789ABCDEF); _mm512_mask_storeu_ps(c7 + 0, vmask0, vscaled7x0123456789ABCDEF); _mm512_mask_storeu_ps(c6 + 0, vmask0, vscaled6x0123456789ABCDEF); _mm512_mask_storeu_ps(c5 + 0, vmask0, vscaled5x0123456789ABCDEF); _mm512_mask_storeu_ps(c4 + 0, vmask0, vscaled4x0123456789ABCDEF); _mm512_mask_storeu_ps(c3 + 0, vmask0, vscaled3x0123456789ABCDEF); _mm512_mask_storeu_ps(c2 + 0, vmask0, vscaled2x0123456789ABCDEF); _mm512_mask_storeu_ps(c1 + 0, vmask0, vscaled1x0123456789ABCDEF); _mm512_mask_storeu_ps(c0 + 0, vmask0, vscaled0x0123456789ABCDEF); nc = 0; } } while (nc != 0); // Release tile config // _tile_release(); __asm__ volatile ("tilerelease" ::); #endif // defined(__x86_64__) }