#pragma once // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] #include #include #include #if defined(CPU_CAPABILITY_AVX512) #define SLEEF_STATIC_LIBS #include #endif namespace at::vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX512) template <> struct is_vec_specialized_for : std::bool_constant {}; template <> class Vectorized { private: static constexpr __m512i zero_vec{0, 0, 0, 0, 0, 0, 0, 0}; public: __m512 values; using value_type = float; using size_type = int; static constexpr size_type size() { return 16; } Vectorized() {} Vectorized(__m512 v) : values(v) {} Vectorized(float val) { values = _mm512_set1_ps(val); } Vectorized( float val1, float val2, float val3, float val4, float val5, float val6, float val7, float val8, float val9, float val10, float val11, float val12, float val13, float val14, float val15, float val16) { values = _mm512_setr_ps( val1, val2, val3, val4, val5, val6, val7, val8, val9, val10, val11, val12, val13, val14, val15, val16); } Vectorized(const float (&arr)[16]) : Vectorized( arr[0], arr[1], arr[2], arr[3], arr[4], arr[5], arr[6], arr[7], arr[8], arr[9], arr[10], arr[11], arr[12], arr[13], arr[14], arr[15]) {} operator __m512() const { return values; } template static Vectorized blend( const Vectorized& a, const Vectorized& b) { return _mm512_mask_blend_ps(mask, a.values, b.values); } static Vectorized blendv( const Vectorized& a, const Vectorized& b, const Vectorized& mask) { auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); auto mmask = _mm512_cmp_epi32_mask( _mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ); return _mm512_mask_blend_ps(mmask, a.values, b.values); } template static Vectorized arange( float base = 0.f, step_t step = static_cast(1)) { return Vectorized( base, base + step, base + 2 * step, base + 3 * step, base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step); } static Vectorized set( const Vectorized& a, const Vectorized& b, int64_t count = size()) { switch (count) { case 0: return a; case 1: return blend<1>(a, b); case 2: return blend<3>(a, b); case 3: return blend<7>(a, b); case 4: return blend<15>(a, b); case 5: return blend<31>(a, b); case 6: return blend<63>(a, b); case 7: return blend<127>(a, b); case 8: return blend<255>(a, b); case 9: return blend<511>(a, b); case 10: return blend<1023>(a, b); case 11: return blend<2047>(a, b); case 12: return blend<4095>(a, b); case 13: return blend<8191>(a, b); case 14: return blend<16383>(a, b); case 15: return blend<32767>(a, b); } return b; } static Vectorized loadu(const void* ptr, int64_t count = size()) { if (count == size()) return _mm512_loadu_ps(reinterpret_cast(ptr)); __mmask16 mask = (1ULL << count) - 1; return _mm512_maskz_loadu_ps(mask, ptr); } void store(void* ptr, int64_t count = size()) const { if (count == size()) { _mm512_storeu_ps(reinterpret_cast(ptr), values); } else if (count > 0) { __mmask16 mask = (1ULL << count) - 1; _mm512_mask_storeu_ps(reinterpret_cast(ptr), mask, values); } } const float& operator[](int idx) const = delete; float& operator[](int idx) = delete; int zero_mask() const { // returns an integer mask where all zero elements are translated to 1-bit // and others are translated to 0-bit __mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ); return static_cast(cmp); } Vectorized isnan() const { auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q); return _mm512_castsi512_ps( _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); } bool has_inf_nan() const { __m512 self_sub = _mm512_sub_ps(values, values); return (_mm512_movepi8_mask(_mm512_castps_si512(self_sub)) & 0x7777777777777777) != 0; } Vectorized map(float (*const f)(float)) const { __at_align__ float tmp[size()]; store(tmp); for (const auto i : c10::irange(size())) { tmp[i] = f(tmp[i]); } return loadu(tmp); } Vectorized abs() const { auto mask = _mm512_set1_ps(-0.f); return _mm512_andnot_ps(mask, values); } Vectorized angle() const { __m512 zero_vec = _mm512_set1_ps(0.f); const auto nan_vec = _mm512_set1_ps(NAN); const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); const auto not_nan_vec = _mm512_mask_set1_epi32( _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF); const auto nan_mask = _mm512_cmp_ps_mask( _mm512_castsi512_ps(not_nan_vec), zero_vec, _CMP_EQ_OQ); const auto pi = _mm512_set1_ps(c10::pi); const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); return angle; } Vectorized real() const { return *this; } Vectorized imag() const { return _mm512_set1_ps(0); } Vectorized conj() const { return *this; } Vectorized acos() const { return Vectorized(Sleef_acosf16_u10(values)); } Vectorized acosh() const { return Vectorized(Sleef_acoshf16_u10(values)); } Vectorized asin() const { return Vectorized(Sleef_asinf16_u10(values)); } Vectorized asinh() const { return Vectorized(Sleef_asinhf16_u10(values)); } Vectorized atan() const { return Vectorized(Sleef_atanf16_u10(values)); } Vectorized atanh() const { return Vectorized(Sleef_atanhf16_u10(values)); } Vectorized atan2(const Vectorized& b) const { return Vectorized(Sleef_atan2f16_u10(values, b)); } Vectorized copysign(const Vectorized& sign) const { return Vectorized(Sleef_copysignf16(values, sign)); } Vectorized erf() const { // constants const auto neg_zero_vec = _mm512_set1_ps(-0.f); const auto one_vec = _mm512_set1_ps(1.0f); const auto p = _mm512_set1_ps(0.3275911f); const auto p1 = _mm512_set1_ps(0.254829592f); const auto p2 = _mm512_set1_ps(-0.284496736f); const auto p3 = _mm512_set1_ps(1.421413741f); const auto p4 = _mm512_set1_ps(-1.453152027f); const auto p5 = _mm512_set1_ps(1.061405429f); // sign(x) auto sign_mask = _mm512_and_ps(neg_zero_vec, values); auto abs_vec = _mm512_abs_ps(values); // t = 1 / (p * abs(x) + 1) auto tmp0 = _mm512_fmadd_ps(p, abs_vec, one_vec); auto t = _mm512_div_ps(one_vec, tmp0); // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 auto tmp1 = _mm512_fmadd_ps(p5, t, p4); auto tmp2 = _mm512_fmadd_ps(tmp1, t, p3); auto tmp3 = _mm512_fmadd_ps(tmp2, t, p2); auto r = _mm512_fmadd_ps(tmp3, t, p1); // - exp(- x * x) auto pow_2 = _mm512_mul_ps(values, values); auto neg_pow_2 = _mm512_xor_ps(neg_zero_vec, pow_2); // auto tmp4 = exp(neg_pow_2); auto tmp4 = Vectorized(Sleef_expf16_u10(neg_pow_2)); auto tmp5 = _mm512_xor_ps(neg_zero_vec, tmp4); // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) auto tmp6 = _mm512_mul_ps(tmp5, t); auto tmp7 = _mm512_fmadd_ps(tmp6, r, one_vec); return _mm512_xor_ps(sign_mask, tmp7); } Vectorized erfc() const { return Vectorized(Sleef_erfcf16_u15(values)); } Vectorized erfinv() const { return map(calc_erfinv); } Vectorized exp() const { return Vectorized(Sleef_expf16_u10(values)); } Vectorized exp2() const { return Vectorized(Sleef_exp2f16_u10(values)); } Vectorized expm1() const { return Vectorized(Sleef_expm1f16_u10(values)); } Vectorized exp_u20() const { // A faster version of exp with ULP=20 const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); // 1/factorial(1) const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); // 1/factorial(2) const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); // 1/factorial(3) const __m512 vec_factorial_4 = _mm512_set1_ps(0.0418978221f); // 1/factorial(4) const __m512 vec_factorial_5 = _mm512_set1_ps(0.00828929059f); // 1/factorial(5) const __m512 vec_exp_log2ef = _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) const __m512 vec_half = _mm512_set1_ps(0.5f); const __m512 vec_one = _mm512_set1_ps(1.f); const __m512 vec_zero = _mm512_set1_ps(0.f); const __m512 vec_two = _mm512_set1_ps(2.f); const __m512 vec_ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) const __m512 vec_ln_flt_min = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); const __m512 vec_ln_flt_max = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); const int n_mantissa_bits = 23; // exp(x) = // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression auto less_ln_flt_min_mask = _mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); // fx = floorf(x * log2ef + 0.5) auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); auto vec_fx_i = _mm512_cvt_roundps_epi32( vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); vec_fx = _mm512_cvtepi32_ps(vec_fx_i); // x = x - fx * ln2 auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); // compute polynomial auto vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); // compute 2^(n-1) auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); vec_two_pow_n = _mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero); // y = y * 2^n vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); vec_res = _mm512_mul_ps(vec_res, vec_two); return vec_res; } Vectorized fmod(const Vectorized& q) const { return Vectorized(Sleef_fmodf16(values, q)); } Vectorized log() const { return Vectorized(Sleef_logf16_u10(values)); } Vectorized log2() const { return Vectorized(Sleef_log2f16_u10(values)); } Vectorized log10() const { return Vectorized(Sleef_log10f16_u10(values)); } Vectorized log1p() const { return Vectorized(Sleef_log1pf16_u10(values)); } Vectorized frac() const; Vectorized sin() const { return Vectorized(Sleef_sinf16_u35(values)); } Vectorized sinh() const { return Vectorized(Sleef_sinhf16_u10(values)); } Vectorized cos() const { return Vectorized(Sleef_cosf16_u35(values)); } Vectorized cosh() const { return Vectorized(Sleef_coshf16_u10(values)); } Vectorized ceil() const { return _mm512_ceil_ps(values); } Vectorized floor() const { return _mm512_floor_ps(values); } Vectorized hypot(const Vectorized& b) const { return Vectorized(Sleef_hypotf16_u05(values, b)); } Vectorized i0() const { return map(calc_i0); } Vectorized i0e() const { return map(calc_i0e); } Vectorized digamma() const { return map(calc_digamma); } Vectorized igamma(const Vectorized& x) const { __at_align__ float tmp[size()]; __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (const auto i : c10::irange(size())) { tmp[i] = calc_igamma(tmp[i], tmp_x[i]); } return loadu(tmp); } Vectorized igammac(const Vectorized& x) const { __at_align__ float tmp[size()]; __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (const auto i : c10::irange(size())) { tmp[i] = calc_igammac(tmp[i], tmp_x[i]); } return loadu(tmp); } Vectorized neg() const { return _mm512_xor_ps(_mm512_set1_ps(-0.f), values); } Vectorized nextafter(const Vectorized& b) const { return Vectorized(Sleef_nextafterf16(values, b)); } Vectorized round() const { return _mm512_roundscale_ps( values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); } Vectorized tan() const { return Vectorized(Sleef_tanf16_u10(values)); } Vectorized tanh() const { return Vectorized(Sleef_tanhf16_u10(values)); } Vectorized trunc() const { return _mm512_roundscale_ps( values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); } Vectorized lgamma() const { return Vectorized(Sleef_lgammaf16_u10(values)); } Vectorized sqrt() const { return _mm512_sqrt_ps(values); } Vectorized reciprocal() const { return _mm512_div_ps(_mm512_set1_ps(1), values); } Vectorized rsqrt() const { return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values)); } Vectorized pow(const Vectorized& b) const { return Vectorized(Sleef_powf16_u10(values, b)); } float reduce_add() const { return _mm512_reduce_add_ps(values); } float reduce_max() const { return _mm512_reduce_max_ps(values); } // Comparison using the _CMP_**_OQ predicate. // `O`: get false if an operand is NaN // `Q`: do not raise if an operand is NaN Vectorized operator==(const Vectorized& other) const { auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); return _mm512_castsi512_ps( _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); } Vectorized operator!=(const Vectorized& other) const { auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ); return _mm512_castsi512_ps( _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); } Vectorized operator<(const Vectorized& other) const { auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ); return _mm512_castsi512_ps( _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); } Vectorized operator<=(const Vectorized& other) const { auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ); return _mm512_castsi512_ps( _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); } Vectorized operator>(const Vectorized& other) const { auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ); return _mm512_castsi512_ps( _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); } Vectorized operator>=(const Vectorized& other) const { auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ); return _mm512_castsi512_ps( _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); } Vectorized eq(const Vectorized& other) const; Vectorized ne(const Vectorized& other) const; Vectorized gt(const Vectorized& other) const; Vectorized ge(const Vectorized& other) const; Vectorized lt(const Vectorized& other) const; Vectorized le(const Vectorized& other) const; }; template <> Vectorized inline operator+( const Vectorized& a, const Vectorized& b) { return _mm512_add_ps(a, b); } template <> Vectorized inline operator-( const Vectorized& a, const Vectorized& b) { return _mm512_sub_ps(a, b); } template <> Vectorized inline operator*( const Vectorized& a, const Vectorized& b) { return _mm512_mul_ps(a, b); } template <> Vectorized inline operator/( const Vectorized& a, const Vectorized& b) { return _mm512_div_ps(a, b); } // frac. Implement this here so we can use subtraction inline Vectorized Vectorized::frac() const { return *this - this->trunc(); } // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. template <> Vectorized inline maximum( const Vectorized& a, const Vectorized& b) { auto zero_vec = _mm512_set1_epi32(0); auto max = _mm512_max_ps(a, b); auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); auto isnan = _mm512_castsi512_ps( _mm512_mask_set1_epi32(zero_vec, isnan_mask, 0xFFFFFFFF)); // Exploit the fact that all-ones is a NaN. return _mm512_or_ps(max, isnan); } // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. template <> Vectorized inline minimum( const Vectorized& a, const Vectorized& b) { auto zero_vec = _mm512_set1_epi32(0); auto min = _mm512_min_ps(a, b); auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); auto isnan = _mm512_castsi512_ps( _mm512_mask_set1_epi32(zero_vec, isnan_mask, 0xFFFFFFFF)); // Exploit the fact that all-ones is a NaN. return _mm512_or_ps(min, isnan); } template <> Vectorized inline clamp( const Vectorized& a, const Vectorized& min, const Vectorized& max) { return _mm512_min_ps(max, _mm512_max_ps(min, a)); } template <> Vectorized inline clamp_max( const Vectorized& a, const Vectorized& max) { return _mm512_min_ps(max, a); } template <> Vectorized inline clamp_min( const Vectorized& a, const Vectorized& min) { return _mm512_max_ps(min, a); } template <> Vectorized inline operator&( const Vectorized& a, const Vectorized& b) { return _mm512_and_ps(a, b); } template <> Vectorized inline operator|( const Vectorized& a, const Vectorized& b) { return _mm512_or_ps(a, b); } template <> Vectorized inline operator^( const Vectorized& a, const Vectorized& b) { return _mm512_xor_ps(a, b); } inline Vectorized Vectorized::eq( const Vectorized& other) const { return (*this == other) & Vectorized(1.0f); } inline Vectorized Vectorized::ne( const Vectorized& other) const { return (*this != other) & Vectorized(1.0f); } inline Vectorized Vectorized::gt( const Vectorized& other) const { return (*this > other) & Vectorized(1.0f); } inline Vectorized Vectorized::ge( const Vectorized& other) const { return (*this >= other) & Vectorized(1.0f); } inline Vectorized Vectorized::lt( const Vectorized& other) const { return (*this < other) & Vectorized(1.0f); } inline Vectorized Vectorized::le( const Vectorized& other) const { return (*this <= other) & Vectorized(1.0f); } template <> inline void convert(const float* src, float* dst, int64_t n) { int64_t i; #ifndef __msvc_cl__ #pragma unroll #endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { _mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i)); } #ifndef __msvc_cl__ #pragma unroll #endif for (; i < n; i++) { dst[i] = src[i]; } } template <> Vectorized inline fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return _mm512_fmadd_ps(a, b, c); } template <> Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return _mm512_fmsub_ps(a, b, c); } // TODO: rewrite with ATEN vectorized (need to add unpack and shuffle) // Used by Inductor CPP codegen for micro gemm // Code referred to FBGEMM: // https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304 // kernel for transposing mxn where m, n <= 16 // (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + N instructions inline void transpose_block( at::vec::VectorizedN& input, int M = 16, int N = 16) { TORCH_CHECK(M <= 16 && N <= 16, "transpose_block expects M, N <= 16."); // unpacking and interleaving 32-bit elements __m512 temp[16]; int i; for (i = 0; i < (M + 1) / 2; ++i) { temp[2 * i] = _mm512_unpacklo_ps(input[2 * i], input[2 * i + 1]); temp[2 * i + 1] = _mm512_unpackhi_ps(input[2 * i], input[2 * i + 1]); } for (i = i * 2; i < 16; ++i) { temp[i] = _mm512_setzero_ps(); } // unpacking and interleaving 64-bit elements for (i = 0; i < (M + 3) / 4; ++i) { input[4 * i] = _mm512_castpd_ps(_mm512_unpacklo_pd( _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); input[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd( _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); input[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd( _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3]))); input[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd( _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3]))); } // shuffle 128-bits (composed of 4 32-bit elements) for (i = 0; i < (M + 7) / 8; ++i) { temp[8 * i] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0x88); temp[8 * i + 1] = _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0x88); temp[8 * i + 2] = _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0x88); temp[8 * i + 3] = _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0x88); temp[8 * i + 4] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0xdd); temp[8 * i + 5] = _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0xdd); temp[8 * i + 6] = _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0xdd); temp[8 * i + 7] = _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd); } for (i = 0; i < N; ++i) { if (i < 8) { input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88); } else { input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd); } } } // TODO(jgong5): rewrite with ATEN vectorized (need to add unpack and shuffle) // Used by Inductor CPP codegen // Code referred to FBGEMM: // https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304 // kernel for transposing mxn where m, n <= 16 // M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions inline void transpose_mxn_16x16( const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) { TORCH_CHECK(M <= 16 && N <= 16, "transpose_mxn expects M, N <= 16."); // load from src to registers at::vec::VectorizedN input; int i; if (N == 16) { for (i = 0; i < M; ++i) { input[i] = _mm512_loadu_ps(&src[i * ld_src]); } } else { __mmask16 src_mask = (1 << N) - 1; for (i = 0; i < M; ++i) { input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]); } } for (; i < 16; ++i) { // Not really needed but to avoid uninitialized variable warning. // Shouldn't be much overhead because xor can be executed in parallel with // other instructions. input[i] = _mm512_setzero_ps(); } transpose_block(input, M, N); // store from registers to dst if (M == 16) { for (i = 0; i < N; ++i) { _mm512_storeu_ps(&dst[i * ld_dst], input[i]); } } else { __mmask16 dst_mask = (1 << M) - 1; for (i = 0; i < N; ++i) { _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]); } } } template <> inline void transpose_mxn( const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) { int64_t i = 0; for (; i < M / 16 * 16; i += 16) { int64_t j = 0; for (; j < N / 16 * 16; j += 16) { transpose_mxn_16x16( src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, 16); } // handle remainder j int nrem = N - j; if (nrem > 0) { transpose_mxn_16x16( src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem); } } // handle remainder i int mrem = M - i; if (mrem > 0) { int j = 0; for (; j < N / 16 * 16; j += 16) { transpose_mxn_16x16( src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16); } // handle remainder j int nrem = N - j; transpose_mxn_16x16( src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem); } } template < typename T, int M, int N, typename std::enable_if_t, int> = 0> inline void transpose_mxn( const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { transpose_mxn(src, ld_src, dst, ld_dst, M, N); } #endif } // namespace CPU_CAPABILITY } // namespace at::vec