#pragma once // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with SVE] #include #include #include #if defined(CPU_CAPABILITY_SVE) #include #include #include #include #endif namespace at::vec { // Note [CPU_CAPABILITY namespace] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // This header, and all of its subheaders, will be compiled with // different architecture flags for each supported set of vector // intrinsics. So we need to make sure they aren't inadvertently // linked together. We do this by declaring objects in an `inline // namespace` which changes the name mangling, but can still be // accessed as `at::vec`. inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_SVE) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template<> inline Vectorized cast(const Vectorized& src) { return svreinterpret_f32_f64(src); } template<> inline Vectorized cast(const Vectorized& src) { return svreinterpret_f64_f32(src); } #define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit) \ template<> \ inline Vectorized cast(const Vectorized& src) { \ return svreinterpret_s##int_bit##_f##float_bit(src); \ } \ template<> \ inline Vectorized cast(const Vectorized& src) { \ return svreinterpret_f##float_bit##_s##int_bit(src); \ } DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64) DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64) DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64) DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32) DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32) DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template std::enable_if_t> inline gather(const double* base_addr, const Vectorized& vindex_) { svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); return svld1_gather_s64index_f64(ptrue, base_addr, vindex); } template std::enable_if_t> inline gather(const float* base_addr, const Vectorized& vindex_) { svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); return svld1_gather_s32index_f32(ptrue, base_addr, vindex); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template std::enable_if_t> inline mask_gather(const Vectorized& src, const double* base_addr, const Vectorized& vindex_, const Vectorized& mask_) { svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); return svsel_f64(mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src); } template std::enable_if_t> inline mask_gather(const Vectorized& src, const float* base_addr, const Vectorized& vindex_, const Vectorized& mask_) { svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK); svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); return svsel_f32(mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Only works for inputs in the range: [-2^51, 2^51] // From: https://stackoverflow.com/a/41148578 template<> Vectorized inline convert_to_int_of_same_size(const Vectorized &src) { svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000)); return svsub_s64_x(ptrue, svreinterpret_s64_f64(x), svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000))); } template<> Vectorized inline convert_to_int_of_same_size(const Vectorized &src) { return svcvt_s32_f32_x(ptrue, src); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <> std::pair, Vectorized> inline interleave2(const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, a1, a3, a3} // b = {b0, b1, b2, b3} // group cols crossing lanes: // return {a0, b0, a1, b1} // {a2, b2, a3, b3} return std::make_pair(Vectorized(svzip1_f64(a, b)), Vectorized(svzip2_f64(a, b))); } template <> std::pair, Vectorized> inline interleave2(const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, a1, a2, a3, a4, a5, a6, a7} // b = {b0, b1, b2, b3, b4, b5, b6, b7} // group cols crossing lanes: // return {a0, b0, a1, b1, a2, b2, a3, b3} // {a4, b4, a5, b5, a6, b6, a7, b7} return std::make_pair(Vectorized(svzip1_f32(a, b)), Vectorized(svzip2_f32(a, b))); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <> std::pair, Vectorized> inline deinterleave2(const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, b0, a1, b1} // b = {a2, b2, a3, b3} // swap lanes: // return {a0, a1, a2, a3} // {b0, b1, b2, b3} return std::make_pair(Vectorized(svuzp1_f64(a, b)), Vectorized(svuzp2_f64(a, b))); } template <> std::pair, Vectorized> inline deinterleave2(const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, b0, a1, b1, a2, b2, a3, b3} // b = {a4, b4, a5, b5, a6, b6, a7, b7} // swap lanes: // return {a0, a1, a2, a3, a4, a5, a6, a7} // {b0, b1, b2, b3, b4, b5, b6, b7} return std::make_pair(Vectorized(svuzp1_f32(a, b)), Vectorized(svuzp2_f32(a, b))); } #endif // defined(CPU_CAPABILITY_SVE) }}