/*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once #include #include #if defined(__CUDACC_RTC__) #include #else #include #endif #include "cutlass/cutlass.h" #include "cutlass/functional.h" #include "cutlass/platform/platform.h" #include "cutlass/real.h" #include "cutlass/numeric_types.h" #include "cutlass/fast_math.h" #if !defined(__CUDACC_RTC__) #include #endif namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Enumeraed type describing a transformation on a complex value. enum class ComplexTransform { kNone, kConjugate }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines ComplexTransform inversions template struct InvertComplexTransform; /// Invert ComplexTransform from kNone to kConjugate template <> struct InvertComplexTransform { static ComplexTransform const transform = ComplexTransform::kConjugate; }; /// Invert ComplexTransform from kConjugate to kNone template <> struct InvertComplexTransform { static ComplexTransform const transform = ComplexTransform::kNone; }; ///////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////// // // Accessors for CUDA complex types // #if !defined(__CUDACC_RTC__) /// Returns the real part of the complex number CUTLASS_HOST_DEVICE float const &real(cuFloatComplex const &z) { return z.x; } /// Returns the real part of the complex number CUTLASS_HOST_DEVICE float &real(cuFloatComplex &z) { return z.x; } /// Returns the real part of the complex number CUTLASS_HOST_DEVICE double const &real(cuDoubleComplex const &z) { return z.x; } /// Returns the real part of the complex number CUTLASS_HOST_DEVICE double &real(cuDoubleComplex &z) { return z.x; } /// Returns the imaginary part of the complex number 
CUTLASS_HOST_DEVICE float const &imag(cuFloatComplex const &z) { return z.y; } /// Returns the imaginary part of the complex number CUTLASS_HOST_DEVICE float &imag(cuFloatComplex &z) { return z.y; } /// Returns the imaginary part of the complex number CUTLASS_HOST_DEVICE double const &imag(cuDoubleComplex const &z) { return z.y; } /// Returns the imaginary part of the complex number CUTLASS_HOST_DEVICE double &imag(cuDoubleComplex &z) { return z.y; } // Returns the conjugate of the complex number CUTLASS_HOST_DEVICE cuFloatComplex conj(cuFloatComplex const& z) { return make_cuFloatComplex(z.x, -z.y); } // Returns the conjugate of the complex number CUTLASS_HOST_DEVICE cuDoubleComplex conj(cuDoubleComplex const& z) { return make_cuDoubleComplex(z.x, -z.y); } #endif /////////////////////////////////////////////////////////////////////////////////////////////////// /// Class for representing and manipulating complex numbers with conversions from built-in CUDA /// complex types. template class complex { public: /// Type alias for scalar type using value_type = T; private: // // Data members // /// Real part T _real; /// Imaginary part T _imag; public: // // Methods // /// Default constructor complex() = default; /// Constructor CUTLASS_HOST_DEVICE complex(T r) : _real(r), _imag(T(0)) {} /// Constructor CUTLASS_HOST_DEVICE complex(T r, T i) : _real(r), _imag(i) {} /// Constructor template CUTLASS_HOST_DEVICE complex(complex const &z) : _real(static_cast(z.real())), _imag(static_cast(z.imag())) {} #if !defined(__CUDACC_RTC__) /// Conversion from cuFloatComplex CUTLASS_HOST_DEVICE complex(cuFloatComplex const &z) : _real(static_cast(cuCrealf(z))), _imag(static_cast(cuCimagf(z))) {} /// Conversion from cuDoubleComplex CUTLASS_HOST_DEVICE complex(cuDoubleComplex const &z) : _real(static_cast(cuCreal(z))), _imag(static_cast(cuCimag(z))) {} #endif /// Equality operator CUTLASS_HOST_DEVICE bool operator==(complex const &rhs) const { return this->real() == rhs.real() && 
this->imag() == rhs.imag(); } /// Inequality operator CUTLASS_HOST_DEVICE bool operator!=(complex const &rhs) const { return !(*this == rhs); } /// Addition template CUTLASS_HOST_DEVICE complex operator+(complex const &rhs) const { return complex(this->real() + rhs.real(), this->imag() + rhs.imag()); } /// Reduction into memory address. Components may update out of order. template CUTLASS_DEVICE void red(complex *ptr) const { static_assert(platform::is_same::value, "Component type must match"); cutlass::atomic_add reduce; reduce(&ptr->_real, _real); reduce(&ptr->_imag, _imag); } /// Reduction into memory address. Components may update out of order. (Half specialization) CUTLASS_DEVICE void red(complex *ptr) const { static_assert(platform::is_same::value, "Component type must match"); half2 *h2_ptr = reinterpret_cast(ptr); half2 h2_data = reinterpret_cast(*this); cutlass::atomic_add reduce; reduce(h2_ptr, h2_data); } /// Subtraction template CUTLASS_HOST_DEVICE complex operator-(complex const &rhs) const { return complex(this->real() - rhs.real(), this->imag() - rhs.imag()); } /// Multiplication template CUTLASS_HOST_DEVICE complex operator*(complex const &rhs) const { return complex(this->real() * rhs.real() - this->imag() * rhs.imag(), this->real() * rhs.imag() + this->imag() * rhs.real()); } /// Scalar Multiplication template CUTLASS_HOST_DEVICE complex operator*(A const &s) const { return complex(this->real() * s, this->imag() * s); } /// Division template CUTLASS_HOST_DEVICE complex operator/(complex const &rhs) const { T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag()); return complex( (real() * rhs.real() + imag() * rhs.imag()) / d, (imag() * rhs.real() - real() * rhs.imag()) / d ); } /// Scalar Division template CUTLASS_HOST_DEVICE complex operator/(A const &s) const { return complex(this->real() / s, this->imag() / s); } /// Addition template CUTLASS_HOST_DEVICE complex &operator+=(complex const &rhs) { *this = *this + rhs; return *this; } /// 
Subtraction template CUTLASS_HOST_DEVICE complex &operator-=(complex const &rhs) { *this = *this - rhs; return *this; } /// Multiplication template CUTLASS_HOST_DEVICE complex &operator*=(complex const &rhs) { *this = *this * rhs; return *this; } /// Scalar multiplication template CUTLASS_HOST_DEVICE complex &operator*=(A s) { *this = *this * s; return *this; } /// Division template CUTLASS_HOST_DEVICE complex &operator/=(complex const &rhs) { *this = *this / rhs; return *this; } /// Accesses the real part of the complex number CUTLASS_HOST_DEVICE T const &real() const { return _real; } /// Accesses the real part of the complex number CUTLASS_HOST_DEVICE T &real() { return _real; } /// Accesses the imaginary part of the complex number CUTLASS_HOST_DEVICE T const &imag() const { return _imag; } /// Accesses the imaginary part of the complex number CUTLASS_HOST_DEVICE T &imag() { return _imag; } /// Set the real part of the complex number CUTLASS_HOST_DEVICE void real(T real) { _real = real; } /// Set the imaginary part of the complex number CUTLASS_HOST_DEVICE void imag(T imag) { _imag = imag; } #if !defined(__CUDACC_RTC__) /// Converts to cuFloatComplex CUTLASS_HOST_DEVICE explicit operator cuFloatComplex() const { return make_cuFloatComplex(float(real()), float(imag())); } /// Converts to cuDoubleComplex CUTLASS_HOST_DEVICE explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); } #endif }; // Complex conjugate template CUTLASS_HOST_DEVICE complex conj(complex const& z) { return {z.real(), -z.imag()}; } /////////////////////////////////////////////////////////////////////////////////////////////////// // // Accessors for complex template // // Nonmember real and imag need to work for non-complex numbers too. // That means cutlass::complex, std::complex, cuda::std::complex, and // any user-defined complex number type that looks like std::complex. 
// It's reasonable to assume that a "complex number type" has // zero-argument real() and imag() member functions returning // non-void. While cuFloatComplex and cuDoubleComplex lack those // member functions, one-argument nonmember real and imag overloads // for those types are defined above. namespace detail { template struct has_zero_argument_real_member_function : cutlass::platform::false_type {}; template struct has_zero_argument_real_member_function().real()) > > > : cutlass::platform::true_type {}; template constexpr bool has_zero_argument_real_member_function_v = has_zero_argument_real_member_function::value; template struct has_zero_argument_imag_member_function : cutlass::platform::false_type {}; template struct has_zero_argument_imag_member_function().imag()) > > > : cutlass::platform::true_type {}; template constexpr bool has_zero_argument_imag_member_function_v = has_zero_argument_imag_member_function::value; } // namespace detail template CUTLASS_HOST_DEVICE auto real(T z) { if constexpr (detail::has_zero_argument_real_member_function_v) { return z.real(); } else { return z; } } template CUTLASS_HOST_DEVICE auto imag(T z) { if constexpr (detail::has_zero_argument_imag_member_function_v) { return z.imag(); } else { // Imaginary part of a non-complex input has the same type as the // input, and its value is zero. CUTLASS assumes in this case // that value-initializing T is well-formed and results in zero. return T{}; } } // // Output operators // #if !defined(__CUDACC_RTC__) template std::ostream &operator<<(std::ostream &out, complex const &z) { T _r = real(z); T _i = imag(z); if (bool(_i)) { return out << _r << "+i" << _i; } return out << _r; } #endif // // Non-member operators defined for complex types // // // Non-member functions defined for complex numbers // // abs returns the magnitude of the complex number. 
CUTLASS_HOST_DEVICE float abs(complex<float> const &z) {
  return ::hypot(z.real(), z.imag());
}

CUTLASS_HOST_DEVICE double abs(complex<double> const &z) {
  return ::hypot(z.real(), z.imag());
}

// In theory, it would make sense to add a complex<long double>
// specialization of abs here, since hypot works for long double too.
// In practice, long double doesn't have a portable number of bits or
// behavior, so users who care about higher-precision floating-point
// computation should probably insist on an actual FP128 type.

/// Generic magnitude: sqrt(re^2 + im^2), computed without any scaling.
template <typename T>
CUTLASS_HOST_DEVICE T abs(complex<T> const &z) {
  // cutlass::complex permits all kinds of T, including types that
  // don't have NaN.  For a generic floating-point type with Inf
  // and/or NaN, LAPACK's DLAPY2 algorithm would make sense, as it
  // would handle issues like avoiding unwarranted overflow if
  // z.real() or z.imag() is slightly bigger than the square root of
  // the max finite number.  That could be a future improvement; for
  // now, the code just uses the naive algorithm.
  //
  // Use the "swap two-step" idiom so that argument-dependent lookup
  // can find any CUTLASS-specific overloads.
  using cutlass::sqrt;
  return sqrt(z.real() * z.real() + z.imag() * z.imag());
}

/// Returns the phase angle (argument) of the complex number, i.e. atan2(imag, real)
template <typename T>
CUTLASS_HOST_DEVICE T arg(complex<T> const &z) {
  return atan2(imag(z), real(z));
}

/// Returns the squared magnitude of a real number
template <typename T>
CUTLASS_HOST_DEVICE T norm(T const &z) {
  return z * z;
}

/// Returns the squared magnitude of a real number.
/// The product is explicitly narrowed back to int8_t.
template <>
CUTLASS_HOST_DEVICE int8_t norm(int8_t const &z) {
  return static_cast<int8_t>(z * z);
}

/// Returns the squared magnitude of a complex number.
/// Note that the result type is double regardless of T.
template <typename T>
CUTLASS_HOST_DEVICE double norm(complex<T> const &z) {
  return real(z) * real(z) + imag(z) * imag(z);
}

/// Norm-accumulate calculation: accumulator + x^2, with x converted to R before squaring
template <typename T, typename R>
CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) {
  return accumulator + static_cast<R>(x) * static_cast<R>(x);
}

/// Norm accumulate specialized for complex types: accumulator + re^2 + im^2, with each
/// component converted to R before squaring
template <typename T, typename R>
CUTLASS_HOST_DEVICE R norm_accumulate(complex<T> const &z, R const &accumulator) {
  return accumulator + static_cast<R>(real(z)) * static_cast<R>(real(z)) +
    static_cast<R>(imag(z)) * static_cast<R>(imag(z));
}

namespace detail {

/// Tag-dispatched helper: the true_type overload forwards to an unqualified conj(z).
template <class T>
CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::true_type) {
  return conj(z);
}

/// Tag-dispatched helper: the false_type overload returns z unchanged.
template <class T>
CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::false_type) {
  return z;
}

/// Chooses between the two helpers above. The unqualified conj(z) is used only when T is
/// not arithmetic, has no cutlass::conj overload, and does have an unqualified conj.
/// (has_cutlass_conj_v / has_unqualified_conj_v are declared outside this file.)
template <class T>
CUTLASS_HOST_DEVICE T conj_impl(T const& z) {
  constexpr bool use_unqualified_conj =
    ! cutlass::platform::is_arithmetic_v<T> &&
    ! detail::has_cutlass_conj_v<T> &&
    detail::has_unqualified_conj_v<T>;
  return conj_impl(z, cutlass::platform::bool_constant<use_unqualified_conj>{});
}

} // namespace detail

// Return the complex conjugate of the input.
//
// This MUST be a function and not a function object, because it may
// be common practice for downstream types to define specifically
// cutlass::conj overloads, instead of overloads in their namespace.
// // As a result of this being a function and not a function object, // CUTLASS code needs to declare "using cutlass::conj;" in scope and // then call this function unqualified, just like std::swap. // // If an overload already exists for cutlass::conj(T), that overload // will be called instead of this one. Otherwise: // // 1. for arithmetic types, return z; // // 2. for types where (namespace-unqualified) conj(z) is well formed // and cutlass::conj(z) is NOT well formed, return conj(z); and, // // 3. for everything else, return z. // // Regarding (1), the C++ Standard Library makes std::conj always // return std::complex, even for (noncomplex) arithmetic types. // cutlass::conj(T t) needs to return type T. This follows the // convention of linear algebra software like the BLAS, where // "conjugate transpose" means the same thing as "transpose" for a // matrix of noncomplex numbers. // // Case (2) covers std::complex, cuda::std::complex, and non-Standard // (including user-defined) complex number types (for which "conj(z)" // is findable via argument-dependent lookup, but does not live in the // cutlass namespace). It excludes cutlass::conj(z) in order to // prevent infinite recursion. // // Case (3) covers non-Standard non-complex number types. template CUTLASS_HOST_DEVICE T conj(T const& z) { return detail::conj_impl(z); } /// Projects the complex number z onto the Riemann sphere template CUTLASS_HOST_DEVICE complex proj(complex const &z) { T d = real(z) * real(z) + imag(z) * imag(z) + T(1); return complex((T(2) * real(z)) / d, (T(2) * imag(z)) / d); } /// Returns a complex number with magnitude r and phase theta template CUTLASS_HOST_DEVICE complex polar(T const &r, T const &theta = T()) { return complex(r * cos(theta), r * sin(theta)); } /// Computes the complex exponential of z. 
template CUTLASS_HOST_DEVICE complex exp(complex const &z) { return complex(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z))); } /// Computes the log of z template CUTLASS_HOST_DEVICE complex log(complex const &z) { return complex(log(abs(z)), arg(z)); } /// Computes the log base 10 of z template CUTLASS_HOST_DEVICE complex log10(complex const &z) { return log(z) / T(log(T(10))); } /// Computes the square root of complex number z template CUTLASS_HOST_DEVICE complex sqrt(complex const &z) { return sqrt(T(2)) / T(2) * complex(sqrt(sqrt(norm(z)) + real(z)), (imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z))); } /// Computes the cosine of complex z. template CUTLASS_HOST_DEVICE complex cos(complex const &z) { return (exp(z) + exp(-z)) / T(2); } /// Computes the sin of complex z. template CUTLASS_HOST_DEVICE complex sin(complex const &z) { return (exp(-z) - exp(z)) * complex(T(0), T(1) / T(2)); } /// Comparison template CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) { return true; } ////////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization for complex-valued type. 
template struct RealType< complex > { using Type = T; /// Number of elements static int const kExtent = 2; CUTLASS_HOST_DEVICE static complex from_real(double x) { return complex(static_cast(x)); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// template <> CUTLASS_HOST_DEVICE cutlass::complex from_real >(double r) { return cutlass::complex(half_t(r)); } template <> CUTLASS_HOST_DEVICE cutlass::complex from_real >(double r) { return cutlass::complex(float(r)); } template <> CUTLASS_HOST_DEVICE cutlass::complex from_real >(double r) { return cutlass::complex(r); } ////////////////////////////////////////////////////////////////////////////////////////////////// template struct is_complex { static bool const value = false; }; template struct is_complex> { static bool const value = true; }; ///////////////////////////////////////////////////////////////////////////////////////////////// // functional.h numeric specializations ///////////////////////////////////////////////////////////////////////////////////////////////// /// Squares with optional conversion template struct magnitude_squared, Output> { CUTLASS_HOST_DEVICE Output operator()(complex lhs) const { multiplies mul_op; Output y_r = Output(lhs.real()); Output y_i = Output(lhs.imag()); return mul_op(y_r, y_r) + mul_op(y_i, y_i); } }; /// Fused multiply-add template struct multiply_add, complex, complex> { CUTLASS_HOST_DEVICE complex operator()( complex const &a, complex const &b, complex const &c) const { T real = c.real(); T imag = c.imag(); real += a.real() * b.real(); real += -a.imag() * b.imag(); imag += a.real() * b.imag(); imag += a.imag () * b.real(); return complex{ real, imag }; } }; /// Fused multiply-add template struct multiply_add, T, complex> { CUTLASS_HOST_DEVICE complex operator()( complex const &a, T const &b, complex const &c) const { T real = c.real(); T imag = c.imag(); real += a.real() * b; imag += a.imag () * b; return complex{ real, 
imag }; } }; /// Fused multiply-add template struct multiply_add, complex> { CUTLASS_HOST_DEVICE complex operator()( T const &a, complex const &b, complex const &c) const { T real = c.real(); T imag = c.imag(); real += a * b.real(); imag += a * b.imag(); return complex{ real, imag }; } }; /// Conjugate template struct conjugate> { CUTLASS_HOST_DEVICE complex operator()(complex const &a) const { // Invoke the complex overload specifically, rather than // wasting the compiler's effort on overload resolution. return cutlass::conj(a); } }; #if ! defined(__CUDACC_RTC__) template <> struct conjugate { CUTLASS_HOST_DEVICE cuFloatComplex operator()(cuFloatComplex const& z) const { return make_cuFloatComplex(z.x, -z.y); } }; template <> struct conjugate { CUTLASS_HOST_DEVICE cuDoubleComplex operator()(cuDoubleComplex const& z) const { return make_cuDoubleComplex(z.x, -z.y); } }; #endif /// Computes the square of a difference with optional conversion template struct magnitude_squared_difference, Output> { CUTLASS_HOST_DEVICE Output operator()(complex lhs, complex rhs) const { multiplies mul_op; Output y_r = Output(lhs.real()) - Output(rhs.real()); Output y_i = Output(lhs.imag()) - Output(rhs.imag()); return mul_op(y_r, y_r) + mul_op(y_i, y_i); } }; /// Reduces value into the data pointed to by ptr (complex specialization) template struct atomic_add> { CUTLASS_DEVICE void operator()(complex *ptr, const complex &data) { data.red(ptr); } }; ////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////