#pragma once #include #include #include namespace at::vec { inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) template struct VecMaskLoad< T, dst_n, mask_t, mask_n, typename std::enable_if_t< (mask_n == dst_n * 2 && dst_n >= 1) && (std::is_same_v || std::is_same_v), void>> { static inline VectorizedN apply( const T* ptr, const VecMask& vec_mask) { at::vec::Vectorized zero_vec(0); auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); VectorizedN tmp_vec; VectorizedN result; for (int i = 0; i < dst_n; i++) { tmp_vec[0] = vec_mask[2 * i]; tmp_vec[1] = vec_mask[2 * i + 1]; auto int64_mask = VecMask(tmp_vec).template cast(); auto int_mask = int64_mask.template cast()[0]; auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); if constexpr (std::is_same_v) { result[i] = Vectorized(_mm512_mask_loadu_ps( zero_vec, mmask, ptr + i * Vectorized::size())); } else { result[i] = Vectorized(_mm512_mask_loadu_epi32( zero_vec, mmask, ptr + i * Vectorized::size())); } } return result; } }; template struct VecMaskLoad< T, dst_n, mask_t, dst_n, typename std::enable_if_t< std::is_same_v || std::is_same_v, void>> { static inline VectorizedN apply( const T* ptr, const VecMask& vec_mask) { at::vec::Vectorized zero_vec(0); auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < dst_n; i++) { auto tmp_mask = VecMask(vec_mask[i]); auto int_mask = tmp_mask.template cast()[0]; auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); if constexpr (std::is_same_v) { result[i] = Vectorized(_mm512_mask_loadu_ps( zero_vec, mmask, ptr + i * Vectorized::size())); } else { result[i] = Vectorized(_mm512_mask_loadu_epi32( zero_vec, mmask, ptr + i * Vectorized::size())); } } return result; } }; template struct VecMaskLoad< data_t, dst_n, mask_t, dst_n, std::enable_if_t< std::is_same_v || std::is_same_v>> { static inline VectorizedN apply( const data_t* ptr, const VecMask& vec_mask) { auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < dst_n; i++) { auto tmp_mask = VecMask(vec_mask[i]); auto int_mask = tmp_mask.template cast(); auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); auto zero = _mm256_set1_epi16(0); auto temp0 = _mm256_mask_loadu_epi16( zero, mmask0, ptr + (2 * i) * Vectorized::size()); auto temp1 = _mm256_mask_loadu_epi16( zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); result[i] = Vectorized( _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); } return result; } }; template struct VecMaskLoad< data_t, dst_n, mask_t, mask_n, typename std::enable_if_t< (mask_n == 2 * dst_n && dst_n >= 1) && (std::is_same_v || std::is_same_v)>> { static inline VectorizedN apply( const data_t* ptr, const VecMask& vec_mask) { auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); VectorizedN result; VectorizedN tmp_vec; for (int i = 0; i < dst_n; i++) { tmp_vec[0] = vec_mask[2 * i]; tmp_vec[1] = vec_mask[2 * i + 1]; auto int_mask = VecMask(tmp_vec).template cast(); auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); auto zero = _mm256_set1_epi16(0); auto temp0 = _mm256_mask_loadu_epi16( zero, mmask0, ptr + (2 * i) * Vectorized::size()); auto temp1 = _mm256_mask_loadu_epi16( zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); result[i] = Vectorized( _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); } return result; } }; template struct VecMaskLoad< data_t, 1, mask_t, 1, std::enable_if_t< std::is_same_v || std::is_same_v>> { static inline VectorizedN apply( const data_t* ptr, const VecMask& vec_mask) { auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); auto int_mask = vec_mask.template cast()[0]; auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); auto zero = _mm_set1_epi8(0); auto temp = _mm_mask_loadu_epi8(zero, mmask, ptr); return Vectorized( _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0)); } }; template struct VecMaskLoad< data_t, 2, mask_t, 1, std::enable_if_t< std::is_same_v || std::is_same_v>> { static inline VectorizedN apply( const data_t* ptr, const VecMask& vec_mask) { auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); at::vec::Vectorized zero_vec(0); auto int_mask = vec_mask.template cast()[0]; auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); at::vec::VectorizedN result; if constexpr (std::is_same_v) { result[0] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)mmask, ptr); result[1] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); } else { result[0] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)mmask, ptr); result[1] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); } return result; } }; template struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < N; ++i) { result[i] = _mm512_castsi512_ps(vec_mask[i]); } return result; } }; template struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < N; ++i) { result[i] = _mm512_castps_si512(vec_mask[i]); } return result; } }; template struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < N; ++i) { result[i] = _mm512_castpd_si512(vec_mask[i]); } return result; } }; template struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < N; ++i) { result[i] = _mm512_castsi512_pd(vec_mask[i]); } return result; } }; template struct VecMaskCast< int64_t, dst_n, mask_t, mask_n, typename std::enable_if_t< (dst_n == 2 * mask_n) && (std::is_same_v || std::is_same_v), void>> { static inline VecMask apply( const VecMask& vec_mask) { VectorizedN result; auto int_mask = vec_mask.template cast(); #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < mask_n; ++i) { auto int64_vec = convert(VectorizedN(int_mask[i])); result[2 * i] = int64_vec[0]; result[2 * i + 1] = int64_vec[1]; } return VecMask(result); } }; template struct VecMaskCast< dst_t, dst_n, int64_t, mask_n, typename std::enable_if_t< (mask_n == 2 * dst_n) && (std::is_same_v || std::is_same_v), void>> { static inline VecMask apply( const VecMask& vec_mask) { VectorizedN result; VectorizedN int64_vec; for (int i = 0; i < dst_n; ++i) { int64_vec[0] = vec_mask[2 * i]; int64_vec[1] = vec_mask[2 * i + 1]; result[i] = convert(int64_vec); } return VecMask(result).template cast(); } }; template <> struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { auto int64_mask = VecMaskCast::apply(vec_mask); return VecMaskCast::apply(int64_mask); } }; template <> struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { auto int64_mask = VecMaskCast::apply(vec_mask); return VecMaskCast::apply(int64_mask); } }; template <> inline bool VecMask::all_zero() const { __mmask16 mask = _mm512_test_epi32_mask(mask_[0], mask_[0]); return mask == 0; } template <> inline bool VecMask::is_masked(int i) const { return _mm512_movepi32_mask(mask_[0]) & (1 << i); } template <> inline bool VecMask::all_masked() const { __mmask16 mask = _mm512_movepi32_mask(mask_[0]); return mask == 0xffff; } template struct VecMaskCheck { static inline bool all_zero(const VectorizedN& vec_mask) { bool all_zero = true; for (int i = 0; i < N; ++i) { all_zero = all_zero && (_mm512_test_epi64_mask(vec_mask[i], vec_mask[i]) == 0); if (!all_zero) { return all_zero; } } return all_zero; } static inline bool is_masked(const VectorizedN& vec_mask, int i) { for (int j = 0; j < N; ++j) { if (i < (j + 1) * 8) { return _mm512_movepi64_mask(vec_mask[j]) & (1 << (i - j * 8)); } } return false; } static inline bool all_masked(const VectorizedN& vec_mask) { bool all_masked = true; for (int i = 0; i < N; ++i) { all_masked = all_masked && (_mm512_movepi64_mask(vec_mask[i]) == 0xff); if (!all_masked) { return all_masked; } } return all_masked; } }; #define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ T, N, return_type, method, args_def, args) \ template <> \ inline return_type VecMask::method args_def const { \ return cast().method args; \ } VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ()) VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ()) VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i)) VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i)) VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ()) VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ()) #undef VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT #endif } // namespace CPU_CAPABILITY } // namespace at::vec