// Auto-generated file. Do not edit! // Template: src/qs8-igemm/c4-avx512amx.c.in // Generator: tools/xngen // // Copyright 2024 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. #include #if defined(__has_feature) #if __has_feature(memory_sanitizer) #include #endif #endif #include #include "xnnpack/gemm.h" #include "xnnpack/intrinsics-polyfill.h" #include "xnnpack/math.h" #include "xnnpack/unaligned.h" void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512amx( size_t mr, size_t nc, size_t kc, size_t ks, const int8_t** restrict a, const void* restrict w, float* restrict c, size_t cm_stride, size_t cn_stride, size_t a_offset, const int8_t* zero, const int8_t* zero_data, const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)], const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) { assert(mr != 0); assert(mr <= 7); assert(nc != 0); assert(kc != 0); assert(kc % sizeof(int8_t) == 0); assert(a != NULL); assert(w != NULL); assert(c != NULL); // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[7 * 16]; __attribute__((aligned(64))) int32_t res[1][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; const __mmask16 kremainder_mask = _cvtu32_mask16((UINT32_C(1) << (kremainder >> 2)) - 1); // Define tile config data structure struct __tile_config { uint8_t palette_id; uint8_t start_row; uint8_t reserved_0[14]; uint16_t colsb[8]; uint16_t reserved_1[8]; uint8_t rows[8]; uint8_t reserved_2[8]; }; // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; tile_data.rows[0] = mr; // tmm0 = res[0] tile_data.rows[1] = mr; // tmm1 = res[1] tile_data.rows[2] = mr; // tmm2 = res[2] tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder tile_data.colsb[0] = 64; // tmm0 = res[0] tile_data.colsb[1] = 64; // tmm1 = res[1] tile_data.colsb[2] = 64; // tmm2 = res[2] tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder tile_data.colsb[7] = 64; // tmm7 = weights remainder //_tile_loadconfig(&tile_data); __asm__ volatile ("ldtilecfg %0" :: "m" (tile_data)); float* c0 = c; float* c1 = (float*) ((uintptr_t) c0 + cm_stride); if XNN_UNPREDICTABLE(mr < 2) { c1 = c0; } float* c2 = (float*) ((uintptr_t) c1 + cm_stride); if XNN_UNPREDICTABLE(mr <= 2) { c2 = c1; } float* c3 = (float*) ((uintptr_t) c2 + cm_stride); if XNN_UNPREDICTABLE(mr < 4) { c3 = c2; } float* c4 = (float*) ((uintptr_t) c3 + cm_stride); if XNN_UNPREDICTABLE(mr <= 4) { c4 = c3; } float* c5 = (float*) ((uintptr_t) c4 + cm_stride); if XNN_UNPREDICTABLE(mr < 6) { c5 = c4; } float* c6 = (float*) ((uintptr_t) c5 + cm_stride); if XNN_UNPREDICTABLE(mr <= 6) { c6 = c5; } const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); // XNN_FORCE_REALIZATION(voutput_max); do { const __m512i vksum0123456789ABCDEF = _mm512_loadu_epi32((const int32_t*) w + 0); w = (const int32_t*) w + 16; // Zero tile accumulator __asm__ volatile ( "tilezero %%tmm0\n" ::); size_t p = ks; do { const int8_t* restrict a0 = a[0]; if XNN_UNPREDICTABLE(a0 != zero) { a0 = (const int8_t*) ((uintptr_t) a0 + a_offset); } else { a0 = zero_data; } const int8_t* restrict a1 = a[1]; if XNN_UNPREDICTABLE(a1 != zero) { a1 = (const int8_t*) ((uintptr_t) a1 + a_offset); } else { a1 = zero_data; } const int8_t* restrict a2 = a[2]; if XNN_UNPREDICTABLE(a2 != zero) { a2 = (const int8_t*) ((uintptr_t) a2 + a_offset); } else { a2 = zero_data; } const int8_t* restrict a3 = a[3]; if XNN_UNPREDICTABLE(a3 != zero) { a3 = (const int8_t*) ((uintptr_t) a3 + a_offset); } else { a3 = zero_data; } const int8_t* restrict a4 = a[4]; if XNN_UNPREDICTABLE(a4 != zero) { a4 = (const int8_t*) ((uintptr_t) a4 + a_offset); } else { a4 = zero_data; } const int8_t* restrict a5 = a[5]; if XNN_UNPREDICTABLE(a5 != zero) { a5 = (const int8_t*) ((uintptr_t) a5 + a_offset); } else { a5 = zero_data; } const int8_t* restrict a6 = a[6]; if XNN_UNPREDICTABLE(a6 != zero) { a6 = (const int8_t*) ((uintptr_t) a6 + a_offset); } else { a6 = zero_data; } a += 7; size_t k = kc; if (mr == 1) { while (k >= 64 * sizeof(int8_t)) { _tile_loadd(4, a0, 64); // Directly load input for mr=1 a6 += 64; _tile_loadd(5, (const int8_t*) w + 0, 64); _tile_dpbssd(0, 4, 5); w = (const int8_t*) w + 1024; k -= 64 * sizeof(int8_t); } } else { while (k >= 64 * sizeof(int8_t)) { const __m512i vin0 = _mm512_loadu_epi32(a0); a0 += 64; _mm512_store_epi32(vintile + 0, vin0); const __m512i vin1 = _mm512_loadu_epi32(a1); a1 += 64; _mm512_store_epi32(vintile + 16, vin1); const __m512i vin2 = _mm512_loadu_epi32(a2); a2 += 64; _mm512_store_epi32(vintile + 32, vin2); const __m512i vin3 = _mm512_loadu_epi32(a3); a3 += 64; _mm512_store_epi32(vintile + 48, vin3); const __m512i vin4 = _mm512_loadu_epi32(a4); a4 += 64; _mm512_store_epi32(vintile + 64, vin4); const __m512i vin5 = _mm512_loadu_epi32(a5); a5 += 64; _mm512_store_epi32(vintile + 80, vin5); const __m512i vin6 = _mm512_loadu_epi32(a6); a6 += 64; _mm512_store_epi32(vintile + 96, vin6); _tile_loadd(4, vintile, 64); _tile_loadd(5, (const int8_t*) w + 0, 64); _tile_dpbssd(0, 4, 5); w = (const int8_t*) w + 1024; k -= 64 * sizeof(int8_t); } } if XNN_UNLIKELY(k != 0) { const __m512i vin0 = _mm512_maskz_loadu_epi32(kremainder_mask, a0); a0 += kremainder; _mm512_store_epi32(vintile + 0, vin0); const __m512i vin1 = _mm512_maskz_loadu_epi32(kremainder_mask, a1); a1 += kremainder; _mm512_store_epi32(vintile + 16, vin1); const __m512i vin2 = _mm512_maskz_loadu_epi32(kremainder_mask, a2); a2 += kremainder; _mm512_store_epi32(vintile + 32, vin2); const __m512i vin3 = _mm512_maskz_loadu_epi32(kremainder_mask, a3); a3 += kremainder; _mm512_store_epi32(vintile + 48, vin3); const __m512i vin4 = _mm512_maskz_loadu_epi32(kremainder_mask, a4); a4 += kremainder; _mm512_store_epi32(vintile + 64, vin4); const __m512i vin5 = _mm512_maskz_loadu_epi32(kremainder_mask, a5); a5 += kremainder; _mm512_store_epi32(vintile + 80, vin5); const __m512i vin6 = _mm512_maskz_loadu_epi32(kremainder_mask, a6); a6 += kremainder; _mm512_store_epi32(vintile + 96, vin6); _tile_loadd(6, vintile, 64); _tile_loadd(7, (const int8_t*) w + 0, 64); _tile_dpbssd(0, 6, 7); w = (const int8_t*) w + kremainder * 16; k -= kremainder * sizeof(int8_t); } p -= 7 * sizeof(void*); } while (p != 0); // TODO: Instead of processing up to 4 tiles (16x64) consider // quantizing 1 tile at a time (16 registers) _tile_stored(0, &res[0][0], 64); // TODO: Fix msan for AMX #if defined(__has_feature) #if __has_feature(memory_sanitizer) __msan_unpoison(res, sizeof(res)); #endif #endif // TODO: Instead of processing up to 4 tiles (16x64) consider // quantizing 1 row at a time. // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc3x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc4x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc5x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc6x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF); __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF); __m512 vscaled4x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc4x0123456789ABCDEF); __m512 vscaled5x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc5x0123456789ABCDEF); __m512 vscaled6x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc6x0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled2x0123456789ABCDEF = _mm512_mul_ps(vscaled2x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled3x0123456789ABCDEF = _mm512_mul_ps(vscaled3x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled4x0123456789ABCDEF = _mm512_mul_ps(vscaled4x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled5x0123456789ABCDEF = _mm512_mul_ps(vscaled5x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); vscaled6x0123456789ABCDEF = _mm512_mul_ps(vscaled6x0123456789ABCDEF, _mm512_set1_ps(quantization_params->inv_scale)); const __m512 vfilter_output_scale0123456789ABCDEF = _mm512_loadu_ps((const float*) w + 0); w = (const int32_t*) w + 16; const __m512 vbias0123456789ABCDEF = _mm512_loadu_ps((const float*) w + 0); w = (const int32_t*) w + 16; vscaled0x0123456789ABCDEF = _mm512_fmadd_ps(vscaled0x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled1x0123456789ABCDEF = _mm512_fmadd_ps(vscaled1x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled2x0123456789ABCDEF = _mm512_fmadd_ps(vscaled2x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled3x0123456789ABCDEF = _mm512_fmadd_ps(vscaled3x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled4x0123456789ABCDEF = _mm512_fmadd_ps(vscaled4x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled5x0123456789ABCDEF = _mm512_fmadd_ps(vscaled5x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled6x0123456789ABCDEF = _mm512_fmadd_ps(vscaled6x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF); vscaled0x0123456789ABCDEF = _mm512_max_ps(vscaled0x0123456789ABCDEF, voutput_min); vscaled1x0123456789ABCDEF = _mm512_max_ps(vscaled1x0123456789ABCDEF, voutput_min); vscaled2x0123456789ABCDEF = _mm512_max_ps(vscaled2x0123456789ABCDEF, voutput_min); vscaled3x0123456789ABCDEF = _mm512_max_ps(vscaled3x0123456789ABCDEF, voutput_min); vscaled4x0123456789ABCDEF = _mm512_max_ps(vscaled4x0123456789ABCDEF, voutput_min); vscaled5x0123456789ABCDEF = _mm512_max_ps(vscaled5x0123456789ABCDEF, voutput_min); vscaled6x0123456789ABCDEF = _mm512_max_ps(vscaled6x0123456789ABCDEF, voutput_min); vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max); vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max); vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max); vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max); vscaled4x0123456789ABCDEF = _mm512_min_ps(vscaled4x0123456789ABCDEF, voutput_max); vscaled5x0123456789ABCDEF = _mm512_min_ps(vscaled5x0123456789ABCDEF, voutput_max); vscaled6x0123456789ABCDEF = _mm512_min_ps(vscaled6x0123456789ABCDEF, voutput_max); if XNN_LIKELY(nc >= 16) { _mm512_storeu_ps(c6 + 0, vscaled6x0123456789ABCDEF); c6 = (float*) ((uintptr_t) c6 + cn_stride); _mm512_storeu_ps(c5 + 0, vscaled5x0123456789ABCDEF); c5 = (float*) ((uintptr_t) c5 + cn_stride); _mm512_storeu_ps(c4 + 0, vscaled4x0123456789ABCDEF); c4 = (float*) ((uintptr_t) c4 + cn_stride); _mm512_storeu_ps(c3 + 0, vscaled3x0123456789ABCDEF); c3 = (float*) ((uintptr_t) c3 + cn_stride); _mm512_storeu_ps(c2 + 0, vscaled2x0123456789ABCDEF); c2 = (float*) ((uintptr_t) c2 + cn_stride); _mm512_storeu_ps(c1 + 0, vscaled1x0123456789ABCDEF); c1 = (float*) ((uintptr_t) c1 + cn_stride); _mm512_storeu_ps(c0 + 0, vscaled0x0123456789ABCDEF); c0 = (float*) ((uintptr_t) c0 + cn_stride); a = (const int8_t**restrict) ((uintptr_t) a - ks); nc -= 16; } else { // Prepare mask for valid 32-bit elements (depends on nc). const __mmask16 vmask0 = _cvtu32_mask16((uint32_t) ((((UINT64_C(1) << nc) - 1) >> 0) & 0xFFFF)); _mm512_mask_storeu_ps(c6 + 0, vmask0, vscaled6x0123456789ABCDEF); _mm512_mask_storeu_ps(c5 + 0, vmask0, vscaled5x0123456789ABCDEF); _mm512_mask_storeu_ps(c4 + 0, vmask0, vscaled4x0123456789ABCDEF); _mm512_mask_storeu_ps(c3 + 0, vmask0, vscaled3x0123456789ABCDEF); _mm512_mask_storeu_ps(c2 + 0, vmask0, vscaled2x0123456789ABCDEF); _mm512_mask_storeu_ps(c1 + 0, vmask0, vscaled1x0123456789ABCDEF); _mm512_mask_storeu_ps(c0 + 0, vmask0, vscaled0x0123456789ABCDEF); nc = 0; } } while (nc != 0); // Release tile config // _tile_release(); __asm__ volatile ("tilerelease" ::); #endif // defined(__x86_64__) }