/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Re-enable implicit half conversions for this translation unit; the CK
// epilogue below assigns float products into half/bfloat16 outputs.
#undef __HIP_NO_HALF_CONVERSIONS__

// NOTE(review): every #include in the original was stripped of its header
// name during extraction. The list below is reconstructed from upstream
// PyTorch's ck_gemm_template.h — verify against the original file.
#include <ATen/OpMathType.h>
#include <ATen/Tensor.h>
#include <ATen/hip/HIPBlas.h>
#include <ATen/native/hip/ck_gemm.h>
#include <c10/hip/HIPStream.h>
#include <c10/util/irange.h>

#include <ck/ck.hpp>
#include <ck/utility/data_type.hpp>
#include <ck/utility/functional3.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>

#include <ck/library/reference_tensor_operation/cpu/reference_gemm.hpp>
#include <ck/library/utility/check_err.hpp>
#include <ck/library/utility/device_memory.hpp>
#include <ck/library/utility/fill.hpp>
#include <ck/library/utility/host_tensor.hpp>
#include <ck/library/utility/host_tensor_generator.hpp>
#include <ck/library/utility/literals.hpp>

// Define commonly used types.
// NOTE(review): the template parameter list was stripped in extraction;
// restored as the conventional CK index-sequence alias.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

namespace at::native {

// Elementwise Operators
//
// Epilogue functor for the CK GEMM: writes c = alpha_ * ab for each output
// element. beta_ is stored but never read yet (see TODO below), so the
// epilogue currently implements C = alpha * (A @ B) with no accumulate/bias.
struct AlphaBetaAdd
{
  AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};

  // Primary template is declared but not defined: only the specializations
  // below (float, bfloat16, half) are usable, which keeps unsupported dtypes
  // a link-time/compile-time error rather than a silent misbehavior.
  template <typename C, typename AB>
  __host__ __device__ constexpr void operator()(C& c, const AB& ab) const;

  template<>
  __host__ __device__ constexpr void operator()
  (float& c, const float& ab) const
  {
    c = alpha_ * ab;
  };

  template<>
  __host__ __device__ constexpr void operator()
  (ck::bhalf_t& c, const ck::bhalf_t& ab) const
  {
    c = alpha_ * ab;
  };

  template<>
  __host__ __device__ constexpr void operator()
  (ck::half_t& c, const ck::half_t& ab) const
  {
    c = alpha_ * ab;
  };

  float alpha_;
  // TODO: Leaving for now, will use later
  float beta_;
};

// Instantiates and runs one Composable Kernel XDL CShuffle GEMM for the
// given dtype and tile/wave tuning parameters.
//
// Parameters come from CUDABLAS_GEMM_ARGTYPES(Dtype): transa/transb, m/n/k,
// alpha, a/lda, b/ldb, beta, c/ldc. Layout selection is compile-time via
// TRANSA/TRANSB; PADDING selects the MNKPadding specialization for problem
// sizes that do not divide the tile shape evenly.
//
// Throws std::runtime_error if the CK instance rejects the problem shape.
template <
    typename Dtype,
    int BLOCK_SIZE,
    int MBLOCK,
    int NBLOCK,
    int KBLOCK,
    int AK1,
    int BK1,
    int MPER_XDL,
    int NPER_XDL,
    int MPER_WAVE,
    int NPER_WAVE,
    typename ABLOCK_CLUSTER_LENS,
    typename ABLOCK_CLUSTER_ORDER,
    typename ABLOCK_SRC_ORDER,
    int ABLOCK_VECTOR_DIM,
    int ABLOCK_SCALAR_VEC,
    int ABLOCK_SCALAR_VEC_AK1,
    bool ABLOCK_LDS_EXTRAM,
    typename BBLOCK_CLUSTER_LENS,
    typename BBLOCK_CLUSTER_ORDER,
    typename BBLOCK_SRC_ORDER,
    int BBLOCK_VECTOR_DIM,
    int BBLOCK_SCALAR_VEC,
    int BBLOCK_SCALAR_VEC_AK1,
    bool BBLOCK_LDS_EXTRAN,
    int CMPER_WAVE,
    int CNPER_WAVE,
    typename BLOCK_CLUSTER_LENS,
    typename CDE_SCALAR_VEC,
    bool PADDING = false,
    bool TRANSA = false,
    bool TRANSB = false>
void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  // Get input information.
  int M = m;
  int N = n;
  int K = k;

  int StrideA = lda;
  int StrideB = ldb;
  int StrideC = ldc;

  // Single K-split: no split-K reduction.
  int KBatch = 1;

  // NOTE(review): falpha/fbeta are unused here (alpha/beta are passed to the
  // epilogue directly below); kept to match the original code.
  float falpha = alpha;
  float fbeta = beta;

  // NOTE(review): the <...> arguments below were stripped in extraction and
  // are restored from upstream: CkMathType<Dtype> maps the ATen dtype to the
  // matching CK device type; CkTensorLayout<TRANSA, TRANSB> picks Row/Col
  // layouts for A and B.
  using ADataType = typename CkMathType<Dtype>::dtype;
  using BDataType = typename CkMathType<Dtype>::dtype;
  using CDataType = typename CkMathType<Dtype>::dtype;
  using DDataType = typename CkMathType<Dtype>::dtype;

  using AccDataType = float;
  using CShuffleDataType = typename CkMathType<Dtype>::dtype;

  using ALayout = typename CkTensorLayout<TRANSA, TRANSB>::a_layout;
  using BLayout = typename CkTensorLayout<TRANSA, TRANSB>::b_layout;

  using DLayout = Row;
  using CLayout = Row;

  using AElementOp = PassThrough;
  using BElementOp = PassThrough;
  using CElementOp = AlphaBetaAdd;

  static constexpr auto GemmDefault =
      ck::tensor_operation::device::GemmSpecialization::Default;
  static constexpr auto GemmMNKPadding =
      ck::tensor_operation::device::GemmSpecialization::MNKPadding;
  static constexpr auto GemmSpec = PADDING ? GemmMNKPadding : GemmDefault;

  // Empty ck::Tuple<> D-operands: this instance computes C = epilogue(A*B)
  // with no extra D tensors fused in.
  using DeviceGemmInstance =
      ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
          ALayout,
          BLayout,
          ck::Tuple<>,
          CLayout,
          ADataType,
          BDataType,
          ck::Tuple<>,
          CDataType,
          AccDataType,
          CShuffleDataType,
          AElementOp,
          BElementOp,
          CElementOp,
          GemmSpec,
          BLOCK_SIZE,
          MBLOCK,
          NBLOCK,
          KBLOCK,
          AK1,
          BK1,
          MPER_XDL,
          NPER_XDL,
          MPER_WAVE,
          NPER_WAVE,
          ABLOCK_CLUSTER_LENS,
          ABLOCK_CLUSTER_ORDER,
          ABLOCK_SRC_ORDER,
          ABLOCK_VECTOR_DIM,
          ABLOCK_SCALAR_VEC,
          ABLOCK_SCALAR_VEC_AK1,
          ABLOCK_LDS_EXTRAM,
          BBLOCK_CLUSTER_LENS,
          BBLOCK_CLUSTER_ORDER,
          BBLOCK_SRC_ORDER,
          BBLOCK_VECTOR_DIM,
          BBLOCK_SCALAR_VEC,
          BBLOCK_SCALAR_VEC_AK1,
          BBLOCK_LDS_EXTRAN,
          CMPER_WAVE,
          CNPER_WAVE,
          BLOCK_CLUSTER_LENS,
          CDE_SCALAR_VEC>;

  auto gemm = DeviceGemmInstance{};
  auto invoker = gemm.MakeInvoker();

  auto a_element_op = AElementOp{};
  auto b_element_op = BElementOp{};
  auto c_element_op = CElementOp{alpha, beta};

  // Zero-length D-pointer array to match the empty ck::Tuple<> above.
  using DDataArrayType = std::array<const void*, 0>;
  DDataArrayType DDataArray;

  // We swap A and B inputs here as a temporary workaround: the problem is
  // posed as (B^T A^T)^T = C^T, so M/N and the A/B strides are exchanged.
  auto argument = gemm.MakeArgument(
      reinterpret_cast<const ADataType*>(b),
      reinterpret_cast<const BDataType*>(a),
      DDataArray,
      reinterpret_cast<CDataType*>(c),
      N,
      M,
      K,
      StrideB,
      StrideA,
      std::array<ck::index_t, 0>{},
      StrideC,
      KBatch,
      a_element_op,
      b_element_op,
      c_element_op);

  if (!gemm.IsSupportedArgument(argument)) {
    throw std::runtime_error(
        "wrong! device_gemm with the specified compilation parameters does "
        "not support this GEMM problem");
  }

  // Run on the current ATen HIP stream; do not time the kernel.
  auto stream = at::cuda::getCurrentHIPStream().stream();
  invoker.Run(argument, StreamConfig{stream, false});
}

} // namespace at::native