// AVX512 specializations of VecConvert, which converts VectorizedN values
// between element types. Everything below is guarded by
// `#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)` and lives in
// namespace at::vec::CPU_CAPABILITY.
//
// NOTE(review): this copy of the header is damaged and cannot compile as-is.
//   * The four `#include` lines have lost their header names.
//   * Almost every template argument list has been stripped: `VecConvert<...>`,
//     `VectorizedN<...>`, `Vectorized<...>`, the `template <typename ...>`
//     parameter lists of the partial specializations, and the arguments of
//     `is_reduced_floating_point_v`, `is_8bit_integer_v`, `std::is_same_v`
//     and several `std::enable_if_t`.
//     For example, `template struct VecConvert< dst_t, ...>` now reads like an
//     explicit instantiation instead of a partial specialization.
//   * The file has been folded onto four physical lines. The first one begins
//     with `#pragma once`. A preprocessing directive runs to the end of its
//     line, so the includes, the namespace opening and specializations #1-#9
//     on that line are part of the pragma, not C++ source.
// TODO: restore this file from version control rather than patching it here.
// The element types mentioned below are inferred from helper/intrinsic names;
// confirm them against the restored source.
//
// Specializations, in source order (numbers are used in the line notes below):
//  #1  cvtbf16_fp32 on the low 256 bits of src[0] (presumably BFloat16 -> float)
//  #2  cvtfp16_fp32 on the low 256 bits of src[0] (presumably Half -> float)
//  #3  cvtfp32_bf16(src[0]); the result occupies the low 256 bits
//      (presumably float -> BFloat16)
//  #4  convert_float_bfloat16(src[0], src[1]): two input vectors -> one
//  #5  convert_bfloat16_float(src[0]): one input vector -> result[0], result[1]
//  #6  cvtfp32_fp16(src[0]); the result occupies the low 256 bits
//      (presumably float -> Half)
//  #7  convert_float_half(src[0], src[1]): two input vectors -> one
//  #8  convert_half_float(src[0]): one input vector -> result[0], result[1]
//  #9  _mm512_cvtepi64_ps on src[0] and src[1], joined with _mm512_insertf32x8
//      into a single float vector (two int64 vectors -> one float vector)
//  #10 _mm512_cvt_roundps_epi64 with _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC on
//      the low and high 256 bits of src[0] -> result[0], result[1]
//      (float -> int64, rounding toward zero)
//  #11 _mm512_cvtepi64_epi32 on src[0] and src[1], joined with
//      _mm512_inserti32x8 (two int64 vectors -> one int32 vector)
//  #12 _mm512_cvtepi32_epi64 on the low and high 256 bits of src[0]
//      (one int32 vector -> two int64 vectors)
//  #13 _mm512_cvtepi8_epi32 on the low 128 bits of src[0] (signed 8-bit -> int32)
//  #14 _mm512_cvtepu8_epi32 on the low 128 bits of src[0] (unsigned 8-bit -> int32)
//  #15 _mm512_cvttps_epi32 (float -> int32, truncating "cvtt" conversion)
//  #16 _mm512_cvtepi32_ps (int32 -> float)
//  #17 _mm512_cvtepu8_epi16 on the low 256 bits of src[0] (unsigned 8-bit -> 16-bit)
//  #18 _mm512_cvtepi32_epi8; the result occupies the low 128 bits (int32 -> 8-bit)
//  #19 _mm512_cvtepi16_epi8; the result occupies the low 256 bits (16-bit -> 8-bit)
//  NOTE(review): #3, #6, #18, #19 and #26/#28 widen with _mm512_castsi*_si512.
//      Intel documents the upper bits of such casts as undefined, so callers
//      presumably read only the low lanes. Confirm.
//  #20 generic, enabled when one side is a reduced floating-point type and the
//      other an 8-bit integer: routes through a float VectorizedN via two
//      VecConvert::apply calls
//  #21 generic, dst from 2 float vectors: each float vector goes through
//      convert_float_to_int8. The low 128 bits of the second result are then
//      inserted into 128-bit lane 1 of the first. Lanes 2-3 keep whatever the
//      first call left there.
//  #22 generic, dst from 1 float vector: convert_float_to_int8(src[0])
//  #23 generic, 2 float vectors from src_t: 128-bit lane 1 of src[0] is moved
//      into the low lane of src2, then src[0] and src2 each go through
//      convert_int8_to_float. This presumably consumes only the low 16 bytes
//      of its argument; confirm.
//  #24 generic, 1 float vector from src_t: convert_int8_to_float(src[0])
//  #25 generic, dst from 2 int64 vectors (enabled by a std::is_same_v
//      disjunction): two chained VecConvert::apply calls. The intermediate
//      type was lost to the stripping; presumably int32. Confirm.
//  #26 cvtfp32_fp8e4m3; the result occupies the low 128 bits
//      (float -> Float8_e4m3fn)
//  #27 cvtfp8e4m3_fp32 on the low 128 bits (Float8_e4m3fn -> float)
//  #28 cvtfp32_fp8e5m2; the result occupies the low 128 bits
//      (float -> Float8_e5m2)
//  #29 cvtfp8e5m2_fp32 on the low 128 bits (Float8_e5m2 -> float)
//
// Physical line 1: the `#pragma once` directive line (see NOTE above).
// It holds #1-#9 and the opening `template <> struct` of #10.
#pragma once #include #include #include #include namespace at::vec { inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { VectorizedN result; __m512 value; cvtbf16_fp32(_mm512_castsi512_si256(src[0]), value); result[0] = value; return result; } }; template <> struct VecConvert { static inline VectorizedN apply(const VectorizedN& src) { VectorizedN result; __m512 value; cvtfp16_fp32(_mm512_castsi512_si256(src[0]), value); result[0] = value; return result; } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { VectorizedN result; result[0] = _mm512_castsi256_si512(cvtfp32_bf16(src[0])); return result; } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { VectorizedN result; result[0] = convert_float_bfloat16(src[0], src[1]); return result; } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { VectorizedN result; std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); return result; } }; template <> struct VecConvert { static inline VectorizedN apply(const VectorizedN& src) { VectorizedN result; result[0] = _mm512_castsi256_si512(cvtfp32_fp16(src[0])); return result; } }; template <> struct VecConvert { static inline VectorizedN apply(const VectorizedN& src) { VectorizedN result; result[0] = convert_float_half(src[0], src[1]); return result; } }; template <> struct VecConvert { static inline VectorizedN apply(const VectorizedN& src) { VectorizedN result; std::tie(result[0], result[1]) = convert_half_float(src[0]); return result; } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { auto low = _mm512_cvtepi64_ps(src[0]); auto high = _mm512_cvtepi64_ps(src[1]); return Vectorized( _mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1)); } }; template <> struct 
// Physical line 2: the remainder of #10, then #11-#18, and the start of #19.
VecConvert { static inline VectorizedN apply( const VectorizedN& src) { at::vec::VectorizedN result; result[0] = _mm512_cvt_roundps_epi64( _mm512_castps512_ps256(src[0]), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); result[1] = _mm512_cvt_roundps_epi64( _mm512_extractf32x8_ps(src[0], 1), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); return result; } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { auto low = _mm512_cvtepi64_epi32(src[0]); auto high = _mm512_cvtepi64_epi32(src[1]); return Vectorized( _mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1)); } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { at::vec::VectorizedN result; result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src[0])); result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src[0], 1)); return result; } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { auto src128 = _mm512_castsi512_si128(src[0]); return Vectorized(_mm512_cvtepi8_epi32(src128)); } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { auto src128 = _mm512_castsi512_si128(src[0]); return Vectorized(_mm512_cvtepu8_epi32(src128)); } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { return Vectorized(_mm512_cvttps_epi32(src[0])); } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { return Vectorized(_mm512_cvtepi32_ps(src[0])); } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { auto src256 = _mm512_castsi512_si256(src[0]); return Vectorized(_mm512_cvtepu8_epi16(src256)); } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src) { auto src128 = _mm512_cvtepi32_epi8(src[0]); return Vectorized(_mm512_castsi128_si512(src128)); } }; template <> struct VecConvert { static inline VectorizedN 
// Physical line 3: the remainder of #19, the generic specializations #20-#25,
// and the start of #26.
apply( const VectorizedN& src) { auto src256 = _mm512_cvtepi16_epi8(src[0]); return Vectorized(_mm512_castsi256_si512(src256)); } }; template struct VecConvert< dst_t, 1, src_t, 1, typename std::enable_if_t< (is_reduced_floating_point_v && is_8bit_integer_v) || (is_reduced_floating_point_v && is_8bit_integer_v), void>> { static inline VectorizedN apply(const VectorizedN& src) { VectorizedN tmp_fp32 = VecConvert::apply(src); return VecConvert::apply(tmp_fp32); } }; template struct VecConvert< dst_t, 1, float, 2, typename std::enable_if_t, void>> { static inline VectorizedN apply(const VectorizedN& src) { at::vec::Vectorized vec1 = convert_float_to_int8(src[0]); at::vec::Vectorized vec2 = convert_float_to_int8(src[1]); __m128 lane2 = _mm512_castps512_ps128(_mm512_castsi512_ps(vec2)); __m512 result = _mm512_insertf32x4( _mm512_castsi512_ps(vec1), lane2, 1); // Insert lane2 into the second 128-bit lane return at::vec::Vectorized(_mm512_castps_si512(result)); } }; template struct VecConvert< dst_t, 1, float, 1, typename std::enable_if_t, void>> { static inline VectorizedN apply(const VectorizedN& src) { return convert_float_to_int8(src[0]); } }; template struct VecConvert< float, 2, src_t, 1, typename std::enable_if_t, void>> { static inline VectorizedN apply(const VectorizedN& src) { __m512i src2 = _mm512_castsi128_si512(_mm_castps_si128(_mm512_extractf32x4_ps( _mm512_castsi512_ps(src[0]), 1) // Extract the second 128-bit lane )); return VectorizedN( convert_int8_to_float(src[0]), convert_int8_to_float(src2)); } }; template struct VecConvert< float, 1, src_t, 1, typename std::enable_if_t, void>> { static inline VectorizedN apply(const VectorizedN& src) { return convert_int8_to_float(src[0]); } }; template struct VecConvert< dst_t, 1, int64_t, 2, std::enable_if_t< std::is_same_v || std::is_same_v>> { static inline VectorizedN apply( const VectorizedN& src) { return VecConvert::apply( VecConvert::apply(src)); } }; template <> struct VecConvert { static inline VectorizedN 
// Physical line 4: the remainder of #26, the FP8 conversions #27-#29, the
// closing `#endif`, and the two namespace closers.
// NOTE(review): in this folded form, the `//` comments embedded mid-line on
// lines 3 and 4 (e.g. "// Insert lane2 ...", "// cvt first 16x8 bits ...")
// comment out the remainder of their physical line. This is another reason
// the file must be restored rather than used as-is.
apply( const VectorizedN& src_n) { at::vec::Vectorized src = src_n[0]; __m128i res128 = cvtfp32_fp8e4m3(src); return at::vec::Vectorized(_mm512_castsi128_si512(res128)); } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src_n) { // cvt first 16x8 bits from Float8_e4m3fn to float at::vec::Vectorized src = src_n[0]; __m512 result; cvtfp8e4m3_fp32(_mm512_castsi512_si128(src), result); return at::vec::Vectorized(result); } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src_n) { at::vec::Vectorized src = src_n[0]; __m128i res128 = cvtfp32_fp8e5m2(src); return at::vec::Vectorized(_mm512_castsi128_si512(res128)); } }; template <> struct VecConvert { static inline VectorizedN apply( const VectorizedN& src_n) { // cvt first 16x8 bits from Float8_e5m2 to float at::vec::Vectorized src = src_n[0]; __m512 result; cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result); return at::vec::Vectorized(result); } }; #endif } // namespace CPU_CAPABILITY } // namespace at::vec