#pragma once #include #include namespace at::native { template inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented"); } template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)); template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(float)); template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)); template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); } // namespace at::native