#pragma once // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include // void layernorm2d(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, torch::Tensor &bias, double epsilon); torch::Tensor layernorm2d(torch::Tensor &input, torch::Tensor &weight, torch::Tensor &bias, double epsilon, std::optional x_bias); void layernorm2d_with_add(torch::Tensor &out, torch::Tensor &input, torch::Tensor &residual_in, torch::Tensor &residual_out, torch::Tensor &weight, torch::Tensor &bias, double epsilon, std::optional x_bias); void layernorm2d_with_smoothquant(torch::Tensor &out, // [m ,n] torch::Tensor &input, // [m ,n] torch::Tensor &xscale, // [1 ,n] torch::Tensor &yscale, // [m ,1] torch::Tensor &weight, // [1 ,n] torch::Tensor &bias, // [1 ,n] double epsilon, std::optional x_bias); void layernorm2d_with_add_smoothquant(torch::Tensor &out, // [m ,n] torch::Tensor &input, // [m ,n] torch::Tensor &residual_in, // [m ,n] torch::Tensor &residual_out, // [m ,n] torch::Tensor &xscale, // [1 ,n] torch::Tensor &yscale, // [m ,1] torch::Tensor &weight, // [1 ,n] torch::Tensor &bias, // [1 ,n] double epsilon, std::optional x_bias); void layernorm2d_with_dynamicquant(torch::Tensor &out, // [m ,n] torch::Tensor &input, // [m ,n] torch::Tensor &yscale, // [m ,1] torch::Tensor &weight, // [1 ,n] torch::Tensor &bias, // [1 ,n] double epsilon, std::optional x_bias); void layernorm2d_with_add_dynamicquant(torch::Tensor &out, // [m ,n] torch::Tensor &input, // [m ,n] torch::Tensor &residual_in, // [m ,n] torch::Tensor &residual_out, // [m ,n] torch::Tensor &yscale, // [m ,1] torch::Tensor &weight, // [1 ,n] torch::Tensor &bias, // [1 ,n] double epsilon, std::optional x_bias); // following are asm kernels void layernorm2d_with_add_asm(torch::Tensor &out, // [m ,n] torch::Tensor &input, // [m ,n] torch::Tensor &residual_in, // [m ,n] torch::Tensor &residual_out, // [m ,n] torch::Tensor &weight, // [1 ,n] torch::Tensor &bias, // [1 ,n] float epsilon); void layernorm2d_with_add_smoothquant_asm(torch::Tensor &out, // [m ,n] torch::Tensor &input, // [m ,n] torch::Tensor &residual_in, // [m ,n] torch::Tensor &residual_out, // [m ,n] torch::Tensor &xscale, // [1 ,n] torch::Tensor &yscale, // [m ,1] torch::Tensor &weight, // [1 ,n] torch::Tensor &bias, // [1 ,n] float epsilon);