/*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Matrix multiply */ #pragma once #include #include "cutlass/cutlass.h" #include "mma.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_types.h" //////////////////////////////////////////////////////////////////////////////// #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) #define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) #define CUTLASS_ARCH_MMA_SM80_ENABLED #if (__CUDA_ARCH__ <= 900) #define CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED #endif #if (__CUDA_ARCH__ <= 890) #define CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED #endif #endif #endif //////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace arch { //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 1688 - Float BF16, FP32 accumulation // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation - F32 = bf16 * bf16 + F32 template <> struct Mma< gemm::GemmShape<16, 8, 8>, 32, bfloat16_t, layout::RowMajor, bfloat16_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd> { using Shape = gemm::GemmShape<16, 8, 8>; using ElementA = bfloat16_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = bfloat16_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = float; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAdd; using ArchTag = arch::Sm80; CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); asm( "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) ); #else CUTLASS_UNUSED(d); CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_NOT_IMPLEMENTED(); #endif } }; //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 1684 - Float TF32 // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 template <> struct Mma< gemm::GemmShape<16, 8, 4>, 32, tfloat32_t, layout::RowMajor, tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd> { using Shape = gemm::GemmShape<16, 8, 4>; using ElementA = tfloat32_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = tfloat32_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = float; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAdd; using ArchTag = arch::Sm80; CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) ); #else CUTLASS_UNUSED(d); CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_NOT_IMPLEMENTED(); #endif } }; //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 1688 - Float TF32 // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 template <> struct Mma, 32, tfloat32_t, layout::RowMajor, tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd> { using Shape = gemm::GemmShape<16, 8, 8>; using ElementA = tfloat32_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = tfloat32_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = float; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAdd; using ArchTag = arch::Sm80; CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); #else CUTLASS_UNUSED(d); CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_NOT_IMPLEMENTED(); #endif } }; //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 16816 // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: F16 = F16 * F16 + F16 template <> struct Mma< gemm::GemmShape<16, 8, 16>, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd> { using Shape = gemm::GemmShape<16, 8, 16>; using ElementA = half_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = half_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = half_t; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAdd; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); uint32_t const *C = reinterpret_cast(&c); uint32_t *D = reinterpret_cast(&d); asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]) ); #else CUTLASS_UNUSED(d); CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_NOT_IMPLEMENTED(); #endif } }; //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 template <> struct Mma< gemm::GemmShape<16, 8, 16>, 32, bfloat16_t, layout::RowMajor, bfloat16_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd> { using Shape = gemm::GemmShape<16, 8, 16>; using ElementA = bfloat16_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = bfloat16_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = float; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAdd; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); #else CUTLASS_UNUSED(d); CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_NOT_IMPLEMENTED(); #endif } }; //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: F32 = F16 * F16 + F32 template <> struct Mma< gemm::GemmShape<16, 8, 16>, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd> { using Shape = gemm::GemmShape<16, 8, 16>; using ElementA = half_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = half_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = float; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAdd; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " "{%10,%11,%12,%13};\n" : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); #else CUTLASS_UNUSED(d); CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_NOT_IMPLEMENTED(); #endif } }; //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 884 - F64 // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: F64 = F64 * F64 + F64 template <> struct Mma< gemm::GemmShape<8,8,4>, 32, double, layout::RowMajor, double, layout::ColumnMajor, double, layout::RowMajor, OpMultiplyAdd> { using Shape = gemm::GemmShape<8,8,4>; using ElementA = double; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = double; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = double; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAdd; using ArchTag = arch::Sm80; CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) double const & A = reinterpret_cast(a); double const & B = reinterpret_cast(b); double const *C = reinterpret_cast(&c); double *D = reinterpret_cast(&d); asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=d"(D[0]), "=d"(D[1]) : "d"(A), "d"(B), "d"(C[0]), "d"(C[1])); #else CUTLASS_UNUSED(d); CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_NOT_IMPLEMENTED(); #endif } }; //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: S32 = S8 * S8 + S32 template <> struct Mma< gemm::GemmShape<16,8,16>, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = int8_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const &B = reinterpret_cast(b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); #endif } }; /// Matrix multiply-add operation: S32 = U8 * S8 + S32 template <> struct Mma< gemm::GemmShape<16,8,16>, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,16>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = int8_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const &B = reinterpret_cast(b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); #endif } }; /// Matrix multiply-add operation: S32 = S8 * U8 + S32 template <> struct Mma< gemm::GemmShape<16,8,16>, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const &B = reinterpret_cast(b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); #endif } }; /// Matrix multiply-add operation: S32 = U8 * U8 + S32 template <> struct Mma< gemm::GemmShape<16,8,16>, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,16>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const &B = reinterpret_cast(b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); #endif } }; //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: S32 = S8 * S8 + S32 template <> struct Mma< gemm::GemmShape<16,8,32>, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,32>; using ElementA = int8_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = int8_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const * A = reinterpret_cast(&a); uint32_t const * B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); #endif } }; /// Matrix multiply-add operation: S32 = U8 * S8 + S32 template <> struct Mma< gemm::GemmShape<16,8,32>, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,32>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = int8_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); #endif } }; /// Matrix multiply-add operation: S32 = S8 * U8 + S32 template <> struct Mma< gemm::GemmShape<16,8,32>, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,32>; using ElementA = int8_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); #endif } }; /// Matrix multiply-add operation: S32 = U8 * U8 + S32 template <> struct Mma< gemm::GemmShape<16,8,32>, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,32>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); #endif } }; //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: S32 = S4 * S4 + S32 template <> struct Mma< gemm::GemmShape<16, 8, 64>, 32, cutlass::int4b_t, layout::RowMajor, cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16, 8, 64>; using ElementA = cutlass::int4b_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = cutlass::int4b_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const * A = reinterpret_cast(&a); uint32_t const * B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); assert(0); #endif } }; /// Matrix multiply-add operation: S32 = U4 * S4 + S32 template <> struct Mma< gemm::GemmShape<16, 8, 64>, 32, cutlass::uint4b_t, layout::RowMajor, cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16, 8, 64>; using ElementA = cutlass::uint4b_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = cutlass::int4b_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); assert(0); #endif } }; /// Matrix multiply-add operation: S32 = S4 * U4 + S32 template <> struct Mma< gemm::GemmShape<16, 8, 64>, 32, cutlass::int4b_t, layout::RowMajor, cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16, 8, 64>; using ElementA = cutlass::int4b_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = cutlass::uint4b_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); assert(0); #endif } }; /// Matrix multiply-add operation: S32 = U4 * U4 + S32 template <> struct Mma< gemm::GemmShape<16, 8, 64>, 32, cutlass::uint4b_t, layout::RowMajor, cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16, 8, 64>; using ElementA = cutlass::uint4b_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = cutlass::uint4b_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); assert(0); #endif } }; //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 168256 - B1 input, S32 accumulation - AND,POPC // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: S32 = B1 & B1 + S32 template <> struct Mma< gemm::GemmShape<16,8,256>, 32, cutlass::uint1b_t, layout::RowMajor, cutlass::uint1b_t, layout::ColumnMajor, int32_t, layout::RowMajor, OpAndPopc> { using Shape = gemm::GemmShape<16,8,256>; using ElementA = cutlass::uint1b_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = cutlass::uint1b_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int32_t; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpAndPopc; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, " "{%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); assert(0); #endif } }; /// Matrix multiply-add operation: S32 = B1 & B1 + S32 template <> struct Mma< gemm::GemmShape<16,8,256>, 32, cutlass::uint1b_t, layout::RowMajor, cutlass::uint1b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd> { using Shape = gemm::GemmShape<16,8,256>; using ElementA = cutlass::uint1b_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = cutlass::uint1b_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int32_t; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpMultiplyAdd; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, " "{%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); assert(0); #endif } }; //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 168256 - B1 input, S32 accumulation - XOR,POPC // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: S32 = B1 & B1 + S32 template <> struct Mma< gemm::GemmShape<16,8,256>, 32, cutlass::uint1b_t, layout::RowMajor, cutlass::uint1b_t, layout::ColumnMajor, int, layout::RowMajor, OpXorPopc> { using Shape = gemm::GemmShape<16,8,256>; using ElementA = cutlass::uint1b_t; using LayoutA = layout::RowMajor; using FragmentA = Array; using ElementB = cutlass::uint1b_t; using LayoutB = layout::ColumnMajor; using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; using Operator = OpXorPopc; using ArchTag = arch::Sm80; /// Computes multiply-add CUTLASS_HOST_DEVICE void operator()( FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c ) const { #if defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, " "{%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); assert(0); #endif // defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) } }; //////////////////////////////////////////////////////////////////////////////// } // namespace arch } // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////