// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include torch::Tensor ck_moe(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &topk_weights, // [tokens, topk] torch::Tensor &topk_ids, // [tokens, topk] std::optional w1_scale, // [e, 1, n], gate(up) scale std::optional w2_scale, // [e, 1, k], down scale std::optional a1_scale, // [m, 1], token scale std::optional a2_scale, // [e, 1, n], smooth-quant-scale for 2nd gemm input std::optional block_m = 32, std::optional expert_mask = std::nullopt); void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] torch::Tensor &out, // [max_num_tokens_padded, inter_dim] int topk, std::optional w1_scale, // [e, 1, n], gate(up) scale std::optional a1_scale, // [m, 1], token scale std::optional block_m, std::optional sorted_weights, std::optional act_op); void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // [1] torch::Tensor &out, // [max_num_tokens_padded, inter_dim] int topk, std::optional w2_scale, // [e, 1, n], gate(up) scale std::optional a2_scale, // [m, 1], token scale std::optional block_m, std::optional sorted_weights); // [max_num_tokens_padded]);