#pragma once #include #if defined(CUDART_VERSION) && defined(CUSOLVER_VERSION) && CUSOLVER_VERSION >= 11000 // cuSOLVER version >= 11000 includes 64-bit API #define USE_CUSOLVER_64_BIT #endif #if defined(CUDART_VERSION) && defined(CUSOLVER_VERSION) && CUSOLVER_VERSION >= 11701 // cuSOLVER version >= 11701 includes 64-bit API for batched syev #define USE_CUSOLVER_64_BIT_XSYEV_BATCHED #endif #if defined(CUDART_VERSION) || defined(USE_ROCM) namespace at { namespace cuda { namespace solver { #define CUDASOLVER_GETRF_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, int m, int n, Dtype* dA, int ldda, int* ipiv, int* info template void getrf(CUDASOLVER_GETRF_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrf: not implemented"); } template<> void getrf(CUDASOLVER_GETRF_ARGTYPES(float)); template<> void getrf(CUDASOLVER_GETRF_ARGTYPES(double)); template<> void getrf>(CUDASOLVER_GETRF_ARGTYPES(c10::complex)); template<> void getrf>(CUDASOLVER_GETRF_ARGTYPES(c10::complex)); #define CUDASOLVER_GETRS_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, int n, int nrhs, Dtype* dA, int lda, int* ipiv, Dtype* ret, int ldb, int* info, cublasOperation_t trans template void getrs(CUDASOLVER_GETRS_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrs: not implemented"); } template<> void getrs(CUDASOLVER_GETRS_ARGTYPES(float)); template<> void getrs(CUDASOLVER_GETRS_ARGTYPES(double)); template<> void getrs>(CUDASOLVER_GETRS_ARGTYPES(c10::complex)); template<> void getrs>(CUDASOLVER_GETRS_ARGTYPES(c10::complex)); #define CUDASOLVER_SYTRF_BUFFER_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, int n, Dtype *A, int lda, int *lwork template void sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::sytrf_bufferSize: not implemented"); } template <> void sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(float)); template <> void sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(double)); template <> void sytrf_bufferSize>( CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex)); template <> void sytrf_bufferSize>( CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex)); #define CUDASOLVER_SYTRF_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype *A, int lda, \ int *ipiv, Dtype *work, int lwork, int *devInfo template void sytrf(CUDASOLVER_SYTRF_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::sytrf: not implemented"); } template <> void sytrf(CUDASOLVER_SYTRF_ARGTYPES(float)); template <> void sytrf(CUDASOLVER_SYTRF_ARGTYPES(double)); template <> void sytrf>( CUDASOLVER_SYTRF_ARGTYPES(c10::complex)); template <> void sytrf>(CUDASOLVER_SYTRF_ARGTYPES(c10::complex)); #define CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES() \ cusolverDnHandle_t handle, int m, int n, int *lwork template void gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd_buffersize: not implemented"); } template<> void gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()); template<> void gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()); template<> void gesvd_buffersize>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()); template<> void gesvd_buffersize>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()); #define CUDASOLVER_GESVD_ARGTYPES(Dtype, Vtype) \ cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, int n, Dtype *A, int lda, \ Vtype *S, Dtype *U, int ldu, Dtype *VT, int ldvt, Dtype *work, int lwork, Vtype *rwork, int *info template void gesvd(CUDASOLVER_GESVD_ARGTYPES(Dtype, Vtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd: not implemented"); } template<> void gesvd(CUDASOLVER_GESVD_ARGTYPES(float, float)); template<> void gesvd(CUDASOLVER_GESVD_ARGTYPES(double, double)); template<> void gesvd>(CUDASOLVER_GESVD_ARGTYPES(c10::complex, float)); template<> void gesvd>(CUDASOLVER_GESVD_ARGTYPES(c10::complex, double)); #define CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(Dtype, Vtype) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, Dtype *A, int lda, Vtype *S, \ Dtype *U, int ldu, Dtype *V, int ldv, int *lwork, gesvdjInfo_t params template void gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj_buffersize: not implemented"); } template<> void gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(float, float)); template<> void gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(double, double)); template<> void gesvdj_buffersize>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(c10::complex, float)); template<> void gesvdj_buffersize>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(c10::complex, double)); #define CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \ int ldu, Dtype* V, int ldv, Dtype* work, int lwork, int *info, gesvdjInfo_t params template void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented"); } template<> void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(float, float)); template<> void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(double, double)); template<> void gesvdj>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex, float)); template<> void gesvdj>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex, double)); #define CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \ int ldu, Dtype *V, int ldv, int *info, gesvdjInfo_t params, int batchSize template void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented"); } template<> void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(float, float)); template<> void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(double, double)); template<> void gesvdjBatched>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex, float)); template<> void gesvdjBatched>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex, double)); #define CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(Dtype, Vtype) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, Dtype *A, int lda, long long int strideA, \ Vtype *S, long long int strideS, Dtype *U, int ldu, long long int strideU, Dtype *V, int ldv, long long int strideV, \ int *lwork, int batchSize template void gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched_buffersize: not implemented"); } template<> void gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(float, float)); template<> void gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(double, double)); template<> void gesvdaStridedBatched_buffersize>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex, float)); template<> void gesvdaStridedBatched_buffersize>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex, double)); #define CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(Dtype, Vtype) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, Dtype *A, int lda, long long int strideA, \ Vtype *S, long long int strideS, Dtype *U, int ldu, long long int strideU, Dtype *V, int ldv, long long int strideV, \ Dtype *work, int lwork, int *info, double *h_R_nrmF, int batchSize // h_R_nrmF is always double, regardless of input Dtype. template void gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(Dtype, Vtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched: not implemented"); } template<> void gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(float, float)); template<> void gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(double, double)); template<> void gesvdaStridedBatched>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(c10::complex, float)); template<> void gesvdaStridedBatched>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(c10::complex, double)); #define CUDASOLVER_POTRF_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype* A, int lda, Dtype* work, int lwork, int* info template void potrf(CUDASOLVER_POTRF_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf: not implemented"); } template<> void potrf(CUDASOLVER_POTRF_ARGTYPES(float)); template<> void potrf(CUDASOLVER_POTRF_ARGTYPES(double)); template<> void potrf>(CUDASOLVER_POTRF_ARGTYPES(c10::complex)); template<> void potrf>(CUDASOLVER_POTRF_ARGTYPES(c10::complex)); #define CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype* A, int lda, int* lwork template void potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf_buffersize: not implemented"); } template<> void potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(float)); template<> void potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(double)); template<> void potrf_buffersize>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(c10::complex)); template<> void potrf_buffersize>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(c10::complex)); #define CUDASOLVER_POTRF_BATCHED_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype** A, int lda, int* info, int batchSize template void potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrfBatched: not implemented"); } template<> void potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES(float)); template<> void potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES(double)); template<> void potrfBatched>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(c10::complex)); template<> void potrfBatched>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(c10::complex)); #define CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(scalar_t) \ cusolverDnHandle_t handle, int m, int n, scalar_t *A, int lda, int *lwork template void geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::geqrf_bufferSize: not implemented"); } template <> void geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(float)); template <> void geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(double)); template <> void geqrf_bufferSize>( CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex)); template <> void geqrf_bufferSize>( CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex)); #define CUDASOLVER_GEQRF_ARGTYPES(scalar_t) \ cusolverDnHandle_t handle, int m, int n, scalar_t *A, int lda, \ scalar_t *tau, scalar_t *work, int lwork, int *devInfo template void geqrf(CUDASOLVER_GEQRF_ARGTYPES(scalar_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::geqrf: not implemented"); } template <> void geqrf(CUDASOLVER_GEQRF_ARGTYPES(float)); template <> void geqrf(CUDASOLVER_GEQRF_ARGTYPES(double)); template <> void geqrf>(CUDASOLVER_GEQRF_ARGTYPES(c10::complex)); template <> void geqrf>( CUDASOLVER_GEQRF_ARGTYPES(c10::complex)); #define CUDASOLVER_POTRS_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const Dtype *A, int lda, Dtype *B, int ldb, int *devInfo template void potrs(CUDASOLVER_POTRS_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrs: not implemented"); } template<> void potrs(CUDASOLVER_POTRS_ARGTYPES(float)); template<> void potrs(CUDASOLVER_POTRS_ARGTYPES(double)); template<> void potrs>(CUDASOLVER_POTRS_ARGTYPES(c10::complex)); template<> void potrs>(CUDASOLVER_POTRS_ARGTYPES(c10::complex)); #define CUDASOLVER_POTRS_BATCHED_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, Dtype *Aarray[], int lda, Dtype *Barray[], int ldb, int *info, int batchSize template void potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrsBatched: not implemented"); } template<> void potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES(float)); template<> void potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES(double)); template<> void potrsBatched>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(c10::complex)); template<> void potrsBatched>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(c10::complex)); #define CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, int m, int n, int k, const Dtype *A, int lda, \ const Dtype *tau, int *lwork template void orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr_buffersize: not implemented"); } template <> void orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(float)); template <> void orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(double)); template <> void orgqr_buffersize>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(c10::complex)); template <> void orgqr_buffersize>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(c10::complex)); #define CUDASOLVER_ORGQR_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, int m, int n, int k, Dtype *A, int lda, \ const Dtype *tau, Dtype *work, int lwork, int *devInfo template void orgqr(CUDASOLVER_ORGQR_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr: not implemented"); } template <> void orgqr(CUDASOLVER_ORGQR_ARGTYPES(float)); template <> void orgqr(CUDASOLVER_ORGQR_ARGTYPES(double)); template <> void orgqr>(CUDASOLVER_ORGQR_ARGTYPES(c10::complex)); template <> void orgqr>(CUDASOLVER_ORGQR_ARGTYPES(c10::complex)); #define CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, \ int m, int n, int k, const Dtype *A, int lda, const Dtype *tau, \ const Dtype *C, int ldc, int *lwork template void ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::ormqr_bufferSize: not implemented"); } template <> void ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(float)); template <> void ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(double)); template <> void ormqr_bufferSize>( CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex)); template <> void ormqr_bufferSize>( CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex)); #define CUDASOLVER_ORMQR_ARGTYPES(Dtype) \ cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, \ int m, int n, int k, const Dtype *A, int lda, const Dtype *tau, Dtype *C, \ int ldc, Dtype *work, int lwork, int *devInfo template void ormqr(CUDASOLVER_ORMQR_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype), "at::cuda::solver::ormqr: not implemented"); } template <> void ormqr(CUDASOLVER_ORMQR_ARGTYPES(float)); template <> void ormqr(CUDASOLVER_ORMQR_ARGTYPES(double)); template <> void ormqr>(CUDASOLVER_ORMQR_ARGTYPES(c10::complex)); template <> void ormqr>( CUDASOLVER_ORMQR_ARGTYPES(c10::complex)); #ifdef USE_CUSOLVER_64_BIT template cudaDataType get_cusolver_datatype() { static_assert(false&&sizeof(Dtype), "cusolver doesn't support data type"); return {}; } template<> cudaDataType get_cusolver_datatype(); template<> cudaDataType get_cusolver_datatype(); template<> cudaDataType get_cusolver_datatype>(); template<> cudaDataType get_cusolver_datatype>(); void xpotrf_buffersize( cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, const void *A, int64_t lda, cudaDataType computeType, size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost); void xpotrf( cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, void *A, int64_t lda, cudaDataType computeType, void *bufferOnDevice, size_t workspaceInBytesOnDevice, void *bufferOnHost, size_t workspaceInBytesOnHost, int *info); void xpotrs( cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, int64_t nrhs, cudaDataType dataTypeA, const void *A, int64_t lda, cudaDataType dataTypeB, void *B, int64_t ldb, int *info); #endif // USE_CUSOLVER_64_BIT #define CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \ int n, const scalar_t *A, int lda, const value_t *W, int *lwork template void syevd_bufferSize(CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevd_bufferSize: not implemented"); } template <> void syevd_bufferSize( CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(float, float)); template <> void syevd_bufferSize( CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(double, double)); template <> void syevd_bufferSize, float>( CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(c10::complex, float)); template <> void syevd_bufferSize, double>( CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(c10::complex, double)); #define CUDASOLVER_SYEVD_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \ int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork, \ int *info template void syevd(CUDASOLVER_SYEVD_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevd: not implemented"); } template <> void syevd(CUDASOLVER_SYEVD_ARGTYPES(float, float)); template <> void syevd(CUDASOLVER_SYEVD_ARGTYPES(double, double)); template <> void syevd, float>( CUDASOLVER_SYEVD_ARGTYPES(c10::complex, float)); template <> void syevd, double>( CUDASOLVER_SYEVD_ARGTYPES(c10::complex, double)); #define CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \ int n, const scalar_t *A, int lda, const value_t *W, int *lwork, \ syevjInfo_t params template void syevj_bufferSize(CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevj_bufferSize: not implemented"); } template <> void syevj_bufferSize( CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(float, float)); template <> void syevj_bufferSize( CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(double, double)); template <> void syevj_bufferSize, float>( CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(c10::complex, float)); template <> void syevj_bufferSize, double>( CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(c10::complex, double)); #define CUDASOLVER_SYEVJ_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \ int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork, \ int *info, syevjInfo_t params template void syevj(CUDASOLVER_SYEVJ_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevj: not implemented"); } template <> void syevj(CUDASOLVER_SYEVJ_ARGTYPES(float, float)); template <> void syevj(CUDASOLVER_SYEVJ_ARGTYPES(double, double)); template <> void syevj, float>( CUDASOLVER_SYEVJ_ARGTYPES(c10::complex, float)); template <> void syevj, double>( CUDASOLVER_SYEVJ_ARGTYPES(c10::complex, double)); #define CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \ int n, const scalar_t *A, int lda, const value_t *W, int *lwork, \ syevjInfo_t params, int batchsize template void syevjBatched_bufferSize( CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevjBatched_bufferSize: not implemented"); } template <> void syevjBatched_bufferSize( CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(float, float)); template <> void syevjBatched_bufferSize( CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(double, double)); template <> void syevjBatched_bufferSize, float>( CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex, float)); template <> void syevjBatched_bufferSize, double>( CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex, double)); #define CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \ int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork, \ int *info, syevjInfo_t params, int batchsize template void syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevjBatched: not implemented"); } template <> void syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(float, float)); template <> void syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(double, double)); template <> void syevjBatched, float>( CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(c10::complex, float)); template <> void syevjBatched, double>( CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(c10::complex, double)); #ifdef USE_CUSOLVER_64_BIT #define CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(scalar_t) \ cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, int64_t n, \ const scalar_t *A, int64_t lda, const scalar_t *tau, \ size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost template void xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xgeqrf_bufferSize: not implemented"); } template <> void xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(float)); template <> void xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(double)); template <> void xgeqrf_bufferSize>( CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex)); template <> void xgeqrf_bufferSize>( CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex)); #define CUDASOLVER_XGEQRF_ARGTYPES(scalar_t) \ cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, int64_t n, \ scalar_t *A, int64_t lda, scalar_t *tau, scalar_t *bufferOnDevice, \ size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost, \ size_t workspaceInBytesOnHost, int *info template void xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES(scalar_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xgeqrf: not implemented"); } template <> void xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES(float)); template <> void xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES(double)); template <> void xgeqrf>( CUDASOLVER_XGEQRF_ARGTYPES(c10::complex)); template <> void xgeqrf>( CUDASOLVER_XGEQRF_ARGTYPES(c10::complex)); #define CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, cusolverDnParams_t params, \ cusolverEigMode_t jobz, cublasFillMode_t uplo, int64_t n, \ const scalar_t *A, int64_t lda, const value_t *W, \ size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost template void xsyevd_bufferSize( CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xsyevd_bufferSize: not implemented"); } template <> void xsyevd_bufferSize( CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(float, float)); template <> void xsyevd_bufferSize( CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(double, double)); template <> void xsyevd_bufferSize, float>( CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(c10::complex, float)); template <> void xsyevd_bufferSize, double>( CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(c10::complex, double)); #define CUDASOLVER_XSYEVD_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, cusolverDnParams_t params, \ cusolverEigMode_t jobz, cublasFillMode_t uplo, int64_t n, scalar_t *A, \ int64_t lda, value_t *W, scalar_t *bufferOnDevice, \ size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost, \ size_t workspaceInBytesOnHost, int *info template void xsyevd(CUDASOLVER_XSYEVD_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xsyevd: not implemented"); } template <> void xsyevd(CUDASOLVER_XSYEVD_ARGTYPES(float, float)); template <> void xsyevd(CUDASOLVER_XSYEVD_ARGTYPES(double, double)); template <> void xsyevd, float>( CUDASOLVER_XSYEVD_ARGTYPES(c10::complex, float)); template <> void xsyevd, double>( CUDASOLVER_XSYEVD_ARGTYPES(c10::complex, double)); #endif // USE_CUSOLVER_64_BIT #ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED #define CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, \ cusolverDnParams_t params, \ cusolverEigMode_t jobz, \ cublasFillMode_t uplo, \ int64_t n, \ const scalar_t *A, \ int64_t lda, \ const value_t *W, \ size_t *workspaceInBytesOnDevice, \ size_t *workspaceInBytesOnHost, \ int64_t batchSize template void xsyevBatched_bufferSize( CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xsyevBatched_bufferSize: not implemented"); } template <> void xsyevBatched_bufferSize( CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(float, float)); template <> void xsyevBatched_bufferSize( CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(double, double)); template <> void xsyevBatched_bufferSize, float>( CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex, float)); template <> void xsyevBatched_bufferSize, double>( CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex, double)); #define CUDASOLVER_XSYEV_BATCHED_ARGTYPES(scalar_t, value_t) \ cusolverDnHandle_t handle, \ cusolverDnParams_t params, \ cusolverEigMode_t jobz, \ cublasFillMode_t uplo, \ int64_t n, \ scalar_t *A, \ int64_t lda, \ value_t *W, \ void *bufferOnDevice, \ size_t workspaceInBytesOnDevice, \ void *bufferOnHost, \ size_t workspaceInBytesOnHost, \ int *info, \ int64_t batchSize template void xsyevBatched(CUDASOLVER_XSYEV_BATCHED_ARGTYPES(scalar_t, value_t)) { static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xsyevBatched: not implemented"); } template <> void xsyevBatched( CUDASOLVER_XSYEV_BATCHED_ARGTYPES(float, float)); template <> void xsyevBatched( CUDASOLVER_XSYEV_BATCHED_ARGTYPES(double, double)); template <> void xsyevBatched, float>( CUDASOLVER_XSYEV_BATCHED_ARGTYPES(c10::complex, float)); template <> void xsyevBatched, double>( CUDASOLVER_XSYEV_BATCHED_ARGTYPES(c10::complex, double)); #endif // USE_CUSOLVER_64_BIT_XSYEV_BATCHED } // namespace solver } // namespace cuda } // namespace at #endif // CUDART_VERSION