#pragma once #include #include #include namespace at::vec { // Note [CPU_CAPABILITY namespace] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // This header, and all of its subheaders, will be compiled with // different architecture flags for each supported set of vector // intrinsics. So we need to make sure they aren't inadvertently // linked together. We do this by declaring objects in an `inline // namespace` which changes the name mangling, but can still be // accessed as `at::vec`. inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_SVE) #define VEC_INT_SVE_TEMPLATE(vl, bit) \ template <> \ struct is_vec_specialized_for : std::bool_constant {}; \ \ template <> \ class Vectorized { \ private: \ vls_int##bit##_t values; \ \ public: \ using value_type = int##bit##_t; \ using size_type = int; \ static constexpr size_type size() { \ return vl; \ } \ Vectorized() {} \ Vectorized(svint##bit##_t v) : values(v) {} \ Vectorized(int##bit##_t val) { \ values = svdup_n_s##bit(val); \ } \ template < \ typename... Args, \ typename = std::enable_if_t<(sizeof...(Args) == size())>> \ Vectorized(Args... vals) { \ __at_align__ int##bit##_t buffer[size()] = {vals...}; \ values = svld1_s##bit(ptrue, buffer); \ } \ operator svint##bit##_t() const { \ return values; \ } \ template \ static Vectorized blend( \ const Vectorized& a, \ const Vectorized& b) { \ __at_align__ int##bit##_t flag_arr[size()]; \ for (int i = 0; i < size(); ++i) { \ flag_arr[i] = (i < 64 && (mask & (1ULL << i))) ? 1 : 0; \ } \ svbool_t blend_mask = svcmpne_n_s##bit( \ svptrue_b##bit(), svld1_s##bit(svptrue_b##bit(), flag_arr), 0); \ return Vectorized( \ svsel_s##bit(blend_mask, b.values, a.values)); \ } \ static Vectorized blendv( \ const Vectorized& a, \ const Vectorized& b, \ const Vectorized& mask_) { \ svbool_t mask = svcmpeq_s##bit(ptrue, mask_, ALL_S##bit##_TRUE_MASK); \ return svsel_s##bit(mask, b, a); \ } \ /* step sometimes requires a higher precision type (e.g., T=int, \ * step_t=double) */ \ template \ static Vectorized arange( \ int##bit##_t base = 0, \ step_t step = static_cast(1)) { \ __at_align__ int##bit##_t buffer[size()]; \ for (int64_t i = 0; i < size(); i++) { \ buffer[i] = base + i * step; \ } \ return svld1_s##bit(ptrue, buffer); \ } \ static Vectorized set( \ const Vectorized& a, \ const Vectorized& b, \ int##bit##_t count = size()) { \ if (count == 0) { \ return a; \ } else if (count < size()) { \ return svsel_s##bit(svwhilelt_b##bit(0ull, count), b, a); \ } \ return b; \ } \ static Vectorized loadu( \ const void* ptr, \ int64_t count = size()) { \ if (count == size()) \ return svld1_s##bit( \ ptrue, reinterpret_cast(ptr)); \ svbool_t pg = svwhilelt_b##bit(0ull, count); \ return svld1_s##bit(pg, reinterpret_cast(ptr)); \ } \ void store(void* ptr, int64_t count = size()) const { \ if (count == size()) { \ svst1_s##bit(ptrue, reinterpret_cast(ptr), values); \ } else { \ svbool_t pg = svwhilelt_b##bit(0ull, count); \ svst1_s##bit(pg, reinterpret_cast(ptr), values); \ } \ } \ const int##bit##_t& operator[](int idx) const = delete; \ int##bit##_t& operator[](int idx) = delete; \ Vectorized abs() const { \ return svabs_s##bit##_x(ptrue, values); \ } \ Vectorized real() const { \ return values; \ } \ Vectorized imag() const { \ return svdup_n_s##bit(0); \ } \ Vectorized conj() const { \ return values; \ } \ Vectorized frac() const; \ Vectorized neg() const { \ return svneg_s##bit##_x(ptrue, values); \ } \ Vectorized operator==( \ const Vectorized& other) const { \ svbool_t mask = svcmpeq_s##bit(ptrue, values, other); \ return svsel_s##bit( \ mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ } \ Vectorized operator!=( \ const Vectorized& other) const { \ svbool_t mask = svcmpne_s##bit(ptrue, values, other); \ return svsel_s##bit( \ mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ } \ Vectorized operator<( \ const Vectorized& other) const { \ svbool_t mask = svcmplt_s##bit(ptrue, values, other); \ return svsel_s##bit( \ mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ } \ Vectorized operator<=( \ const Vectorized& other) const { \ svbool_t mask = svcmple_s##bit(ptrue, values, other); \ return svsel_s##bit( \ mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ } \ Vectorized operator>( \ const Vectorized& other) const { \ svbool_t mask = svcmpgt_s##bit(ptrue, values, other); \ return svsel_s##bit( \ mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ } \ Vectorized operator>=( \ const Vectorized& other) const { \ svbool_t mask = svcmpge_s##bit(ptrue, values, other); \ return svsel_s##bit( \ mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ } \ Vectorized eq(const Vectorized& other) const; \ Vectorized ne(const Vectorized& other) const; \ Vectorized gt(const Vectorized& other) const; \ Vectorized ge(const Vectorized& other) const; \ Vectorized lt(const Vectorized& other) const; \ Vectorized le(const Vectorized& other) const; \ }; \ template <> \ Vectorized inline operator+( \ const Vectorized& a, const Vectorized& b) { \ return svadd_s##bit##_x(ptrue, a, b); \ } \ template <> \ Vectorized inline operator-( \ const Vectorized& a, const Vectorized& b) { \ return svsub_s##bit##_x(ptrue, a, b); \ } \ template <> \ Vectorized inline operator*( \ const Vectorized& a, const Vectorized& b) { \ return svmul_s##bit##_x(ptrue, a, b); \ } \ template <> \ Vectorized inline maximum( \ const Vectorized& a, const Vectorized& b) { \ return svmax_s##bit##_x(ptrue, a, b); \ } \ template <> \ Vectorized inline minimum( \ const Vectorized& a, const Vectorized& b) { \ return svmin_s##bit##_x(ptrue, a, b); \ } \ template <> \ Vectorized inline clamp( \ const Vectorized& a, \ const Vectorized& min, \ const Vectorized& max) { \ return svmin_s##bit##_x(ptrue, max, svmax_s##bit##_x(ptrue, min, a)); \ } \ template <> \ Vectorized inline clamp_max( \ const Vectorized& a, \ const Vectorized& max) { \ return svmin_s##bit##_x(ptrue, max, a); \ } \ template <> \ Vectorized inline clamp_min( \ const Vectorized& a, \ const Vectorized& min) { \ return svmax_s##bit##_x(ptrue, min, a); \ } \ template <> \ Vectorized inline operator&( \ const Vectorized& a, const Vectorized& b) { \ return svand_s##bit##_x(ptrue, a, b); \ } \ template <> \ Vectorized inline operator|( \ const Vectorized& a, const Vectorized& b) { \ return svorr_s##bit##_x(ptrue, a, b); \ } \ template <> \ Vectorized inline operator^( \ const Vectorized& a, const Vectorized& b) { \ return sveor_s##bit##_x(ptrue, a, b); \ } \ template <> \ inline Vectorized operator~( \ const Vectorized& a) { \ return sveor_s##bit##_x(ptrue, a, svdup_n_s##bit(-1)); \ } \ Vectorized inline Vectorized::eq( \ const Vectorized& other) const { \ return (*this == other) & Vectorized(1); \ } \ Vectorized inline Vectorized::ne( \ const Vectorized& other) const { \ return (*this != other) & Vectorized(1); \ } \ Vectorized inline Vectorized::gt( \ const Vectorized& other) const { \ return (*this > other) & Vectorized(1); \ } \ Vectorized inline Vectorized::ge( \ const Vectorized& other) const { \ return (*this >= other) & Vectorized(1); \ } \ Vectorized inline Vectorized::lt( \ const Vectorized& other) const { \ return (*this < other) & Vectorized(1); \ } \ Vectorized inline Vectorized::le( \ const Vectorized& other) const { \ return (*this <= other) & Vectorized(1); \ } VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int64_t), 64) VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int32_t), 32) VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int16_t), 16) VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int8_t), 8) template Vectorized inline intdiv_nosve( const Vectorized& a, const Vectorized& b) { T values_a[Vectorized::size()]; T values_b[Vectorized::size()]; a.store(values_a); b.store(values_b); for (int i = 0; i != Vectorized::size(); i++) { values_a[i] /= values_b[i]; } return Vectorized::loadu(values_a); } template <> Vectorized inline operator/( const Vectorized& a, const Vectorized& b) { return svdiv_s64_x(ptrue, a, b); } template <> Vectorized inline operator/( const Vectorized& a, const Vectorized& b) { return svdiv_s32_x(ptrue, a, b); } template <> Vectorized inline operator/( const Vectorized& a, const Vectorized& b) { return intdiv_nosve(a, b); } template <> Vectorized inline operator/( const Vectorized& a, const Vectorized& b) { return intdiv_nosve(a, b); } template <> inline void convert(const int32_t* src, int64_t* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); #pragma unroll for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); #pragma unroll for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { pg_32 = svwhilelt_b32(i, n); pg_64 = svwhilelt_b64(i, n); svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); } } template <> inline void convert(const int64_t* src, float* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); #pragma unroll for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); svfloat32_t src_vec_f32 = svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); svst1_f32(pg_32, dst + i, src_vec_f32); } #pragma unroll for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { pg_32 = svwhilelt_b32(i, n); pg_64 = svwhilelt_b64(i, n); svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); svfloat32_t src_vec_f32 = svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); svst1_f32(pg_32, dst + i, src_vec_f32); } } template <> inline void convert(const int32_t* src, float* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); svbool_t pg = svwhilelt_b32(0ull, Vectorized::size()); #pragma unroll for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { svint32_t src_vec = svldnt1_s32(pg, src + i); svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); } #pragma unroll for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { pg = svwhilelt_b32(i, n); svint32_t src_vec = svldnt1_s32(pg, src + i); svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); } } template <> inline void convert(const bool* src, int64_t* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); #pragma unroll for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); svuint64_t src_vec_u64 = svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); } #pragma unroll for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { pg_8 = svwhilelt_b8(i, n); pg_64 = svwhilelt_b64(i, n); svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); svuint64_t src_vec_u64 = svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); } } template <> inline void convert(const bool* src, int32_t* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); #pragma unroll for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); } #pragma unroll for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { pg_8 = svwhilelt_b8(i, n); pg_32 = svwhilelt_b32(i, n); svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); } } template <> inline void convert(const uint8_t* src, bool* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); svbool_t pg = svwhilelt_b8(0ull, Vectorized::size()); #pragma unroll for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); svst1_u8( pg, reinterpret_cast(dst) + i, svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); } #pragma unroll for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { pg = svwhilelt_b8(i, n); svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); svst1_u8( pg, reinterpret_cast(dst) + i, svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); } } template <> Vectorized inline operator<<( const Vectorized& a, const Vectorized& b) { return svlsl_s64_x(ptrue, a, svreinterpret_u64_s64(b)); } template <> Vectorized inline operator<<( const Vectorized& a, const Vectorized& b) { return svlsl_s32_x(ptrue, a, svreinterpret_u32_s32(b)); } template <> Vectorized inline operator<<( const Vectorized& a, const Vectorized& b) { return svlsl_s16_x(ptrue, a, svreinterpret_u16_s16(b)); } template <> Vectorized inline operator<<( const Vectorized& a, const Vectorized& b) { return svlsl_s8_x(ptrue, a, svreinterpret_u8_s8(b)); } template <> Vectorized inline operator>>( const Vectorized& a, const Vectorized& b) { return svasr_s64_x(ptrue, a, svreinterpret_u64_s64(b)); } template <> Vectorized inline operator>>( const Vectorized& a, const Vectorized& b) { return svasr_s32_x(ptrue, a, svreinterpret_u32_s32(b)); } template <> Vectorized inline operator>>( const Vectorized& a, const Vectorized& b) { return svasr_s16_x(ptrue, a, svreinterpret_u16_s16(b)); } template <> Vectorized inline operator>>( const Vectorized& a, const Vectorized& b) { return svasr_s8_x(ptrue, a, svreinterpret_u8_s8(b)); } #endif // defined(CPU_CAPABILITY_SVE) } // namespace CPU_CAPABILITY } // namespace at::vec