#pragma once #include #include #include // Note: header order is important here #include #include #include #include #include #include #include #include #include #include #include namespace at { namespace vec { inline namespace CPU_CAPABILITY { DEFINE_CLAMP_FUNCS(c10::quint8) DEFINE_CLAMP_FUNCS(c10::qint8) DEFINE_CLAMP_FUNCS(c10::qint32) DEFINE_CLAMP_FUNCS(int16_t) DEFINE_CLAMP_FUNCS(int32_t) DEFINE_CLAMP_FUNCS(int64_t) DEFINE_CLAMP_FUNCS(float) DEFINE_CLAMP_FUNCS(double) template <> Vectorized C10_ALWAYS_INLINE fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return Vectorized{ vec_madd(a.vec0(), b.vec0(), c.vec0()), vec_madd(a.vec1(), b.vec1(), c.vec1())}; } template <> Vectorized C10_ALWAYS_INLINE fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return Vectorized{ a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; } template <> Vectorized C10_ALWAYS_INLINE fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return Vectorized{ a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; } template <> Vectorized C10_ALWAYS_INLINE fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return Vectorized{ a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; } DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float) DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double) DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t) DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t) DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t) template <> Vectorized C10_ALWAYS_INLINE convert_to_int_of_same_size(const Vectorized& src) { return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())}; } template <> Vectorized C10_ALWAYS_INLINE convert_to_int_of_same_size(const Vectorized& src) { return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())}; } template <> inline void convert(const int32_t* src, float* dst, int64_t n) { // int32_t and float have same size int64_t i; for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { const int32_t* src_a = src + i; float* dst_a = dst + i; vint32 input_vec0 = vec_vsx_ld(offset0, reinterpret_cast(src_a)); vint32 input_vec1 = vec_vsx_ld(offset16, reinterpret_cast(src_a)); vfloat32 c0 = vec_float(input_vec0); vfloat32 c1 = vec_float(input_vec1); vec_vsx_st(c0, offset0, dst_a); vec_vsx_st(c1, offset16, dst_a); } for (; i < n; i++) { dst[i] = static_cast(src[i]); } } template <> inline void convert(const int64_t* src, double* dst, int64_t n) { int64_t i; for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { const int64_t* src_a = src + i; double* dst_a = dst + i; vint64 input_vec0 = vec_vsx_ld(offset0, reinterpret_cast(src_a)); vint64 input_vec1 = vec_vsx_ld(offset16, reinterpret_cast(src_a)); vfloat64 c0 = vec_double(input_vec0); vfloat64 c1 = vec_double(input_vec1); vec_vsx_st(c0, offset0, reinterpret_cast(dst_a)); vec_vsx_st(c1, offset16, reinterpret_cast(dst_a)); } for (; i < n; i++) { dst[i] = static_cast(src[i]); } } // Generic implementation to fix compiler error // TO-DO : Add optimized version for ppc64 inline std::tuple, Vectorized> convert_half_float( const Vectorized& a) { constexpr int64_t K = Vectorized::size(); __at_align__ float arr[K]; __at_align__ Half arr2[K]; a.store(arr2); convert(arr2, arr, K); return std::make_tuple( Vectorized::loadu(arr), Vectorized::loadu(arr + Vectorized::size())); } inline Vectorized convert_float_half( const Vectorized& a, const Vectorized& b) { constexpr int64_t K = Vectorized::size(); __at_align__ float arr[K]; __at_align__ Half arr2[K]; a.store(arr); b.store(arr + Vectorized::size()); convert(arr, arr2, K); return Vectorized::loadu(arr2); }; template <> std::pair, Vectorized> inline interleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, a1, a2, a3} // b = {b0, b1, b2, b3} vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0); vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3); vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0); vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3); // return {a0, b0, a1, b1} // {a2, b2, a3, b3} return std::make_pair( Vectorized{ab00, ab11}, Vectorized{ab2_00, ab2_11}); } template <> std::pair, Vectorized> inline deinterleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, b0, a1, b1} // b = {a2, b2, a3, b3} vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0); vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0); vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3); vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3); // swap lanes: // return {a0, a1, a2, a3} // {b0, b1, b2, b3} return std::make_pair( Vectorized{aa01, aa23}, Vectorized{bb_01, bb_23}); } template <> std::pair, Vectorized> inline interleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, a1, a2, a3,, a4, a5, a6, a7} // b = {b0, b1, b2, b3,, b4, b5, b6, b7} vfloat32 ab0011 = vec_mergeh(a.vec0(), b.vec0()); vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0()); vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1()); vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1()); // group cols crossing lanes: // return {a0, b0, a1, b1,, a2, b2, a3, b3} // {a4, b4, a5, b5,, a6, b6, a7, b7} return std::make_pair( Vectorized{ab0011, ab2233}, Vectorized{ab2_0011, ab2_2233}); } template <> std::pair, Vectorized> inline deinterleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, b0, a1, b1,, a2, b2, a3, b3} // b = {a4, b4, a5, b5,, a6, b6, a7, b7} // {a0,a2,b0,b2} {a1,a3,b1,b3} vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1()); vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1()); vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3); vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3); vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1()); vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1()); vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2); vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2); // it could be done with vec_perm ,too // swap lanes: // return {a0, a1, a2, a3,, a4, a5, a6, a7} // {b0, b1, b2, b3,, b4, b5, b6, b7} return std::make_pair( Vectorized{aa0123, aa0123_2}, Vectorized{bb0123, bb0123_2}); } } // namespace CPU_CAPABILITY } // namespace vec } // namespace at