/*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Default template for a Blocked-Ell MMA. */ #pragma once #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/arch/arch.h" #include "cutlass/arch/wmma.h" #include "cutlass/layout/matrix.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/threadblock/default_mma_core_simt.h" #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h" #include "cutlass/gemm/threadblock/default_mma_core_sm80.h" #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include "cutlass/gemm/threadblock/default_mma_core_wmma.h" #endif //CUTLASS_ARCH_WMMA_ENABLED #include "cutlass/gemm/threadblock/ell_mma_pipelined.h" #include "cutlass/gemm/threadblock/ell_mma_multistage.h" #include "cutlass/transform/threadblock/ell_predicated_tile_iterator.h" //////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace gemm { namespace threadblock { //////////////////////////////////////////////////////////////////////////////// template < /// Element type for A matrix operand typename ElementA_, /// Layout type for A matrix operand typename LayoutA_, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Element type for B matrix operand typename ElementB_, /// Layout type for B matrix operand typename LayoutB_, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator_, /// Layout type for C and D matrix operands typename LayoutC_, /// Operator class tag typename OperatorClass_, /// Tag indicating architecture to tune for typename ArchTag_, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape_, /// Warp-level tile size (concept: GemmShape) typename WarpShape_, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape_, /// Number of stages used in the pipelined mainloop int Stages, /// Operation perfomed by GEMM typename Operator, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. bool AccumulatorsInRowMajor = false > struct DefaultEllMma; //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass Simt) template < /// Element type for A matrix operand typename ElementA, /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Element type for B matrix operand typename ElementB, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Operation performed by GEMM typename Operator> struct DefaultEllMma { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, 2, Operator>; // Define iterators over tiles from the A operand using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy>; }; //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass TensorOp) template < /// Element type for A matrix operand typename ElementA, /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Element type for B matrix operand typename ElementB, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Operation performed by GEMM typename Operator > struct DefaultEllMma { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator>; // Define iterators over tiles from the A operand using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy>; }; //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass TensorOp) template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Operation performed by GEMM typename Operator > struct DefaultEllMma { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, LayoutB, float, layout::RowMajor, arch::OpClassTensorOp, 2, arch::OpMultiplyAddFastF16>; // Define iterators over tiles from the A operand using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, float, layout::RowMajor, typename MmaCore::MmaPolicy>; }; //////////////////////////////////////////////////////////////////////////////// /// Specialization for column-major-interleaved output template < /// Element type for A matrix operand typename ElementA, /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Element type for B matrix operand typename ElementB, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, /// Tag indicating architecture to tune for typename OperatorClass, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Operation performed by GEMM typename Operator, /// Number of Interleaved K int InterleavedK> struct DefaultEllMma, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, true> { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, true>; static_assert(kAlignmentA == 128 / sizeof_bits::value, "Alignment must match thread data map's vector length"); static_assert(kAlignmentB ==128 / sizeof_bits::value, "Alignment must match thread data map's vector length"); // Define iterators over tiles from the A operand using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, layout::ColumnMajorInterleaved, typename MmaCore::MmaPolicy>; }; //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output template < /// Element type for A matrix operand typename ElementA, /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Element type for B matrix operand typename ElementB, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Number of stages used in the multistage mainloop int Stages, /// Operation perfomed by GEMM typename Operator > struct DefaultEllMma { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, Stages, Operator>; // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; using AccessTypeA = cutlass::Array; using IteratorA = cutlass::transform::threadblock::EllPredicatedTileAccessIterator< cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; using AccessTypeB = cutlass::Array; using IteratorB = cutlass::transform::threadblock::EllPredicatedTileAccessIterator< cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, Stages>; }; //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass TensorOp) template < /// Element type for A matrix operand typename ElementA, /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Element type for B matrix operand typename ElementB, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Number of stages used in the multistage mainloop int Stages, /// Operation perfomed by GEMM typename Operator > struct DefaultEllMma { static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, Stages, Operator, false, CacheOpA, CacheOpB>; // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; using AccessTypeA = cutlass::Array; using IteratorA = cutlass::transform::threadblock::EllPredicatedTileAccessIterator< cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; using AccessTypeB = cutlass::Array; using IteratorB = cutlass::transform::threadblock::EllPredicatedTileAccessIterator< cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, Stages>; }; //////////////////////////////////////////////////////////////////////////////// /// Specialization for column-major-interleaved output template < /// Element type for A matrix operand typename ElementA, /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Element type for B matrix operand typename ElementB, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, /// Tag indicating architecture to tune for typename OperatorClass, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Number of stages used in the multistage mainloop int Stages, /// Operation performed by GEMM typename Operator, /// Number of Interleaved K int InterleavedK> struct DefaultEllMma, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, true> { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::ColumnMajorInterleaved, OperatorClass, Stages, Operator, true>; // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; using AccessTypeA = cutlass::Array; using IteratorA = cutlass::transform::threadblock::EllPredicatedTileAccessIterator< cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; using AccessTypeB = cutlass::Array; using IteratorB = cutlass::transform::threadblock::EllPredicatedTileAccessIterator< cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, Stages>; }; //////////////////////////////////////////////////////////////////////////////// /// Specialization for SIMT IDP4A Kernels template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Operation performed by GEMM typename Operator, /// Warp-level tile size (concept: GemmShape) typename WarpShape> struct DefaultEllMma, 2, Operator, false> { using InstructionShape = GemmShape<1, 1, 4>; using ElementA = int8_t; using ElementB = int8_t; using OperatorClass = arch::OpClassSimt; static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value; static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value; // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2, Operator>; // Define iterators over tiles from the A operand using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< cutlass::MatrixShape, ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< cutlass::MatrixShape, ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy>; }; //////////////////////////////////////////////////////////////////////////////// #if defined(CUTLASS_ARCH_WMMA_ENABLED) /// Specialization for Wmma TensorOp operator with 2 staged pipeline template < ///< Element type for A matrix operand typename ElementA, /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Element type for B matrix operand typename ElementB, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, /// Layout type for C and D matrix operands typename LayoutC, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Operation performed by GEMM typename Operator> struct DefaultEllMma { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassWmmaTensorOp, 2, Operator>; // Define iterators over tiles from the A operand using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, LayoutC, typename MmaCore::MmaPolicy>; }; //////////////////////////////////////////////////////////////////////////////// /// Specialization for Wmma TensorOp operator with 1 staged pipeline template < ///< Element type for A matrix operand typename ElementA, /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements int kAlignmentA, /// Element type for B matrix operand typename ElementB, /// Layout type for B matrix operand typename LayoutB, /// Access granularity of B matrix in units of elements int kAlignmentB, /// Element type for internal accumulation typename ElementAccumulator, /// Layout type for C and D matrix operands typename LayoutC, /// Tag indicating architecture to tune for typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape, /// Warp-level tile size (concept: GemmShape) typename WarpShape, /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Operation performed by GEMM typename Operator> struct DefaultEllMma { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassWmmaTensorOp, 1, Operator>; // Define iterators over tiles from the A operand using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< cutlass::MatrixShape, ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; // Define the threadblock-scoped singlestage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, LayoutC, typename MmaCore::MmaPolicy>; }; //////////////////////////////////////////////////////////////////////////////// #endif //CUTLASS_ARCH_WMMA_ENABLED } // namespace threadblock } // namespace gemm } // namespace cutlass ////////////////////////////////////////////////////////////////////////////////