/*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Templates implementing warp-level matrix multiply-accumulate operations targeting Tensor Cores. */ #pragma once #include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/platform/platform.h" #include "cutlass/numeric_conversion.h" #include "cutlass/numeric_types.h" #include "cutlass/matrix_shape.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/mma_sm80.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/warp/mma.h" #include "cutlass/gemm/warp/mma_tensor_op_policy.h" #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace gemm { namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// namespace detail { template struct ConvertAndPack { using Converter = NumericArrayConverter; CUTLASS_HOST_DEVICE Array operator()(Array const &source) { Converter converter; return converter(source); } }; template struct ConvertAndPack { CUTLASS_HOST_DEVICE Array operator()(Array const &source) { return source; } }; template struct ConvertAndPack { using Converter = NumericArrayConverter; CUTLASS_HOST_DEVICE Array operator()(Array const &source) { Converter converter; Array tmp; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); tmp[i] = source[idx]; } return converter(tmp); } }; template struct ConvertAndPack { using Converter = NumericArrayConverter; CUTLASS_HOST_DEVICE Array operator()(Array const &source) { Converter converter; Array tmp; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); tmp[i] = source[idx]; } return converter(tmp); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// /// Structure to compute the matrix product targeting Tensor Cores. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Data type of A elements typename ElementA_, /// Layout of A matrix (concept: MatrixLayout) typename LayoutA_, /// Data type of B elements typename ElementB_, /// Layout of B matrix (concept: MatrixLayout) typename LayoutB_, /// Element type of C matrix typename ElementC_, /// Layout of C matrix (concept: MatrixLayout) typename LayoutC_, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) typename Policy_, /// Number of partitions along K dimension int PartitionsK_ = 1, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. bool AccumulatorsInRowMajor = false, /// Used for partial specialization typename Enable = bool > class MmaTensorOp { public: /// Shape of warp-level matrix operation (concept: GemmShape) using Shape = Shape_; /// Data type of multiplicand A using ElementA = ElementA_; /// Layout of multiplicand A using LayoutA = LayoutA_; /// Data type of multiplicand B using ElementB = ElementB_; /// Layout of multiplicand B using LayoutB = LayoutB_; /// Data type of accumulator matrix C using ElementC = ElementC_; /// Layout of accumulator matrix C using LayoutC = LayoutC_; /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) using Policy = Policy_; /// Underlying matrix multiply operator (concept: arch::Mma) using ArchMmaOperator = typename Policy::Operator; /// Indicates math operator using MathOperator = typename ArchMmaOperator::Operator; /// Architecture tag from underlying instruction using ArchTag = typename ArchMmaOperator::ArchTag; /// Indicates class of matrix operator using OperatorClass = arch::OpClassTensorOp; /// Shape of underlying instruction using InstructionShape = typename ArchMmaOperator::Shape; /// Complex transform on A operand static ComplexTransform const kTransformA = ComplexTransform::kNone; /// Complex transform on B operand static ComplexTransform const kTransformB = ComplexTransform::kNone; /// Number of threads participating in warp-level matrix product static int const kThreadCount = 32; /// Number of partitions along K dimension static int const kPartitionsK = PartitionsK_; #if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ < 800) || (__CUDA_ARCH__ == 890)) static int const kVerticalVisit = true; #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1200) static int const kVerticalVisit = true; #else static int const kVerticalVisit = false; #endif public: /// Iterates over the A operand in memory using IteratorA = MmaTensorOpMultiplicandTileIterator< MatrixShape, Operand::kA, ElementA, LayoutA, MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; /// Storage for A tile using FragmentA = typename IteratorA::Fragment; /// Storage for transformed A tile using TransformedFragmentA = Array; /// Iterates over the B operand in memory using IteratorB = MmaTensorOpMultiplicandTileIterator< MatrixShape, Operand::kB, ElementB, LayoutB, MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; /// Storage for B tile using FragmentB = typename IteratorB::Fragment; /// Storage for transformed B tile using TransformedFragmentB = Array; /// Iterates over the C operand in memory using IteratorC = MmaTensorOpAccumulatorTileIterator< MatrixShape, ElementC, LayoutC, typename ArchMmaOperator::Shape, typename Policy::OpDelta>; /// Storage for C tile using FragmentC = typename IteratorC::Fragment; /// Number of mma operations performed using MmaIterations = MatrixShape< (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN >; public: /// Underlying matrix multiply operator (concept: arch::Mma) ArchMmaOperator mma; public: // // Methods // /// Ctor CUTLASS_DEVICE MmaTensorOp() {} /// Performs a warp-level matrix multiply-accumulate operation CUTLASS_DEVICE void operator()( FragmentC &D, TransformedFragmentA const &A, TransformedFragmentB const &B, FragmentC const &C ) const { using MmaOperandA = typename ArchMmaOperator::FragmentA; using MmaOperandB = typename ArchMmaOperator::FragmentB; using MmaOperandC = typename ArchMmaOperator::FragmentC; D = C; MmaOperandA const *ptr_A = reinterpret_cast(&A); MmaOperandB const *ptr_B = reinterpret_cast(&B); MmaOperandC *ptr_D = reinterpret_cast(&D); if (kVerticalVisit) { CUTLASS_PRAGMA_UNROLL for (int n = 0; n < MmaIterations::kColumn; ++n) { CUTLASS_PRAGMA_UNROLL for (int m = 0; m < MmaIterations::kRow; ++m) { int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); if (AccumulatorsInRowMajor) { // matrix B is reordered mma( ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n], ptr_D[n + m_serpentine * MmaIterations::kColumn]); } else { mma( ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n], ptr_D[m_serpentine + n * MmaIterations::kRow]); } } } } else { CUTLASS_PRAGMA_UNROLL for (int m = 0; m < MmaIterations::kRow; ++m) { CUTLASS_PRAGMA_UNROLL for (int n = 0; n < MmaIterations::kColumn; ++n) { int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); if (AccumulatorsInRowMajor) { // matrix B is reordered mma( ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine], ptr_D[n_serpentine + m * MmaIterations::kColumn]); } else { mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine], ptr_D[m + n_serpentine * MmaIterations::kRow]); } } } } } /// Transform the mma operands to the required types CUTLASS_DEVICE void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, FragmentA const &A, FragmentB const &B) const { // // Define conversions from source type to instruction type // FloatRoundStyle const kRoundA = PreferredRoundingMode::kRound; FloatRoundStyle const kRoundB = PreferredRoundingMode::kRound; if (kVerticalVisit) { detail::ConvertAndPack convert_A; NumericArrayConverter convert_B; Array const *ptr_B = reinterpret_cast const *>(&B); Array * ptr_dst_B = reinterpret_cast *>(&dst_B); dst_A = convert_A(A); ptr_dst_B[0] = convert_B(ptr_B[0]); ptr_dst_B[1] = convert_B(ptr_B[1]); } else { detail::ConvertAndPack convert_A; NumericArrayConverter convert_B; Array const *ptr_A = reinterpret_cast const *>(&A); Array * ptr_dst_A = reinterpret_cast *>(&dst_A); dst_B = convert_B(B); ptr_dst_A[0] = convert_A(ptr_A[0]); ptr_dst_A[1] = convert_A(ptr_A[1]); } } }; ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace warp } // namespace gemm } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// #include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" /////////////////////////////////////////////////////////////////////////////////////////////////