#pragma once #include #include #include #include #include #include namespace at { namespace vec { // Note [CPU_CAPABILITY namespace] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // This header, and all of its subheaders, will be compiled with // different architecture flags for each supported set of vector // intrinsics. So we need to make sure they aren't inadvertently // linked together. We do this by declaring objects in an `inline // namespace` which changes the name mangling, but can still be // accessed as `at::vec`. inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) template <> struct is_vec_specialized_for : std::bool_constant {}; template <> class Vectorized { private: vls_bfloat16_t values; public: using value_type = BFloat16; using size_type = int; static constexpr size_type size() { return VECTOR_WIDTH / sizeof(BFloat16); } Vectorized() {} Vectorized(svbfloat16_t v) : values(v) {} Vectorized(int val); Vectorized(BFloat16 val); template < typename... Args, typename = std::enable_if_t<(sizeof...(Args) == size())>> Vectorized(Args... vals) { __at_align__ BFloat16 buffer[size()] = {vals...}; values = svld1_bf16(ptrue, reinterpret_cast(buffer)); } operator svbfloat16_t() const { return values; } static Vectorized blendv( const Vectorized& a, const Vectorized& b, const Vectorized& mask_) { svbool_t mask = svcmpeq_s16(ptrue, svreinterpret_s16_bf16(mask_), ALL_S16_TRUE_MASK); return svsel_bf16(mask, b, a); } template static Vectorized arange( BFloat16 base = 0.f, step_t step = static_cast(1)) { __at_align__ BFloat16 buffer[size()]; for (int64_t i = 0; i < size(); i++) { buffer[i] = base + i * step; } return svld1_bf16(ptrue, reinterpret_cast(buffer)); } static Vectorized set( const Vectorized& a, const Vectorized& b, int64_t count = size()) { if (count == 0) { return a; } else if (count < size()) { return svsel_bf16(svwhilelt_b16(0ull, count), b, a); } return b; } static Vectorized loadu(const void* ptr, int64_t count = size()) { if (count == size()) return svld1_bf16(ptrue, reinterpret_cast(ptr)); svbool_t pg = svwhilelt_b16(0ull, count); return svld1_bf16(pg, reinterpret_cast(ptr)); } void store(void* ptr, int64_t count = size()) const { __at_align__ bfloat16_t tmp[size()]; std::memset(tmp, 0, sizeof(tmp)); if (count == size()) { svst1_bf16(ptrue, reinterpret_cast(tmp), values); } else { svbool_t pg = svwhilelt_b16(0ull, count); svst1_bf16(pg, reinterpret_cast(tmp), values); } std::memcpy( reinterpret_cast(ptr), reinterpret_cast(tmp), count * sizeof(bfloat16_t)); } const BFloat16& operator[](int idx) const = delete; BFloat16& operator[](int idx) = delete; int64_t zero_mask() const { int64_t mask = 0; // returns an integer mask where all zero elements are translated to // 1-bit and others are translated to 0-bit int64_t mask = 0; __at_align__ int16_t mask_array[size()]; svbool_t svbool_mask = svcmpeq_f16(ptrue, svreinterpret_f16_bf16(values), ZERO_F16); svst1_s16( ptrue, mask_array, svsel_s16(svbool_mask, ALL_S16_TRUE_MASK, ALL_S16_FALSE_MASK)); for (int64_t i = 0; i < size(); ++i) { if (mask_array[i]) mask |= (1ull << i); } return mask; } Vectorized isnan() const; bool has_inf_nan() const; Vectorized map(BFloat16 (*f)(BFloat16)) const { __at_align__ BFloat16 tmp[size()]; store(tmp); for (int64_t i = 0; i < size(); ++i) { tmp[i] = f(tmp[i]); } return loadu(tmp); } Vectorized abs() const { auto mask = svdup_n_u16(0x7FFF); auto vals = svreinterpret_u16_bf16(values); vals = svand_u16_x(ptrue, vals, mask); return svreinterpret_bf16_u16(vals); } Vectorized angle() const; Vectorized real() const { return values; } Vectorized imag() const { return Vectorized(0.f); } Vectorized conj() const { return values; } Vectorized acos() const; Vectorized acosh() const; Vectorized asin() const; Vectorized atan() const; Vectorized atanh() const; Vectorized atan2(const Vectorized& b) const; Vectorized copysign(const Vectorized& sign) const; Vectorized erf() const; Vectorized erfc() const; Vectorized erfinv() const; Vectorized exp() const; Vectorized exp2() const; Vectorized expm1() const; Vectorized exp_u20() const { return exp(); } Vectorized fmod(const Vectorized& q) const; Vectorized hypot(const Vectorized& b) const; Vectorized i0() const; Vectorized i0e() const; Vectorized digamma() const; Vectorized igamma(const Vectorized& x) const; Vectorized igammac(const Vectorized& x) const; Vectorized nextafter(const Vectorized& b) const; Vectorized log() const; Vectorized log2() const; Vectorized log10() const; Vectorized log1p() const; Vectorized frac() const; Vectorized sin() const; Vectorized sinh() const; Vectorized cos() const; Vectorized cosh() const; Vectorized ceil() const; Vectorized floor() const; Vectorized neg() const { auto mask = svdup_n_u16(0x8000); auto vals = svreinterpret_u16_bf16(values); vals = sveor_u16_x(ptrue, vals, mask); return svreinterpret_bf16_u16(vals); }; Vectorized round() const; Vectorized tan() const; Vectorized tanh() const; Vectorized trunc() const; Vectorized lgamma() const; Vectorized sqrt() const; Vectorized reciprocal() const; Vectorized rsqrt() const; Vectorized pow(const Vectorized& b) const; // Comparison using the _CMP_**_OQ predicate. // `O`: get false if an operand is NaN // `Q`: do not raise if an operand is NaN Vectorized operator==(const Vectorized& other) const; Vectorized operator!=(const Vectorized& other) const; Vectorized operator<(const Vectorized& other) const; Vectorized operator<=(const Vectorized& other) const; Vectorized operator>(const Vectorized& other) const; Vectorized operator>=(const Vectorized& other) const; Vectorized eq(const Vectorized& other) const; Vectorized ne(const Vectorized& other) const; Vectorized gt(const Vectorized& other) const; Vectorized ge(const Vectorized& other) const; Vectorized lt(const Vectorized& other) const; Vectorized le(const Vectorized& other) const; }; inline std::tuple, Vectorized> convert_bfloat16_float( const Vectorized& a) { static_assert( Vectorized::size() == 2 * Vectorized::size()); auto zero = svreinterpret_bf16_f32(svdup_n_f32(0.0f)); auto bf16_vec1 = svzip1_bf16(zero, a); auto bf16_vec2 = svzip2_bf16(zero, a); auto x1 = svreinterpret_f32_bf16(bf16_vec1); auto x2 = svreinterpret_f32_bf16(bf16_vec2); return {Vectorized(x1), Vectorized(x2)}; } inline Vectorized convert_float_bfloat16( const Vectorized& a, const Vectorized& b) { static_assert( Vectorized::size() == 2 * Vectorized::size()); svbfloat16_t x1 = svcvt_bf16_f32_z(ptrue, a); svbfloat16_t x2 = svcvt_bf16_f32_z(ptrue, b); return Vectorized(svuzp1_bf16(x1, x2)); } inline void load_fp32_from_bf16(const BFloat16* data, Vectorized& out) { __at_align__ float values[Vectorized::size()]; for (const auto k : c10::irange(Vectorized::size())) { values[k] = data[k]; } out = Vectorized::loadu(values); } inline void load_fp32_from_bf16( const BFloat16* data, Vectorized& out1, Vectorized& out2) { Vectorized bf16_vec = Vectorized::loadu(data); auto floats = convert_bfloat16_float(bf16_vec); out1 = std::get<0>(floats); out2 = std::get<1>(floats); } template Vectorized binary_operator_via_float( Op op, const Vectorized& a, const Vectorized& b) { const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); return convert_float_bfloat16( op(a_float_low, b_float_low), op(a_float_high, b_float_high)); } template <> Vectorized inline operator+( const Vectorized& a, const Vectorized& b) { return binary_operator_via_float(std::plus>(), a, b); } template <> Vectorized inline operator-( const Vectorized& a, const Vectorized& b) { return binary_operator_via_float(std::minus>(), a, b); } template <> Vectorized inline operator*( const Vectorized& a, const Vectorized& b) { return binary_operator_via_float(std::multiplies>(), a, b); } template <> Vectorized inline operator/( const Vectorized& a, const Vectorized& b) { return binary_operator_via_float(std::divides>(), a, b); } inline Vectorized::Vectorized(int val) { auto vals_f = svdup_n_f32(val); values = convert_float_bfloat16(vals_f, vals_f); } inline Vectorized::Vectorized(BFloat16 val) { auto vals_f = svdup_n_f32((float)val); values = convert_float_bfloat16(vals_f, vals_f); } bool inline Vectorized::has_inf_nan() const { auto [v1, v2] = convert_bfloat16_float(values); return v1.has_inf_nan() || v2.has_inf_nan(); } // frac. Implement this here so we can use subtraction Vectorized inline Vectorized::frac() const { return *this - this->trunc(); } #define DEFINE_BF16_FUNC_VIA_FLOAT(func_name) \ Vectorized inline Vectorized::func_name() const { \ auto [v1, v2] = convert_bfloat16_float(*this); \ v1 = v1.func_name(); \ v2 = v2.func_name(); \ return convert_float_bfloat16(v1, v2); \ } #define DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(func_name) \ Vectorized inline Vectorized::func_name( \ const Vectorized& a) const { \ auto [v1, v2] = convert_bfloat16_float(*this); \ auto [v3, v4] = convert_bfloat16_float(a); \ v1 = v1.func_name(v3); \ v2 = v2.func_name(v4); \ return convert_float_bfloat16(v1, v2); \ } DEFINE_BF16_FUNC_VIA_FLOAT(isnan); DEFINE_BF16_FUNC_VIA_FLOAT(angle); DEFINE_BF16_FUNC_VIA_FLOAT(acos); DEFINE_BF16_FUNC_VIA_FLOAT(acosh); DEFINE_BF16_FUNC_VIA_FLOAT(asin); DEFINE_BF16_FUNC_VIA_FLOAT(atan); DEFINE_BF16_FUNC_VIA_FLOAT(atanh); DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2); DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign); DEFINE_BF16_FUNC_VIA_FLOAT(erf); DEFINE_BF16_FUNC_VIA_FLOAT(erfc); DEFINE_BF16_FUNC_VIA_FLOAT(exp); DEFINE_BF16_FUNC_VIA_FLOAT(exp2); DEFINE_BF16_FUNC_VIA_FLOAT(expm1); DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod); DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot); DEFINE_BF16_FUNC_VIA_FLOAT(i0); DEFINE_BF16_FUNC_VIA_FLOAT(i0e); DEFINE_BF16_FUNC_VIA_FLOAT(digamma); DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma); DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac); DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter); DEFINE_BF16_FUNC_VIA_FLOAT(log); DEFINE_BF16_FUNC_VIA_FLOAT(log2); DEFINE_BF16_FUNC_VIA_FLOAT(log10); DEFINE_BF16_FUNC_VIA_FLOAT(log1p); DEFINE_BF16_FUNC_VIA_FLOAT(sin); DEFINE_BF16_FUNC_VIA_FLOAT(sinh); DEFINE_BF16_FUNC_VIA_FLOAT(cos); DEFINE_BF16_FUNC_VIA_FLOAT(cosh); DEFINE_BF16_FUNC_VIA_FLOAT(ceil); DEFINE_BF16_FUNC_VIA_FLOAT(floor); DEFINE_BF16_FUNC_VIA_FLOAT(round); DEFINE_BF16_FUNC_VIA_FLOAT(tan); DEFINE_BF16_FUNC_VIA_FLOAT(tanh); DEFINE_BF16_FUNC_VIA_FLOAT(trunc); DEFINE_BF16_FUNC_VIA_FLOAT(lgamma); DEFINE_BF16_FUNC_VIA_FLOAT(sqrt); DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal); DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt); DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow); Vectorized inline Vectorized::operator==( const Vectorized& other) const { auto [f1, f2] = convert_bfloat16_float(values); auto [f3, f4] = convert_bfloat16_float(other); svbool_t mask1 = svcmpeq_f32(ptrue, f1, f3); svbool_t mask2 = svcmpeq_f32(ptrue, f2, f4); auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); auto bf16_1 = svreinterpret_bf16_f32(res1); auto bf16_2 = svreinterpret_bf16_f32(res2); return svuzp1_bf16(bf16_1, bf16_2); } Vectorized inline Vectorized::operator!=( const Vectorized& other) const { auto [f1, f2] = convert_bfloat16_float(values); auto [f3, f4] = convert_bfloat16_float(other); svbool_t mask1 = svcmpne_f32(ptrue, f1, f3); svbool_t mask2 = svcmpne_f32(ptrue, f2, f4); auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); auto bf16_1 = svreinterpret_bf16_f32(res1); auto bf16_2 = svreinterpret_bf16_f32(res2); return svuzp1_bf16(bf16_1, bf16_2); } Vectorized inline Vectorized::operator>( const Vectorized& other) const { auto [v1, v2] = convert_bfloat16_float(*this); auto [v3, v4] = convert_bfloat16_float(other); return convert_float_bfloat16(v1 > v3, v2 > v4); } Vectorized inline Vectorized::operator>=( const Vectorized& other) const { auto [v1, v2] = convert_bfloat16_float(*this); auto [v3, v4] = convert_bfloat16_float(other); return convert_float_bfloat16(v1 >= v3, v2 >= v4); } Vectorized inline Vectorized::operator<( const Vectorized& other) const { auto [v1, v2] = convert_bfloat16_float(*this); auto [v3, v4] = convert_bfloat16_float(other); return convert_float_bfloat16(v1 < v3, v2 < v4); } Vectorized inline Vectorized::operator<=( const Vectorized& other) const { auto [v1, v2] = convert_bfloat16_float(*this); auto [v3, v4] = convert_bfloat16_float(other); return convert_float_bfloat16(v1 <= v3, v2 <= v4); } // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. template <> Vectorized inline maximum( const Vectorized& a, const Vectorized& b) { return binary_operator_via_float( static_cast (*)( const Vectorized&, const Vectorized&)>(&maximum), a, b); } // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. template <> Vectorized inline minimum( const Vectorized& a, const Vectorized& b) { return binary_operator_via_float( static_cast (*)( const Vectorized&, const Vectorized&)>(&minimum), a, b); } template <> Vectorized inline clamp_max( const Vectorized& a, const Vectorized& max) { return binary_operator_via_float( static_cast (*)( const Vectorized&, const Vectorized&)>(&clamp_max), a, max); } template <> Vectorized inline clamp_min( const Vectorized& a, const Vectorized& min) { return binary_operator_via_float( static_cast (*)( const Vectorized&, const Vectorized&)>(&clamp_min), a, min); } template <> Vectorized inline clamp( const Vectorized& a, const Vectorized& min, const Vectorized& max) { return clamp_min(clamp_max(a, max), min); } template <> Vectorized inline operator&( const Vectorized& a, const Vectorized& b) { return svreinterpret_bf16_u16( svand_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); } template <> Vectorized inline operator|( const Vectorized& a, const Vectorized& b) { return svreinterpret_bf16_u16( svorr_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); } template <> Vectorized inline operator^( const Vectorized& a, const Vectorized& b) { return svreinterpret_bf16_u16( sveor_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); } Vectorized inline Vectorized::eq( const Vectorized& other) const { return (*this == other) & Vectorized(1.0f); } Vectorized inline Vectorized::ne( const Vectorized& other) const { return (*this != other) & Vectorized(1.0f); } Vectorized inline Vectorized::gt( const Vectorized& other) const { return (*this > other) & Vectorized(1.0f); } Vectorized inline Vectorized::ge( const Vectorized& other) const { return (*this >= other) & Vectorized(1.0f); } Vectorized inline Vectorized::lt( const Vectorized& other) const { return (*this < other) & Vectorized(1.0f); } Vectorized inline Vectorized::le( const Vectorized& other) const { return (*this <= other) & Vectorized(1.0f); } template <> inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { const int64_t fraction = n % Vectorized::size(); #pragma unroll for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { svst1_bf16( ptrue, const_cast(reinterpret_cast(dst)) + i, svldnt1_bf16( ptrue, const_cast(reinterpret_cast(src)) + i)); } #pragma unroll for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { svbool_t pg = svwhilelt_b16(i, n); svst1_bf16( pg, const_cast(reinterpret_cast(dst)) + i, svldnt1_bf16( pg, const_cast(reinterpret_cast(src)) + i)); } } template <> Vectorized inline fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return a * b + c; } #endif // defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16) } // namespace CPU_CAPABILITY } // namespace vec } // namespace at