/**
 * Calls a cuSOLVER routine and then blocks until the stream attached to
 * `handle` has finished its work.
 *
 * The expansion is ONE statement (wrapped in do { ... } while (0)), so it is
 * safe after an unbraced `if`/`else` and may be used more than once in the
 * same scope; the stream variable it needs stays private to the expansion.
 *
 * Steps performed:
 *   1. `err = func(handle, <variadic arguments>)`.
 *   2. If `err` is not CUSOLVER_STATUS_SUCCESS, throw `cusolver_error` whose
 *      message is `name` followed by " : ". Nothing is synchronized then.
 *   3. Query the stream of `handle` through cusolverDnGetStream, using the
 *      file's CUSOLVER_ERROR_FUNC macro (which also writes to `err`).
 *   4. Wait on that stream with cuStreamSynchronize. A non-CUDA_SUCCESS
 *      result is reported as std::runtime_error instead of being dropped.
 *      NOTE(review): this file also declares `cuda_error` (derived from
 *      std::runtime_error); switch to it here once its constructor
 *      signature is confirmed.
 *
 * @param name    C string naming the routine; used only in error messages.
 * @param func    Callable invoked as `func(handle, args...)`.
 * @param err     Assignable lvalue (cusolverStatus_t at the visible call
 *                sites) that receives the most recent cuSOLVER status.
 * @param handle  cuSOLVER handle passed to `func` and cusolverDnGetStream.
 * @param ...     Remaining arguments forwarded to `func`; at least one is
 *                required.
 */
#define CUSOLVER_ERROR_FUNC_T_SYNC(name, func, err, handle, ...)                          \
    do {                                                                                   \
        err = func(handle, __VA_ARGS__);                                                   \
        if (err != CUSOLVER_STATUS_SUCCESS) {                                              \
            throw cusolver_error(std::string(name) + std::string(" : "), err);             \
        }                                                                                  \
        cudaStream_t cusolver_sync_stream_;                                                \
        CUSOLVER_ERROR_FUNC(cusolverDnGetStream, err, handle, &cusolver_sync_stream_);     \
        const auto cusolver_sync_status_ = cuStreamSynchronize(cusolver_sync_stream_);     \
        if (cusolver_sync_status_ != CUDA_SUCCESS) {                                       \
            throw std::runtime_error(                                                      \
                std::string(name) +                                                        \
                std::string(" : cuStreamSynchronize failed with CUresult ") +              \
                std::to_string(static_cast<int>(cusolver_sync_status_)));                  \
        }                                                                                  \
    } while (0)
a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); } @@ -164,8 +164,8 @@ void getrf(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, scratch_, ipiv32_, - devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, + scratch_, ipiv32_, devInfo_); }); }); @@ -250,8 +250,9 @@ inline void getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = sc.get_mem(ipiv_acc); auto b_ = sc.get_mem(b_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_operation(trans), n, - nrhs, a_, lda, ipiv_, b_, ldb, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_operation(trans), n, nrhs, a_, lda, ipiv_, + b_, ldb, nullptr); }); }); } @@ -299,9 +300,11 @@ inline void gesvd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; // rwork is set to nullptr. If set it is filled with information from the superdiagonal. 
- CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_jobsvd(jobu), - get_cusolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, ldvt, - scratch_, scratchpad_size, nullptr, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cusolver_jobsvd(jobu), + get_cusolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, + ldu, vt_, ldvt, scratch_, scratchpad_size, nullptr, + devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -345,9 +348,9 @@ inline void heevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_job(jobz), - get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, - scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, + scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -390,9 +393,10 @@ inline void hegvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_itype(itype), - get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, - b_, ldb, w_, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cusolver_itype(itype), get_cusolver_job(jobz), + get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, + scratch_, scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -437,8 +441,9 @@ inline void hetrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, 
err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, + scratch_, scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -487,8 +492,9 @@ inline void orgbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, - a_, lda, tau_, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_generate(vec), m, n, k, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); } @@ -522,8 +528,8 @@ inline void orgqr(const char *func_name, Func func, sycl::queue &queue, std::int auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, - scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); } @@ -557,8 +563,9 @@ inline void orgtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, tau_, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); } @@ -596,9 +603,11 @@ inline void ormtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, 
err, handle, get_cublas_side_mode(side), - get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, - lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_side_mode(side), + get_cublas_fill_mode(uplo), + get_cublas_operation(trans), m, n, a_, lda, tau_, c_, + ldc, scratch_, scratchpad_size, nullptr); }); }); } @@ -650,9 +659,10 @@ inline void ormqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), - get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, - scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_side_mode(side), + get_cublas_operation(trans), m, n, k, a_, lda, tau_, + c_, ldc, scratch_, scratchpad_size, nullptr); }); }); } @@ -688,8 +698,9 @@ inline void potrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, scratch_, + scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -726,8 +737,9 @@ inline void potri(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, scratch_, + scratchpad_size, 
devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -763,8 +775,9 @@ inline void potrs(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nrhs, - a_, lda, b_, ldb, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, nrhs, a_, lda, b_, + ldb, nullptr); }); }); } @@ -803,9 +816,9 @@ inline void syevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_job(jobz), - get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, - scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, + scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -846,9 +859,10 @@ inline void sygvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_itype(itype), - get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, - b_, ldb, w_, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cusolver_itype(itype), get_cusolver_job(jobz), + get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, + scratch_, scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -892,8 +906,9 @@ inline void sytrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; 
- CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, + scratch_, scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -940,8 +955,9 @@ inline void sytrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, ipiv32_, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, ipiv32_, + scratch_, scratchpad_size, devInfo_); }); }); @@ -1015,8 +1031,9 @@ inline void ungbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, - a_, lda, tau_, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_generate(vec), m, n, k, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); } @@ -1050,8 +1067,8 @@ inline void ungqr(const char *func_name, Func func, sycl::queue &queue, std::int auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, - scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); } @@ -1085,8 +1102,9 @@ inline void ungtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto scratch_ = sc.get_mem(scratch_acc); 
cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, tau_, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); } @@ -1138,9 +1156,10 @@ inline void unmqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), - get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, - scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_side_mode(side), + get_cublas_operation(trans), m, n, k, a_, lda, tau_, + c_, ldc, scratch_, scratchpad_size, nullptr); }); }); } @@ -1179,9 +1198,11 @@ inline void unmtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto c_ = sc.get_mem(c_acc); auto scratch_ = sc.get_mem(scratch_acc); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), - get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, - lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_side_mode(side), + get_cublas_fill_mode(uplo), + get_cublas_operation(trans), m, n, a_, lda, tau_, c_, + ldc, scratch_, scratchpad_size, nullptr); }); }); } @@ -1229,8 +1250,8 @@ inline sycl::event gebrd(const char *func_name, Func func, sycl::queue &queue, s auto taup_ = reinterpret_cast(taup); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, taup_, - scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, d_, e_, + tauq_, taup_, scratch_, scratchpad_size, 
nullptr); }); }); return done; @@ -1291,8 +1312,8 @@ inline sycl::event geqrf(const char *func_name, Func func, sycl::queue &queue, s auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, tau_, scratch_, - scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); return done; @@ -1340,8 +1361,8 @@ inline sycl::event getrf(const char *func_name, Func func, sycl::queue &queue, s auto scratch_ = reinterpret_cast(scratchpad); auto ipiv_ = reinterpret_cast(ipiv32); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, scratch_, ipiv_, - devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, + scratch_, ipiv_, devInfo_); }); }); @@ -1433,8 +1454,9 @@ inline sycl::event getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto b_ = reinterpret_cast(b); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_operation(trans), n, - nrhs, a_, lda, ipiv_, b_, ldb, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_operation(trans), n, nrhs, a_, lda, ipiv_, + b_, ldb, nullptr); }); }); @@ -1486,9 +1508,11 @@ inline sycl::event gesvd(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; // rwork is set to nullptr. If set it is filled with information from the superdiagonal. 
- CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_jobsvd(jobu), - get_cusolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, ldvt, - scratch_, scratchpad_size, nullptr, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cusolver_jobsvd(jobu), + get_cusolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, + ldu, vt_, ldvt, scratch_, scratchpad_size, nullptr, + devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1534,9 +1558,9 @@ inline sycl::event heevd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_job(jobz), - get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, - scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, + scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1581,9 +1605,10 @@ inline sycl::event hegvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_itype(itype), - get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, - b_, ldb, w_, scratch_, scratchpad_size, devInfo); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cusolver_itype(itype), get_cusolver_job(jobz), + get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, + scratch_, scratchpad_size, devInfo); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1629,8 +1654,9 @@ inline sycl::event hetrd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - 
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, + scratch_, scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1684,8 +1710,9 @@ inline sycl::event orgbr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, - a_, lda, tau_, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_generate(vec), m, n, k, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); return done; @@ -1723,8 +1750,8 @@ inline sycl::event orgqr(const char *func_name, Func func, sycl::queue &queue, s auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, - scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); return done; @@ -1761,8 +1788,9 @@ inline sycl::event orgtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, tau_, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); return done; @@ -1802,9 +1830,11 @@ inline sycl::event ormtr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto 
scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), - get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, - lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_side_mode(side), + get_cublas_fill_mode(uplo), + get_cublas_operation(trans), m, n, a_, lda, tau_, c_, + ldc, scratch_, scratchpad_size, nullptr); }); }); return done; @@ -1858,9 +1888,10 @@ inline sycl::event ormqr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), - get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, - scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_side_mode(side), + get_cublas_operation(trans), m, n, k, a_, lda, tau_, + c_, ldc, scratch_, scratchpad_size, nullptr); }); }); return done; @@ -1900,8 +1931,9 @@ inline sycl::event potrf(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, scratch_, + scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1943,8 +1975,9 @@ inline sycl::event potri(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, scratch_, scratchpad_size, 
devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, scratch_, + scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1986,8 +2019,9 @@ inline sycl::event potrs(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, nrhs, - a_, lda, b_, ldb, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, nrhs, a_, lda, b_, + ldb, nullptr); }); }); return done; @@ -2029,9 +2063,9 @@ inline sycl::event syevd(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_job(jobz), - get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, - scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cusolver_job(jobz), + get_cublas_fill_mode(uplo), n, a_, lda, w_, scratch_, + scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -2075,9 +2109,10 @@ inline sycl::event sygvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cusolver_itype(itype), - get_cusolver_job(jobz), get_cublas_fill_mode(uplo), n, a_, lda, - b_, ldb, w_, scratch_, scratchpad_size, devInfo); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cusolver_itype(itype), get_cusolver_job(jobz), + get_cublas_fill_mode(uplo), n, a_, lda, b_, ldb, w_, + scratch_, scratchpad_size, devInfo); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -2121,8 +2156,9 @@ inline 
sycl::event sytrd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, d_, e_, tau_, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, d_, e_, tau_, + scratch_, scratchpad_size, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -2171,8 +2207,9 @@ inline sycl::event sytrf(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto devInfo_ = reinterpret_cast(devInfo); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, ipiv_, scratch_, scratchpad_size, devInfo_); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, ipiv_, + scratch_, scratchpad_size, devInfo_); }); }); @@ -2255,8 +2292,9 @@ inline sycl::event ungbr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_generate(vec), m, n, k, - a_, lda, tau_, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_generate(vec), m, n, k, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); return done; @@ -2294,8 +2332,8 @@ inline sycl::event ungqr(const char *func_name, Func func, sycl::queue &queue, s auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, a_, lda, tau_, scratch_, - scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_, + scratch_, 
scratchpad_size, nullptr); }); }); return done; @@ -2332,8 +2370,9 @@ inline sycl::event ungtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo), n, a_, - lda, tau_, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_fill_mode(uplo), n, a_, lda, tau_, + scratch_, scratchpad_size, nullptr); }); }); return done; @@ -2387,9 +2426,10 @@ inline sycl::event unmqr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), - get_cublas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc, - scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_side_mode(side), + get_cublas_operation(trans), m, n, k, a_, lda, tau_, + c_, ldc, scratch_, scratchpad_size, nullptr); }); }); return done; @@ -2431,9 +2471,11 @@ inline sycl::event unmtr(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); auto scratch_ = reinterpret_cast(scratchpad); cusolverStatus_t err; - CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_side_mode(side), - get_cublas_fill_mode(uplo), get_cublas_operation(trans), m, n, a_, - lda, tau_, c_, ldc, scratch_, scratchpad_size, nullptr); + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, + get_cublas_side_mode(side), + get_cublas_fill_mode(uplo), + get_cublas_operation(trans), m, n, a_, lda, tau_, c_, + ldc, scratch_, scratchpad_size, nullptr); }); }); return done;