#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>

#if defined(CPU_CAPABILITY_AVX512)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX512)

#ifndef SLEEF_CONST
#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
#define SLEEF_CONST const
#else
#define SLEEF_CONST
#endif
#define SLEEF_CONST_OLD SLEEF_CONST
#else
#define SLEEF_CONST_OLD
#endif

// bfloat16 conversion
static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
  o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
}

static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
  __m256i lo = _mm512_extracti32x8_epi32(a, 0);
  __m256i hi = _mm512_extracti32x8_epi32(a, 1);
  cvtbf16_fp32(lo, o1);
  cvtbf16_fp32(hi, o2);
}

static inline __m256i cvtfp32_bf16(const __m512& src) {
  __m512i value = _mm512_castps_si512(src);
  __m512i nan = _mm512_set1_epi32(0xffff);
  auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
  __m512i ones = _mm512_set1_epi32(0x1);
  __m512i vec_bias = _mm512_set1_epi32(0x7fff);
  // uint32_t lsb = (input >> 16) & 1;
  auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
  // uint32_t rounding_bias = 0x7fff + lsb;
  t_value = _mm512_add_epi32(t_value, vec_bias);
  // input += rounding_bias;
  t_value = _mm512_add_epi32(t_value, value);
  // input = input >> 16;
  t_value = _mm512_srli_epi32(t_value, 16);
  // Check NaN before converting back to bf16
  t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
  return _mm512_cvtusepi32_epi16(t_value);
}

static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) {
  __m512i lo = _mm512_castps_si512(a);
  __m512i hi = _mm512_castps_si512(b);
  __m512i nan = _mm512_set1_epi32(0xffff);
  auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
  auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q);
  __m512i ones = _mm512_set1_epi32(0x1);
  __m512i vec_bias = _mm512_set1_epi32(0x7fff);
  // uint32_t lsb = (input >> 16) & 1;
  auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones);
  auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones);
  // uint32_t rounding_bias = 0x7fff + lsb;
  t_lo = _mm512_add_epi32(t_lo, vec_bias);
  t_hi = _mm512_add_epi32(t_hi, vec_bias);
  // input += rounding_bias;
  t_lo = _mm512_add_epi32(t_lo, lo);
  t_hi = _mm512_add_epi32(t_hi, hi);
  // input = input >> 16;
  t_lo = _mm512_srli_epi32(t_lo, 16);
  t_hi = _mm512_srli_epi32(t_hi, 16);
  // Check NaN before converting back to bf16
  t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo);
  t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi);
  t_lo = _mm512_packus_epi32(
      t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4]
  __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
  return _mm512_permutexvar_epi64(idx, t_lo);
}

static inline __m512i merge_compare_result(const __m512& a, const __m512& b) {
  __m512i lo = _mm512_castps_si512(a);
  __m512i hi = _mm512_castps_si512(b);
  lo = _mm512_srli_epi32(lo, 16);
  hi = _mm512_srli_epi32(hi, 16);
  auto out = _mm512_packus_epi32(lo, hi);
  __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
  return _mm512_permutexvar_epi64(idx, out);
}
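// For reference, the vector rounding above matches this scalar sketch of
// round-to-nearest-even (illustrative only; fp32_to_bf16_rne is a
// hypothetical helper, not part of this header):
//
//   uint16_t fp32_to_bf16_rne(float f) {
//     uint32_t input;
//     std::memcpy(&input, &f, sizeof(input));
//     if (f != f) {
//       return 0xFFFF; // quiet NaN payload, as in the `nan` vector above
//     }
//     uint32_t lsb = (input >> 16) & 1;
//     uint32_t rounding_bias = 0x7fff + lsb;
//     return static_cast<uint16_t>((input + rounding_bias) >> 16);
//   }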
// float16 conversion
static inline void cvtfp16_fp32(const __m256i& a, __m512& o) {
  o = _mm512_cvtph_ps(a);
}

static inline void cvtfp16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
  __m256i lo = _mm512_extracti32x8_epi32(a, 0);
  __m256i hi = _mm512_extracti32x8_epi32(a, 1);
  cvtfp16_fp32(lo, o1);
  cvtfp16_fp32(hi, o2);
}

static inline __m256i cvtfp32_fp16(const __m512& src) {
  return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}

static inline __m512i cvtfp32_fp16(const __m512& a, const __m512& b) {
  __m256i lo =
      _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  __m256i hi =
      _mm512_cvtps_ph(b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  __m512 t_lo = _mm512_castsi512_ps(_mm512_castsi256_si512(lo));
  __m256 t_hi = _mm256_castsi256_ps(hi);
  return _mm512_castps_si512(_mm512_insertf32x8(t_lo, t_hi, 1));
}

// dtype conversion between float16/bfloat16 and float32
template <
    typename T,
    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline void cvt_to_fp32(const __m256i& a, __m512& o);
template <>
inline void cvt_to_fp32<BFloat16>(const __m256i& a, __m512& o) {
  cvtbf16_fp32(a, o);
}
template <>
inline void cvt_to_fp32<Half>(const __m256i& a, __m512& o) {
  cvtfp16_fp32(a, o);
}

template <
    typename T,
    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2);
template <>
inline void cvt_to_fp32<BFloat16>(const __m512i& a, __m512& o1, __m512& o2) {
  cvtbf16_fp32(a, o1, o2);
}
template <>
inline void cvt_to_fp32<Half>(const __m512i& a, __m512& o1, __m512& o2) {
  cvtfp16_fp32(a, o1, o2);
}

template <
    typename T,
    bool is_compare_op = false,
    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline __m512i cvt_from_fp32(const __m512& a, const __m512& b);
template <>
inline __m512i cvt_from_fp32<BFloat16, false>(
    const __m512& a,
    const __m512& b) {
  return cvtfp32_bf16(a, b);
}
template <>
inline __m512i cvt_from_fp32<BFloat16, true>(
    const __m512& a,
    const __m512& b) {
  return merge_compare_result(a, b);
}
template <>
inline __m512i cvt_from_fp32<Half, false>(const __m512& a, const __m512& b) {
  return cvtfp32_fp16(a, b);
}
template <>
inline __m512i cvt_from_fp32<Half, true>(const __m512& a, const __m512& b) {
  return cvtfp32_fp16(a, b);
}

template <typename T>
class Vectorized16 {
  static_assert(
      is_reduced_floating_point_v<T>,
      "Support only float16 and bfloat16.");

 private:
  __m512i values;

 public:
  using value_type = uint16_t;
  using size_type = int;
  static constexpr size_type size() {
    return 32;
  }
  Vectorized16() {}
  Vectorized16(__m512i v) : values(v) {}
  Vectorized16(T val) {
    value_type uw = val.x;
    values = _mm512_set1_epi16(uw);
  }
  Vectorized16(
      T val1, T val2, T val3, T val4,
      T val5, T val6, T val7, T val8,
      T val9, T val10, T val11, T val12,
      T val13, T val14, T val15, T val16,
      T val17, T val18, T val19, T val20,
      T val21, T val22, T val23, T val24,
      T val25, T val26, T val27, T val28,
      T val29, T val30, T val31, T val32) {
    values = _mm512_set_epi16(
        val32.x, val31.x, val30.x, val29.x, val28.x, val27.x, val26.x, val25.x,
        val24.x, val23.x, val22.x, val21.x, val20.x, val19.x, val18.x, val17.x,
        val16.x, val15.x, val14.x, val13.x, val12.x, val11.x, val10.x, val9.x,
        val8.x, val7.x, val6.x, val5.x, val4.x, val3.x, val2.x, val1.x);
  }
  operator __m512i() const {
    return values;
  }
  T& operator[](int idx) = delete;
  const T& operator[](int idx) const = delete;
  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit
    // and others are translated to 0-bit
    return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
  }
  static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
    if (count == size())
      return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
    __mmask32 mask = (1ULL << count) - 1;
    return _mm512_maskz_loadu_epi16(mask, ptr);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
    } else if (count > 0) {
      __mmask32 mask = (1ULL << count) - 1;
      _mm512_mask_storeu_epi16(ptr, mask, values);
    }
  }
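  // Note on the tail handling above: `(1ULL << count) - 1` sets the low
  // `count` bits of the __mmask32, e.g. count == 3 gives mask 0b111, so only
  // the first three 16-bit lanes are touched. _mm512_maskz_loadu_epi16 zeroes
  // the masked-off lanes, while _mm512_mask_storeu_epi16 leaves the
  // corresponding memory untouched.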
  template <int64_t mask>
  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
    return _mm512_mask_blend_epi16(mask, a.values, b.values);
  }
  static Vectorized<T> blendv(
      const Vectorized<T>& a,
      const Vectorized<T>& b,
      const Vectorized<T>& mask) {
    auto all_ones = _mm512_set1_epi16(0xFFFF);
    auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ);
    return _mm512_mask_blend_epi16(mask_, a.values, b.values);
  }
  template <typename step_t>
  static Vectorized<T> arange(
      T base = 0.f,
      step_t step = static_cast<step_t>(1)) {
    return Vectorized<T>(
        base, base + step, base + 2 * step, base + 3 * step,
        base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
        base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
        base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
        base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
        base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
        base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
        base + 28 * step, base + 29 * step, base + 30 * step,
        base + 31 * step);
  }
  static Vectorized<T> set(
      const Vectorized<T>& a,
      const Vectorized<T>& b,
      int64_t count = size()) {
    switch (count) {
      case 0: return a;
      case 1: return blend<1>(a, b);
      case 2: return blend<3>(a, b);
      case 3: return blend<7>(a, b);
      case 4: return blend<15>(a, b);
      case 5: return blend<31>(a, b);
      case 6: return blend<63>(a, b);
      case 7: return blend<127>(a, b);
      case 8: return blend<255>(a, b);
      case 9: return blend<511>(a, b);
      case 10: return blend<1023>(a, b);
      case 11: return blend<2047>(a, b);
      case 12: return blend<4095>(a, b);
      case 13: return blend<8191>(a, b);
      case 14: return blend<16383>(a, b);
      case 15: return blend<32767>(a, b);
      case 16: return blend<65535>(a, b);
      case 17: return blend<131071>(a, b);
      case 18: return blend<262143>(a, b);
      case 19: return blend<524287>(a, b);
      case 20: return blend<1048575>(a, b);
      case 21: return blend<2097151>(a, b);
      case 22: return blend<4194303>(a, b);
      case 23: return blend<8388607>(a, b);
      case 24: return blend<16777215>(a, b);
      case 25: return blend<33554431>(a, b);
      case 26: return blend<67108863>(a, b);
      case 27: return blend<134217727>(a, b);
      case 28: return blend<268435455>(a, b);
      case 29: return blend<536870911>(a, b);
      case 30: return blend<1073741823>(a, b);
      case 31: return blend<2147483647>(a, b);
    }
    return b;
  }
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wignored-qualifiers"
  Vectorized<T> map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    const auto o1 = vop(lo);
    const auto o2 = vop(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> isnan() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __mmask16 lo_mask, hi_mask;
    __m512 zero = _mm512_set1_ps(0.0);
    __m512i zeroi = _mm512_castps_si512(zero);
    lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q);
    lo = _mm512_castsi512_ps(
        _mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF));
    hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q);
    hi = _mm512_castsi512_ps(
        _mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF));
    return merge_compare_result(lo, hi);
  }
#pragma clang diagnostic pop
  Vectorized<T> abs() const {
    return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values);
  }
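  // The transcendental ops below go through `map`, which widens to two fp32
  // vectors, applies a SLEEF kernel to each half, and narrows back. In
  // SLEEF's naming, `f16` means 16 float lanes (one __m512), and the
  // `_u10`/`_u05`/`_u15` suffixes bound the maximum error at 1.0, 0.5 and
  // 1.5 ULP respectively.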
  Vectorized<T> angle() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto angle_lambda = [](__m512 values) {
      const auto zero_vec = _mm512_set1_ps(0.f);
      const auto nan_vec = _mm512_set1_ps(NAN);
      const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
      const auto non_nan_mask_vec = _mm512_mask_set1_epi32(
          _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF);
      const auto nan_mask = _mm512_cmp_ps_mask(
          _mm512_castsi512_ps(non_nan_mask_vec), zero_vec, _CMP_EQ_OQ);
      const auto pi = _mm512_set1_ps(c10::pi<float>);
      const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
      auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
      angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
      return angle;
    };
    auto o1 = angle_lambda(lo);
    auto o2 = angle_lambda(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> real() const {
    return *this;
  }
  Vectorized<T> imag() const {
    return _mm512_set1_epi16(0);
  }
  Vectorized<T> conj() const {
    return *this;
  }
  Vectorized<T> acos() const {
    return map(Sleef_acosf16_u10);
  }
  Vectorized<T> acosh() const {
    return map(Sleef_acoshf16_u10);
  }
  Vectorized<T> asin() const {
    return map(Sleef_asinf16_u10);
  }
  Vectorized<T> asinh() const {
    return map(Sleef_asinhf16_u10);
  }
  Vectorized<T> atan() const {
    return map(Sleef_atanf16_u10);
  }
  Vectorized<T> atanh() const {
    return map(Sleef_atanhf16_u10);
  }
  Vectorized<T> atan2(const Vectorized<T>& b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_atan2f16_u10(lo, b1);
    auto o2 = Sleef_atan2f16_u10(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> copysign(const Vectorized<T>& sign) const {
    // copy sign bit (0x8000) from sign and remaining bits from values
    __m512i mask_value = _mm512_set1_epi32(~0x80008000);
    __m512i mask_signbit = _mm512_set1_epi32(0x80008000);
    return Vectorized<T>(_mm512_or_si512(
        _mm512_and_si512(values, mask_value),
        _mm512_and_si512(sign, mask_signbit)));
  }
  Vectorized<T> erf() const {
    return map(Sleef_erff16_u10);
  }
  Vectorized<T> erfc() const {
    return map(Sleef_erfcf16_u15);
  }
  Vectorized<T> erfinv() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_erfinv(tmp1[i]);
      tmp2[i] = calc_erfinv(tmp2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> exp() const {
    return map(Sleef_expf16_u10);
  }
  Vectorized<T> exp2() const {
    return map(Sleef_exp2f16_u10);
  }
  Vectorized<T> expm1() const {
    return map(Sleef_expm1f16_u10);
  }
  Vectorized<T> exp_u20() const {
    return exp();
  }
  Vectorized<T> fmod(const Vectorized<T>& q) const {
    __m512 x_lo, x_hi;
    cvt_to_fp32<T>(values, x_lo, x_hi);
    __m512 q_lo, q_hi;
    cvt_to_fp32<T>(q.values, q_lo, q_hi);
    auto o1 = Sleef_fmodf16(x_lo, q_lo);
    auto o2 = Sleef_fmodf16(x_hi, q_hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> hypot(const Vectorized<T>& b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_hypotf16_u05(lo, b1);
    auto o2 = Sleef_hypotf16_u05(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> i0() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_i0(tmp1[i]);
      tmp2[i] = calc_i0(tmp2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
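  // erfinv and i0 above (and i0e/digamma/igamma/igammac below) have no SLEEF
  // vector kernel, so they use a spill-and-loop pattern instead: widen to
  // fp32, store to aligned scratch arrays, run the scalar calc_* kernel per
  // element, then reload and narrow back to the 16-bit type.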
  Vectorized<T> i0e() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    constexpr auto sz = size();
    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
      tmp1[i] = calc_i0e(tmp1[i]);
      tmp2[i] = calc_i0e(tmp2[i]);
    }
    const auto o1 = _mm512_loadu_ps(tmp1);
    const auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> digamma() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    constexpr auto sz = size();
    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
      tmp1[i] = calc_digamma(tmp1[i]);
      tmp2[i] = calc_digamma(tmp2[i]);
    }
    const auto o1 = _mm512_loadu_ps(tmp1);
    const auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> igamma(const Vectorized<T>& x) const {
    __m512 lo, hi;
    __m512 xlo, xhi;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> igammac(const Vectorized<T>& x) const {
    __m512 lo, hi;
    __m512 xlo, xhi;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> log() const {
    return map(Sleef_logf16_u10);
  }
  Vectorized<T> log2() const {
    return map(Sleef_log2f16_u10);
  }
  Vectorized<T> log10() const {
    return map(Sleef_log10f16_u10);
  }
  Vectorized<T> log1p() const {
    return map(Sleef_log1pf16_u10);
  }
  Vectorized<T> sin() const {
    return map(Sleef_sinf16_u10);
  }
  Vectorized<T> sinh() const {
    return map(Sleef_sinhf16_u10);
  }
  Vectorized<T> cos() const {
    return map(Sleef_cosf16_u10);
  }
  Vectorized<T> cosh() const {
    return map(Sleef_coshf16_u10);
  }
  Vectorized<T> ceil() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_ceil_ps(lo);
    auto o2 = _mm512_ceil_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> floor() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_floor_ps(lo);
    auto o2 = _mm512_floor_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> neg() const {
    return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000));
  }
  Vectorized<T> round() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_roundscale_ps(
        lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    auto o2 = _mm512_roundscale_ps(
        hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> tan() const {
    return map(Sleef_tanf16_u10);
  }
  Vectorized<T> tanh() const {
    return map(Sleef_tanhf16_u10);
  }
  Vectorized<T> trunc() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 =
        _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    auto o2 =
        _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    return cvt_from_fp32<T>(o1, o2);
  }
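  // _MM_FROUND_TO_NEAREST_INT used by round() above is round-half-to-even
  // (banker's rounding), matching std::nearbyint rather than std::round;
  // trunc() rounds toward zero via _MM_FROUND_TO_ZERO.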
  Vectorized<T> lgamma() const {
    return map(Sleef_lgammaf16_u10);
  }
  Vectorized<T> sqrt() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_sqrt_ps(lo);
    auto o2 = _mm512_sqrt_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> reciprocal() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto ones = _mm512_set1_ps(1);
    auto o1 = _mm512_div_ps(ones, lo);
    auto o2 = _mm512_div_ps(ones, hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> rsqrt() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto ones = _mm512_set1_ps(1);
    auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo));
    auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi));
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> pow(const Vectorized<T>& b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_powf16_u10(lo, b1);
    auto o2 = Sleef_powf16_u10(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }

 private:
  template <typename Op, typename VectorizedType>
  Vectorized<T> inline binary_compare(const VectorizedType& b, Op op) const {
    __m512 a_lo, a_hi;
    __m512 b_lo, b_hi;
    cvt_to_fp32<T>(values, a_lo, a_hi);
    cvt_to_fp32<T>(b.values, b_lo, b_hi);
    auto o1 = op(a_lo, b_lo);
    auto o2 = op(a_hi, b_hi);
    return cvt_from_fp32<T, /*is_compare_op*/ true>(o1, o2);
  }

 public:
  Vectorized<T> inline operator>(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator>=(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<=(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator==(const Vectorized16<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator!=(const Vectorized16<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
};

template <typename T, typename Op>
static inline Vectorized<T> binary_op_as_fp32(
    const Vectorized<T>& a,
    const Vectorized<T>& b,
    Op op) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  cvt_to_fp32<T>(__m512i(a), a_lo, a_hi);
  cvt_to_fp32<T>(__m512i(b), b_lo, b_hi);
  auto o1 = op(a_lo, b_lo);
  auto o2 = op(a_hi, b_hi);
  return cvt_from_fp32<T>(o1, o2);
}
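// All arithmetic on the reduced-precision types follows this widen/compute/
// narrow pattern: both operands are converted to fp32, the op runs at fp32
// precision, and the result is rounded back once. A scalar equivalent for
// bf16 addition would be (illustrative only):
//
//   BFloat16 add(BFloat16 a, BFloat16 b) {
//     return BFloat16(static_cast<float>(a) + static_cast<float>(b));
//   }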
template <>
struct is_vec_specialized_for<BFloat16> : std::bool_constant<true> {};

template <>
class Vectorized<BFloat16> : public Vectorized16<BFloat16> {
 public:
  using Vectorized16::Vectorized16;
  using value_type = BFloat16;
  Vectorized<BFloat16> frac() const;
  Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
};

Vectorized<BFloat16> inline operator+(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) {
    return _mm512_add_ps(x, y);
  });
}
Vectorized<BFloat16> inline operator-(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) {
    return _mm512_sub_ps(x, y);
  });
}
Vectorized<BFloat16> inline operator*(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) {
    return _mm512_mul_ps(x, y);
  });
}
Vectorized<BFloat16> inline operator/(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) {
    return _mm512_div_ps(x, y);
  });
}
Vectorized<BFloat16> inline operator&(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b) {
  return _mm512_and_si512(a, b);
}
Vectorized<BFloat16> inline operator|(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b) {
  return _mm512_or_si512(a, b);
}
Vectorized<BFloat16> inline operator^(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b) {
  return _mm512_xor_si512(a, b);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::eq(
    const Vectorized<BFloat16>& other) const {
  return (*this == other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::ne(
    const Vectorized<BFloat16>& other) const {
  return (*this != other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::gt(
    const Vectorized<BFloat16>& other) const {
  return (*this > other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::ge(
    const Vectorized<BFloat16>& other) const {
  return (*this >= other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::lt(
    const Vectorized<BFloat16>& other) const {
  return (*this < other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::le(
    const Vectorized<BFloat16>& other) const {
  return (*this <= other) & Vectorized<BFloat16>(1.0f);
}

// frac. Implement this here so we can use subtraction
inline Vectorized<BFloat16> Vectorized<BFloat16>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline maximum(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(b), b_lo, b_hi);
  auto max_lo = _mm512_max_ps(a_lo, b_lo);
  auto max_hi = _mm512_max_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  auto nan_lo = _mm512_castsi512_ps(
      _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(
      _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(max_lo, nan_lo);
  auto o2 = _mm512_or_ps(max_hi, nan_hi);
  return cvtfp32_bf16(o1, o2);
}
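// The OR above forces 0xFFFFFFFF (an all-ones NaN bit pattern) into every
// lane where either input compared unordered, which restores IEEE 754
// `maximum` NaN propagation on top of _mm512_max_ps (which by itself simply
// returns the second operand when an input is NaN).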
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline minimum(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(b), b_lo, b_hi);
  auto min_lo = _mm512_min_ps(a_lo, b_lo);
  auto min_hi = _mm512_min_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  auto nan_lo = _mm512_castsi512_ps(
      _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(
      _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(min_lo, nan_lo);
  auto o2 = _mm512_or_ps(min_hi, nan_hi);
  return cvtfp32_bf16(o1, o2);
}

template <>
Vectorized<BFloat16> inline clamp(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& min,
    const Vectorized<BFloat16>& max) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  __m512 max_lo, max_hi;
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(min), min_lo, min_hi);
  cvtbf16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
  auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
  return cvtfp32_bf16(o1, o2);
}

template <>
Vectorized<BFloat16> inline clamp_max(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& max) {
  __m512 a_lo, a_hi;
  __m512 max_lo, max_hi;
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, a_lo);
  auto o2 = _mm512_min_ps(max_hi, a_hi);
  return cvtfp32_bf16(o1, o2);
}

template <>
Vectorized<BFloat16> inline clamp_min(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& min) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(min), min_lo, min_hi);
  auto o1 = _mm512_max_ps(min_lo, a_lo);
  auto o2 = _mm512_max_ps(min_hi, a_hi);
  return cvtfp32_bf16(o1, o2);
}

template <>
inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
  int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<BFloat16>::size());
       i += Vectorized<BFloat16>::size()) {
    auto vsrc =
        _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
  }
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

template <>
inline void convert(const float* src, BFloat16* dst, int64_t n) {
  int64_t i;
  for (i = 0; i + Vectorized<BFloat16>::size() <= n;
       i += Vectorized<BFloat16>::size()) {
    __m512 a = _mm512_loadu_ps(&src[i]);
    __m512 b = _mm512_loadu_ps(&src[i + 16]);
    __m512i bf = cvtfp32_bf16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<BFloat16>(src[i]);
  }
}

template <>
inline void convert(const double* src, BFloat16* dst, int64_t n) {
  auto load_float = [](const double* src) -> __m512 {
    // Load one float vector from an array of doubles
    __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
    __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
    return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
  };
  int64_t i;
  for (i = 0; i + Vectorized<BFloat16>::size() <= n;
       i += Vectorized<BFloat16>::size()) {
    __m512 a = load_float(&src[i]);
    __m512 b = load_float(&src[i + 16]);
    __m512i bf = cvtfp32_bf16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<BFloat16>(src[i]);
  }
}
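// Each convert() above processes 32 elements per iteration (the Vectorized
// width for 16-bit types) and falls back to scalar c10::convert for the
// remaining n % 32 tail elements.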
template <>
Vectorized<BFloat16> inline fmadd(
    const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b,
    const Vectorized<BFloat16>& c) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512 c_lo, c_hi;
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(b), b_lo, b_hi);
  cvtbf16_fp32(__m512i(c), c_lo, c_hi);
  auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
  auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
  return cvtfp32_bf16(o1, o2);
}

static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) {
  __m512i r[8];
  // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15
  //   e0e1 e2e3 e4e5 e6e7 e8e9 e10e11 e12e13 e14e15
  // b0-b15 f0-f15
  // c0-c15 g0-g15
  // d0-d15 h0-h15
  // i0-i15 m0-m15
  // j0-j15 n0-n15
  // k0-k15 o0-o15
  // l0-l15 p0-p15
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
  for (int i = 0; i < 4; i++) {
    r[i] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i]), t[i + 4], 0x01);
    r[i + 4] =
        _mm512_inserti64x4(_mm512_castsi256_si512(t[i + 8]), t[i + 12], 0x01);
  }

  // u0: a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11
  //     e0e1 f0f1 e2e3 f2f3 e8e9 f8f9 e10e11 f10f11
  // u1: a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15
  //     e4e5 f4f5 e6e7 f6f7 e12e13 f12f13 e14e15 f14f15
  // u2: c0c1 d0d1 c2c3 d2d3 c8c9 d8d9 c10c11 d10d11
  //     g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 g10g11 h10h11
  // u3: c4c5 d4d5 c6c7 d6d7 c12c13 d12d13 c14c15 d14d15
  //     g4g5 h4h5 g6g7 h6h7 g12g13 h12h13 g14g15 h14h15
  // and analogously for i j m n k l o p
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
  for (int i = 0; i < 8; i += 2) {
    u[i] = _mm512_unpacklo_epi32(r[i], r[i + 1]);
    u[i + 1] = _mm512_unpackhi_epi32(r[i], r[i + 1]);
  }

  // r0: a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9
  //     e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 g8g9 h8h9
  // r1: a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11
  //     e2e3 f2f3 g2g3 h2h3 e10e11 f10f11 g10g11 h10h11
  // r2: a4a5 b4b5 c4c5 d4d5 a12a13 b12b13 c12c13 d12d13
  // r3: a6a7 b6b7 c6c7 d6d7 a14a15 b14b15 c14c15 d14d15
  // r4..r7: same pattern for i j k l m n o p
  r[0] = _mm512_unpacklo_epi64(u[0], u[2]);
  r[1] = _mm512_unpackhi_epi64(u[0], u[2]);
  r[2] = _mm512_unpacklo_epi64(u[1], u[3]);
  r[3] = _mm512_unpackhi_epi64(u[1], u[3]);
  r[4] = _mm512_unpacklo_epi64(u[4], u[6]);
  r[5] = _mm512_unpackhi_epi64(u[4], u[6]);
  r[6] = _mm512_unpacklo_epi64(u[5], u[7]);
  r[7] = _mm512_unpackhi_epi64(u[5], u[7]);

  __m512i const1 = _mm512_set_epi32(
      0x00370035, 0x00330031, 0x00270025, 0x00230021,
      0x00170015, 0x00130011, 0x00070005, 0x00030001,
      0x00360034, 0x00320030, 0x00260024, 0x00220020,
      0x00160014, 0x00120010, 0x00060004, 0x00020000);
  __m512i const2 = _mm512_set_epi32(
      0x003f003d, 0x003b0039, 0x002f002d, 0x002b0029,
      0x001f001d, 0x001b0019, 0x000f000d, 0x000b0009,
      0x003e003c, 0x003a0038, 0x002e002c, 0x002a0028,
      0x001e001c, 0x001a0018, 0x000e000c, 0x000a0008);
  // merge values from two regs
  // 0-- 1--
  // 8-- 9--
  // 2-- 3--
  // 10-- 11--
  // 4-- 5--
  // 12-- 13--
  // 6-- 7--
  // 14-- 15--
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
  for (int i = 0; i < 4; i++) {
    u[i] = _mm512_permutex2var_epi16(r[i], const1, r[i + 4]);
    u[i + 4] = _mm512_permutex2var_epi16(r[i], const2, r[i + 4]);
  }
}
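// The 16x16 transpose above is a log-depth shuffle network: rows are paired
// into 512-bit registers, then interleaved at 32-bit, 64-bit and finally
// 16-bit granularity (via _mm512_permutex2var_epi16), so after four stages
// each of the eight output registers holds two transposed rows.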
// TODO(Leslie): Add the AVX2 Version of transpose_mxn for BFloat16 and Float16
// Code referred to FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
template <>
inline void transpose_mxn<BFloat16, 16, 16>(
    const BFloat16* src,
    int64_t ld_src,
    BFloat16* dst,
    int64_t ld_dst) {
  __m256i t[16];
  // load from src to registers
  // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
  // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
  // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
  // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
  // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
  // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
  // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
  // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
  // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
  // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
  // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
  // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
  // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
  // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
  // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
  // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
  for (int i = 0; i < 16; i++) {
    t[i] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + i * ld_src));
  }

  __m512i u[8];
  _transpose_mxn_half_16_16(t, u);

#ifndef __msvc_cl__
#pragma unroll(8)
#endif
  for (int i = 0; i < 8; i++) {
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
        _mm512_extracti32x8_epi32(u[i], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
        _mm512_extracti32x8_epi32(u[i], 0x01));
  }
}

// Code referred to FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
template <>
inline void transpose_mxn<Half, 16, 16>(
    const Half* src,
    int64_t ld_src,
    Half* dst,
    int64_t ld_dst) {
  __m256i t[16];
  // load from src to registers
  // Same matrix indices as above transpose_mxn<BFloat16, 16, 16>
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
  for (int i = 0; i < 16; i++) {
    t[i] = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + i * ld_src));
  }

  __m512i u[8];
  _transpose_mxn_half_16_16(t, u);

#ifndef __msvc_cl__
#pragma unroll(8)
#endif
  for (int i = 0; i < 8; i++) {
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
        _mm512_extracti32x8_epi32(u[i], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
        _mm512_extracti32x8_epi32(u[i], 0x01));
  }
}
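// Both 16x16 kernels above share _transpose_mxn_half_16_16: a transpose only
// rearranges 16-bit lanes, so the same shuffle network works for any 16-bit
// payload (bf16 or fp16 alike).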
static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) {
  // t[0]: 0 32 1 33 2 34 3 35 8 40 9 41 10 42 11 43 16 ... 59
  // t[1]: 4 36 5 37 6 38 7 39 12 44 13 45 14 46 15 47 20 ... 63
  // t[2]: 64 96 65 97 66 98 67 99 72 104 73 105 74 106 75 ... 123
  // t[3]: 68 100 69 101 70 102 71 103 76 108 77 109 78 110 79 111 84 ... 127
  // t[4]: 128 160 129 161 130 162 131 163 136 168 137 169 138 170 139 171
  //       144 ... 187
  // t[5]: 132 164 133 165 134 166 135 167 140 172 141 173 142 174 143 175
  //       148 ... 191
  // t[6]: 192 224 193 225 194 226 195 227 200 232 201 233 202 234 203 235
  //       208 ... 251
  // t[7]: 196 228 197 229 198 230 199 231 204 236 205 237 206 238 207 239
  //       212 ... 255
  // t[8]: 256 288 257 289 258 290 259 291 264 296 265 297 266 298 267 299
  //       272 ... 315
  // t[9]: 260 292 261 293 262 294 263 295 268 300 269 301 270 302 271 303
  //       276 ... 319
  // t[10]: 320 352 321 353 322 354 323 355 328 360 329 361 330 362 331 363
  //        336 ... 379
  // t[11]: 324 356 325 357 326 358 327 359 332 364 333 365 334 366 335 367
  //        340 ... 383
  // t[12]: 384 416 385 417 386 418 387 419 392 424 393 425 394 426 395 427
  //        400 ... 443
  // t[13]: 388 420 389 421 390 422 391 423 396 428 397 429 398 430 399 431
  //        404 ... 447
  // t[14]: 448 480 449 481 450 482 451 483 456 488 457 489 458 490 459 491
  //        464 ... 507
  // t[15]: 452 484 453 485 454 486 455 487 460 492 461 493 462 494 463 495
  //        468 ... 511
  // t[16]: 512 544 513 545 514 546 515 547 520 552 521 553 522 554 523 555
  //        528 ... 571
  // ...
  // t[31]: 964 996 965 997 966 998 967 999 972 1004 973 1005 974 1006 975
  //        1007 980 ... 1023
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
  for (int i = 0; i < 16; ++i) {
    d[i * 2] = _mm512_unpacklo_epi16(r[i * 2], r[i * 2 + 1]);
    d[i * 2 + 1] = _mm512_unpackhi_epi16(r[i * 2], r[i * 2 + 1]);
  }

  // t[0]: 0 32 64 96 1 33 65 97 8 40 72 104 9 41 73 105 16 ... 121
  // t[1]: 2 34 66 98 3 35 67 99 10 42 74 106 11 43 75 107 18 ... 123
  // t[2]: 4 36 68 100 5 37 69 101 12 44 76 108 13 45 77 109 20 ... 125
  // t[3]: 6 38 70 102 7 39 71 103 14 46 78 110 15 47 79 111 22 ... 127
  // t[4]: 128 160 192 224 129 161 193 225 136 168 200 232 137 169 201 233
  //       144 ... 249
  // t[5]: 130 162 194 226 131 163 195 227 138 170 202 234 139 171 203 235
  //       146 ... 251
  // t[6]: 132 164 196 228 133 165 197 229 140 172 204 236 141 173 205 237
  //       148 ... 253
  // t[7]: 134 166 198 230 135 167 199 231 142 174 206 238 143 175 207 239
  //       150 ... 255
  // t[8]: 256 288 320 352 257 289 321 353 264 296 328 360 265 297 329 361
  //       272 ... 377
  // t[9]: 258 290 322 354 259 291 323 355 266 298 330 362 267 299 331 363
  //       274 ... 379
  // t[10]: 260 292 324 356 261 293 325 357 268 300 332 364 269 301 333 365
  //        276 ... 381
  // t[11]: 262 294 326 358 263 295 327 359 270 302 334 366 271 303 335 367
  //        278 ... 383
  // t[12]: 384 416 448 480 385 417 449 481 392 424 456 488 393 425 457 489
  //        400 ... 505
  // t[13]: 386 418 450 482 387 419 451 483 394 426 458 490 395 427 459 491
  //        402 ... 507
  // t[14]: 388 420 452 484 389 421 453 485 396 428 460 492 397 429 461 493
  //        404 ... 509
  // t[15]: 390 422 454 486 391 423 455 487 398 430 462 494 399 431 463 495
  //        406 ... 511
  // t[16]: 512 544 576 608 513 545 577 609 520 552 584 616 521 553 585 617
  //        528 ... 633
  // ...
  // t[31]: 902 934 966 998 903 935 967 999 910 942 974 1006 911 943 975 1007
  //        918 ... 1023
#ifndef __msvc_cl__
#pragma unroll(8)
#endif
  for (int i = 0; i < 8; ++i) {
    r[i * 4] = _mm512_unpacklo_epi32(d[i * 4], d[i * 4 + 2]);
    r[i * 4 + 1] = _mm512_unpackhi_epi32(d[i * 4], d[i * 4 + 2]);
    r[i * 4 + 2] = _mm512_unpacklo_epi32(d[i * 4 + 1], d[i * 4 + 3]);
    r[i * 4 + 3] = _mm512_unpackhi_epi32(d[i * 4 + 1], d[i * 4 + 3]);
  }

  // t[0]: 0 32 64 96 128 160 192 224 8 40 72 104 136 168 200 232 16 ... 248
  // t[1]: 1 33 65 97 129 161 193 225 9 41 73 105 137 169 201 233 17 ... 249
  // t[2]: 2 34 66 98 130 162 194 226 10 42 74 106 138 170 202 234 18 ... 250
  // t[3]: 3 35 67 99 131 163 195 227 11 43 75 107 139 171 203 235 19 ... 251
  // t[4]: 4 36 68 100 132 164 196 228 12 44 76 108 140 172 204 236 20 ... 252
  // t[5]: 5 37 69 101 133 165 197 229 13 45 77 109 141 173 205 237 21 ... 253
  // t[6]: 6 38 70 102 134 166 198 230 14 46 78 110 142 174 206 238 22 ... 254
  // t[7]: 7 39 71 103 135 167 199 231 15 47 79 111 143 175 207 239 23 ... 255
  // t[8]: 256 288 320 352 384 416 448 480 264 296 328 360 392 424 456 488
  //       272 ... 504
  // t[9]: 257 289 321 353 385 417 449 481 265 297 329 361 393 425 457 489
  //       273 ... 505
  // t[10]: 258 290 322 354 386 418 450 482 266 298 330 362 394 426 458 490
  //        274 ... 506
  // t[11]: 259 291 323 355 387 419 451 483 267 299 331 363 395 427 459 491
  //        275 ... 507
  // t[12]: 260 292 324 356 388 420 452 484 268 300 332 364 396 428 460 492
  //        276 ... 508
  // t[13]: 261 293 325 357 389 421 453 485 269 301 333 365 397 429 461 493
  //        277 ... 509
  // t[14]: 262 294 326 358 390 422 454 486 270 302 334 366 398 430 462 494
  //        278 ... 510
  // t[15]: 263 295 327 359 391 423 455 487 271 303 335 367 399 431 463 495
  //        279 ... 511
  // t[16]: 512 544 576 608 640 672 704 736 520 552 584 616 648 680 712 744
  //        528 ... 760
  // ...
  // t[31]: 775 807 839 871 903 935 967 999 783 815 847 879 911 943 975 1007
  //        791 ... 1023
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
  for (int i = 0; i < 4; ++i) {
    d[i * 8] = _mm512_unpacklo_epi64(r[i * 8], r[i * 8 + 4]);
    d[i * 8 + 1] = _mm512_unpackhi_epi64(r[i * 8], r[i * 8 + 4]);
    d[i * 8 + 2] = _mm512_unpacklo_epi64(r[i * 8 + 1], r[i * 8 + 5]);
    d[i * 8 + 3] = _mm512_unpackhi_epi64(r[i * 8 + 1], r[i * 8 + 5]);
    d[i * 8 + 4] = _mm512_unpacklo_epi64(r[i * 8 + 2], r[i * 8 + 6]);
    d[i * 8 + 5] = _mm512_unpackhi_epi64(r[i * 8 + 2], r[i * 8 + 6]);
    d[i * 8 + 6] = _mm512_unpacklo_epi64(r[i * 8 + 3], r[i * 8 + 7]);
    d[i * 8 + 7] = _mm512_unpackhi_epi64(r[i * 8 + 3], r[i * 8 + 7]);
  }

  // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480
  //       16 ... 496
  // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481
  //       17 ... 497
  // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482
  //       18 ... 498
  // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483
  //       19 ... 499
  // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484
  //       20 ... 500
  // t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485
  //       21 ... 501
  // t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486
  //       22 ... 502
  // t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487
  //       23 ... 503
  // t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488
  //       24 ... 504
  // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489
  //       25 ... 505
  // t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490
  //        26 ... 506
  // t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491
  //        27 ... 507
  // t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492
  //        28 ... 508
  // t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493
  //        29 ... 509
  // t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494
  //        30 ... 510
  // t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495
  //        31 ... 511
  // t[16]: 512 544 576 608 640 672 704 736 768 800 832 864 896 928 960 992
  //        528 ... 1008
  // ...
  // t[31]: 527 559 591 623 655 687 719 751 783 815 847 879 911 943 975 1007
  //        543 ... 1023
  __m512i const1 = _mm512_set_epi64(
      0x000000000000000d, 0x000000000000000c,
      0x0000000000000005, 0x0000000000000004,
      0x0000000000000009, 0x0000000000000008,
      0x0000000000000001, 0x0000000000000000);
  __m512i const2 = _mm512_set_epi64(
      0x000000000000000f, 0x000000000000000e,
      0x0000000000000007, 0x0000000000000006,
      0x000000000000000b, 0x000000000000000a,
      0x0000000000000003, 0x0000000000000002);
#ifndef __msvc_cl__
#pragma unroll(8)
#endif
  for (int i = 0; i < 8; ++i) {
    r[i] = _mm512_permutex2var_epi64(d[i], /*idx*/ const1, d[i + 8]);
    r[i + 8] = _mm512_permutex2var_epi64(d[i], /*idx*/ const2, d[i + 8]);
    r[i + 16] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/ const1, d[i + 24]);
    r[i + 24] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/ const2, d[i + 24]);
  }

  // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480
  //       512 544 ... 992
  // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481
  //       513 545 ... 993
  // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482
  //       514 546 ... 994
  // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483
  //       515 547 ... 995
  // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484
  //       516 548 ... 996
  // t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485
  //       517 549 ... 997
  // t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486
  //       518 550 ... 998
  // t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487
  //       519 551 ... 999
  // t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488
  //       520 552 ... 1000
  // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489
  //       521 553 ... 1001
  // t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490
  //        522 554 ... 1002
  // t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491
  //        523 555 ... 1003
  // t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492
  //        524 556 ... 1004
  // t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493
  //        525 557 ... 1005
  // t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494
  //        526 558 ... 1006
  // t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495
  //        527 559 ... 1007
  // t[16]: 16 48 80 112 144 176 208 240 272 304 336 368 400 432 464 496
  //        528 560 ... 1008
  // ...
  // t[31]: 31 63 95 127 159 191 223 255 287 319 351 383 415 447 479 511
  //        543 575 ... 1023
  __m512i const3 = _mm512_set_epi64(
      0x000000000000000b, 0x000000000000000a,
      0x0000000000000009, 0x0000000000000008,
      0x0000000000000003, 0x0000000000000002,
      0x0000000000000001, 0x0000000000000000);
  __m512i const4 = _mm512_set_epi64(
      0x000000000000000f, 0x000000000000000e,
      0x000000000000000d, 0x000000000000000c,
      0x0000000000000007, 0x0000000000000006,
      0x0000000000000005, 0x0000000000000004);
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
  for (int i = 0; i < 16; ++i) {
    d[i] = _mm512_permutex2var_epi64(r[i], /*idx*/ const3, r[i + 16]);
    d[i + 16] = _mm512_permutex2var_epi64(r[i], /*idx*/ const4, r[i + 16]);
  }
}

// Code referred to FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6
template <>
inline void transpose_mxn<BFloat16>(
    const BFloat16* src,
    int64_t ld_src,
    BFloat16* dst,
    int64_t ld_dst,
    int M,
    int N) {
  TORCH_CHECK(
      M <= 32 && N <= 32, "transpose_mxn<BFloat16> expects M, N <= 32.");
  // load from src
  __m512i r[32];
  int i;
  if (N == 32) {
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_loadu_si512(&src[i * ld_src]);
    }
  } else {
    __mmask32 src_mask = (1 << N) - 1;
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]);
    }
  }
  for (; i < 32; ++i) {
    r[i] = _mm512_setzero_si512();
  }

  __m512i d[32];
  _transpose_mxn_half_32_32(r, d);

  // store to dst
  if (M == 32) {
    for (i = 0; i < N; ++i) {
      _mm512_storeu_si512(&dst[i * ld_dst], d[i]);
    }
  } else {
    __mmask32 dst_mask = (1 << M) - 1;
    for (i = 0; i < N; ++i) {
      _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]);
    }
  }
}

template <
    typename T,
    int M,
    int N,
    typename std::enable_if_t<
        std::is_same_v<T, BFloat16> &&
            ((M <= 32 && M != 16) || (N <= 32 && N != 16)),
        int> = 0>
inline void transpose_mxn(
    const BFloat16* src,
    int64_t ld_src,
    BFloat16* dst,
    int64_t ld_dst) {
  transpose_mxn<BFloat16>(src, ld_src, dst, ld_dst, M, N);
}

template <>
inline void transpose_mxn<Half>(
    const Half* src,
    int64_t ld_src,
    Half* dst,
    int64_t ld_dst,
    int M,
    int N) {
  TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn<Half> expects M, N <= 32.");
  // load from src
  __m512i r[32];
  int i;
  if (N == 32) {
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_loadu_si512(&src[i * ld_src]);
    }
  } else {
    __mmask32 src_mask = (1 << N) - 1;
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]);
    }
  }
  for (; i < 32; ++i) {
    r[i] = _mm512_setzero_si512();
  }

  __m512i d[32];
  _transpose_mxn_half_32_32(r, d);

  // store to dst
  if (M == 32) {
    for (i = 0; i < N; ++i) {
      _mm512_storeu_si512(&dst[i * ld_dst], d[i]);
    }
  } else {
    __mmask32 dst_mask = (1 << M) - 1;
    for (i = 0; i < N; ++i) {
      _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]);
    }
  }
}
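// A hypothetical call site for the shape-templated overloads (illustrative
// only): transposing an 8x24 bf16 tile with row strides ld_src/ld_dst,
//
//   at::vec::transpose_mxn<c10::BFloat16, 8, 24>(src, ld_src, dst, ld_dst);
//
// forwards to the runtime version above with M = 8, N = 24; the 16x16 shape
// is excluded from these overloads because it has its own dedicated kernel.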
template <
    typename T,
    int M,
    int N,
    typename std::enable_if_t<
        std::is_same_v<T, Half> &&
            ((M <= 32 && M != 16) || (N <= 32 && N != 16)),
        int> = 0>
inline void transpose_mxn(
    const Half* src,
    int64_t ld_src,
    Half* dst,
    int64_t ld_dst) {
  transpose_mxn<Half>(src, ld_src, dst, ld_dst, M, N);
}

template <>
struct is_vec_specialized_for<Half> : std::bool_constant<true> {};

template <>
class Vectorized<Half> : public Vectorized16<Half> {
 public:
  using Vectorized16::Vectorized16;
  using value_type = Half;
  Vectorized<Half> frac() const;
  Vectorized<Half> eq(const Vectorized<Half>& other) const;
  Vectorized<Half> ne(const Vectorized<Half>& other) const;
  Vectorized<Half> gt(const Vectorized<Half>& other) const;
  Vectorized<Half> ge(const Vectorized<Half>& other) const;
  Vectorized<Half> lt(const Vectorized<Half>& other) const;
  Vectorized<Half> le(const Vectorized<Half>& other) const;
};

Vectorized<Half> inline operator+(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) {
    return _mm512_add_ps(x, y);
  });
}
Vectorized<Half> inline operator-(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) {
    return _mm512_sub_ps(x, y);
  });
}
Vectorized<Half> inline operator*(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) {
    return _mm512_mul_ps(x, y);
  });
}
Vectorized<Half> inline operator/(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) {
    return _mm512_div_ps(x, y);
  });
}
Vectorized<Half> inline operator&(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b) {
  return _mm512_and_si512(a, b);
}
Vectorized<Half> inline operator|(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b) {
  return _mm512_or_si512(a, b);
}
Vectorized<Half> inline operator^(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b) {
  return _mm512_xor_si512(a, b);
}

inline Vectorized<Half> Vectorized<Half>::eq(
    const Vectorized<Half>& other) const {
  return (*this == other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::ne(
    const Vectorized<Half>& other) const {
  return (*this != other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::gt(
    const Vectorized<Half>& other) const {
  return (*this > other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::ge(
    const Vectorized<Half>& other) const {
  return (*this >= other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::lt(
    const Vectorized<Half>& other) const {
  return (*this < other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::le(
    const Vectorized<Half>& other) const {
  return (*this <= other) & Vectorized<Half>(1.0f);
}

// frac. Implement this here so we can use subtraction
inline Vectorized<Half> Vectorized<Half>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<Half> inline maximum(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
  auto max_lo = _mm512_max_ps(a_lo, b_lo);
  auto max_hi = _mm512_max_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  auto nan_lo = _mm512_castsi512_ps(
      _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(
      _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(max_lo, nan_lo);
  auto o2 = _mm512_or_ps(max_hi, nan_hi);
  return cvtfp32_fp16(o1, o2);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<Half> inline minimum(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
  auto min_lo = _mm512_min_ps(a_lo, b_lo);
  auto min_hi = _mm512_min_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  auto nan_lo = _mm512_castsi512_ps(
      _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(
      _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(min_lo, nan_lo);
  auto o2 = _mm512_or_ps(min_hi, nan_hi);
  return cvtfp32_fp16(o1, o2);
}

template <>
Vectorized<Half> inline clamp(
    const Vectorized<Half>& a,
    const Vectorized<Half>& min,
    const Vectorized<Half>& max) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  __m512 max_lo, max_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(min), min_lo, min_hi);
  cvtfp16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
  auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
  return cvtfp32_fp16(o1, o2);
}

template <>
Vectorized<Half> inline clamp_max(
    const Vectorized<Half>& a,
    const Vectorized<Half>& max) {
  __m512 a_lo, a_hi;
  __m512 max_lo, max_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, a_lo);
  auto o2 = _mm512_min_ps(max_hi, a_hi);
  return cvtfp32_fp16(o1, o2);
}

template <>
Vectorized<Half> inline clamp_min(
    const Vectorized<Half>& a,
    const Vectorized<Half>& min) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(min), min_lo, min_hi);
  auto o1 = _mm512_max_ps(min_lo, a_lo);
  auto o2 = _mm512_max_ps(min_hi, a_hi);
  return cvtfp32_fp16(o1, o2);
}

template <>
inline void convert(const Half* src, Half* dst, int64_t n) {
  int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<Half>::size());
       i += Vectorized<Half>::size()) {
    auto vsrc =
        _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
  }
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

template <>
inline void convert(const float* src, Half* dst, int64_t n) {
  int64_t i;
  for (i = 0; i + Vectorized<Half>::size() <= n;
       i += Vectorized<Half>::size()) {
    __m512 a = _mm512_loadu_ps(&src[i]);
    __m512 b = _mm512_loadu_ps(&src[i + 16]);
    __m512i bf = cvtfp32_fp16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<Half>(src[i]);
  }
}

template <>
inline void convert(const double* src, Half* dst, int64_t n) {
  auto load_float = [](const double* src) -> __m512 {
    // Load one float vector from an array of doubles
    __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
    __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
    return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
  };
  int64_t i;
  for (i = 0; i + Vectorized<Half>::size() <= n;
       i += Vectorized<Half>::size()) {
    __m512 a = load_float(&src[i]);
    __m512 b = load_float(&src[i + 16]);
    __m512i bf = cvtfp32_fp16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<Half>(src[i]);
  }
}
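// As with the other arithmetic, fmadd below widens to fp32, issues a real
// fused multiply-add there (_mm512_fmadd_ps), and rounds back to the 16-bit
// type once, so a * b + c incurs a single rounding step at fp32 before the
// final narrowing.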
template <>
Vectorized<Half> inline fmadd(
    const Vectorized<Half>& a,
    const Vectorized<Half>& b,
    const Vectorized<Half>& c) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512 c_lo, c_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
  cvtfp16_fp32(__m512i(c), c_lo, c_hi);
  auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
  auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
  return cvtfp32_fp16(o1, o2);
}

#define CONVERT_VECTORIZED_INIT(type, name)                       \
  inline std::tuple<Vectorized<float>, Vectorized<float>>         \
  convert_##name##_float(const Vectorized<type>& a) {             \
    __m512 o1, o2;                                                \
    cvt_to_fp32<type>(__m512i(a), o1, o2);                        \
    return std::make_tuple(o1, o2);                               \
  }                                                               \
                                                                  \
  inline Vectorized<type> convert_float_##name(                   \
      const Vectorized<float>& a, const Vectorized<float>& b) {   \
    return cvt_from_fp32<type>(__m512(a), __m512(b));             \
  }
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16)
CONVERT_VECTORIZED_INIT(Half, half)

#else // defined(CPU_CAPABILITY_AVX512)

#define CONVERT_NON_VECTORIZED_INIT(type, name)                      \
  inline std::tuple<Vectorized<float>, Vectorized<float>>            \
  convert_##name##_float(const Vectorized<type>& a) {                \
    constexpr int64_t K = Vectorized<type>::size();                  \
    __at_align__ float arr[K];                                       \
    __at_align__ type arr2[K];                                       \
    a.store(arr2);                                                   \
    for (const auto k : c10::irange(K)) {                            \
      arr[k] = c10::convert<float>(arr2[k]);                         \
    }                                                                \
    return std::make_tuple(                                          \
        Vectorized<float>::loadu(arr),                               \
        Vectorized<float>::loadu(arr + Vectorized<float>::size()));  \
  }                                                                  \
                                                                     \
  inline Vectorized<type> convert_float_##name(                      \
      const Vectorized<float>& a, const Vectorized<float>& b) {      \
    constexpr int64_t K = Vectorized<type>::size();                  \
    __at_align__ float arr[K];                                       \
    __at_align__ type arr2[K];                                       \
    a.store(arr);                                                    \
    b.store(arr + Vectorized<float>::size());                        \
    for (const auto k : c10::irange(K)) {                            \
      arr2[k] = c10::convert<type>(arr[k]);                          \
    }                                                                \
    return Vectorized<type>::loadu(arr2);                            \
  }
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
CONVERT_NON_VECTORIZED_INIT(Half, half)

#endif // defined(CPU_CAPABILITY_AVX512)

#if defined(CPU_CAPABILITY_AVX512)

#define LOAD_FP32_VECTORIZED_INIT(type, name)                          \
  inline void load_fp32_from_##name(                                   \
      const type* data, Vectorized<float>& out) {                      \
    auto values =                                                      \
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data));    \
    __m512 out_values;                                                 \
    cvt_to_fp32<type>(values, out_values);                             \
    out = out_values;                                                  \
  }                                                                    \
                                                                       \
  inline void load_fp32_from_##name(                                   \
      const type* data,                                                \
      Vectorized<float>& out1,                                         \
      Vectorized<float>& out2) {                                       \
    auto vec = Vectorized<type>::loadu(data);                          \
    __m512 out1_values, out2_values;                                   \
    cvt_to_fp32<type>(vec, out1_values, out2_values);                  \
    out1 = out1_values;                                                \
    out2 = out2_values;                                                \
  }
LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16)
LOAD_FP32_VECTORIZED_INIT(Half, fp16)

#else // defined(CPU_CAPABILITY_AVX512)

#define LOAD_FP32_NON_VECTORIZED_INIT(type, name)                   \
  inline void load_fp32_from_##name(                                \
      const type* data, Vectorized<float>& out) {                   \
    __at_align__ float values[Vectorized<float>::size()];           \
    for (const auto k : c10::irange(Vectorized<float>::size())) {   \
      values[k] = data[k];                                          \
    }                                                               \
    out = Vectorized<float>::loadu(values);                         \
  }                                                                 \
                                                                    \
  inline void load_fp32_from_##name(                                \
      const type* data,                                             \
      Vectorized<float>& out1,                                      \
      Vectorized<float>& out2) {                                    \
    load_fp32_from_##name(data, out1);                              \
    data += Vectorized<float>::size();                              \
    load_fp32_from_##name(data, out2);                              \
  }
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16)
LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16)

#endif

} // namespace CPU_CAPABILITY
} // namespace at::vec