// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // auto-generated by generate_configs.py #pragma once #include "cutlass/gemm_coord.h" namespace ap { /* Entry counts: 23 equals the number of func<0>..func<22> entries listed in AP_AUTOTUNE_half, and 13 equals the number of func<0>..func<12> entries listed in AP_AUTOTUNE_float. */ constexpr int kNumConfigsHalf = 23; constexpr int kNumConfigsFloat = 13; /* SwizzleWrapper: exposes cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle under the member alias Type. NOTE(review): the template parameter list is not visible in this view - confirm against the generator output. */ template struct SwizzleWrapper { using Type = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; }; // template // struct SwizzleWrapper { // using Type = // cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; // }; /* AP_AUTOTUNE_half(func, stream, ...): expands to a braced block that (1) builds a function-local static table holding func<0>..func<22>, (2) while the static selected_config_id is still -1, calls ap::ProfileBestConfig(table, stream, args...) and stores the returned index, (3) invokes the selected table entry with __VA_ARGS__ only (stream is passed to ProfileBestConfig but not to the final call). NOTE(review): the macro body itself contains no locking around the selected_config_id check and assignment, and the stored index is used without a range check - confirm both are acceptable for callers. */ #define AP_AUTOTUNE_half(func, stream, ...) \ { \ using FuncType = decltype(func<0>); \ static int selected_config_id = -1; \ static std::vector> matmul_functions = { \ func<0>, func<1>, func<2>, func<3>, func<4>, func<5>, \ func<6>, func<7>, func<8>, func<9>, func<10>, func<11>, \ func<12>, func<13>, func<14>, func<15>, func<16>, func<17>, \ func<18>, func<19>, func<20>, func<21>, func<22>}; \ if (selected_config_id == -1) { \ selected_config_id = \ ap::ProfileBestConfig(matmul_functions, stream, ##__VA_ARGS__); \ } \ matmul_functions[selected_config_id](__VA_ARGS__); \ } /* AP_AUTOTUNE_nv_bfloat16: forwards its arguments unchanged to AP_AUTOTUNE_half. */ #define AP_AUTOTUNE_nv_bfloat16(func, stream, ...) \ AP_AUTOTUNE_half(func, stream, __VA_ARGS__) /* AP_AUTOTUNE_float: same profile-once-then-dispatch flow as AP_AUTOTUNE_half, over the 13 entries func<0>..func<12>. */ #define AP_AUTOTUNE_float(func, stream, ...) 
\ { \ using FuncType = decltype(func<0>); \ static int selected_config_id = -1; \ static std::vector> matmul_functions = {func<0>, \ func<1>, \ func<2>, \ func<3>, \ func<4>, \ func<5>, \ func<6>, \ func<7>, \ func<8>, \ func<9>, \ func<10>, \ func<11>, \ func<12>}; \ if (selected_config_id == -1) { \ selected_config_id = \ ap::ProfileBestConfig(matmul_functions, stream, ##__VA_ARGS__); \ } \ matmul_functions[selected_config_id](__VA_ARGS__); \ } /* GemmTuningConfigs: each definition below fixes one configuration - TShape, WShape and IShape as cutlass::gemm::GemmShape instantiations, kNumStages, SwizzleThreadBlock (taken from SwizzleWrapper) and kId. The first definition sets kId from Id; the ones that follow it hard-code kId 1..22 and all use IShape GemmShape<16, 8, 16>. NOTE(review): the template parameter lists and specialization argument lists are not visible in this view; presumably TShape, WShape and IShape are the threadblock, warp and instruction tile shapes - confirm against the CUTLASS usage site. */ template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<256, 128, 64>; using WShape = cutlass::gemm::GemmShape<64, 64, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 2; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = Id; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 128, 64>; using WShape = cutlass::gemm::GemmShape<32, 64, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 1; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 128, 64>; using WShape = cutlass::gemm::GemmShape<64, 64, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 2; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 64, 64>; using WShape = cutlass::gemm::GemmShape<64, 32, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 3; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 128, 32>; using WShape = cutlass::gemm::GemmShape<64, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 3; using 
SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 4; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 128, 64>; using WShape = cutlass::gemm::GemmShape<64, 64, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 5; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<256, 64, 32>; using WShape = cutlass::gemm::GemmShape<64, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 6; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<256, 64, 64>; using WShape = cutlass::gemm::GemmShape<64, 64, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 7; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<256, 128, 32>; using WShape = cutlass::gemm::GemmShape<64, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 8; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<256, 128, 64>; using WShape = cutlass::gemm::GemmShape<64, 64, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 9; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 32, 64>; using WShape = cutlass::gemm::GemmShape<32, 32, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 4; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static 
constexpr int kId = 10; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 128, 32>; using WShape = cutlass::gemm::GemmShape<64, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 4; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 11; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 128, 64>; using WShape = cutlass::gemm::GemmShape<64, 64, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 4; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 12; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<256, 64, 64>; using WShape = cutlass::gemm::GemmShape<64, 64, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 4; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 13; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<256, 64, 32>; using WShape = cutlass::gemm::GemmShape<64, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 4; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 14; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<32, 64, 64>; using WShape = cutlass::gemm::GemmShape<16, 32, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 5; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 15; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 64, 64>; using WShape = cutlass::gemm::GemmShape<32, 32, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 5; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 16; }; template struct 
GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 128, 32>; using WShape = cutlass::gemm::GemmShape<64, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 5; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 17; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 128, 64>; using WShape = cutlass::gemm::GemmShape<64, 64, 64>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 5; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 18; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 128, 32>; using WShape = cutlass::gemm::GemmShape<32, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 6; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 19; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 64, 32>; using WShape = cutlass::gemm::GemmShape<64, 32, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 6; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 20; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 32, 32>; using WShape = cutlass::gemm::GemmShape<32, 32, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 7; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 21; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 64, 32>; using WShape = cutlass::gemm::GemmShape<32, 32, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 16>; static constexpr int kNumStages = 10; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 22; }; /* The definitions from here to the end of namespace ap all use IShape GemmShape<16, 8, 8>; the first sets kId from Id and the rest hard-code kId 1..12. */ // Specialization for float template struct GemmTuningConfigs { using 
TShape = cutlass::gemm::GemmShape<64, 64, 16>; using WShape = cutlass::gemm::GemmShape<32, 32, 16>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = Id; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 64, 32>; using WShape = cutlass::gemm::GemmShape<32, 32, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 1; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 128, 32>; using WShape = cutlass::gemm::GemmShape<32, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 2; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 256, 16>; using WShape = cutlass::gemm::GemmShape<32, 64, 16>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 3; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 256, 32>; using WShape = cutlass::gemm::GemmShape<32, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 4; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 64, 32>; using WShape = cutlass::gemm::GemmShape<64, 32, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 5; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 128, 16>; using WShape = 
cutlass::gemm::GemmShape<32, 64, 16>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 6; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 128, 32>; using WShape = cutlass::gemm::GemmShape<32, 64, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 7; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<256, 64, 16>; using WShape = cutlass::gemm::GemmShape<64, 32, 16>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 8; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<256, 64, 32>; using WShape = cutlass::gemm::GemmShape<64, 32, 32>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 3; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 9; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<64, 128, 16>; using WShape = cutlass::gemm::GemmShape<32, 64, 16>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 4; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 10; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 64, 16>; using WShape = cutlass::gemm::GemmShape<64, 32, 16>; using IShape = cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 4; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 11; }; template struct GemmTuningConfigs { using TShape = cutlass::gemm::GemmShape<128, 128, 16>; using WShape = cutlass::gemm::GemmShape<32, 64, 16>; using IShape = 
cutlass::gemm::GemmShape<16, 8, 8>; static constexpr int kNumStages = 4; using SwizzleThreadBlock = typename SwizzleWrapper::Type; static constexpr int kId = 12; }; } // namespace ap