#pragma once #include #include #include #include namespace at { namespace vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { inline std::tuple, Vectorized> convert_bfloat16_float( const Vectorized& a) { constexpr int64_t K = Vectorized::size(); __at_align__ float arr[K]; __at_align__ BFloat16 arr2[K]; a.store(arr2); convert(arr2, arr, K); return std::make_tuple( Vectorized::loadu(arr), Vectorized::loadu(arr + Vectorized::size())); } inline Vectorized convert_float_bfloat16( const Vectorized& a, const Vectorized& b) { constexpr int64_t K = Vectorized::size(); __at_align__ float arr[K]; __at_align__ BFloat16 arr2[K]; a.store(arr); b.store(arr + Vectorized::size()); convert(arr, arr2, K); return Vectorized::loadu(arr2); } inline void load_fp32_from_bf16( const c10::BFloat16* data, Vectorized& out) { __at_align__ float values[Vectorized::size()]; for (const auto k : c10::irange(Vectorized::size())) { values[k] = data[k]; } out = Vectorized::loadu(values); } inline void load_fp32_from_bf16( const c10::BFloat16* data, Vectorized& out1, Vectorized& out2) { load_fp32_from_bf16(data, out1); data += Vectorized::size(); load_fp32_from_bf16(data, out2); } inline void load_fp32_from_fp16(const c10::Half* data, Vectorized& out) { __at_align__ float values[Vectorized::size()]; for (const auto k : c10::irange(Vectorized::size())) { values[k] = data[k]; } out = Vectorized::loadu(values); } inline void load_fp32_from_fp16( const c10::Half* data, Vectorized& out1, Vectorized& out2) { load_fp32_from_fp16(data, out1); data += Vectorized::size(); load_fp32_from_fp16(data, out2); } } // namespace CPU_CAPABILITY } // namespace vec } // namespace at