#pragma once #include #include #include namespace at::vec { inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template struct VecMaskLoad< T, dst_n, mask_t, mask_n, typename std::enable_if_t< (mask_n == dst_n * 2 && dst_n >= 1) && (std::is_same_v || std::is_same_v), void>> { static inline VectorizedN apply( const T* ptr, const VecMask& vec_mask) { VectorizedN tmp_vec; VectorizedN result; for (int i = 0; i < dst_n; i++) { tmp_vec[0] = vec_mask[2 * i]; tmp_vec[1] = vec_mask[2 * i + 1]; auto int64_mask = VecMask(tmp_vec).template cast(); auto int_mask = int64_mask.template cast()[0]; if constexpr (std::is_same_v) { result[i] = Vectorized( _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); } else { result[i] = Vectorized( _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); } } return result; } }; template struct VecMaskLoad< T, dst_n, mask_t, dst_n, typename std::enable_if_t< std::is_same_v || std::is_same_v, void>> { static inline VectorizedN apply( const T* ptr, const VecMask& vec_mask) { VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < dst_n; i++) { auto tmp_mask = VecMask(vec_mask[i]); auto int_mask = tmp_mask.template cast()[0]; if constexpr (std::is_same_v) { result[i] = Vectorized( _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); } else { result[i] = Vectorized( _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); } } return result; } }; template struct VecMaskLoad< T, 2, mask_t, 1, typename std::enable_if_t< std::is_same_v || std::is_same_v>> { static inline VectorizedN apply( const T* ptr, const VecMask& vec_mask) { auto int64_mask = vec_mask.template cast(); auto result = at::vec::VectorizedN(); if constexpr (std::is_same_v) { result[0] = _mm256_maskload_pd(ptr, int64_mask[0]); result[1] = _mm256_maskload_pd( ptr + at::vec::Vectorized::size(), int64_mask[1]); } else { result[0] = _mm256_maskload_epi64( reinterpret_cast(ptr), int64_mask[0]); result[1] = _mm256_maskload_epi64( reinterpret_cast( ptr + at::vec::Vectorized::size()), int64_mask[1]); } return result; } }; // TODO: add specialization of VecMaskLoad for bfloat16/half and int8/uint8 template struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < N; ++i) { result[i] = _mm256_castsi256_ps(vec_mask[i]); } return result; } }; template struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < N; ++i) { result[i] = _mm256_castps_si256(vec_mask[i]); } return result; } }; template struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < N; ++i) { result[i] = _mm256_castpd_si256(vec_mask[i]); } return result; } }; template struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { VectorizedN result; #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < N; ++i) { result[i] = _mm256_castsi256_pd(vec_mask[i]); } return result; } }; template struct VecMaskCast< int64_t, dst_n, mask_t, mask_n, typename std::enable_if_t< (dst_n == 2 * mask_n) && (std::is_same_v || std::is_same_v), void>> { static inline VecMask apply( const VecMask& vec_mask) { VectorizedN result; auto int_mask = vec_mask.template cast(); #ifndef _MSC_VER #pragma unroll #endif for (int i = 0; i < mask_n; ++i) { auto int64_vec = convert(VectorizedN(int_mask[i])); result[2 * i] = int64_vec[0]; result[2 * i + 1] = int64_vec[1]; } return VecMask(result); } }; template struct VecMaskCast< dst_t, dst_n, int64_t, mask_n, typename std::enable_if_t< (mask_n == 2 * dst_n) && (std::is_same_v || std::is_same_v), void>> { static inline VecMask apply( const VecMask& vec_mask) { VectorizedN result; VectorizedN int64_vec; for (int i = 0; i < dst_n; ++i) { int64_vec[0] = vec_mask[2 * i]; int64_vec[1] = vec_mask[2 * i + 1]; result[i] = convert(int64_vec); } return VecMask(result).template cast(); } }; template <> struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { auto int64_mask = VecMaskCast::apply(vec_mask); return VecMaskCast::apply(int64_mask); } }; template <> struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { auto int64_mask = VecMaskCast::apply(vec_mask); return VecMaskCast::apply(int64_mask); } }; template <> inline bool VecMask::all_zero() const { return _mm256_testz_si256(mask_[0], mask_[0]); } template <> inline bool VecMask::is_masked(int i) const { return _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0])) & (1 << i); } template <> inline bool VecMask::all_masked() const { int mask = _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0])); return mask == 0xff; } template struct VecMaskCheck { static inline bool all_zero(const VectorizedN& vec_mask) { bool all_zero = true; for (int i = 0; i < N; ++i) { all_zero = all_zero && (_mm256_testz_si256(vec_mask[i], vec_mask[i]) > 0); if (!all_zero) { return all_zero; } } return all_zero; } static inline bool is_masked(const VectorizedN& vec_mask, int i) { for (int j = 0; j < N; ++j) { if (i < (j + 1) * 4) { return _mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[j])) & (1 << (i - j * 4)); } } return false; } static inline bool all_masked(const VectorizedN& vec_mask) { bool all_masked = true; for (int i = 0; i < N; ++i) { all_masked = all_masked && (_mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[i])) == 0x0f); if (!all_masked) { return all_masked; } } return all_masked; } }; #define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ T, N, return_type, method, args_def, args) \ template <> \ inline return_type VecMask::method args_def const { \ return cast().method args; \ } VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ()) VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ()) VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i)) VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i)) VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ()) VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ()) #undef VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT #endif } // namespace CPU_CAPABILITY } // namespace at::vec