#pragma once #include #include #include #include #if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) #include #define USE_SLEEF(sleef_code, non_sleef_code) sleef_code #else #define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code #endif namespace at::vec { // Note [CPU_CAPABILITY namespace] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // This header, and all of its subheaders, will be compiled with // different architecture flags for each supported set of vector // intrinsics. So we need to make sure they aren't inadvertently // linked together. We do this by declaring objects in an `inline // namespace` which changes the name mangling, but can still be // accessed as `at::vec`. inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_SVE) template <> struct is_vec_specialized_for : std::bool_constant {}; template <> class Vectorized { private: vls_float64_t values; public: using value_type = double; using size_type = int; static constexpr size_type size() { return VECTOR_WIDTH / sizeof(double); } Vectorized() {} Vectorized(svfloat64_t v) : values(v) {} Vectorized(double val) { values = svdup_n_f64(val); } template < typename... Args, typename = std::enable_if_t<(sizeof...(Args) == size())>> Vectorized(Args... vals) { __at_align__ double buffer[size()] = {vals...}; values = svld1_f64(ptrue, buffer); } operator svfloat64_t() const { return values; } template static Vectorized blend( const Vectorized& a, const Vectorized& b) { // Build an array of flags: each element is 1 if the corresponding bit in // 'mask' is set, 0 otherwise. __at_align__ int64_t flag_arr[size()]; for (int i = 0; i < size(); i++) { flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0; } // Load the flag array into an SVE int64 vector. svint64_t int_mask = svld1_s64(svptrue_b64(), flag_arr); // Compare each lane of int_mask to 0; returns an svbool_t predicate where // true indicates a nonzero flag. svbool_t blend_mask = svcmpne_n_s64(svptrue_b64(), int_mask, 0); // Use svsel to select elements from b where the predicate is true, else // from a. svfloat64_t result = svsel(blend_mask, b.values, a.values); return Vectorized(result); } static Vectorized blendv( const Vectorized& a, const Vectorized& b, const Vectorized& mask_) { svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); return svsel_f64(mask, b, a); } template static Vectorized arange( double base = 0., step_t step = static_cast(1)) { __at_align__ double buffer[size()]; for (int64_t i = 0; i < size(); i++) { buffer[i] = base + i * step; } return svld1_f64(ptrue, buffer); } static Vectorized set( const Vectorized& a, const Vectorized& b, int64_t count = size()) { if (count == 0) { return a; } else if (count < size()) { return svsel_f64(svwhilelt_b64(0ull, count), b, a); } return b; } static Vectorized loadu(const void* ptr, int64_t count = size()) { if (count == size()) return svld1_f64(ptrue, reinterpret_cast(ptr)); svbool_t pg = svwhilelt_b64(0ull, count); return svld1_f64(pg, reinterpret_cast(ptr)); } void store(void* ptr, int64_t count = size()) const { if (count == size()) { svst1_f64(ptrue, reinterpret_cast(ptr), values); } else { svbool_t pg = svwhilelt_b64(0ull, count); svst1_f64(pg, reinterpret_cast(ptr), values); } } const double& operator[](int idx) const = delete; double& operator[](int idx) = delete; int64_t zero_mask() const { // returns an integer mask where all zero elements are translated to 1-bit // and others are translated to 0-bit int64_t mask = 0; __at_align__ int64_t mask_array[size()]; svbool_t svbool_mask = svcmpeq_f64(ptrue, values, ZERO_F64); svst1_s64( ptrue, mask_array, svsel_s64(svbool_mask, ALL_S64_TRUE_MASK, ALL_S64_FALSE_MASK)); for (int64_t i = 0; i < size(); ++i) { if (mask_array[i]) mask |= (1ull << i); } return mask; } Vectorized isnan() const { // NaN check svbool_t mask = svcmpuo_f64(ptrue, values, ZERO_F64); return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); } bool has_inf_nan() const { return svptest_any( ptrue, svcmpuo_f64(ptrue, svsub_f64_x(ptrue, values, values), ZERO_F64)); } Vectorized map(double (*f)(double)) const { __at_align__ double tmp[size()]; store(tmp); for (int64_t i = 0; i < size(); ++i) { tmp[i] = f(tmp[i]); } return loadu(tmp); } Vectorized abs() const { return svabs_f64_x(ptrue, values); } Vectorized angle() const { const auto nan_vec = svdup_n_f64(NAN); const auto nan_mask = svcmpuo_f64(ptrue, values, ZERO_F64); const auto pi = svdup_n_f64(c10::pi); const auto neg_mask = svcmplt_f64(ptrue, values, ZERO_F64); auto angle = svsel_f64(neg_mask, pi, ZERO_F64); angle = svsel_f64(nan_mask, nan_vec, angle); return angle; } Vectorized real() const { return *this; } Vectorized imag() const { return Vectorized(0.0); } Vectorized conj() const { return *this; } Vectorized acos() const { return USE_SLEEF( Vectorized(Sleef_acosdx_u10sve(values)), map(std::acos)); } Vectorized acosh() const { return USE_SLEEF( Vectorized(Sleef_acoshdx_u10sve(values)), map(std::acosh)); } Vectorized asin() const { return USE_SLEEF( Vectorized(Sleef_asindx_u10sve(values)), map(std::asin)); } Vectorized asinh() const { return USE_SLEEF( Vectorized(Sleef_asinhdx_u10sve(values)), map(std::asinh)); } Vectorized atan() const { return USE_SLEEF( Vectorized(Sleef_atandx_u10sve(values)), map(std::atan)); } Vectorized atanh() const { return USE_SLEEF( Vectorized(Sleef_atanhdx_u10sve(values)), map(std::atanh)); } Vectorized atan2(const Vectorized& b) const {USE_SLEEF( { return Vectorized(Sleef_atan2dx_u10sve(values, b)); }, { __at_align__ double tmp[size()]; __at_align__ double tmp_b[size()]; store(tmp); b.store(tmp_b); for (int64_t i = 0; i < size(); i++) { tmp[i] = std::atan2(tmp[i], tmp_b[i]); } return loadu(tmp); })} Vectorized copysign(const Vectorized& sign) const { USE_SLEEF( { return Vectorized(Sleef_copysigndx_sve(values, sign)); }, { __at_align__ double tmp[size()]; __at_align__ double tmp_sign[size()]; store(tmp); sign.store(tmp_sign); for (int64_t i = 0; i < size(); i++) { tmp[i] = std::copysign(tmp[i], tmp_sign[i]); } return loadu(tmp); })} Vectorized erf() const { return USE_SLEEF( Vectorized(Sleef_erfdx_u10sve(values)), map(std::erf)); } Vectorized erfc() const { return USE_SLEEF( Vectorized(Sleef_erfcdx_u15sve(values)), map(std::erfc)); } Vectorized erfinv() const { return map(calc_erfinv); } Vectorized exp() const { return USE_SLEEF( Vectorized(Sleef_expdx_u10sve(values)), map(std::exp)); } Vectorized exp2() const { return USE_SLEEF( Vectorized(Sleef_exp2dx_u10sve(values)), map(std::exp2)); } Vectorized expm1() const { return USE_SLEEF( Vectorized(Sleef_expm1dx_u10sve(values)), map(std::expm1)); } Vectorized exp_u20() const { return exp(); } Vectorized fmod(const Vectorized& q) const {USE_SLEEF( { return Vectorized(Sleef_fmoddx_sve(values, q)); }, { __at_align__ double tmp[size()]; __at_align__ double tmp_q[size()]; store(tmp); q.store(tmp_q); for (int64_t i = 0; i < size(); i++) { tmp[i] = std::fmod(tmp[i], tmp_q[i]); } return loadu(tmp); })} Vectorized hypot(const Vectorized& b) const { USE_SLEEF( { return Vectorized(Sleef_hypotdx_u05sve(values, b)); }, { __at_align__ double tmp[size()]; __at_align__ double tmp_b[size()]; store(tmp); b.store(tmp_b); for (int64_t i = 0; i < size(); i++) { tmp[i] = std::hypot(tmp[i], tmp_b[i]); } return loadu(tmp); })} Vectorized i0() const { return map(calc_i0); } Vectorized i0e() const { return map(calc_i0e); } Vectorized digamma() const { return map(calc_digamma); } Vectorized igamma(const Vectorized& x) const { __at_align__ double tmp[size()]; __at_align__ double tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { tmp[i] = calc_igamma(tmp[i], tmp_x[i]); } return loadu(tmp); } Vectorized igammac(const Vectorized& x) const { __at_align__ double tmp[size()]; __at_align__ double tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { tmp[i] = calc_igammac(tmp[i], tmp_x[i]); } return loadu(tmp); } Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( { return Vectorized(Sleef_nextafterdx_sve(values, b)); }, { __at_align__ double tmp[size()]; __at_align__ double tmp_b[size()]; store(tmp); b.store(tmp_b); for (int64_t i = 0; i < size(); ++i) { tmp[i] = std::nextafter(tmp[i], tmp_b[i]); } return loadu(tmp); })} Vectorized log() const { return USE_SLEEF( Vectorized(Sleef_logdx_u10sve(values)), map(std::log)); } Vectorized log2() const { return USE_SLEEF( Vectorized(Sleef_log2dx_u10sve(values)), map(std::log2)); } Vectorized log10() const { return USE_SLEEF( Vectorized(Sleef_log10dx_u10sve(values)), map(std::log10)); } Vectorized log1p() const { return USE_SLEEF( Vectorized(Sleef_log1pdx_u10sve(values)), map(std::log1p)); } Vectorized frac() const; Vectorized sin() const { return USE_SLEEF( Vectorized(Sleef_sindx_u10sve(values)), map(std::sin)); } Vectorized sinh() const { return USE_SLEEF( Vectorized(Sleef_sinhdx_u10sve(values)), map(std::sinh)); } Vectorized cos() const { return USE_SLEEF( Vectorized(Sleef_cosdx_u10sve(values)), map(std::cos)); } Vectorized cosh() const { return USE_SLEEF( Vectorized(Sleef_coshdx_u10sve(values)), map(std::cosh)); } Vectorized ceil() const { return svrintp_f64_x(ptrue, values); } Vectorized floor() const { return svrintm_f64_x(ptrue, values); } Vectorized neg() const { return svneg_f64_x(ptrue, values); } Vectorized round() const { return svrinti_f64_x(ptrue, values); } Vectorized tan() const { return USE_SLEEF( Vectorized(Sleef_tandx_u10sve(values)), map(std::tan)); } Vectorized tanh() const { return USE_SLEEF( Vectorized(Sleef_tanhdx_u10sve(values)), map(std::tanh)); } Vectorized trunc() const { return svrintz_f64_x(ptrue, values); } Vectorized lgamma() const { return USE_SLEEF( Vectorized(Sleef_lgammadx_u10sve(values)), map(std::lgamma)); } Vectorized sqrt() const { return svsqrt_f64_x(ptrue, values); } Vectorized reciprocal() const { return svdivr_f64_x(ptrue, values, ONE_F64); } Vectorized rsqrt() const { return svdivr_f64_x(ptrue, svsqrt_f64_x(ptrue, values), ONE_F64); } Vectorized pow(const Vectorized& b) const {USE_SLEEF( { return Vectorized(Sleef_powdx_u10sve(values, b)); }, { __at_align__ double tmp[size()]; __at_align__ double tmp_b[size()]; store(tmp); b.store(tmp_b); for (int64_t i = 0; i < size(); i++) { tmp[i] = std::pow(tmp[i], tmp_b[i]); } return loadu(tmp); })} // Comparison using the _CMP_**_OQ predicate. // `O`: get false if an operand is NaN // `Q`: do not raise if an operand is NaN Vectorized operator==(const Vectorized& other) const { svbool_t mask = svcmpeq_f64(ptrue, values, other); return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); } Vectorized operator!=(const Vectorized& other) const { svbool_t mask = svcmpne_f64(ptrue, values, other); return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); } Vectorized operator<(const Vectorized& other) const { svbool_t mask = svcmplt_f64(ptrue, values, other); return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); } Vectorized operator<=(const Vectorized& other) const { svbool_t mask = svcmple_f64(ptrue, values, other); return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); } Vectorized operator>(const Vectorized& other) const { svbool_t mask = svcmpgt_f64(ptrue, values, other); return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); } Vectorized operator>=(const Vectorized& other) const { svbool_t mask = svcmpge_f64(ptrue, values, other); return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); } Vectorized eq(const Vectorized& other) const; Vectorized ne(const Vectorized& other) const; Vectorized gt(const Vectorized& other) const; Vectorized ge(const Vectorized& other) const; Vectorized lt(const Vectorized& other) const; Vectorized le(const Vectorized& other) const; }; template <> Vectorized inline operator+( const Vectorized& a, const Vectorized& b) { return svadd_f64_x(ptrue, a, b); } template <> Vectorized inline operator-( const Vectorized& a, const Vectorized& b) { return svsub_f64_x(ptrue, a, b); } template <> Vectorized inline operator*( const Vectorized& a, const Vectorized& b) { return svmul_f64_x(ptrue, a, b); } template <> Vectorized inline operator/( const Vectorized& a, const Vectorized& b) { return svdiv_f64_x(ptrue, a, b); } // frac. Implement this here so we can use subtraction Vectorized inline Vectorized::frac() const { return *this - this->trunc(); } // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. template <> Vectorized inline maximum( const Vectorized& a, const Vectorized& b) { return svmax_f64_x(ptrue, a, b); } // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. template <> Vectorized inline minimum( const Vectorized& a, const Vectorized& b) { return svmin_f64_x(ptrue, a, b); } template <> Vectorized inline clamp( const Vectorized& a, const Vectorized& min, const Vectorized& max) { return svmin_f64_x(ptrue, max, svmax_f64_x(ptrue, min, a)); } template <> Vectorized inline clamp_max( const Vectorized& a, const Vectorized& max) { return svmin_f64_x(ptrue, max, a); } template <> Vectorized inline clamp_min( const Vectorized& a, const Vectorized& min) { return svmax_f64_x(ptrue, min, a); } template <> Vectorized inline operator&( const Vectorized& a, const Vectorized& b) { return svreinterpret_f64_s64( svand_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); } template <> Vectorized inline operator|( const Vectorized& a, const Vectorized& b) { return svreinterpret_f64_s64( svorr_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); } template <> Vectorized inline operator^( const Vectorized& a, const Vectorized& b) { return svreinterpret_f64_s64( sveor_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); } Vectorized inline Vectorized::eq( const Vectorized& other) const { return (*this == other) & Vectorized(1.0); } Vectorized inline Vectorized::ne( const Vectorized& other) const { return (*this != other) & Vectorized(1.0); } Vectorized inline Vectorized::gt( const Vectorized& other) const { return (*this > other) & Vectorized(1.0); } Vectorized inline Vectorized::ge( const Vectorized& other) const { return (*this >= other) & Vectorized(1.0); } Vectorized inline Vectorized::lt( const Vectorized& other) const { return (*this < other) & Vectorized(1.0); } Vectorized inline Vectorized::le( const Vectorized& other) const { return (*this <= other) & Vectorized(1.0); } template <> inline void convert(const double* src, double* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); #pragma unroll for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { svst1_f64(ptrue, dst + i, svldnt1_f64(ptrue, src + i)); } #pragma unroll for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { svbool_t pg = svwhilelt_b64(i, n); svst1_f64(pg, dst + i, svldnt1_f64(pg, src + i)); } } template <> Vectorized inline fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return svmad_f64_x(ptrue, a, b, c); } #endif // defined(CPU_CAPABILITY_SVE) } // namespace CPU_CAPABILITY } // namespace at::vec