// Declarations for the _cslt_* family of at::native operators. The "cslt"
// prefix and the AT_CUSPARSELT_ENABLED() guard suggest these wrap NVIDIA
// cuSPARSELt -- confirm.
//
// NOTE(review): this header appears to have been collapsed onto a single
// physical line and stripped of every `<...>` span. Consequences:
//   * A preprocessor directive ends at the end of its line, so every token
//     after `#pragma once` on the line below is part of that one directive.
//     The #include / #if / #endif directives and the namespace body do not
//     take effect as written (compilers typically warn about extra tokens
//     after `#pragma once`).
//   * Every `#include` has lost its header name, including the one guarded
//     by AT_CUSPARSELT_ENABLED() (presumably the cuSPARSELt header).
//   * `std::tuple` and every `std::optional` have lost their template
//     arguments. Presumably bias_opt / alpha_opt are optional tensors and
//     out_dtype_opt is an optional scalar type -- TODO confirm.
// TODO: restore the original line structure, include targets and template
// arguments from version control rather than re-deriving them by hand. The
// tuple type on _cslt_sparse_mm_impl must match its definition exactly.
//
// Declared in namespace at::native:
//   _cslt_compress(sparse_input) -> at::Tensor
//   _cslt_sparse_mm_impl(compressed_A, dense_B, bias_opt, alpha_opt,
//                        out_dtype_opt, transpose_result, alg_id, split_k,
//                        split_k_mode, search_alg_id) -> std::tuple<...>
//       The only declaration marked TORCH_CUDA_CPP_API (presumably so it is
//       exported for use outside this library). Takes alg_id / split_k /
//       split_k_mode as `int`, plus a `search_alg_id` flag that the public
//       _cslt_sparse_mm does not expose.
//   _cslt_sparse_mm(compressed_A, dense_B, bias_opt, alpha_opt,
//                   out_dtype_opt, transpose_result, alg_id, split_k,
//                   split_k_mode) -> at::Tensor
//       Same leading parameters as the impl, but alg_id / split_k /
//       split_k_mode are int64_t. If it forwards to the impl, those values
//       are narrowed to int -- confirm the accepted range.
//   _cslt_sparse_mm_search(compressed_A, dense_B, bias_opt, alpha_opt,
//                          out_dtype_opt, transpose_result) -> int64_t
//       Takes no alg_id / split_k / split_k_mode. The name suggests it
//       searches for and returns an algorithm id -- confirm.
#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #if AT_CUSPARSELT_ENABLED() #include #endif namespace at::native { at::Tensor _cslt_compress(const Tensor& sparse_input); TORCH_CUDA_CPP_API std::tuple _cslt_sparse_mm_impl( const Tensor& compressed_A, const Tensor& dense_B, const std::optional& bias_opt, const std::optional& alpha_opt, const std::optional out_dtype_opt, bool transpose_result, int alg_id, int split_k, int split_k_mode, bool search_alg_id ); at::Tensor _cslt_sparse_mm( const Tensor& compressed_A, const Tensor& dense_B, const std::optional& bias_opt, const std::optional& alpha_opt, const std::optional out_dtype_opt, bool transpose_result, int64_t alg_id, int64_t split_k, int64_t split_k_mode ); int64_t _cslt_sparse_mm_search( const Tensor& compressed_A, const Tensor& dense_B, const std::optional& bias_opt, const std::optional& alpha_opt, const std::optional out_dtype_opt, bool transpose_result ); } // namespace at::native