#pragma once // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with SVE] #include #include #include #if defined(CPU_CAPABILITY_SVE) #include #include #include #include #include #endif namespace at::vec { // Note [CPU_CAPABILITY namespace] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // This header, and all of its subheaders, will be compiled with // different architecture flags for each supported set of vector // intrinsics. So we need to make sure they aren't inadvertently // linked together. We do this by declaring objects in an `inline // namespace` which changes the name mangling, but can still be // accessed as `at::vec`. inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_SVE) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #define DEFINE_SVE_CAST(t1_t, t1_prefix, t2_t, t2_prefix) \ template <> \ inline Vectorized cast(const Vectorized& src) { \ return svreinterpret_##t1_prefix##_##t2_prefix(src); \ } \ template <> \ inline Vectorized cast(const Vectorized& src) { \ return svreinterpret_##t2_prefix##_##t1_prefix(src); \ } DEFINE_SVE_CAST(int64_t, s64, double, f64) DEFINE_SVE_CAST(int32_t, s32, double, f64) DEFINE_SVE_CAST(int16_t, s16, double, f64) DEFINE_SVE_CAST(int64_t, s64, float, f32) DEFINE_SVE_CAST(int32_t, s32, float, f32) DEFINE_SVE_CAST(int16_t, s16, float, f32) DEFINE_SVE_CAST(float, f32, double, f64) #ifdef __ARM_FEATURE_BF16 DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16) DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16) DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16) #endif // __ARM_FEATURE_BF16 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template std::enable_if_t< scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized< double>> inline gather(const double* base_addr, const Vectorized& vindex_) { svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); return svld1_gather_s64index_f64(ptrue, base_addr, vindex); } template std::enable_if_t< scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized< float>> inline gather(const float* base_addr, const Vectorized& vindex_) { svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); return svld1_gather_s32index_f32(ptrue, base_addr, vindex); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template std:: enable_if_t> inline mask_gather( const Vectorized& src, const double* base_addr, const Vectorized& vindex_, const Vectorized& mask_) { svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); return svsel_f64( mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src); } template std:: enable_if_t> inline mask_gather( const Vectorized& src, const float* base_addr, const Vectorized& vindex_, const Vectorized& mask_) { svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK); svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); return svsel_f32( mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Only works for inputs in the range: [-2^51, 2^51] // From: https://stackoverflow.com/a/41148578 template <> Vectorized inline convert_to_int_of_same_size( const Vectorized& src) { svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000)); return svsub_s64_x( ptrue, svreinterpret_s64_f64(x), svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000))); } template <> Vectorized inline convert_to_int_of_same_size( const Vectorized& src) { return svcvt_s32_f32_x(ptrue, src); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <> std::pair, Vectorized> inline interleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, a1, a3, a3} // b = {b0, b1, b2, b3} // group cols crossing lanes: // return {a0, b0, a1, b1} // {a2, b2, a3, b3} return std::make_pair( Vectorized(svzip1_f64(a, b)), Vectorized(svzip2_f64(a, b))); } template <> std::pair, Vectorized> inline interleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, a1, a2, a3, a4, a5, a6, a7} // b = {b0, b1, b2, b3, b4, b5, b6, b7} // group cols crossing lanes: // return {a0, b0, a1, b1, a2, b2, a3, b3} // {a4, b4, a5, b5, a6, b6, a7, b7} return std::make_pair( Vectorized(svzip1_f32(a, b)), Vectorized(svzip2_f32(a, b))); } #ifdef __ARM_FEATURE_BF16 template <> std::pair< Vectorized, Vectorized> inline interleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, a1, a2, a3, a4, a5, a6, a7} // b = {b0, b1, b2, b3, b4, b5, b6, b7} // group cols crossing lanes: // return {a0, b0, a1, b1, a2, b2, a3, b3} // {a4, b4, a5, b5, a6, b6, a7, b7} return std::make_pair( Vectorized(svzip1_bf16(a, b)), Vectorized(svzip2_bf16(a, b))); } #endif // __ARM_FEATURE_BF16 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <> std::pair, Vectorized> inline deinterleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, b0, a1, b1} // b = {a2, b2, a3, b3} // swap lanes: // return {a0, a1, a2, a3} // {b0, b1, b2, b3} return std::make_pair( Vectorized(svuzp1_f64(a, b)), Vectorized(svuzp2_f64(a, b))); } template <> std::pair, Vectorized> inline deinterleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, b0, a1, b1, a2, b2, a3, b3} // b = {a4, b4, a5, b5, a6, b6, a7, b7} // swap lanes: // return {a0, a1, a2, a3, a4, a5, a6, a7} // {b0, b1, b2, b3, b4, b5, b6, b7} return std::make_pair( Vectorized(svuzp1_f32(a, b)), Vectorized(svuzp2_f32(a, b))); } #ifdef __ARM_FEATURE_BF16 template <> std::pair< Vectorized, Vectorized> inline deinterleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, b0, a1, b1, a2, b2, a3, b3} // b = {a4, b4, a5, b5, a6, b6, a7, b7} // swap lanes: // return {a0, a1, a2, a3, a4, a5, a6, a7} // {b0, b1, b2, b3, b4, b5, b6, b7} return std::make_pair( Vectorized(svuzp1_bf16((svbfloat16_t)a, (svbfloat16_t)b)), Vectorized(svuzp2_bf16((svbfloat16_t)a, (svbfloat16_t)b))); } #endif // __ARM_FEATURE_BF16 #endif // defined(CPU_CAPABILITY_SVE) } // namespace CPU_CAPABILITY } // namespace at::vec