#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <c10/util/BFloat16.h>
#include <c10/util/irange.h>

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

// Following vec128_half_neon.h, we only support aarch64.
#if !defined(C10_MOBILE) && defined(__aarch64__)

#ifdef __BIG_ENDIAN__
#error "Big endian is not supported."
#endif

// Unlike the float16_t family of types, bfloat16_t is not available
// when we're not targeting bfloat16 hardware support on some
// platforms (but not Mac, so we have to be careful not to shadow the
// definitions in case they are actually there!). (See
// https://godbolt.org/z/orv6e94n4 ) So, we need to handle it as
// uint16_t in that case.
#define IMPLEMENT_AT_BF16_SHIM(vec_suffix)                               \
  inline at_bfloat16x4_t at_vget_low_bf16(at_bfloat16x8_t a) {           \
    return vget_low_##vec_suffix(a);                                     \
  }                                                                      \
                                                                         \
  inline at_bfloat16x4_t at_vget_high_bf16(at_bfloat16x8_t a) {          \
    return vget_high_##vec_suffix(a);                                    \
  }                                                                      \
                                                                         \
  inline at_bfloat16x8_t at_vcombine_bf16(                               \
      at_bfloat16x4_t low, at_bfloat16x4_t high) {                       \
    return vcombine_##vec_suffix(low, high);                             \
  }                                                                      \
                                                                         \
  inline at_bfloat16x8_t at_vdupq_n_bf16(at_bfloat16_t value) {          \
    return vdupq_n_##vec_suffix(value);                                  \
  }                                                                      \
                                                                         \
  inline at_bfloat16x8_t at_vld1q_bf16(const at_bfloat16_t* ptr) {       \
    return vld1q_##vec_suffix(ptr);                                      \
  }                                                                      \
                                                                         \
  inline void at_vst1q_bf16(at_bfloat16_t* ptr, at_bfloat16x8_t value) { \
    vst1q_##vec_suffix(ptr, value);                                      \
  }                                                                      \
                                                                         \
  template <typename T>                                                  \
  inline at_bfloat16x8_t at_vreinterpretq_bf16_u16(T val) {              \
    if constexpr (std::is_same_v<T, at_bfloat16x8_t>) {                  \
      return val;                                                        \
    } else {                                                             \
      return vreinterpretq_bf16_u16(val);                                \
    }                                                                    \
  }                                                                      \
  template <typename T>                                                  \
  inline at_bfloat16x4_t at_vreinterpret_bf16_u16(T val) {               \
    if constexpr (std::is_same_v<T, at_bfloat16x4_t>) {                  \
      return val;                                                        \
    } else {                                                             \
      return vreinterpret_bf16_u16(val);                                 \
    }                                                                    \
  }                                                                      \
  template <typename T>                                                  \
  inline uint16x8_t at_vreinterpretq_u16_bf16(T val) {                   \
    if constexpr (std::is_same_v<T, uint16x8_t>) {                       \
      return val;                                                        \
    } else {                                                             \
      return vreinterpretq_u16_bf16(val);                                \
    }                                                                    \
  }                                                                      \
  template <typename T>                                                  \
  inline uint16x4_t at_vreinterpret_u16_bf16(T val) {                    \
    if constexpr (std::is_same_v<T, uint16x4_t>) {                       \
      return val;                                                        \
    } else {                                                             \
      return vreinterpret_u16_bf16(val);                                 \
    }                                                                    \
  }

#ifdef __ARM_FEATURE_BF16
using at_bfloat16x8_t = bfloat16x8_t;
using at_bfloat16x4_t = bfloat16x4_t;
using at_bfloat16_t = bfloat16_t;
IMPLEMENT_AT_BF16_SHIM(bf16)
#define at_vsetq_lane_bf16 vsetq_lane_bf16
#define at_vgetq_lane_bf16 vgetq_lane_bf16
#else
using at_bfloat16x8_t = uint16x8_t;
using at_bfloat16x4_t = uint16x4_t;
using at_bfloat16_t = uint16_t;
IMPLEMENT_AT_BF16_SHIM(u16)
#define at_vsetq_lane_bf16 vsetq_lane_u16
#define at_vgetq_lane_bf16 vgetq_lane_u16
#endif // __ARM_FEATURE_BF16
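// For illustration, a call through the shim resolves as follows
// (derived from the macro and aliases above; not additional API):
//
//   with __ARM_FEATURE_BF16:    at_vld1q_bf16(p) -> vld1q_bf16(p)  // bfloat16x8_t
//   without __ARM_FEATURE_BF16: at_vld1q_bf16(p) -> vld1q_u16(p)   // uint16x8_t
//
// and the at_vreinterpret*_bf16/u16 helpers compile to no-ops whenever
// the source and destination types already coincide (the uint16_t
// fallback case), so callers can reinterpret unconditionally.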
template <int index, bool mask_val>
struct BlendBFloat16Regs {
  static at_bfloat16x8_t impl(
      const at_bfloat16x8_t& a,
      const at_bfloat16x8_t& b,
      at_bfloat16x8_t& res);
};

template <int index>
struct BlendBFloat16Regs<index, true> {
  static at_bfloat16x8_t impl(
      const at_bfloat16x8_t& a,
      const at_bfloat16x8_t& b,
      at_bfloat16x8_t& res) {
    return at_vsetq_lane_bf16(at_vgetq_lane_bf16(b, index), res, index);
  }
};

template <int index>
struct BlendBFloat16Regs<index, false> {
  static at_bfloat16x8_t impl(
      const at_bfloat16x8_t& a,
      const at_bfloat16x8_t& b,
      at_bfloat16x8_t& res) {
    return at_vsetq_lane_bf16(at_vgetq_lane_bf16(a, index), res, index);
  }
};

template <>
struct is_vec_specialized_for<c10::BFloat16> : std::bool_constant<true> {};

template <>
class Vectorized<c10::BFloat16> : public Vectorized16<
                                      at_bfloat16x8_t,
                                      c10::BFloat16,
                                      BlendBFloat16Regs,
                                      Vectorized<c10::BFloat16>> {
  using Base = Vectorized16<
      at_bfloat16x8_t,
      c10::BFloat16,
      BlendBFloat16Regs,
      Vectorized<c10::BFloat16>>;
  friend Base;
  friend std::tuple<Vectorized<float>, Vectorized<float>>
  convert_bfloat16_float(const Vectorized<c10::BFloat16>& a);
  friend Vectorized<c10::BFloat16> convert_float_bfloat16(
      const Vectorized<float>& a,
      const Vectorized<float>& b);

 private:
  Vectorized<c10::BFloat16> map2(
      const Vectorized<c10::BFloat16>& second,
      c10::BFloat16 (*const f)(c10::BFloat16, c10::BFloat16)) const {
    __at_align__ c10::BFloat16 tmp_first[size()];
    __at_align__ c10::BFloat16 tmp_second[size()];
    store(tmp_first); // store this to tmp_first
    second.store(tmp_second);
    for (const auto i : c10::irange(size())) {
      tmp_first[i] = f(tmp_first[i], tmp_second[i]);
    }
    return loadu(tmp_first);
  }

  static float32x4_t convert_f32_bf16(at_bfloat16x4_t bf16) {
#ifdef __ARM_FEATURE_BF16
    return vcvt_f32_bf16(bf16);
#else
    // A bf16 value is the high 16 bits of the f32 with the same value,
    // so widening to 32 bits and shifting left by 16 is an exact
    // conversion.
    int32x4_t shift = vdupq_n_s32(16);
    return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(bf16), shift));
#endif // __ARM_FEATURE_BF16
  }

  static at_bfloat16x4_t convert_bf16_f32(const Vectorized<float>& f32) {
#ifdef __ARM_FEATURE_BF16
    return vcvt_bf16_f32(f32);
#else
    static_assert(std::is_same_v<at_bfloat16_t, uint16_t>);
    uint32x4_t as_uint32 = vreinterpretq_u32_f32(f32);
    // Round to nearest even: add 0x7FFF, plus 1 more if the bit that
    // will become the bf16 LSB is set, then truncate to the high half.
    uint32x4_t rounding_bias = vaddq_u32(
        vandq_u32(vshrq_n_u32(as_uint32, 16), vdupq_n_u32(1)),
        vdupq_n_u32(0x7FFF));
    at_bfloat16x4_t rounded =
        vshrn_n_u32(vaddq_u32(as_uint32, rounding_bias), 16);
    // The bias would mangle NaN payloads, so NaN lanes are replaced
    // with a canonical bf16 NaN.
    const auto bf16_nan = vdup_n_u16(0x7FC0);
    return vbsl_u16(
        vmovn_u32(vreinterpretq_u32_f32(f32.isnan())), bf16_nan, rounded);
#endif // __ARM_FEATURE_BF16
  }
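  // A scalar sketch of the rounding trick above, for reference only
  // (bf16_round_scalar is a hypothetical helper, not part of this header;
  // NaN handling is done separately, as above):
  //
  //   uint16_t bf16_round_scalar(float f) {
  //     uint32_t bits = c10::bit_cast<uint32_t>(f);
  //     uint32_t bias = 0x7FFF + ((bits >> 16) & 1);
  //     return static_cast<uint16_t>((bits + bias) >> 16);
  //   }
  //
  // On a tie (low 16 bits exactly 0x8000), the extra +1 fires only when
  // the surviving LSB is odd, so the carry rounds it up to even; when
  // the LSB is already even, 0x8000 + 0x7FFF does not carry and the
  // value rounds down, also landing on even.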
  Vectorized<c10::BFloat16> map_with_vec_float_method(
      Vectorized<float> (Vectorized<float>::*m)() const) const {
    float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values));
    float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)();
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)();
    at_bfloat16x4_t r00 = convert_bf16_f32(mv0);
    at_bfloat16x4_t r01 = convert_bf16_f32(mv1);
    return Vectorized<c10::BFloat16>(at_vcombine_bf16(r00, r01));
  }

  Vectorized<c10::BFloat16> map2_with_vec_float_method(
      const Vectorized<c10::BFloat16>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values));
    float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values));
    float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values));
    float32x4_t second_v01 =
        convert_f32_bf16(at_vget_high_bf16(second.values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)(second_v00);
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)(second_v01);
    at_bfloat16x4_t r00 = convert_bf16_f32(mv0);
    at_bfloat16x4_t r01 = convert_bf16_f32(mv1);
    return Vectorized<c10::BFloat16>(at_vcombine_bf16(r00, r01));
  }

  Vectorized<c10::BFloat16> map2_bitmask_with_vec_float_method(
      const Vectorized<c10::BFloat16>& second,
      Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
          const) const {
    float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values));
    float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values));
    float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values));
    float32x4_t second_v01 =
        convert_f32_bf16(at_vget_high_bf16(second.values));
    Vectorized<float> mv0 = (Vectorized<float>(v00).*m)(second_v00);
    Vectorized<float> mv1 = (Vectorized<float>(v01).*m)(second_v01);
    // Assume the operator returns a bitmask, not "real" floats, and
    // just narrow the bits. All-ones is a NaN and will get mangled by
    // conversion!
    at_bfloat16x4_t r00 =
        at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0)));
    at_bfloat16x4_t r01 =
        at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1)));
    return Vectorized<c10::BFloat16>(at_vcombine_bf16(r00, r01));
  }

 public:
  using Vectorized16::Vectorized16;

  Vectorized() = default;

  Vectorized(c10::BFloat16 val)
      : Vectorized16(at_vdupq_n_bf16(c10::bit_cast<at_bfloat16_t>(val.x))) {}
  Vectorized(float val) : Vectorized(c10::BFloat16(val)) {}
  Vectorized(
      value_type val0,
      value_type val1,
      value_type val2,
      value_type val3,
      value_type val4,
      value_type val5,
      value_type val6,
      value_type val7)
      : Vectorized16(at_bfloat16x8_t{
            c10::bit_cast<at_bfloat16_t>(val0.x),
            c10::bit_cast<at_bfloat16_t>(val1.x),
            c10::bit_cast<at_bfloat16_t>(val2.x),
            c10::bit_cast<at_bfloat16_t>(val3.x),
            c10::bit_cast<at_bfloat16_t>(val4.x),
            c10::bit_cast<at_bfloat16_t>(val5.x),
            c10::bit_cast<at_bfloat16_t>(val6.x),
            c10::bit_cast<at_bfloat16_t>(val7.x)}) {}

  static Vectorized<c10::BFloat16> blendv(
      const Vectorized<c10::BFloat16>& a,
      const Vectorized<c10::BFloat16>& b,
      const Vectorized<c10::BFloat16>& mask) {
    // NOTE: blendv has the same problems as it does for Half; see comments in
    // vec128_half_neon.h.
    Vectorized<c10::BFloat16> vec(mask.values);
    vec.values = at_vreinterpretq_bf16_u16(vbslq_u16(
        at_vreinterpretq_u16_bf16(vec.values),
        at_vreinterpretq_u16_bf16(b.values),
        at_vreinterpretq_u16_bf16(a.values)));
    return vec;
  }

  static Vectorized<c10::BFloat16> set(
      const Vectorized<c10::BFloat16>& a,
      const Vectorized<c10::BFloat16>& b,
      int64_t count = size()) {
    uint16_t pre_mask[size()] = {0};
    for (int i = 0; i < count; i++) {
      pre_mask[i] = 0xFFFF;
    }
    uint16x8_t mask = vld1q_u16(pre_mask);
    Vectorized<c10::BFloat16> vec(at_vreinterpretq_bf16_u16(vbslq_u16(
        mask,
        at_vreinterpretq_u16_bf16(b.values),
        at_vreinterpretq_u16_bf16(a.values))));
    return vec;
  }

  static Vectorized<c10::BFloat16> loadu(
      const void* ptr,
      int64_t count = size()) {
    if (count == size()) {
      return at_vld1q_bf16(reinterpret_cast<const at_bfloat16_t*>(ptr));
    }
    __at_align__ at_bfloat16_t tmp_values[size()];
    std::memset(tmp_values, 0, sizeof(tmp_values));
    std::memcpy(
        tmp_values,
        reinterpret_cast<const at_bfloat16_t*>(ptr),
        count * sizeof(at_bfloat16_t));
    return at_vld1q_bf16(reinterpret_cast<const at_bfloat16_t*>(tmp_values));
  }

  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      at_vst1q_bf16(reinterpret_cast<at_bfloat16_t*>(ptr), values);
      return;
    } else {
      at_bfloat16_t tmp_values[size()];
      at_vst1q_bf16(reinterpret_cast<at_bfloat16_t*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(at_bfloat16_t));
    }
  }
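  // Usage sketch for the tail-handling paths above (illustrative only;
  // `buf` is a hypothetical caller-owned array): partial loads zero-fill
  // the remaining lanes, and partial stores write exactly `count`
  // elements back.
  //
  //   c10::BFloat16 buf[3] = {1.0f, 2.0f, 3.0f};
  //   auto v = Vectorized<c10::BFloat16>::loadu(buf, 3); // lanes 3..7 are 0
  //   v.store(buf, 3);                                   // writes 3 elements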
  Vectorized<c10::BFloat16> isnan() const {
    // NOTE: we could make this faster by doing vectorized checks of
    // exponent/payload bits.
    __at_align__ c10::BFloat16 tmp[size()];
    __at_align__ c10::BFloat16 res[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(c10::BFloat16));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(c10::BFloat16));
      }
    }
    return loadu(res);
  }

  bool has_inf_nan() const {
    __at_align__ c10::BFloat16 tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i]) || _isinf(tmp[i])) {
        return true;
      }
    }
    return false;
  }

#define DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(name)    \
  Vectorized<c10::BFloat16> name() const {                      \
    return map_with_vec_float_method(&Vectorized<float>::name); \
  }

#define DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(name)         \
  Vectorized<c10::BFloat16> name(const Vectorized<c10::BFloat16>& other) \
      const {                                                            \
    return map2_bitmask_with_vec_float_method(                           \
        other, &Vectorized<float>::name);                                \
  }

  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
  Vectorized<c10::BFloat16> frac() const;
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)

#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD

  Vectorized<c10::BFloat16> eq(const Vectorized<c10::BFloat16>& other) const;
  Vectorized<c10::BFloat16> ne(const Vectorized<c10::BFloat16>& other) const;
  Vectorized<c10::BFloat16> gt(const Vectorized<c10::BFloat16>& other) const;
  Vectorized<c10::BFloat16> ge(const Vectorized<c10::BFloat16>& other) const;
  Vectorized<c10::BFloat16> lt(const Vectorized<c10::BFloat16>& other) const;
  Vectorized<c10::BFloat16> le(const Vectorized<c10::BFloat16>& other) const;
}; // Vectorized<c10::BFloat16>

inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
    const Vectorized<c10::BFloat16>& a) {
  static_assert(
      Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
  at_bfloat16x8_t x = a;
  float32x4_t x1 =
      Vectorized<c10::BFloat16>::convert_f32_bf16(at_vget_low_bf16(x));
  float32x4_t x2 =
      Vectorized<c10::BFloat16>::convert_f32_bf16(at_vget_high_bf16(x));
  return {Vectorized<float>(x1), Vectorized<float>(x2)};
}

inline Vectorized<c10::BFloat16> convert_float_bfloat16(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  static_assert(
      Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
  at_bfloat16x4_t x1 = Vectorized<c10::BFloat16>::convert_bf16_f32(a);
  at_bfloat16x4_t x2 = Vectorized<c10::BFloat16>::convert_bf16_f32(b);
  return Vectorized<c10::BFloat16>(at_vcombine_bf16(x1, x2));
}

template <typename Op>
Vectorized<c10::BFloat16> binary_operator_via_float(
    Op op,
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
  const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
  return convert_float_bfloat16(
      op(a_float_low, b_float_low), op(a_float_high, b_float_high));
}

template <>
Vectorized<c10::BFloat16> inline operator+(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
}

template <>
Vectorized<c10::BFloat16> inline operator-(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
}

template <>
Vectorized<c10::BFloat16> inline operator*(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
}

template <>
Vectorized<c10::BFloat16> inline operator/(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
}
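// For reference, the widen-compute-narrow pattern above expands
// (illustratively) as follows for operator+; the other arithmetic
// operators differ only in the functor passed to
// binary_operator_via_float:
//
//   auto [a_lo, a_hi] = convert_bfloat16_float(a); // 8 bf16 -> 2x4 f32
//   auto [b_lo, b_hi] = convert_bfloat16_float(b);
//   return convert_float_bfloat16(a_lo + b_lo, a_hi + b_hi); // RNE narrowing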
// frac. Implement this here so we can use subtraction.
inline Vectorized<c10::BFloat16> Vectorized<c10::BFloat16>::frac() const {
  return *this - this->trunc();
}

template <>
Vectorized<c10::BFloat16> inline maximum(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  return binary_operator_via_float(
      static_cast<Vectorized<float> (*)(
          const Vectorized<float>&, const Vectorized<float>&)>(&maximum),
      a,
      b);
}

template <>
Vectorized<c10::BFloat16> inline minimum(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  return binary_operator_via_float(
      static_cast<Vectorized<float> (*)(
          const Vectorized<float>&, const Vectorized<float>&)>(&minimum),
      a,
      b);
}

template <>
Vectorized<c10::BFloat16> inline clamp(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& min,
    const Vectorized<c10::BFloat16>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<c10::BFloat16> inline clamp_max(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& max) {
  return minimum(max, a);
}

template <>
Vectorized<c10::BFloat16> inline clamp_min(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& min) {
  return maximum(min, a);
}

template <>
Vectorized<c10::BFloat16> inline operator&(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  return Vectorized<c10::BFloat16>(at_vreinterpretq_bf16_u16(
      vandq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b))));
}

template <>
Vectorized<c10::BFloat16> inline operator|(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  return Vectorized<c10::BFloat16>(at_vreinterpretq_bf16_u16(
      vorrq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b))));
}

template <>
Vectorized<c10::BFloat16> inline operator^(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b) {
  return Vectorized<c10::BFloat16>(at_vreinterpretq_bf16_u16(
      veorq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b))));
}

inline Vectorized<c10::BFloat16> Vectorized<c10::BFloat16>::eq(
    const Vectorized<c10::BFloat16>& other) const {
  return (*this == other) & Vectorized<c10::BFloat16>(1);
}

inline Vectorized<c10::BFloat16> Vectorized<c10::BFloat16>::ne(
    const Vectorized<c10::BFloat16>& other) const {
  return (*this != other) & Vectorized<c10::BFloat16>(1);
}

inline Vectorized<c10::BFloat16> Vectorized<c10::BFloat16>::gt(
    const Vectorized<c10::BFloat16>& other) const {
  return (*this > other) & Vectorized<c10::BFloat16>(1);
}

inline Vectorized<c10::BFloat16> Vectorized<c10::BFloat16>::ge(
    const Vectorized<c10::BFloat16>& other) const {
  return (*this >= other) & Vectorized<c10::BFloat16>(1);
}

inline Vectorized<c10::BFloat16> Vectorized<c10::BFloat16>::lt(
    const Vectorized<c10::BFloat16>& other) const {
  return (*this < other) & Vectorized<c10::BFloat16>(1);
}

inline Vectorized<c10::BFloat16> Vectorized<c10::BFloat16>::le(
    const Vectorized<c10::BFloat16>& other) const {
  return (*this <= other) & Vectorized<c10::BFloat16>(1);
}

template <>
Vectorized<c10::BFloat16> inline fmadd(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b,
    const Vectorized<c10::BFloat16>& c) {
  // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
  // vbfmlalbq_f32 and vbfmlaltq_f32 take the even- and odd-numbered
  // elements, not the bottom and top half, so they don't seem
  // particularly useful here. Ideally we would include dot product in
  // the Vectorized interface...
  return a * b + c;
}

template <>
Vectorized<c10::BFloat16> inline fmsub(
    const Vectorized<c10::BFloat16>& a,
    const Vectorized<c10::BFloat16>& b,
    const Vectorized<c10::BFloat16>& c) {
  // See NOTE [BF16 FMA] above.
  return a * b - c;
}

#endif // !defined(C10_MOBILE) && defined(__aarch64__)

} // namespace CPU_CAPABILITY
} // namespace at::vec