#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/vec_base.h>
#include <cmath>

#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif

namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE)

template <>
struct is_vec_specialized_for<float> : std::bool_constant<true> {};

template <>
class Vectorized<float> {
 private:
  vls_float32_t values;

 public:
  using value_type = float;
  using size_type = int;
  static constexpr size_type size() {
    return VECTOR_WIDTH / sizeof(float);
  }
  Vectorized() {}
  Vectorized(svfloat32_t v) : values(v) {}
  Vectorized(float val) {
    values = svdup_n_f32(val);
  }
  template <
      typename... Args,
      typename = std::enable_if_t<(sizeof...(Args) == size())>>
  Vectorized(Args... vals) {
    __at_align__ float buffer[size()] = {vals...};
    values = svld1_f32(ptrue, buffer);
  }
  operator svfloat32_t() const {
    return values;
  }
  template <int64_t mask>
  static Vectorized<float> blend(
      const Vectorized<float>& a,
      const Vectorized<float>& b) {
    // Build an array of flags: each element is 1 if the corresponding bit in
    // 'mask' is set, 0 otherwise.
    __at_align__ int32_t flag_arr[size()];
    for (int i = 0; i < size(); i++) {
      flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0;
    }
    // Load the flag array into an SVE int32 vector.
    svint32_t int_mask = svld1_s32(svptrue_b32(), flag_arr);
    // Compare each lane of int_mask to 0; returns an svbool_t predicate where
    // true indicates a nonzero flag.
    svbool_t blend_mask = svcmpne_n_s32(svptrue_b32(), int_mask, 0);
    // Use svsel to select elements from b where the predicate is true, else
    // from a.
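    // For example, with size() == 8 and mask == 0b00000101, flag_arr is
    // {1, 0, 1, 0, 0, 0, 0, 0}: lanes 0 and 2 are taken from b and all other
    // lanes come from a.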
    svfloat32_t result = svsel_f32(blend_mask, b.values, a.values);
    return Vectorized<float>(result);
  }
  static Vectorized<float> blendv(
      const Vectorized<float>& a,
      const Vectorized<float>& b,
      const Vectorized<float>& mask_) {
    svbool_t mask = svcmpeq_s32(
        ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK);
    return svsel_f32(mask, b, a);
  }
  template <typename step_t>
  static Vectorized<float> arange(
      float base = 0.f,
      step_t step = static_cast<step_t>(1)) {
    __at_align__ float buffer[size()];
    for (int64_t i = 0; i < size(); i++) {
      buffer[i] = base + i * step;
    }
    return svld1_f32(ptrue, buffer);
  }
  static Vectorized<float> set(
      const Vectorized<float>& a,
      const Vectorized<float>& b,
      int64_t count = size()) {
    if (count == 0) {
      return a;
    } else if (count < size()) {
      return svsel_f32(svwhilelt_b32(0ull, count), b, a);
    }
    return b;
  }
  // Implementation is picked from
  // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
  inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
    const auto c1 =
        svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
    const auto c2 =
        svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
    const auto c3 =
        svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
    const auto c4 =
        svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
    const auto c5 =
        svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
    const auto shift = svreinterpret_f32_u32(
        svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
    const auto inv_ln2 = svreinterpret_f32_u32(
        svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
    const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
        0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
    const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
        0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
    const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
    const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
    const auto zero = svdup_n_f32(0.f);
    const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
    // Range reduction:
    //   e^x = 2^n * e^r
    // where:
    //   n = floor(x / ln(2))
    //   r = x - n * ln(2)
    //
    // By adding 2^23 + 127 (shift) to x / ln(2):
    //   * As the FP32 fraction part only has 23 bits, adding 2^23 + 127 forces
    //     the decimal part of x / ln(2) out of the result. The integer part of
    //     x / ln(2) (i.e. n) + 127 will occupy the whole fraction part of z in
    //     FP32 format. Subtracting 2^23 + 127 (shift) from z will result in
    //     the integer part of x / ln(2) (i.e. n) because the decimal part has
    //     been pushed out and lost.
    //   * The addition of 127 makes the FP32 fraction part of z ready to be
    //     used as the exponent in FP32 format. Left shifting z by 23 bits will
    //     result in 2^n.
    const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
    const auto n = svsub_f32_z(pg, z, shift);
    const auto scale = svreinterpret_f32_u32(
        svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
    // The calculation of n * ln(2) is done using 2 steps to achieve accuracy
    // beyond FP32. This outperforms a longer Taylor series (3-4 terms) in both
    // accuracy and performance.
    const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
    const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
    // Compute the truncated Taylor series of e^r.
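    // (The terms are combined pairwise, Estrin-style: p23 = c2 + c3*r and
    // p45 = c4 + c5*r below are folded together through powers of r^2, which
    // keeps the chain of dependent fused multiply-adds short.)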
    // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
    const auto r2 = svmul_f32_z(pg, r, r);
    const auto p1 = svmul_f32_z(pg, c1, r);
    const auto p23 = svmla_f32_z(pg, c2, c3, r);
    const auto p45 = svmla_f32_z(pg, c4, c5, r);
    const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
    const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
    auto poly = svmla_f32_z(pg, scale, p12345, scale);
    // Handle underflow and overflow.
    poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
    poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
    return poly;
  }
  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
    svbool_t pg = svwhilelt_b32(0ull, count);
    return svld1_f32(pg, reinterpret_cast<const float*>(ptr));
  }
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      svst1_f32(ptrue, reinterpret_cast<float*>(ptr), values);
    } else {
      svbool_t pg = svwhilelt_b32(0ull, count);
      svst1_f32(pg, reinterpret_cast<float*>(ptr), values);
    }
  }
  const float& operator[](int idx) const = delete;
  float& operator[](int idx) = delete;
  int64_t zero_mask() const {
    // Returns an integer mask where each zero lane contributes a 1 bit and
    // every other lane contributes a 0 bit.
    int64_t mask = 0;
    __at_align__ int32_t mask_array[size()];
    svbool_t svbool_mask = svcmpeq_f32(ptrue, values, ZERO_F32);
    svst1_s32(
        ptrue,
        mask_array,
        svsel_s32(svbool_mask, ALL_S32_TRUE_MASK, ALL_S32_FALSE_MASK));
    for (int64_t i = 0; i < size(); ++i) {
      if (mask_array[i])
        mask |= (1ull << i);
    }
    return mask;
  }
  Vectorized<float> isnan() const {
    // NaN check
    svbool_t mask = svcmpuo_f32(ptrue, values, ZERO_F32);
    return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  }
  bool has_inf_nan() const {
    return svptest_any(
        ptrue,
        svcmpuo_f32(ptrue, svsub_f32_x(ptrue, values, values), ZERO_F32));
  }
  Vectorized<float> map(float (*f)(float)) const {
    __at_align__ float tmp[size()];
    store(tmp);
    for (int64_t i = 0; i < size(); ++i) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> abs() const {
    return svabs_f32_x(ptrue, values);
  }
  Vectorized<float> angle() const {
    const auto nan_vec = svdup_n_f32(NAN);
    const auto nan_mask = svcmpuo_f32(ptrue, values, ZERO_F32);
    const auto pi = svdup_n_f32(c10::pi<float>);
    const auto neg_mask = svcmplt_f32(ptrue, values, ZERO_F32);
    auto angle = svsel_f32(neg_mask, pi, ZERO_F32);
    angle = svsel_f32(nan_mask, nan_vec, angle);
    return angle;
  }
  Vectorized<float> real() const {
    return values;
  }
  Vectorized<float> imag() const {
    return Vectorized<float>(0.f);
  }
  Vectorized<float> conj() const {
    return values;
  }
  Vectorized<float> acos() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_acosfx_u10sve(values)), map(std::acos));
  }
  Vectorized<float> acosh() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_acoshfx_u10sve(values)), map(std::acosh));
  }
  Vectorized<float> asin() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_asinfx_u10sve(values)), map(std::asin));
  }
  Vectorized<float> asinh() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_asinhfx_u10sve(values)), map(std::asinh));
  }
  Vectorized<float> atan() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_atanfx_u10sve(values)), map(std::atan));
  }
  Vectorized<float> atanh() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_atanhfx_u10sve(values)), map(std::atanh));
  }
  Vectorized<float> atan2(const Vectorized<float>& b) const {
    USE_SLEEF(
        { return Vectorized<float>(Sleef_atan2fx_u10sve(values, b)); },
        {
          __at_align__ float tmp[size()];
          __at_align__ float tmp_b[size()];
          store(tmp);
          b.store(tmp_b);
          for (int64_t i = 0; i < size(); i++) {
            tmp[i] = std::atan2(tmp[i], tmp_b[i]);
          }
          return loadu(tmp);
        })
  }
  Vectorized<float> copysign(const Vectorized<float>& sign) const {
    USE_SLEEF(
        {
          return Vectorized<float>(Sleef_copysignfx_sve(values, sign));
        },
        {
          __at_align__ float tmp[size()];
          __at_align__ float tmp_sign[size()];
          store(tmp);
          sign.store(tmp_sign);
          for (int64_t i = 0; i < size(); ++i) {
            tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
          }
          return loadu(tmp);
        })
  }
  Vectorized<float> erf() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_erffx_u10sve(values)), map(std::erf));
  }
  Vectorized<float> erfc() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_erfcfx_u15sve(values)), map(std::erfc));
  }
  Vectorized<float> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<float> exp() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_expfx_u10sve(values)), map(std::exp));
  }
  Vectorized<float> exp2() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_exp2fx_u10sve(values)), map(std::exp2));
  }
  Vectorized<float> expm1() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
  }
  Vectorized<float> exp_u20() const {
    return exp();
  }
  Vectorized<float> fmod(const Vectorized<float>& q) const {
    USE_SLEEF(
        { return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
        {
          __at_align__ float tmp[size()];
          __at_align__ float tmp_q[size()];
          store(tmp);
          q.store(tmp_q);
          for (int64_t i = 0; i < size(); ++i) {
            tmp[i] = std::fmod(tmp[i], tmp_q[i]);
          }
          return loadu(tmp);
        })
  }
  Vectorized<float> hypot(const Vectorized<float>& b) const {
    USE_SLEEF(
        { return Vectorized<float>(Sleef_hypotfx_u05sve(values, b)); },
        {
          __at_align__ float tmp[size()];
          __at_align__ float tmp_b[size()];
          store(tmp);
          b.store(tmp_b);
          for (int64_t i = 0; i < size(); i++) {
            tmp[i] = std::hypot(tmp[i], tmp_b[i]);
          }
          return loadu(tmp);
        })
  }
  Vectorized<float> i0() const {
    return map(calc_i0);
  }
  Vectorized<float> i0e() const {
    return map(calc_i0e);
  }
  Vectorized<float> digamma() const {
    return map(calc_digamma);
  }
  Vectorized<float> igamma(const Vectorized<float>& x) const {
    __at_align__ float tmp[size()];
    __at_align__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (int64_t i = 0; i < size(); i++) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> igammac(const Vectorized<float>& x) const {
    __at_align__ float tmp[size()];
    __at_align__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (int64_t i = 0; i < size(); i++) {
      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> nextafter(const Vectorized<float>& b) const {
    USE_SLEEF(
        { return Vectorized<float>(Sleef_nextafterfx_sve(values, b)); },
        {
          __at_align__ float tmp[size()];
          __at_align__ float tmp_b[size()];
          store(tmp);
          b.store(tmp_b);
          for (int64_t i = 0; i < size(); ++i) {
            tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
          }
          return loadu(tmp);
        })
  }
  Vectorized<float> log() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_logfx_u10sve(values)), map(std::log));
  }
  Vectorized<float> log2() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_log2fx_u10sve(values)), map(std::log2));
  }
  Vectorized<float> log10() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_log10fx_u10sve(values)), map(std::log10));
  }
  Vectorized<float> log1p() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_log1pfx_u10sve(values)), map(std::log1p));
  }
  Vectorized<float> frac() const;
  Vectorized<float> sin() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_sinfx_u10sve(values)), map(std::sin));
  }
  Vectorized<float> sinh() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_sinhfx_u10sve(values)), map(std::sinh));
  }
  Vectorized<float> cos() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_cosfx_u10sve(values)), map(std::cos));
  }
  Vectorized<float> cosh() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_coshfx_u10sve(values)), map(std::cosh));
  }
  Vectorized<float> ceil() const {
    return svrintp_f32_x(ptrue, values);
  }
  Vectorized<float> floor() const {
    return svrintm_f32_x(ptrue, values);
  }
  Vectorized<float> neg() const {
    return svneg_f32_x(ptrue, values);
  }
  Vectorized<float> round() const {
    return svrinti_f32_x(ptrue, values);
  }
  Vectorized<float> tan() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_tanfx_u10sve(values)), map(std::tan));
  }
  // Implementation is picked from
  // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L179
  Vectorized<float> tanh() const {
    // Constants used for the tanh calculation.
    const svfloat32_t CONST_1 =
        svdup_n_f32(1.f); // Constant 1.0f for the tanh formula.
    const svfloat32_t CONST_2 = svdup_n_f32(
        2.f); // Constant 2.0f for the tanh formula (used in exp(2x)).
    const svfloat32_t CONST_MIN_TANH = svdup_n_f32(
        -10.f); // Minimum threshold for input values to prevent overflow.
    const svfloat32_t CONST_MAX_TANH = svdup_n_f32(
        10.f); // Maximum threshold for input values to prevent overflow.

    // Step 1: Clamp the values within the range [-10, 10] to prevent overflow
    // during exponentiation. The tanh function approaches ±1 rapidly as the
    // input grows large, so we limit the input range to avoid numerical
    // instability. svmax_f32_z ensures values are greater than -10, and
    // svmin_f32_z ensures they are less than 10.
    svfloat32_t x = svmin_f32_z(
        ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);

    // Step 2: Calculate exp(2 * x), where x is the clamped value.
    // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
    // the result.
    svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));

    // Step 3: Calculate the numerator of the tanh function: exp(2x) - 1.
    svfloat32_t num = svsub_f32_z(ptrue, exp2x, CONST_1);

    // Step 4: Calculate the denominator of the tanh function: exp(2x) + 1.
    svfloat32_t den = svadd_f32_z(ptrue, exp2x, CONST_1);

    // Step 5: Calculate the tanh function as the ratio of the numerator and
    // denominator: num / den.
    svfloat32_t tanh = svdiv_f32_z(ptrue, num, den);

    // Return the calculated tanh values.
    return tanh;
  }
  Vectorized<float> trunc() const {
    return svrintz_f32_x(ptrue, values);
  }
  Vectorized<float> lgamma() const {
    return USE_SLEEF(
        Vectorized<float>(Sleef_lgammafx_u10sve(values)), map(std::lgamma));
  }
  Vectorized<float> sqrt() const {
    return svsqrt_f32_x(ptrue, values);
  }
  Vectorized<float> reciprocal() const {
    return svdivr_f32_x(ptrue, values, ONE_F32);
  }
  Vectorized<float> rsqrt() const {
    return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, values), ONE_F32);
  }
  Vectorized<float> pow(const Vectorized<float>& b) const {
    USE_SLEEF(
        { return Vectorized<float>(Sleef_powfx_u10sve(values, b)); },
        {
          __at_align__ float tmp[size()];
          __at_align__ float tmp_b[size()];
          store(tmp);
          b.store(tmp_b);
          for (int64_t i = 0; i < size(); i++) {
            tmp[i] = std::pow(tmp[i], tmp_b[i]);
          }
          return loadu(tmp);
        })
  }
  // Comparison using the _CMP_**_OQ predicate.
  // `O`: get false if an operand is NaN
  // `Q`: do not raise if an operand is NaN
  Vectorized<float> operator==(const Vectorized<float>& other) const {
    svbool_t mask = svcmpeq_f32(ptrue, values, other);
    return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  }
  Vectorized<float> operator!=(const Vectorized<float>& other) const {
    svbool_t mask = svcmpne_f32(ptrue, values, other);
    return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  }
  Vectorized<float> operator<(const Vectorized<float>& other) const {
    svbool_t mask = svcmplt_f32(ptrue, values, other);
    return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  }
  Vectorized<float> operator<=(const Vectorized<float>& other) const {
    svbool_t mask = svcmple_f32(ptrue, values, other);
    return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  }
  Vectorized<float> operator>(const Vectorized<float>& other) const {
    svbool_t mask = svcmpgt_f32(ptrue, values, other);
    return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  }
  Vectorized<float> operator>=(const Vectorized<float>& other) const {
    svbool_t mask = svcmpge_f32(ptrue, values, other);
    return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  }
  Vectorized<float> eq(const Vectorized<float>& other) const;
  Vectorized<float> ne(const Vectorized<float>& other) const;
  Vectorized<float> gt(const Vectorized<float>& other) const;
  Vectorized<float> ge(const Vectorized<float>& other) const;
  Vectorized<float> lt(const Vectorized<float>& other) const;
  Vectorized<float> le(const Vectorized<float>& other) const;
};

template <>
Vectorized<float> inline operator+(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  return svadd_f32_x(ptrue, a, b);
}

template <>
Vectorized<float> inline operator-(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  return svsub_f32_x(ptrue, a, b);
}

template <>
Vectorized<float> inline operator*(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  return svmul_f32_x(ptrue, a, b);
}

template <>
Vectorized<float> inline operator/(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  return svdiv_f32_x(ptrue, a, b);
}

// frac. Implement this here so we can use subtraction.
Vectorized<float> inline Vectorized<float>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline maximum(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  return svmax_f32_x(ptrue, a, b);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
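// For example, if a lane of `a` holds 1.f while the same lane of `b` holds
// NaN, the result lane is NaN rather than 1.f.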
template <>
Vectorized<float> inline minimum(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  return svmin_f32_x(ptrue, a, b);
}

template <>
Vectorized<float> inline clamp(
    const Vectorized<float>& a,
    const Vectorized<float>& min,
    const Vectorized<float>& max) {
  return svmin_f32_x(ptrue, max, svmax_f32_x(ptrue, min, a));
}

template <>
Vectorized<float> inline clamp_max(
    const Vectorized<float>& a,
    const Vectorized<float>& max) {
  return svmin_f32_x(ptrue, max, a);
}

template <>
Vectorized<float> inline clamp_min(
    const Vectorized<float>& a,
    const Vectorized<float>& min) {
  return svmax_f32_x(ptrue, min, a);
}

template <>
Vectorized<float> inline operator&(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  return svreinterpret_f32_s32(
      svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}

template <>
Vectorized<float> inline operator|(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  return svreinterpret_f32_s32(
      svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}

template <>
Vectorized<float> inline operator^(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  return svreinterpret_f32_s32(
      sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}

Vectorized<float> inline Vectorized<float>::eq(
    const Vectorized<float>& other) const {
  return (*this == other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::ne(
    const Vectorized<float>& other) const {
  return (*this != other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::gt(
    const Vectorized<float>& other) const {
  return (*this > other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::ge(
    const Vectorized<float>& other) const {
  return (*this >= other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::lt(
    const Vectorized<float>& other) const {
  return (*this < other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::le(
    const Vectorized<float>& other) const {
  return (*this <= other) & Vectorized<float>(1.0f);
}

template <>
inline void convert(const float* src, float* dst, int64_t n) {
  const int64_t fraction = n % Vectorized<float>::size();
#pragma unroll
  for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
    svst1_f32(ptrue, dst + i, svldnt1_f32(ptrue, src + i));
  }
#pragma unroll
  for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
    svbool_t pg = svwhilelt_b32(i, n);
    svst1_f32(pg, dst + i, svldnt1_f32(pg, src + i));
  }
}

template <>
inline void convert(const float* src, at::Half* dst, int64_t n) {
  const int64_t fraction = n % Vectorized<float>::size();
  svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized<float>::size());
  svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
#pragma unroll
  for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
    svfloat16_t src_vec = svuzp1_f16(
        svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16);
    svst1_f16(pg_16, reinterpret_cast<float16_t*>(dst) + i, src_vec);
  }
#pragma unroll
  for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
    pg_16 = svwhilelt_b16(i, n);
    pg_32 = svwhilelt_b32(i, n);
    svfloat16_t src_vec = svuzp1_f16(
        svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16);
    svst1_f16(pg_16, reinterpret_cast<float16_t*>(dst) + i, src_vec);
  }
}

template <>
inline void convert(const at::Half* src, float* dst, int64_t n) {
  const int64_t fraction = n % Vectorized<float>::size();
  svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized<float>::size());
  svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
#pragma unroll
  for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
    svfloat16_t src_vec = svzip1_f16(
        svldnt1_f16(pg_16, reinterpret_cast<const float16_t*>(src) + i),
        ZERO_F16);
    svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec));
  }
#pragma unroll
  for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
    pg_16 = svwhilelt_b16(i, n);
    pg_32 = svwhilelt_b32(i, n);
    svfloat16_t src_vec = svzip1_f16(
        svldnt1_f16(pg_16, reinterpret_cast<const float16_t*>(src) + i),
        ZERO_F16);
    svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec));
  }
}

template <>
inline void convert(const bool* src, float* dst, int64_t n) {
  const int64_t fraction = n % Vectorized<float>::size();
  svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized<float>::size());
  svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
#pragma unroll
  for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
    svuint8_t src_vec_u8 =
        svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
    svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8));
    svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32);
    svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32));
  }
#pragma unroll
  for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
    pg_8 = svwhilelt_b8(i, n);
    pg_32 = svwhilelt_b32(i, n);
    svuint8_t src_vec_u8 =
        svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
    svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8));
    svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32);
    svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32));
  }
}

template <>
Vectorized<float> inline fmadd(
    const Vectorized<float>& a,
    const Vectorized<float>& b,
    const Vectorized<float>& c) {
  return svmad_f32_x(ptrue, a, b, c);
}

#endif // defined(CPU_CAPABILITY_SVE)

} // namespace CPU_CAPABILITY
} // namespace at::vec
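
// A minimal usage sketch (illustrative only; it assumes an SVE build where
// CPU_CAPABILITY_SVE is defined and uses only members declared above):
//
//   __at_align__ float in[at::vec::Vectorized<float>::size()] = {/* ... */};
//   __at_align__ float out[at::vec::Vectorized<float>::size()];
//   auto v = at::vec::Vectorized<float>::loadu(in); // predicated full load
//   auto r = (v * v + v).sqrt();                    // element-wise ops
//   r.store(out);                                   // write back to memory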