diff --git a/src/lapack/backends/cusolver/cusolver_batch.cpp b/src/lapack/backends/cusolver/cusolver_batch.cpp index b2d3f8e89..57b9f4a88 100644 --- a/src/lapack/backends/cusolver/cusolver_batch.cpp +++ b/src/lapack/backends/cusolver/cusolver_batch.cpp @@ -29,32 +29,53 @@ namespace cusolver { // BATCH BUFFER API -void geqrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, sycl::buffer &a, - std::int64_t lda, std::int64_t stride_a, sycl::buffer &tau, - std::int64_t stride_tau, std::int64_t batch_size, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "geqrf_batch"); -} -void geqrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, sycl::buffer &a, - std::int64_t lda, std::int64_t stride_a, sycl::buffer &tau, - std::int64_t stride_tau, std::int64_t batch_size, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "geqrf_batch"); -} -void geqrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer> &tau, std::int64_t stride_tau, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "geqrf_batch"); -} -void geqrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer> &tau, std::int64_t stride_tau, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "geqrf_batch"); +template +inline void geqrf_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, + std::int64_t n, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &tau, std::int64_t stride_tau, std::int64_t batch_size, + sycl::buffer &scratchpad, std::int64_t scratchpad_size) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(m, n, lda, stride_a, stride_tau, 
batch_size, scratchpad_size); + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto tau_acc = tau.template get_access(cgh); + auto scratch_acc = scratchpad.template get_access(cgh); + + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto tau_ = sc.get_mem(tau_acc); + auto scratch_ = sc.get_mem(scratch_acc); + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + lda, tau_ + stride_tau * i, scratch_, scratchpad_size, + nullptr); + } + }); + }); } + +#define GEQRF_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \ + void geqrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, sycl::buffer &a, \ + std::int64_t lda, std::int64_t stride_a, sycl::buffer &tau, \ + std::int64_t stride_tau, std::int64_t batch_size, \ + sycl::buffer &scratchpad, std::int64_t scratchpad_size) { \ + return geqrf_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, a, lda, stride_a, \ + tau, stride_tau, batch_size, scratchpad, scratchpad_size); \ + } + +GEQRF_STRIDED_BATCH_LAUNCHER(float, cusolverDnSgeqrf) +GEQRF_STRIDED_BATCH_LAUNCHER(double, cusolverDnDgeqrf) +GEQRF_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnCgeqrf) +GEQRF_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnZgeqrf) + +#undef GEQRF_STRIDED_BATCH_LAUNCHER + void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, sycl::buffer &ipiv, std::int64_t stride_ipiv, std::int64_t batch_size, sycl::buffer &scratchpad, @@ -79,251 +100,642 @@ void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer> &scratchpad, std::int64_t scratchpad_size) { throw unimplemented("lapack", "getri_batch"); } -void getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, - 
std::int64_t nrhs, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer &ipiv, std::int64_t stride_ipiv, sycl::buffer &b, - std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, - sycl::buffer &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getrs_batch"); -} -void getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, - std::int64_t nrhs, sycl::buffer &a, std::int64_t lda, - std::int64_t stride_a, sycl::buffer &ipiv, std::int64_t stride_ipiv, - sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getrs_batch"); -} -void getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, - std::int64_t nrhs, sycl::buffer> &a, std::int64_t lda, - std::int64_t stride_a, sycl::buffer &ipiv, std::int64_t stride_ipiv, - sycl::buffer> &b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getrs_batch"); -} -void getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, - std::int64_t nrhs, sycl::buffer> &a, std::int64_t lda, - std::int64_t stride_a, sycl::buffer &ipiv, std::int64_t stride_ipiv, - sycl::buffer> &b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getrs_batch"); -} -void getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, sycl::buffer &a, - std::int64_t lda, std::int64_t stride_a, sycl::buffer &ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getrf_batch"); -} -void getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, sycl::buffer &a, - std::int64_t lda, std::int64_t stride_a, 
sycl::buffer &ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, - sycl::buffer &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getrf_batch"); -} -void getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer &ipiv, std::int64_t stride_ipiv, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getrf_batch"); -} -void getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer &ipiv, std::int64_t stride_ipiv, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getrf_batch"); -} -void orgqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, - sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer &tau, std::int64_t stride_tau, std::int64_t batch_size, - sycl::buffer &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "orgqr_batch"); -} -void orgqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, - sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer &tau, std::int64_t stride_tau, std::int64_t batch_size, - sycl::buffer &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "orgqr_batch"); -} -void potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, sycl::buffer &a, - std::int64_t lda, std::int64_t stride_a, std::int64_t batch_size, - sycl::buffer &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "potrf_batch"); -} -void potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, - sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, - std::int64_t batch_size, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw 
unimplemented("lapack", "potrf_batch"); -} -void potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "potrf_batch"); -} -void potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "potrf_batch"); -} -void potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, - sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "potrs_batch"); + +template +inline void getrs_batch(const char *func_name, Func func, sycl::queue &queue, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t nrhs, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &ipiv, std::int64_t stride_ipiv, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, sycl::buffer &scratchpad, + std::int64_t scratchpad_size) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(n, nrhs, lda, ldb, stride_ipiv, stride_b, batch_size, scratchpad_size); + + // cuSolver legacy api does not accept 64-bit ints. + // To get around the limitation. + // Create new buffer and convert 64-bit values. 
+ std::uint64_t ipiv_size = stride_ipiv * batch_size; + sycl::buffer ipiv32(sycl::range<1>{ ipiv_size }); + + queue.submit([&](sycl::handler &cgh) { + auto ipiv32_acc = ipiv32.template get_access(cgh); + auto ipiv_acc = ipiv.template get_access(cgh); + cgh.parallel_for(sycl::range<1>{ ipiv_size }, [=](sycl::id<1> index) { + ipiv32_acc[index] = static_cast(ipiv_acc[index]); + }); + }); + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto ipiv_acc = ipiv32.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto ipiv_ = sc.get_mem(ipiv_acc); + auto b_ = sc.get_mem(b_acc); + cusolverStatus_t err; + + // Does not use scratch so call cuSolver asynchronously and sync at end + for (int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_operation(trans), n, + nrhs, a_ + stride_a * i, lda, ipiv_ + stride_ipiv * i, + b_ + stride_b * i, ldb, nullptr); + } + CUSOLVER_SYNC(err, handle) + }); + }); } -void potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, - sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "potrs_batch"); + +#define GETRS_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \ + void getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, \ + std::int64_t nrhs, sycl::buffer &a, std::int64_t lda, \ + std::int64_t stride_a, sycl::buffer &ipiv, \ + std::int64_t stride_ipiv, sycl::buffer &b, std::int64_t ldb, \ + std::int64_t stride_b, std::int64_t batch_size, \ + sycl::buffer &scratchpad, std::int64_t scratchpad_size) { \ + return getrs_batch(#CUSOLVER_ROUTINE, 
CUSOLVER_ROUTINE, queue, trans, n, nrhs, a, lda, \ + stride_a, ipiv, stride_ipiv, b, ldb, stride_b, batch_size, scratchpad, \ + scratchpad_size); \ + } + +GETRS_STRIDED_BATCH_LAUNCHER(float, cusolverDnSgetrs) +GETRS_STRIDED_BATCH_LAUNCHER(double, cusolverDnDgetrs) +GETRS_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnCgetrs) +GETRS_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnZgetrs) + +#undef GETRS_STRIDED_BATCH_LAUNCHER + +template +inline void getrf_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, + std::int64_t n, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &ipiv, std::int64_t stride_ipiv, + std::int64_t batch_size, sycl::buffer &scratchpad, + std::int64_t scratchpad_size) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(m, n, lda, stride_a, stride_ipiv, batch_size, scratchpad_size); + + // cuSolver legacy api does not accept 64-bit ints. + // To get around the limitation. + // Create new buffer with 32-bit ints then copy over results + std::uint64_t ipiv_size = stride_ipiv * batch_size; + sycl::buffer ipiv32(sycl::range<1>{ ipiv_size }); + sycl::buffer devInfo{ batch_size }; + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto ipiv32_acc = ipiv32.template get_access(cgh); + auto devInfo_acc = devInfo.template get_access(cgh); + auto scratch_acc = scratchpad.template get_access(cgh); + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto ipiv_ = sc.get_mem(ipiv32_acc); + auto devInfo_ = sc.get_mem(devInfo_acc); + auto scratch_ = sc.get_mem(scratch_acc); + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (std::int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + lda, scratch_, ipiv_ + stride_ipiv * i, devInfo_ + i); + } + }); + }); + + 
// Copy from 32-bit USM to 64-bit + queue.submit([&](sycl::handler &cgh) { + auto ipiv32_acc = ipiv32.template get_access(cgh); + auto ipiv_acc = ipiv.template get_access(cgh); + cgh.parallel_for(sycl::range<1>{ ipiv_size }, + [=](sycl::id<1> index) { ipiv_acc[index] = ipiv32_acc[index]; }); + }); + + lapack_info_check(queue, devInfo, __func__, func_name, batch_size); } -void potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer> &b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "potrs_batch"); + +#define GETRF_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \ + void getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, sycl::buffer &a, \ + std::int64_t lda, std::int64_t stride_a, sycl::buffer &ipiv, \ + std::int64_t stride_ipiv, std::int64_t batch_size, \ + sycl::buffer &scratchpad, std::int64_t scratchpad_size) { \ + return getrf_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, a, lda, stride_a, \ + ipiv, stride_ipiv, batch_size, scratchpad, scratchpad_size); \ + } + +GETRF_STRIDED_BATCH_LAUNCHER(float, cusolverDnSgetrf) +GETRF_STRIDED_BATCH_LAUNCHER(double, cusolverDnDgetrf) +GETRF_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnCgetrf) +GETRF_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnZgetrf) + +#undef GETRF_STRIDED_BATCH_LAUNCHER + +template +inline void orgqr_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, + std::int64_t n, std::int64_t k, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &tau, std::int64_t stride_tau, + std::int64_t batch_size, sycl::buffer &scratchpad, + std::int64_t scratchpad_size) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size); + + 
queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto tau_acc = tau.template get_access(cgh); + auto scratch_acc = scratchpad.template get_access(cgh); + + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto tau_ = sc.get_mem(tau_acc); + auto scratch_ = sc.get_mem(scratch_acc); + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + lda, tau_ + stride_tau * i, scratch_, scratchpad_size, + nullptr); + } + }); + }); } -void potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer> &b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "potrs_batch"); + +#define ORGQR_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \ + void orgqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, \ + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, \ + sycl::buffer &tau, std::int64_t stride_tau, std::int64_t batch_size, \ + sycl::buffer &scratchpad, std::int64_t scratchpad_size) { \ + return orgqr_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, a, lda, stride_a, \ + tau, stride_tau, batch_size, scratchpad, scratchpad_size); \ + } + +ORGQR_STRIDED_BATCH_LAUNCHER(float, cusolverDnSorgqr) +ORGQR_STRIDED_BATCH_LAUNCHER(double, cusolverDnDorgqr) + +#undef ORGQR_STRIDED_BATCH_LAUNCHER + +template +inline void potrf_batch(const char *func_name, Func func, sycl::queue &queue, + oneapi::mkl::uplo uplo, std::int64_t n, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, std::int64_t batch_size, + sycl::buffer &scratchpad, std::int64_t 
scratchpad_size) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(n, lda, stride_a, batch_size, scratchpad_size); + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + CUdeviceptr a_dev; + CUresult cuda_result; + cusolverStatus_t err; + + auto a_ = sc.get_mem(a_acc); + + // Transform ptr and stride to list of ptr's + cuDataType **a_batched = create_ptr_list_from_stride(a_, stride_a, batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &a_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, a_dev, a_batched, sizeof(T *) * batch_size); + + auto **a_dev_ = reinterpret_cast(a_dev); + + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + (int)n, a_dev_, (int)lda, nullptr, (int)batch_size); + + free(a_batched); + cuMemFree(a_dev); + }); + }); } -void ungqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer> &tau, std::int64_t stride_tau, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "ungqr_batch"); + +// Scratchpad memory not needed as parts of buffer a is used as workspace memory +#define POTRF_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \ + void potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, \ + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, \ + std::int64_t batch_size, sycl::buffer &scratchpad, \ + std::int64_t scratchpad_size) { \ + return potrf_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, uplo, n, a, lda, stride_a, \ + batch_size, scratchpad, scratchpad_size); \ + } + +POTRF_STRIDED_BATCH_LAUNCHER(float, cusolverDnSpotrfBatched) +POTRF_STRIDED_BATCH_LAUNCHER(double, cusolverDnDpotrfBatched) 
+POTRF_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnCpotrfBatched) +POTRF_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnZpotrfBatched) + +#undef POTRF_STRIDED_BATCH_LAUNCHER + +template +inline void potrs_batch(const char *func_name, Func func, sycl::queue &queue, + oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, sycl::buffer &scratchpad, + std::int64_t scratchpad_size) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(n, nrhs, lda, ldb, stride_a, stride_b, batch_size, scratchpad_size); + + // cuSolver function only supports nrhs = 1 + if (nrhs != 1) + throw unimplemented("lapack", "potrs_batch", "cusolver potrs_batch only supports nrhs = 1"); + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + CUdeviceptr a_dev, b_dev; + cusolverStatus_t err; + CUresult cuda_result; + + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + + // Transform ptr and stride to list of ptr's + cuDataType **a_batched = create_ptr_list_from_stride(a_, stride_a, batch_size); + cuDataType **b_batched = create_ptr_list_from_stride(b_, stride_b, batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &a_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, a_dev, a_batched, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &b_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, b_dev, b_batched, sizeof(T *) * batch_size); + + auto **a_dev_ = reinterpret_cast(a_dev); + auto **b_dev_ = reinterpret_cast(b_dev); + + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + (int)n, (int)nrhs, a_dev_, (int)lda, b_dev_, 
ldb, nullptr, + (int)batch_size); + + free(a_batched); + free(b_batched); + cuMemFree(a_dev); + cuMemFree(b_dev); + }); + }); } -void ungqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, - sycl::buffer> &a, std::int64_t lda, std::int64_t stride_a, - sycl::buffer> &tau, std::int64_t stride_tau, - std::int64_t batch_size, sycl::buffer> &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "ungqr_batch"); + +// Scratchpad memory not needed as parts of buffer a is used as workspace memory +#define POTRS_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \ + void potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, \ + std::int64_t nrhs, sycl::buffer &a, std::int64_t lda, \ + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, \ + std::int64_t stride_b, std::int64_t batch_size, \ + sycl::buffer &scratchpad, std::int64_t scratchpad_size) { \ + return potrs_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, uplo, n, nrhs, a, lda, \ + stride_a, b, ldb, stride_b, batch_size, scratchpad, scratchpad_size); \ + } + +POTRS_STRIDED_BATCH_LAUNCHER(float, cusolverDnSpotrsBatched) +POTRS_STRIDED_BATCH_LAUNCHER(double, cusolverDnDpotrsBatched) +POTRS_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnCpotrsBatched) +POTRS_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnZpotrsBatched) + +#undef POTRS_STRIDED_BATCH_LAUNCHER + +template +inline void ungqr_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, + std::int64_t n, std::int64_t k, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &tau, std::int64_t stride_tau, + std::int64_t batch_size, sycl::buffer &scratchpad, + std::int64_t scratchpad_size) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size); + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto tau_acc = tau.template get_access(cgh); 
+ auto scratch_acc = scratchpad.template get_access(cgh); + + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto tau_ = sc.get_mem(tau_acc); + auto scratch_ = sc.get_mem(scratch_acc); + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + lda, tau_ + stride_tau * i, scratch_, scratchpad_size, + nullptr); + } + }); + }); } +#define UNGQR_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \ + void ungqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, \ + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, \ + sycl::buffer &tau, std::int64_t stride_tau, std::int64_t batch_size, \ + sycl::buffer &scratchpad, std::int64_t scratchpad_size) { \ + return ungqr_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, a, lda, stride_a, \ + tau, stride_tau, batch_size, scratchpad, scratchpad_size); \ + } + +UNGQR_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnCungqr) +UNGQR_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnZungqr) + +#undef UNGQR_STRIDED_BATCH_LAUNCHER + // BATCH USM API -sycl::event geqrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, float *a, - std::int64_t lda, std::int64_t stride_a, float *tau, - std::int64_t stride_tau, std::int64_t batch_size, float *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "geqrf_batch"); -} -sycl::event geqrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, double *a, - std::int64_t lda, std::int64_t stride_a, double *tau, - std::int64_t stride_tau, std::int64_t batch_size, double *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "geqrf_batch"); -} -sycl::event geqrf_batch(sycl::queue &queue, 
std::int64_t m, std::int64_t n, std::complex *a, - std::int64_t lda, std::int64_t stride_a, std::complex *tau, - std::int64_t stride_tau, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "geqrf_batch"); -} -sycl::event geqrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex *a, - std::int64_t lda, std::int64_t stride_a, std::complex *tau, - std::int64_t stride_tau, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "geqrf_batch"); -} -sycl::event geqrf_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, float **a, - std::int64_t *lda, float **tau, std::int64_t group_count, - std::int64_t *group_sizes, float *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "geqrf_batch"); -} -sycl::event geqrf_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, double **a, - std::int64_t *lda, double **tau, std::int64_t group_count, - std::int64_t *group_sizes, double *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "geqrf_batch"); -} -sycl::event geqrf_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, - std::complex **a, std::int64_t *lda, std::complex **tau, - std::int64_t group_count, std::int64_t *group_sizes, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "geqrf_batch"); -} -sycl::event geqrf_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, - std::complex **a, std::int64_t *lda, std::complex **tau, - std::int64_t group_count, std::int64_t *group_sizes, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "geqrf_batch"); -} -sycl::event 
getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, float *a, - std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, float *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrf_batch"); -} -sycl::event getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, double *a, - std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, double *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrf_batch"); -} -sycl::event getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex *a, - std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrf_batch"); -} -sycl::event getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex *a, - std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrf_batch"); -} -sycl::event getrf_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, float **a, - std::int64_t *lda, std::int64_t **ipiv, std::int64_t group_count, - std::int64_t *group_sizes, float *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrf_batch"); +template +inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, + std::int64_t n, T *a, std::int64_t lda, std::int64_t stride_a, + T *tau, std::int64_t stride_tau, std::int64_t batch_size, + T *scratchpad, std::int64_t scratchpad_size, + const 
std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(m, n, lda, stride_a, stride_tau, batch_size, scratchpad_size); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto tau_ = reinterpret_cast(tau); + auto scratch_ = reinterpret_cast(scratchpad); + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + lda, tau_ + stride_tau * i, scratch_, scratchpad_size, + nullptr); + } + }); + }); + + return done; } -sycl::event getrf_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, double **a, - std::int64_t *lda, std::int64_t **ipiv, std::int64_t group_count, - std::int64_t *group_sizes, double *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrf_batch"); + +#define GEQRF_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event geqrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, TYPE *a, \ + std::int64_t lda, std::int64_t stride_a, TYPE *tau, \ + std::int64_t stride_tau, std::int64_t batch_size, TYPE *scratchpad, \ + std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return geqrf_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, a, lda, stride_a, \ + tau, stride_tau, batch_size, scratchpad, scratchpad_size, \ + dependencies); \ + } + +GEQRF_STRIDED_BATCH_LAUNCHER_USM(float, cusolverDnSgeqrf) +GEQRF_STRIDED_BATCH_LAUNCHER_USM(double, cusolverDnDgeqrf) +GEQRF_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnCgeqrf) +GEQRF_STRIDED_BATCH_LAUNCHER_USM(std::complex, 
cusolverDnZgeqrf) + +#undef GEQRF_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &queue, + std::int64_t *m, std::int64_t *n, T **a, std::int64_t *lda, T **tau, + std::int64_t group_count, std::int64_t *group_sizes, T *scratchpad, + std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(group_count, scratchpad_size); + for (int64_t i = 0; i < group_count; ++i) + overflow_check(m[i], n[i], lda[i], group_sizes[i]); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto tau_ = reinterpret_cast(tau); + auto scratch_ = reinterpret_cast(scratchpad); + int64_t global_id = 0; + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + for (int64_t local_id = 0; local_id < group_sizes[group_id]; + ++local_id, ++global_id) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + n[group_id], a_[global_id], lda[group_id], + tau_[global_id], scratch_, scratchpad_size, nullptr); + } + } + }); + }); + + return done; } -sycl::event getrf_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, - std::complex **a, std::int64_t *lda, std::int64_t **ipiv, - std::int64_t group_count, std::int64_t *group_sizes, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrf_batch"); + +#define GEQRF_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event geqrf_batch( \ + sycl::queue &queue, std::int64_t *m, std::int64_t *n, TYPE **a, std::int64_t *lda, \ + TYPE **tau, 
std::int64_t group_count, std::int64_t *group_sizes, TYPE *scratchpad, \ + std::int64_t scratchpad_size, const std::vector &dependencies) { \ + return geqrf_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, a, lda, tau, \ + group_count, group_sizes, scratchpad, scratchpad_size, dependencies); \ + } + +GEQRF_BATCH_LAUNCHER_USM(float, cusolverDnSgeqrf) +GEQRF_BATCH_LAUNCHER_USM(double, cusolverDnDgeqrf) +GEQRF_BATCH_LAUNCHER_USM(std::complex, cusolverDnCgeqrf) +GEQRF_BATCH_LAUNCHER_USM(std::complex, cusolverDnZgeqrf) + +#undef GEQRF_BATCH_LAUNCHER_USM + +template +inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, + std::int64_t n, T *a, std::int64_t lda, std::int64_t stride_a, + std::int64_t *ipiv, std::int64_t stride_ipiv, + std::int64_t batch_size, T *scratchpad, std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(m, n, lda, stride_a, stride_ipiv, batch_size, scratchpad_size); + + // cuSolver legacy api does not accept 64-bit ints. + // To get around the limitation. 
+ // Allocate memory with 32-bit ints then copy over results + std::uint64_t ipiv_size = stride_ipiv * batch_size; + int *ipiv32 = (int *)malloc_device(sizeof(int) * ipiv_size, queue); + int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto devInfo_ = reinterpret_cast(devInfo); + auto scratchpad_ = reinterpret_cast(scratchpad); + auto ipiv_ = reinterpret_cast(ipiv32); + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i, + lda, scratchpad_, ipiv_ + stride_ipiv * i, devInfo_ + i); + } + }); + }); + + // Copy from 32-bit USM to 64-bit + sycl::event done_casting = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + cgh.parallel_for(sycl::range<1>{ ipiv_size }, + [=](sycl::id<1> index) { ipiv[index] = ipiv32[index]; }); + }); + + // Enqueue free memory, don't return event as not-neccessary for user to wait for ipiv32 being released + queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done_casting); + cgh.host_task([=](sycl::interop_handle ih) { sycl::free(ipiv32, queue); }); + }); + + // lapack_info_check calls queue.wait() + lapack_info_check(queue, devInfo, __func__, func_name, batch_size); + sycl::free(devInfo, queue); + + return done_casting; } -sycl::event getrf_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, - std::complex **a, std::int64_t *lda, std::int64_t **ipiv, - std::int64_t group_count, std::int64_t *group_sizes, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw 
unimplemented("lapack", "getrf_batch"); + +#define GETRF_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event getrf_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, TYPE *a, \ + std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv, \ + std::int64_t stride_ipiv, std::int64_t batch_size, TYPE *scratchpad, \ + std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return getrf_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, a, lda, stride_a, \ + ipiv, stride_ipiv, batch_size, scratchpad, scratchpad_size, \ + dependencies); \ + } + +GETRF_STRIDED_BATCH_LAUNCHER_USM(float, cusolverDnSgetrf) +GETRF_STRIDED_BATCH_LAUNCHER_USM(double, cusolverDnDgetrf) +GETRF_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnCgetrf) +GETRF_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnZgetrf) + +#undef GETRF_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &queue, + std::int64_t *m, std::int64_t *n, T **a, std::int64_t *lda, + std::int64_t **ipiv, std::int64_t group_count, + std::int64_t *group_sizes, T *scratchpad, + std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + int64_t batch_size = 0; + overflow_check(group_count, scratchpad_size); + for (int64_t i = 0; i < group_count; ++i) { + overflow_check(m[i], n[i], lda[i], group_sizes[i]); + batch_size += group_sizes[i]; + } + + // cuSolver legacy api does not accept 64-bit ints. + // To get around the limitation. 
+ // Allocate memory with 32-bit ints then copy over results + int **ipiv32 = (int **)malloc(sizeof(int *) * batch_size); + int64_t global_id = 0; + for (int64_t group_id = 0; group_id < group_count; ++group_id) + for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) + ipiv32[global_id] = (int *)malloc_device(sizeof(int) * n[group_id], queue); + int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto scratch_ = reinterpret_cast(scratchpad); + int64_t global_id = 0; + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + for (int64_t local_id = 0; local_id < group_sizes[group_id]; + ++local_id, ++global_id) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + n[group_id], a_[global_id], lda[group_id], scratch_, + ipiv32[global_id], devInfo + global_id); + } + } + }); + }); + + // Copy from 32-bit USM to 64-bit + std::vector casting_dependencies(batch_size); + for (int64_t group_id = 0, global_id = 0; group_id < group_count; ++group_id) { + uint64_t ipiv_size = n[group_id]; + for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { + int64_t *d_ipiv = ipiv[global_id]; + int *d_ipiv32 = ipiv32[global_id]; + + sycl::event e = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + cgh.parallel_for(sycl::range<1>{ ipiv_size }, + [=](sycl::id<1> index) { d_ipiv[index] = d_ipiv32[index]; }); + }); + casting_dependencies[global_id] = e; + } + } + + // Enqueue free memory + sycl::event done_freeing = queue.submit([&](sycl::handler &cgh) { + 
int64_t num_events = casting_dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(casting_dependencies[i]); + } + cgh.host_task([=](sycl::interop_handle ih) { + for (int64_t global_id = 0; global_id < batch_size; ++global_id) + sycl::free(ipiv32[global_id], queue); + free(ipiv32); + }); + }); + + // lapack_info_check calls queue.wait() + lapack_info_check(queue, devInfo, __func__, func_name, batch_size); + sycl::free(devInfo, queue); + + return done_freeing; } + +#define GETRF_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event getrf_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, TYPE **a, \ + std::int64_t *lda, std::int64_t **ipiv, std::int64_t group_count, \ + std::int64_t *group_sizes, TYPE *scratchpad, \ + std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return getrf_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, a, lda, ipiv, \ + group_count, group_sizes, scratchpad, scratchpad_size, dependencies); \ + } + +GETRF_BATCH_LAUNCHER_USM(float, cusolverDnSgetrf) +GETRF_BATCH_LAUNCHER_USM(double, cusolverDnDgetrf) +GETRF_BATCH_LAUNCHER_USM(std::complex, cusolverDnCgetrf) +GETRF_BATCH_LAUNCHER_USM(std::complex, cusolverDnZgetrf) + +#undef GETRF_BATCH_LAUNCHER_USM + sycl::event getri_batch(sycl::queue &queue, std::int64_t n, float *a, std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv, std::int64_t stride_ipiv, std::int64_t batch_size, float *scratchpad, std::int64_t scratchpad_size, @@ -376,123 +788,338 @@ sycl::event getri_batch(sycl::queue &queue, std::int64_t *n, std::complex &dependencies) { throw unimplemented("lapack", "getri_batch"); } -sycl::event getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, - std::int64_t nrhs, float *a, std::int64_t lda, std::int64_t stride_a, - std::int64_t *ipiv, std::int64_t stride_ipiv, float *b, std::int64_t ldb, - std::int64_t stride_b, std::int64_t batch_size, float *scratchpad, - std::int64_t scratchpad_size, - 
const std::vector &dependencies) { - throw unimplemented("lapack", "getrs_batch"); -} -sycl::event getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, - std::int64_t nrhs, double *a, std::int64_t lda, std::int64_t stride_a, - std::int64_t *ipiv, std::int64_t stride_ipiv, double *b, std::int64_t ldb, - std::int64_t stride_b, std::int64_t batch_size, double *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrs_batch"); -} -sycl::event getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, - std::int64_t nrhs, std::complex *a, std::int64_t lda, - std::int64_t stride_a, std::int64_t *ipiv, std::int64_t stride_ipiv, - std::complex *b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, std::complex *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrs_batch"); -} -sycl::event getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, - std::int64_t nrhs, std::complex *a, std::int64_t lda, - std::int64_t stride_a, std::int64_t *ipiv, std::int64_t stride_ipiv, - std::complex *b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, std::complex *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrs_batch"); -} -sycl::event getrs_batch(sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *n, - std::int64_t *nrhs, float **a, std::int64_t *lda, std::int64_t **ipiv, - float **b, std::int64_t *ldb, std::int64_t group_count, - std::int64_t *group_sizes, float *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrs_batch"); -} -sycl::event getrs_batch(sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *n, - std::int64_t *nrhs, double **a, std::int64_t *lda, std::int64_t **ipiv, - double **b, 
std::int64_t *ldb, std::int64_t group_count, - std::int64_t *group_sizes, double *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrs_batch"); -} -sycl::event getrs_batch(sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *n, - std::int64_t *nrhs, std::complex **a, std::int64_t *lda, - std::int64_t **ipiv, std::complex **b, std::int64_t *ldb, - std::int64_t group_count, std::int64_t *group_sizes, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrs_batch"); -} -sycl::event getrs_batch(sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *n, - std::int64_t *nrhs, std::complex **a, std::int64_t *lda, - std::int64_t **ipiv, std::complex **b, std::int64_t *ldb, - std::int64_t group_count, std::int64_t *group_sizes, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getrs_batch"); -} -sycl::event orgqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, - float *a, std::int64_t lda, std::int64_t stride_a, float *tau, - std::int64_t stride_tau, std::int64_t batch_size, float *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "orgqr_batch"); -} -sycl::event orgqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, - double *a, std::int64_t lda, std::int64_t stride_a, double *tau, - std::int64_t stride_tau, std::int64_t batch_size, double *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "orgqr_batch"); -} -sycl::event orgqr_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, std::int64_t *k, - float **a, std::int64_t *lda, float **tau, std::int64_t group_count, - std::int64_t *group_sizes, float *scratchpad, std::int64_t scratchpad_size, - const 
std::vector &dependencies) { - throw unimplemented("lapack", "orgqr_batch"); -} -sycl::event orgqr_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, std::int64_t *k, - double **a, std::int64_t *lda, double **tau, std::int64_t group_count, - std::int64_t *group_sizes, double *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "orgqr_batch"); + +template +inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &queue, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t nrhs, + T *a, std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv, + std::int64_t stride_ipiv, T *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, T *scratchpad, + std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(n, nrhs, lda, ldb, stride_ipiv, stride_b, batch_size, scratchpad_size); + + // cuSolver legacy api does not accept 64-bit ints. + // To get around the limitation. + // Create new memory and convert 64-bit values. 
+ std::uint64_t ipiv_size = stride_ipiv * batch_size; + int *ipiv32 = (int *)malloc_device(sizeof(int) * ipiv_size, queue); + + auto done_casting = queue.submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::range<1>{ ipiv_size }, [=](sycl::id<1> index) { + ipiv32[index] = static_cast(ipiv[index]); + }); + }); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + cgh.depends_on(done_casting); + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto ipiv_ = reinterpret_cast(ipiv32); + auto b_ = reinterpret_cast(b); + cusolverStatus_t err; + + // Does not use scratch so call cuSolver asynchronously and sync at end + for (int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_operation(trans), n, + nrhs, a_ + stride_a * i, lda, ipiv_ + stride_ipiv * i, + b_ + stride_b * i, ldb, nullptr); + } + CUSOLVER_SYNC(err, handle) + + sycl::free(ipiv32, queue); + }); + }); + + return done; } -sycl::event potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, float *a, - std::int64_t lda, std::int64_t stride_a, std::int64_t batch_size, - float *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "potrf_batch"); + +#define GETRS_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event getrs_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, \ + std::int64_t nrhs, TYPE *a, std::int64_t lda, std::int64_t stride_a, \ + std::int64_t *ipiv, std::int64_t stride_ipiv, TYPE *b, \ + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, \ + TYPE *scratchpad, std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return getrs_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, 
trans, n, nrhs, a, lda, \ + stride_a, ipiv, stride_ipiv, b, ldb, stride_b, batch_size, scratchpad, \ + scratchpad_size, dependencies); \ + } + +GETRS_STRIDED_BATCH_LAUNCHER_USM(float, cusolverDnSgetrs) +GETRS_STRIDED_BATCH_LAUNCHER_USM(double, cusolverDnDgetrs) +GETRS_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnCgetrs) +GETRS_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnZgetrs) + +#undef GETRS_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &queue, + oneapi::mkl::transpose *trans, std::int64_t *n, std::int64_t *nrhs, + T **a, std::int64_t *lda, std::int64_t **ipiv, T **b, + std::int64_t *ldb, std::int64_t group_count, + std::int64_t *group_sizes, T *scratchpad, + std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + int64_t batch_size = 0; + overflow_check(group_count, scratchpad_size); + for (int64_t i = 0; i < group_count; ++i) { + overflow_check(n[i], nrhs[i], lda[i], ldb[i], group_sizes[i]); + batch_size += group_sizes[i]; + } + + // cuSolver legacy api does not accept 64-bit ints. + // ipiv is an array of pointers in host memory, pointing to + // an array of 64-bit ints in device memory. Each vec of ipiv + // values need to be converted from 64-bit to 32-bit. The list + // must stay on host. 
+ int **ipiv32 = (int **)malloc(sizeof(int *) * batch_size); + std::vector casting_dependencies(batch_size); + int64_t global_id = 0; + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id) { + uint64_t ipiv_size = n[group_id]; + int *d_group_ipiv32 = (int *)malloc_device(sizeof(int) * ipiv_size, queue); + ipiv32[global_id] = d_group_ipiv32; + int64_t *d_group_ipiv = ipiv[global_id]; + + auto e = queue.submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::range<1>{ ipiv_size }, [=](sycl::id<1> index) { + d_group_ipiv32[index] = static_cast(d_group_ipiv[index]); + }); + }); + casting_dependencies[global_id] = e; + } + } + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + for (int64_t i = 0; i < batch_size; i++) { + cgh.depends_on(casting_dependencies[i]); + } + + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + cusolverStatus_t err; + int64_t global_id = 0; + + // Does not use scratch so call cuSolver asynchronously and sync at end + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + for (int64_t local_id = 0; local_id < group_sizes[group_id]; + ++local_id, ++global_id) { + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, + get_cublas_operation(trans[group_id]), n[group_id], + nrhs[group_id], a_[global_id], lda[group_id], + ipiv32[global_id], b_[global_id], ldb[group_id], nullptr); + } + } + CUSOLVER_SYNC(err, handle) + + for (int64_t i = 0; i < batch_size; ++i) + sycl::free(ipiv32[i], queue); + free(ipiv32); + }); + }); + + return done; } -sycl::event potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, double *a, - std::int64_t lda, std::int64_t stride_a, 
std::int64_t batch_size, - double *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "potrf_batch"); + +#define GETRS_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event getrs_batch( \ + sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *n, std::int64_t *nrhs, \ + TYPE **a, std::int64_t *lda, std::int64_t **ipiv, TYPE **b, std::int64_t *ldb, \ + std::int64_t group_count, std::int64_t *group_sizes, TYPE *scratchpad, \ + std::int64_t scratchpad_size, const std::vector &dependencies) { \ + return getrs_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, trans, n, nrhs, a, lda, \ + ipiv, b, ldb, group_count, group_sizes, scratchpad, scratchpad_size, \ + dependencies); \ + } + +GETRS_BATCH_LAUNCHER_USM(float, cusolverDnSgetrs) +GETRS_BATCH_LAUNCHER_USM(double, cusolverDnDgetrs) +GETRS_BATCH_LAUNCHER_USM(std::complex, cusolverDnCgetrs) +GETRS_BATCH_LAUNCHER_USM(std::complex, cusolverDnZgetrs) + +#undef GETRS_BATCH_LAUNCHER_USM + +template +inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, + std::int64_t n, std::int64_t k, T *a, std::int64_t lda, + std::int64_t stride_a, T *tau, std::int64_t stride_tau, + std::int64_t batch_size, T *scratchpad, std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto tau_ = reinterpret_cast(tau); + auto scratch_ = reinterpret_cast(scratchpad); + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for 
(int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + lda, tau_ + stride_tau * i, scratch_, scratchpad_size, + nullptr); + } + }); + }); + + return done; } -sycl::event potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, - std::complex *a, std::int64_t lda, std::int64_t stride_a, - std::int64_t batch_size, std::complex *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "potrf_batch"); + +#define ORGQR_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event orgqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, \ + TYPE *a, std::int64_t lda, std::int64_t stride_a, TYPE *tau, \ + std::int64_t stride_tau, std::int64_t batch_size, TYPE *scratchpad, \ + std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return orgqr_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, a, lda, stride_a, \ + tau, stride_tau, batch_size, scratchpad, scratchpad_size, \ + dependencies); \ + } + +ORGQR_STRIDED_BATCH_LAUNCHER_USM(float, cusolverDnSorgqr) +ORGQR_STRIDED_BATCH_LAUNCHER_USM(double, cusolverDnDorgqr) + +#undef ORGQR_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &queue, + std::int64_t *m, std::int64_t *n, std::int64_t *k, T **a, + std::int64_t *lda, T **tau, std::int64_t group_count, + std::int64_t *group_sizes, T *scratchpad, + std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(group_count, scratchpad_size); + for (int64_t i = 0; i < group_count; ++i) + overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + 
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto tau_ = reinterpret_cast(tau); + auto scratch_ = reinterpret_cast(scratchpad); + int64_t global_id = 0; + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + for (int64_t local_id = 0; local_id < group_sizes[group_id]; + ++local_id, ++global_id) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + n[group_id], k[group_id], a_[global_id], + lda[group_id], tau_[global_id], scratch_, + scratchpad_size, nullptr); + } + } + }); + }); + + return done; } -sycl::event potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, - std::complex *a, std::int64_t lda, std::int64_t stride_a, - std::int64_t batch_size, std::complex *scratchpad, - std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "potrf_batch"); + +#define ORGQR_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event orgqr_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, std::int64_t *k, \ + TYPE **a, std::int64_t *lda, TYPE **tau, std::int64_t group_count, \ + std::int64_t *group_sizes, TYPE *scratchpad, \ + std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return orgqr_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, a, lda, tau, \ + group_count, group_sizes, scratchpad, scratchpad_size, dependencies); \ + } + +ORGQR_BATCH_LAUNCHER_USM(float, cusolverDnSorgqr) +ORGQR_BATCH_LAUNCHER_USM(double, cusolverDnDorgqr) + +#undef ORGQR_BATCH_LAUNCHER_USM + +template +inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &queue, + oneapi::mkl::uplo uplo, std::int64_t n, T *a, std::int64_t lda, + std::int64_t stride_a, std::int64_t batch_size, T *scratchpad, + std::int64_t scratchpad_size, + const std::vector &dependencies) { + 
using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(n, lda, stride_a, batch_size, scratchpad_size); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + CUdeviceptr a_dev; + cusolverStatus_t err; + CUresult cuda_result; + + auto *a_ = reinterpret_cast(a); + + // Transform ptr and stride to list of ptr's + cuDataType **a_batched = create_ptr_list_from_stride(a_, stride_a, batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &a_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, a_dev, a_batched, sizeof(T *) * batch_size); + + auto **a_dev_ = reinterpret_cast(a_dev); + + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + (int)n, a_dev_, (int)lda, nullptr, (int)batch_size); + + free(a_batched); + cuMemFree(a_dev); + }); + }); + return done; } +// Scratchpad memory not needed as parts of buffer a is used as workspace memory +#define POTRF_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, TYPE *a, \ + std::int64_t lda, std::int64_t stride_a, std::int64_t batch_size, \ + TYPE *scratchpad, std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return potrf_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, uplo, n, a, lda, stride_a, \ + batch_size, scratchpad, scratchpad_size, dependencies); \ + } + +POTRF_STRIDED_BATCH_LAUNCHER_USM(float, cusolverDnSpotrfBatched) +POTRF_STRIDED_BATCH_LAUNCHER_USM(double, cusolverDnDpotrfBatched) +POTRF_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnCpotrfBatched) +POTRF_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnZpotrfBatched) + +#undef POTRF_STRIDED_BATCH_LAUNCHER_USM + template inline 
sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &queue, oneapi::mkl::uplo *uplo, std::int64_t *n, T **a, std::int64_t *lda, @@ -507,27 +1134,33 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu batch_size += group_sizes[i]; } - T **a_dev = (T **)malloc_device(sizeof(T *) * batch_size, queue); - auto done_cpy = - queue.submit([&](sycl::handler &h) { h.memcpy(a_dev, a, batch_size * sizeof(T *)); }); - auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { cgh.depends_on(dependencies[i]); } - cgh.depends_on(done_cpy); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); int64_t offset = 0; + CUdeviceptr a_dev; + CUresult cuda_result; cusolverStatus_t err; + + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &a_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, a_dev, a, sizeof(T *) * batch_size); + + auto **a_dev_ = reinterpret_cast(a_dev); + + // Does not use scratch so call cuSolver asynchronously and sync at end for (int64_t i = 0; i < group_count; i++) { - auto **a_ = reinterpret_cast(a_dev); CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo[i]), - (int)n[i], a_ + offset, (int)lda[i], nullptr, + (int)n[i], a_dev_ + offset, (int)lda[i], nullptr, (int)group_sizes[i]); offset += group_sizes[i]; } + CUSOLVER_SYNC(err, handle) + + cuMemFree(a_dev); }); }); return done; @@ -550,37 +1183,76 @@ POTRF_BATCH_LAUNCHER_USM(std::complex, cusolverDnZpotrfBatched) #undef POTRF_BATCH_LAUNCHER_USM -sycl::event potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, - std::int64_t nrhs, float *a, std::int64_t lda, std::int64_t stride_a, - float *b, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, - float *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw 
unimplemented("lapack", "potrs_batch"); -} -sycl::event potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, - std::int64_t nrhs, double *a, std::int64_t lda, std::int64_t stride_a, - double *b, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, - double *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "potrs_batch"); -} -sycl::event potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, - std::int64_t nrhs, std::complex *a, std::int64_t lda, - std::int64_t stride_a, std::complex *b, std::int64_t ldb, - std::int64_t stride_b, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "potrs_batch"); -} -sycl::event potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, - std::int64_t nrhs, std::complex *a, std::int64_t lda, - std::int64_t stride_a, std::complex *b, std::int64_t ldb, - std::int64_t stride_b, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "potrs_batch"); +template +inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &queue, + oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, T *a, + std::int64_t lda, std::int64_t stride_a, T *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, T *scratchpad, + std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(n, nrhs, lda, ldb, stride_a, stride_b, batch_size, scratchpad_size); + + // cuSolver function only supports nrhs = 1 + if (nrhs != 1) + throw unimplemented("lapack", "potrs_batch", "cusolver potrs_batch only supports nrhs = 1"); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; 
i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + CUresult cuda_result; + CUdeviceptr a_dev, b_dev; + auto *a_ = reinterpret_cast(a); + auto *b_ = reinterpret_cast(b); + cusolverStatus_t err; + + // Transform ptr and stride to list of ptr's + cuDataType **a_batched = create_ptr_list_from_stride(a_, stride_a, batch_size); + cuDataType **b_batched = create_ptr_list_from_stride(b_, stride_b, batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &a_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &b_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, a_dev, a_batched, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, b_dev, b_batched, sizeof(T *) * batch_size); + + auto **a_dev_ = reinterpret_cast(a_dev); + auto **b_dev_ = reinterpret_cast(b_dev); + + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_fill_mode(uplo), + (int)n, (int)nrhs, a_dev_, (int)lda, b_dev_, ldb, nullptr, + (int)batch_size); + + free(a_batched); + free(b_batched); + cuMemFree(a_dev); + cuMemFree(b_dev); + }); + }); + return done; } +// Scratchpad memory not needed as parts of buffer a is used as workspace memory +#define POTRS_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event potrs_batch( \ + sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, TYPE *a, \ + std::int64_t lda, std::int64_t stride_a, TYPE *b, std::int64_t ldb, std::int64_t stride_b, \ + std::int64_t batch_size, TYPE *scratchpad, std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return potrs_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, uplo, n, nrhs, a, lda, \ + stride_a, b, ldb, stride_b, batch_size, scratchpad, scratchpad_size, \ + dependencies); \ + } + +POTRS_STRIDED_BATCH_LAUNCHER_USM(float, cusolverDnSpotrsBatched) +POTRS_STRIDED_BATCH_LAUNCHER_USM(double, 
cusolverDnDpotrsBatched) +POTRS_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnCpotrsBatched) +POTRS_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnZpotrsBatched) + +#undef POTRS_STRIDED_BATCH_LAUNCHER_USM + template inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &queue, oneapi::mkl::uplo *uplo, std::int64_t *n, std::int64_t *nrhs, T **a, @@ -595,7 +1267,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu overflow_check(n[i], lda[i], group_sizes[i]); batch_size += group_sizes[i]; - // cusolver function only supports nrhs = 1 + // cuSolver function only supports nrhs = 1 if (nrhs[i] != 1) throw unimplemented("lapack", "potrs_batch", "cusolver potrs_batch only supports nrhs = 1"); @@ -621,6 +1293,8 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu auto handle = sc.get_handle(queue); int64_t offset = 0; cusolverStatus_t err; + + // Does not use scratch so call cuSolver asynchronously and sync at end for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a_dev); auto **b_ = reinterpret_cast(b_dev); @@ -630,90 +1304,174 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu b_ + offset, (int)ldb[i], info_, (int)group_sizes[i]); offset += group_sizes[i]; } + CUSOLVER_SYNC(err, handle) + }); + }); + return done; +} + +// Scratchpad memory not needed as parts of buffer a is used as workspace memory +#define POTRS_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event potrs_batch( \ + sycl::queue &queue, oneapi::mkl::uplo *uplo, std::int64_t *n, std::int64_t *nrhs, \ + TYPE **a, std::int64_t *lda, TYPE **b, std::int64_t *ldb, std::int64_t group_count, \ + std::int64_t *group_sizes, TYPE *scratchpad, std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return potrs_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, uplo, n, nrhs, a, lda, b, \ + ldb, group_count, group_sizes, scratchpad, 
scratchpad_size, \ + dependencies); \ + } + +POTRS_BATCH_LAUNCHER_USM(float, cusolverDnSpotrsBatched) +POTRS_BATCH_LAUNCHER_USM(double, cusolverDnDpotrsBatched) +POTRS_BATCH_LAUNCHER_USM(std::complex, cusolverDnCpotrsBatched) +POTRS_BATCH_LAUNCHER_USM(std::complex, cusolverDnZpotrsBatched) + +#undef POTRS_BATCH_LAUNCHER_USM + +template +inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, + std::int64_t n, std::int64_t k, T *a, std::int64_t lda, + std::int64_t stride_a, T *tau, std::int64_t stride_tau, + std::int64_t batch_size, T *scratchpad, std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto tau_ = reinterpret_cast(tau); + auto scratch_ = reinterpret_cast(scratchpad); + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t i = 0; i < batch_size; ++i) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_ + stride_a * i, + lda, tau_ + stride_tau * i, scratch_, scratchpad_size, + nullptr); + } + }); + }); + + return done; +} + +#define UNGQR_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event ungqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, \ + TYPE *a, std::int64_t lda, std::int64_t stride_a, TYPE *tau, \ + std::int64_t stride_tau, std::int64_t batch_size, TYPE *scratchpad, \ + std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return ungqr_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, a, lda, 
stride_a, \ + tau, stride_tau, batch_size, scratchpad, scratchpad_size, \ + dependencies); \ + } + +UNGQR_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnCungqr) +UNGQR_STRIDED_BATCH_LAUNCHER_USM(std::complex, cusolverDnZungqr) + +#undef UNGQR_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &queue, + std::int64_t *m, std::int64_t *n, std::int64_t *k, T **a, + std::int64_t *lda, T **tau, std::int64_t group_count, + std::int64_t *group_sizes, T *scratchpad, + std::int64_t scratchpad_size, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(group_count, scratchpad_size); + for (int64_t i = 0; i < group_count; ++i) + overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]); + + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto tau_ = reinterpret_cast(tau); + auto scratch_ = reinterpret_cast(scratchpad); + int64_t global_id = 0; + cusolverStatus_t err; + + // Uses scratch so sync between each cuSolver call + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + for (int64_t local_id = 0; local_id < group_sizes[group_id]; + ++local_id, ++global_id) { + CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id], + n[group_id], k[group_id], a_[global_id], + lda[group_id], tau_[global_id], scratch_, + scratchpad_size, nullptr); + } + } + }); + }); + + return done; +} + +#define UNGQR_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event ungqr_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, std::int64_t *k, \ + TYPE **a, std::int64_t *lda, TYPE **tau, std::int64_t group_count, \ + std::int64_t *group_sizes, TYPE *scratchpad, 
\ + std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return ungqr_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, a, lda, tau, \ + group_count, group_sizes, scratchpad, scratchpad_size, dependencies); \ + } + +UNGQR_BATCH_LAUNCHER_USM(std::complex, cusolverDnCungqr) +UNGQR_BATCH_LAUNCHER_USM(std::complex, cusolverDnZungqr) + +#undef UNGQR_BATCH_LAUNCHER_USM + +// BATCH SCRATCHPAD API + +template +inline void getrf_batch_scratchpad_size(const char *func_name, Func func, sycl::queue &queue, + std::int64_t m, std::int64_t n, std::int64_t lda, + std::int64_t stride_a, std::int64_t stride_ipiv, + std::int64_t batch_size, int *scratch_size) { + auto e = queue.submit([&](sycl::handler &cgh) { + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + cusolverStatus_t err; + + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); }); }); - return done; + e.wait(); } -// Scratchpad memory not needed as parts of buffer a is used as workspace memory -#define POTRS_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ - sycl::event potrs_batch( \ - sycl::queue &queue, oneapi::mkl::uplo *uplo, std::int64_t *n, std::int64_t *nrhs, \ - TYPE **a, std::int64_t *lda, TYPE **b, std::int64_t *ldb, std::int64_t group_count, \ - std::int64_t *group_sizes, TYPE *scratchpad, std::int64_t scratchpad_size, \ - const std::vector &dependencies) { \ - return potrs_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, uplo, n, nrhs, a, lda, b, \ - ldb, group_count, group_sizes, scratchpad, scratchpad_size, \ - dependencies); \ +#define GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ + template <> \ + std::int64_t getrf_batch_scratchpad_size( \ + sycl::queue & queue, std::int64_t m, std::int64_t n, std::int64_t lda, \ + std::int64_t stride_a, std::int64_t stride_ipiv, std::int64_t batch_size) { \ + int scratch_size; \ + getrf_batch_scratchpad_size(#CUSOLVER_ROUTINE, 
CUSOLVER_ROUTINE, queue, m, n, lda, \ + stride_a, stride_ipiv, batch_size, &scratch_size); \ + return scratch_size; \ } -POTRS_BATCH_LAUNCHER_USM(float, cusolverDnSpotrsBatched) -POTRS_BATCH_LAUNCHER_USM(double, cusolverDnDpotrsBatched) -POTRS_BATCH_LAUNCHER_USM(std::complex, cusolverDnCpotrsBatched) -POTRS_BATCH_LAUNCHER_USM(std::complex, cusolverDnZpotrsBatched) - -#undef POTRS_BATCH_LAUNCHER_USM - -sycl::event ungqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, - std::complex *a, std::int64_t lda, std::int64_t stride_a, - std::complex *tau, std::int64_t stride_tau, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "ungqr_batch"); -} -sycl::event ungqr_batch(sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, - std::complex *a, std::int64_t lda, std::int64_t stride_a, - std::complex *tau, std::int64_t stride_tau, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "ungqr_batch"); -} -sycl::event ungqr_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, std::int64_t *k, - std::complex **a, std::int64_t *lda, std::complex **tau, - std::int64_t group_count, std::int64_t *group_sizes, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "ungqr_batch"); -} -sycl::event ungqr_batch(sycl::queue &queue, std::int64_t *m, std::int64_t *n, std::int64_t *k, - std::complex **a, std::int64_t *lda, std::complex **tau, - std::int64_t group_count, std::int64_t *group_sizes, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "ungqr_batch"); -} +GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH(float, cusolverDnSgetrf_bufferSize) +GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH(double, 
cusolverDnDgetrf_bufferSize) +GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex, cusolverDnCgetrf_bufferSize) +GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex, cusolverDnZgetrf_bufferSize) -// BATCH SCRATCHPAD API +#undef GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH -template <> -std::int64_t getrf_batch_scratchpad_size(sycl::queue &queue, std::int64_t m, std::int64_t n, - std::int64_t lda, std::int64_t stride_a, - std::int64_t stride_ipiv, std::int64_t batch_size) { - throw unimplemented("lapack", "getrf_batch_scratchpad_size"); -} -template <> -std::int64_t getrf_batch_scratchpad_size(sycl::queue &queue, std::int64_t m, std::int64_t n, - std::int64_t lda, std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t batch_size) { - throw unimplemented("lapack", "getrf_batch_scratchpad_size"); -} -template <> -std::int64_t getrf_batch_scratchpad_size>(sycl::queue &queue, std::int64_t m, - std::int64_t n, std::int64_t lda, - std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t batch_size) { - throw unimplemented("lapack", "getrf_batch_scratchpad_size"); -} -template <> -std::int64_t getrf_batch_scratchpad_size>(sycl::queue &queue, std::int64_t m, - std::int64_t n, std::int64_t lda, - std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t batch_size) { - throw unimplemented("lapack", "getrf_batch_scratchpad_size"); -} template <> std::int64_t getri_batch_scratchpad_size(sycl::queue &queue, std::int64_t n, std::int64_t lda, std::int64_t stride_a, @@ -743,174 +1501,201 @@ std::int64_t getri_batch_scratchpad_size>(sycl::queue &queu std::int64_t batch_size) { throw unimplemented("lapack", "getri_batch_scratchpad_size"); } -template <> -std::int64_t getrs_batch_scratchpad_size(sycl::queue &queue, oneapi::mkl::transpose trans, - std::int64_t n, std::int64_t nrhs, std::int64_t lda, - std::int64_t stride_a, std::int64_t stride_ipiv, - std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size) { - throw unimplemented("lapack", 
"getrs_batch_scratchpad_size"); -} -template <> -std::int64_t getrs_batch_scratchpad_size(sycl::queue &queue, oneapi::mkl::transpose trans, - std::int64_t n, std::int64_t nrhs, - std::int64_t lda, std::int64_t stride_a, - std::int64_t stride_ipiv, std::int64_t ldb, - std::int64_t stride_b, std::int64_t batch_size) { - throw unimplemented("lapack", "getrs_batch_scratchpad_size"); -} -template <> -std::int64_t getrs_batch_scratchpad_size>( - sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, std::int64_t nrhs, - std::int64_t lda, std::int64_t stride_a, std::int64_t stride_ipiv, std::int64_t ldb, - std::int64_t stride_b, std::int64_t batch_size) { - throw unimplemented("lapack", "getrs_batch_scratchpad_size"); -} -template <> -std::int64_t getrs_batch_scratchpad_size>( - sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, std::int64_t nrhs, - std::int64_t lda, std::int64_t stride_a, std::int64_t stride_ipiv, std::int64_t ldb, - std::int64_t stride_b, std::int64_t batch_size) { - throw unimplemented("lapack", "getrs_batch_scratchpad_size"); -} -template <> -std::int64_t geqrf_batch_scratchpad_size(sycl::queue &queue, std::int64_t m, std::int64_t n, - std::int64_t lda, std::int64_t stride_a, - std::int64_t stride_tau, std::int64_t batch_size) { - throw unimplemented("lapack", "geqrf_batch_scratchpad_size"); -} -template <> -std::int64_t geqrf_batch_scratchpad_size(sycl::queue &queue, std::int64_t m, std::int64_t n, - std::int64_t lda, std::int64_t stride_a, - std::int64_t stride_tau, std::int64_t batch_size) { - throw unimplemented("lapack", "geqrf_batch_scratchpad_size"); -} -template <> -std::int64_t geqrf_batch_scratchpad_size>(sycl::queue &queue, std::int64_t m, - std::int64_t n, std::int64_t lda, - std::int64_t stride_a, - std::int64_t stride_tau, - std::int64_t batch_size) { - throw unimplemented("lapack", "geqrf_batch_scratchpad_size"); -} -template <> -std::int64_t geqrf_batch_scratchpad_size>(sycl::queue &queue, std::int64_t m, - 
std::int64_t n, std::int64_t lda, - std::int64_t stride_a, - std::int64_t stride_tau, - std::int64_t batch_size) { - throw unimplemented("lapack", "geqrf_batch_scratchpad_size"); -} -template <> -std::int64_t potrf_batch_scratchpad_size(sycl::queue &queue, oneapi::mkl::uplo uplo, - std::int64_t n, std::int64_t lda, - std::int64_t stride_a, std::int64_t batch_size) { - throw unimplemented("lapack", "potrf_batch_scratchpad_size"); -} -template <> -std::int64_t potrf_batch_scratchpad_size(sycl::queue &queue, oneapi::mkl::uplo uplo, - std::int64_t n, std::int64_t lda, - std::int64_t stride_a, std::int64_t batch_size) { - throw unimplemented("lapack", "potrf_batch_scratchpad_size"); -} -template <> -std::int64_t potrf_batch_scratchpad_size>(sycl::queue &queue, - oneapi::mkl::uplo uplo, - std::int64_t n, std::int64_t lda, - std::int64_t stride_a, - std::int64_t batch_size) { - throw unimplemented("lapack", "potrf_batch_scratchpad_size"); -} -template <> -std::int64_t potrf_batch_scratchpad_size>(sycl::queue &queue, - oneapi::mkl::uplo uplo, - std::int64_t n, std::int64_t lda, - std::int64_t stride_a, - std::int64_t batch_size) { - throw unimplemented("lapack", "potrf_batch_scratchpad_size"); -} -template <> -std::int64_t potrs_batch_scratchpad_size(sycl::queue &queue, oneapi::mkl::uplo uplo, - std::int64_t n, std::int64_t nrhs, std::int64_t lda, - std::int64_t stride_a, std::int64_t ldb, - std::int64_t stride_b, std::int64_t batch_size) { - throw unimplemented("lapack", "potrs_batch_scratchpad_size"); -} -template <> -std::int64_t potrs_batch_scratchpad_size(sycl::queue &queue, oneapi::mkl::uplo uplo, - std::int64_t n, std::int64_t nrhs, - std::int64_t lda, std::int64_t stride_a, - std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size) { - throw unimplemented("lapack", "potrs_batch_scratchpad_size"); -} -template <> -std::int64_t potrs_batch_scratchpad_size>( - sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, std::int64_t lda, - 
std::int64_t stride_a, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { - throw unimplemented("lapack", "potrs_batch_scratchpad_size"); -} -template <> -std::int64_t potrs_batch_scratchpad_size>( - sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, std::int64_t lda, - std::int64_t stride_a, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { - throw unimplemented("lapack", "potrs_batch_scratchpad_size"); -} -template <> -std::int64_t orgqr_batch_scratchpad_size(sycl::queue &queue, std::int64_t m, std::int64_t n, - std::int64_t k, std::int64_t lda, - std::int64_t stride_a, std::int64_t stride_tau, - std::int64_t batch_size) { - throw unimplemented("lapack", "orgqr_batch_scratchpad_size"); -} -template <> -std::int64_t orgqr_batch_scratchpad_size(sycl::queue &queue, std::int64_t m, std::int64_t n, - std::int64_t k, std::int64_t lda, - std::int64_t stride_a, std::int64_t stride_tau, - std::int64_t batch_size) { - throw unimplemented("lapack", "orgqr_batch_scratchpad_size"); -} -template <> -std::int64_t ungqr_batch_scratchpad_size>( - sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, std::int64_t lda, - std::int64_t stride_a, std::int64_t stride_tau, std::int64_t batch_size) { - throw unimplemented("lapack", "ungqr_batch_scratchpad_size"); -} -template <> -std::int64_t ungqr_batch_scratchpad_size>( - sycl::queue &queue, std::int64_t m, std::int64_t n, std::int64_t k, std::int64_t lda, - std::int64_t stride_a, std::int64_t stride_tau, std::int64_t batch_size) { - throw unimplemented("lapack", "ungqr_batch_scratchpad_size"); -} -template <> -std::int64_t getrf_batch_scratchpad_size(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getrf_batch_scratchpad_size"); +// cusolverDnXgetrs does not use scratchpad memory +#define GETRS_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE) \ + template 
<> \ + std::int64_t getrs_batch_scratchpad_size( \ + sycl::queue & queue, oneapi::mkl::transpose trans, std::int64_t n, std::int64_t nrhs, \ + std::int64_t lda, std::int64_t stride_a, std::int64_t stride_ipiv, std::int64_t ldb, \ + std::int64_t stride_b, std::int64_t batch_size) { \ + return 0; \ + } + +GETRS_STRIDED_BATCH_LAUNCHER_SCRATCH(float) +GETRS_STRIDED_BATCH_LAUNCHER_SCRATCH(double) +GETRS_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) +GETRS_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) + +#undef GETRS_STRIDED_BATCH_LAUNCHER_SCRATCH + +template +inline void geqrf_batch_scratchpad_size(const char *func_name, Func func, sycl::queue &queue, + std::int64_t m, std::int64_t n, std::int64_t lda, + std::int64_t stride_a, std::int64_t stride_tau, + std::int64_t batch_size, int *scratch_size) { + auto e = queue.submit([&](sycl::handler &cgh) { + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + cusolverStatus_t err; + + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); + }); + }); + e.wait(); } -template <> -std::int64_t getrf_batch_scratchpad_size(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getrf_batch_scratchpad_size"); + +#define GEQRF_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ + template <> \ + std::int64_t geqrf_batch_scratchpad_size( \ + sycl::queue & queue, std::int64_t m, std::int64_t n, std::int64_t lda, \ + std::int64_t stride_a, std::int64_t stride_tau, std::int64_t batch_size) { \ + int scratch_size; \ + geqrf_batch_scratchpad_size(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, lda, \ + stride_a, stride_tau, batch_size, &scratch_size); \ + return scratch_size; \ + } + +GEQRF_STRIDED_BATCH_LAUNCHER_SCRATCH(float, cusolverDnSgeqrf_bufferSize) +GEQRF_STRIDED_BATCH_LAUNCHER_SCRATCH(double, cusolverDnDgeqrf_bufferSize) 
+GEQRF_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex, cusolverDnCgeqrf_bufferSize) +GEQRF_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex, cusolverDnZgeqrf_bufferSize) + +#undef GEQRF_STRIDED_BATCH_LAUNCHER_SCRATCH + +// cusolverDnXpotrfBatched does not use scratchpad memory +#define POTRF_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE) \ + template <> \ + std::int64_t potrf_batch_scratchpad_size( \ + sycl::queue & queue, oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t lda, \ + std::int64_t stride_a, std::int64_t batch_size) { \ + return 0; \ + } + +POTRF_STRIDED_BATCH_LAUNCHER_SCRATCH(float) +POTRF_STRIDED_BATCH_LAUNCHER_SCRATCH(double) +POTRF_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) +POTRF_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) + +#undef POTRF_STRIDED_BATCH_LAUNCHER_SCRATCH + +// cusolverDnXpotrsBatched does not use scratchpad memory +#define POTRS_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE) \ + template <> \ + std::int64_t potrs_batch_scratchpad_size( \ + sycl::queue & queue, oneapi::mkl::uplo uplo, std::int64_t n, std::int64_t nrhs, \ + std::int64_t lda, std::int64_t stride_a, std::int64_t ldb, std::int64_t stride_b, \ + std::int64_t batch_size) { \ + return 0; \ + } + +POTRS_STRIDED_BATCH_LAUNCHER_SCRATCH(float) +POTRS_STRIDED_BATCH_LAUNCHER_SCRATCH(double) +POTRS_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) +POTRS_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) + +#undef POTRS_STRIDED_BATCH_LAUNCHER_SCRATCH + +template +inline void orgqr_batch_scratchpad_size(const char *func_name, Func func, sycl::queue &queue, + std::int64_t m, std::int64_t n, std::int64_t k, + std::int64_t lda, std::int64_t stride_a, + std::int64_t stride_tau, std::int64_t batch_size, + int *scratch_size) { + auto e = queue.submit([&](sycl::handler &cgh) { + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + cusolverStatus_t err; + + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, + scratch_size); + 
}); + }); + e.wait(); } -template <> -std::int64_t getrf_batch_scratchpad_size>(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getrf_batch_scratchpad_size"); + +#define ORGQR_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ + template <> \ + std::int64_t orgqr_batch_scratchpad_size( \ + sycl::queue & queue, std::int64_t m, std::int64_t n, std::int64_t k, std::int64_t lda, \ + std::int64_t stride_a, std::int64_t stride_tau, std::int64_t batch_size) { \ + int scratch_size; \ + orgqr_batch_scratchpad_size(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, lda, \ + stride_a, stride_tau, batch_size, &scratch_size); \ + return scratch_size; \ + } + +ORGQR_STRIDED_BATCH_LAUNCHER_SCRATCH(float, cusolverDnSorgqr_bufferSize) +ORGQR_STRIDED_BATCH_LAUNCHER_SCRATCH(double, cusolverDnDorgqr_bufferSize) + +#undef ORGQR_STRIDED_BATCH_LAUNCHER_SCRATCH + +template +inline void ungqr_batch_scratchpad_size(const char *func_name, Func func, sycl::queue &queue, + std::int64_t m, std::int64_t n, std::int64_t k, + std::int64_t lda, std::int64_t stride_a, + std::int64_t stride_tau, std::int64_t batch_size, + int *scratch_size) { + auto e = queue.submit([&](sycl::handler &cgh) { + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + cusolverStatus_t err; + + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, nullptr, lda, nullptr, + scratch_size); + }); + }); + e.wait(); } -template <> -std::int64_t getrf_batch_scratchpad_size>(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getrf_batch_scratchpad_size"); + +#define ORGQR_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ + template <> \ + std::int64_t ungqr_batch_scratchpad_size( \ + sycl::queue & queue, 
std::int64_t m, std::int64_t n, std::int64_t k, std::int64_t lda, \ + std::int64_t stride_a, std::int64_t stride_tau, std::int64_t batch_size) { \ + int scratch_size; \ + ungqr_batch_scratchpad_size(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, lda, \ + stride_a, stride_tau, batch_size, &scratch_size); \ + return scratch_size; \ + } + +ORGQR_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex, cusolverDnCungqr_bufferSize) +ORGQR_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex, cusolverDnZungqr_bufferSize) + +#undef ORGQR_STRIDED_BATCH_LAUNCHER_SCRATCH + +template +inline void getrf_batch_scratchpad_size(const char *func_name, Func func, sycl::queue &queue, + std::int64_t *m, std::int64_t *n, std::int64_t *lda, + std::int64_t group_count, std::int64_t *group_sizes, + int *scratch_size) { + auto e = queue.submit([&](sycl::handler &cgh) { + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int group_scratch_size = 0; + *scratch_size = 0; + cusolverStatus_t err; + + // Get the maximum scratch_size across the groups + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m[group_id], n[group_id], + nullptr, lda[group_id], &group_scratch_size); + *scratch_size = + group_scratch_size > *scratch_size ? 
group_scratch_size : *scratch_size; + } + }); + }); + e.wait(); } + +#define GETRF_GROUP_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ + template <> \ + std::int64_t getrf_batch_scratchpad_size( \ + sycl::queue & queue, std::int64_t * m, std::int64_t * n, std::int64_t * lda, \ + std::int64_t group_count, std::int64_t * group_sizes) { \ + int scratch_size; \ + getrf_batch_scratchpad_size(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, lda, \ + group_count, group_sizes, &scratch_size); \ + return scratch_size; \ + } + +GETRF_GROUP_LAUNCHER_SCRATCH(float, cusolverDnSgetrf_bufferSize) +GETRF_GROUP_LAUNCHER_SCRATCH(double, cusolverDnDgetrf_bufferSize) +GETRF_GROUP_LAUNCHER_SCRATCH(std::complex, cusolverDnCgetrf_bufferSize) +GETRF_GROUP_LAUNCHER_SCRATCH(std::complex, cusolverDnZgetrf_bufferSize) + +#undef GETRF_GROUP_LAUNCHER_SCRATCH + template <> std::int64_t getri_batch_scratchpad_size(sycl::queue &queue, std::int64_t *n, std::int64_t *lda, std::int64_t group_count, @@ -937,77 +1722,106 @@ std::int64_t getri_batch_scratchpad_size>(sycl::queue &queu std::int64_t *group_sizes) { throw unimplemented("lapack", "getri_batch_scratchpad_size"); } -template <> -std::int64_t getrs_batch_scratchpad_size(sycl::queue &queue, oneapi::mkl::transpose *trans, - std::int64_t *n, std::int64_t *nrhs, - std::int64_t *lda, std::int64_t *ldb, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getrs_batch_scratchpad_size"); -} -template <> -std::int64_t getrs_batch_scratchpad_size(sycl::queue &queue, oneapi::mkl::transpose *trans, - std::int64_t *n, std::int64_t *nrhs, - std::int64_t *lda, std::int64_t *ldb, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getrs_batch_scratchpad_size"); -} -template <> -std::int64_t getrs_batch_scratchpad_size>( - sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *n, std::int64_t *nrhs, - std::int64_t *lda, std::int64_t *ldb, std::int64_t group_count, 
std::int64_t *group_sizes) { - throw unimplemented("lapack", "getrs_batch_scratchpad_size"); -} -template <> -std::int64_t getrs_batch_scratchpad_size>( - sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *n, std::int64_t *nrhs, - std::int64_t *lda, std::int64_t *ldb, std::int64_t group_count, std::int64_t *group_sizes) { - throw unimplemented("lapack", "getrs_batch_scratchpad_size"); -} -template <> -std::int64_t geqrf_batch_scratchpad_size(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "geqrf_batch_scratchpad_size"); -} -template <> -std::int64_t geqrf_batch_scratchpad_size(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "geqrf_batch_scratchpad_size"); -} -template <> -std::int64_t geqrf_batch_scratchpad_size>(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "geqrf_batch_scratchpad_size"); -} -template <> -std::int64_t geqrf_batch_scratchpad_size>(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "geqrf_batch_scratchpad_size"); -} -template <> -std::int64_t orgqr_batch_scratchpad_size(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *k, std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "orgqr_batch_scratchpad_size"); + +#define GETRS_GROUP_LAUNCHER_SCRATCH(TYPE) \ + template <> \ + std::int64_t getrs_batch_scratchpad_size( \ + sycl::queue & queue, oneapi::mkl::transpose * trans, std::int64_t * n, \ + std::int64_t * nrhs, std::int64_t * lda, std::int64_t * ldb, std::int64_t group_count, \ + std::int64_t * 
group_sizes) { \ + return 0; \ + } + +GETRS_GROUP_LAUNCHER_SCRATCH(float) +GETRS_GROUP_LAUNCHER_SCRATCH(double) +GETRS_GROUP_LAUNCHER_SCRATCH(std::complex) +GETRS_GROUP_LAUNCHER_SCRATCH(std::complex) + +#undef GETRS_GROUP_LAUNCHER_SCRATCH + +template +inline void geqrf_batch_scratchpad_size(const char *func_name, Func func, sycl::queue &queue, + std::int64_t *m, std::int64_t *n, std::int64_t *lda, + std::int64_t group_count, std::int64_t *group_sizes, + int *scratch_size) { + auto e = queue.submit([&](sycl::handler &cgh) { + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int group_scratch_size = 0; + *scratch_size = 0; + cusolverStatus_t err; + + // Get the maximum scratch_size across the groups + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m[group_id], n[group_id], + nullptr, lda[group_id], &group_scratch_size); + *scratch_size = + group_scratch_size > *scratch_size ? 
group_scratch_size : *scratch_size; + } + }); + }); + e.wait(); } -template <> -std::int64_t orgqr_batch_scratchpad_size(sycl::queue &queue, std::int64_t *m, - std::int64_t *n, std::int64_t *k, - std::int64_t *lda, std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "orgqr_batch_scratchpad_size"); + +#define GEQRF_GROUP_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ + template <> \ + std::int64_t geqrf_batch_scratchpad_size( \ + sycl::queue & queue, std::int64_t * m, std::int64_t * n, std::int64_t * lda, \ + std::int64_t group_count, std::int64_t * group_sizes) { \ + int scratch_size; \ + geqrf_batch_scratchpad_size(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, lda, \ + group_count, group_sizes, &scratch_size); \ + return scratch_size; \ + } + +GEQRF_GROUP_LAUNCHER_SCRATCH(float, cusolverDnSgeqrf_bufferSize) +GEQRF_GROUP_LAUNCHER_SCRATCH(double, cusolverDnDgeqrf_bufferSize) +GEQRF_GROUP_LAUNCHER_SCRATCH(std::complex, cusolverDnCgeqrf_bufferSize) +GEQRF_GROUP_LAUNCHER_SCRATCH(std::complex, cusolverDnZgeqrf_bufferSize) + +#undef GEQRF_GROUP_LAUNCHER_SCRATCH + +template +inline void orgqr_batch_scratchpad_size(const char *func_name, Func func, sycl::queue &queue, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + std::int64_t *lda, std::int64_t group_count, + std::int64_t *group_sizes, int *scratch_size) { + auto e = queue.submit([&](sycl::handler &cgh) { + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int group_scratch_size = 0; + *scratch_size = 0; + cusolverStatus_t err; + + // Get the maximum scratch_size across the groups + for (int64_t group_id = 0; group_id < group_count; ++group_id) { + CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m[group_id], n[group_id], + k[group_id], nullptr, lda[group_id], nullptr, + &group_scratch_size); + *scratch_size = + group_scratch_size > *scratch_size ? 
group_scratch_size : *scratch_size;
+            }
+        });
+    });
+    e.wait();
+}
+#define ORGQR_GROUP_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE)                                  \
+    template <>                                                                               \
+    std::int64_t orgqr_batch_scratchpad_size<TYPE>(                                           \
+        sycl::queue & queue, std::int64_t * m, std::int64_t * n, std::int64_t * k,            \
+        std::int64_t * lda, std::int64_t group_count, std::int64_t * group_sizes) {           \
+        int scratch_size;                                                                     \
+        orgqr_batch_scratchpad_size(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, lda, \
+                                    group_count, group_sizes, &scratch_size);                 \
+        return scratch_size;                                                                  \
+    }
+
+ORGQR_GROUP_LAUNCHER_SCRATCH(float, cusolverDnSorgqr_bufferSize)
+ORGQR_GROUP_LAUNCHER_SCRATCH(double, cusolverDnDorgqr_bufferSize)
+
+#undef ORGQR_GROUP_LAUNCHER_SCRATCH
+
 // cusolverDnXpotrfBatched does not use scratchpad memory
 #define POTRF_GROUP_LAUNCHER_SCRATCH(TYPE) \
     template <> \
@@ -1041,23 +1855,47 @@ POTRS_GROUP_LAUNCHER_SCRATCH(std::complex<double>)
 
 #undef POTRS_GROUP_LAUNCHER_SCRATCH
 
-template <>
-std::int64_t ungqr_batch_scratchpad_size<std::complex<float>>(sycl::queue &queue, std::int64_t *m,
-                                                              std::int64_t *n, std::int64_t *k,
-                                                              std::int64_t *lda,
-                                                              std::int64_t group_count,
-                                                              std::int64_t *group_sizes) {
-    throw unimplemented("lapack", "ungqr_batch_scratchpad_size");
-}
-template <>
-std::int64_t ungqr_batch_scratchpad_size<std::complex<double>>(sycl::queue &queue, std::int64_t *m,
-                                                               std::int64_t *n, std::int64_t *k,
-                                                               std::int64_t *lda,
-                                                               std::int64_t group_count,
-                                                               std::int64_t *group_sizes) {
-    throw unimplemented("lapack", "ungqr_batch_scratchpad_size");
+template <typename Func>
+inline void ungqr_batch_scratchpad_size(const char *func_name, Func func, sycl::queue &queue,
+                                        std::int64_t *m, std::int64_t *n, std::int64_t *k,
+                                        std::int64_t *lda, std::int64_t group_count,
+                                        std::int64_t *group_sizes, int *scratch_size) {
+    auto e = queue.submit([&](sycl::handler &cgh) {
+        onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
+            auto handle = sc.get_handle(queue);
+            int group_scratch_size = 0;
+            *scratch_size = 0;
+            cusolverStatus_t err;
+
+            // Get the
maximum scratch_size across the groups
+            for (int64_t group_id = 0; group_id < group_count; ++group_id) {
+                CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m[group_id], n[group_id],
+                                      k[group_id], nullptr, lda[group_id], nullptr,
+                                      &group_scratch_size);
+                *scratch_size =
+                    group_scratch_size > *scratch_size ? group_scratch_size : *scratch_size;
+            }
+        });
+    });
+    e.wait();
+}
+#define UNGQR_GROUP_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE)                                  \
+    template <>                                                                               \
+    std::int64_t ungqr_batch_scratchpad_size<TYPE>(                                           \
+        sycl::queue & queue, std::int64_t * m, std::int64_t * n, std::int64_t * k,            \
+        std::int64_t * lda, std::int64_t group_count, std::int64_t * group_sizes) {           \
+        int scratch_size;                                                                     \
+        ungqr_batch_scratchpad_size(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, m, n, k, lda, \
+                                    group_count, group_sizes, &scratch_size);                 \
+        return scratch_size;                                                                  \
+    }
+
+UNGQR_GROUP_LAUNCHER_SCRATCH(std::complex<float>, cusolverDnCungqr_bufferSize)
+UNGQR_GROUP_LAUNCHER_SCRATCH(std::complex<double>, cusolverDnZungqr_bufferSize)
+
+#undef UNGQR_GROUP_LAUNCHER_SCRATCH
+
 } // namespace cusolver
 } // namespace lapack
 } // namespace mkl
diff --git a/src/lapack/backends/cusolver/cusolver_helper.hpp b/src/lapack/backends/cusolver/cusolver_helper.hpp
index 9d96d4f83..a94c3224b 100644
--- a/src/lapack/backends/cusolver/cusolver_helper.hpp
+++ b/src/lapack/backends/cusolver/cusolver_helper.hpp
@@ -185,14 +185,20 @@ class cuda_error : virtual public std::runtime_error {
         throw cusolver_error(std::string(name) + std::string(" : "), err); \
     }
 
-#define CUSOLVER_ERROR_FUNC_T_SYNC(name, func, err, handle, ...)
\
-    err = func(handle, __VA_ARGS__); \
-    if (err != CUSOLVER_STATUS_SUCCESS) { \
-        throw cusolver_error(std::string(name) + std::string(" : "), err); \
-    } \
-    cudaStream_t currentStreamId; \
-    CUSOLVER_ERROR_FUNC(cusolverDnGetStream, err, handle, &currentStreamId); \
-    cuStreamSynchronize(currentStreamId);
+#define CUSOLVER_SYNC(err, handle) \
+    cudaStream_t currentStreamId; \
+    CUSOLVER_ERROR_FUNC(cusolverDnGetStream, err, handle, &currentStreamId); \
+    { \
+        CUresult __cuda_err; \
+        CUDA_ERROR_FUNC(cuStreamSynchronize, __cuda_err, currentStreamId); \
+    }
+
+#define CUSOLVER_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) \
+    err = func(handle, __VA_ARGS__); \
+    if (err != CUSOLVER_STATUS_SUCCESS) { \
+        throw cusolver_error(std::string(name) + std::string(" : "), err); \
+    } \
+    CUSOLVER_SYNC(err, handle)
 
 inline cusolverEigType_t get_cusolver_itype(std::int64_t itype) {
     switch (itype) {
@@ -274,26 +280,43 @@ struct CudaEquivalentType<std::complex<double>> {
 
 /* devinfo */
 
-inline int get_cusolver_devinfo(sycl::queue &queue, sycl::buffer<int> &devInfo) {
-    sycl::host_accessor dev_info_{ devInfo };
-    return dev_info_[0];
+inline void get_cusolver_devinfo(sycl::queue &queue, sycl::buffer<int> &devInfo,
+                                 std::vector<int> &dev_info_) {
+    sycl::host_accessor dev_info_acc{ devInfo };
+    for (unsigned int i = 0; i < dev_info_.size(); ++i)
+        dev_info_[i] = dev_info_acc[i];
 }
 
-inline int get_cusolver_devinfo(sycl::queue &queue, const int *devInfo) {
-    int dev_info_;
+inline void get_cusolver_devinfo(sycl::queue &queue, const int *devInfo,
+                                 std::vector<int> &dev_info_) {
     queue.wait();
-    queue.memcpy(&dev_info_, devInfo, sizeof(int));
-    return dev_info_;
+    queue.memcpy(dev_info_.data(), devInfo, sizeof(int));
 }
 
 template <typename DEVINFO_T>
 inline void lapack_info_check(sycl::queue &queue, DEVINFO_T devinfo, const char *func_name,
-                              const char *cufunc_name) {
-    const int devinfo_ = get_cusolver_devinfo(queue, devinfo);
-    if (devinfo_ > 0)
-        throw oneapi::mkl::lapack::computation_error(
-            func_name, std::string(cufunc_name) + " failed with info = "
+ std::to_string(devinfo_),
-            devinfo_);
+                              const char *cufunc_name, int dev_info_size = 1) {
+    std::vector<int> dev_info_(dev_info_size);
+    get_cusolver_devinfo(queue, devinfo, dev_info_);
+    for (const auto &val : dev_info_) {
+        if (val > 0)
+            throw oneapi::mkl::lapack::computation_error(
+                func_name, std::string(cufunc_name) + " failed with info = " + std::to_string(val),
+                val);
+    }
+}
+
+/* batched helpers */
+
+// Creates list of matrix/vector pointers from initial ptr and stride
+// Note: user is responsible for deallocating memory
+template <typename T>
+T **create_ptr_list_from_stride(T *ptr, int64_t ptr_stride, int64_t batch_size) {
+    T **ptr_list = (T **)malloc(sizeof(T *) * batch_size);
+    for (int64_t i = 0; i < batch_size; i++)
+        ptr_list[i] = ptr + i * ptr_stride;
+
+    return ptr_list;
+}
 
 } // namespace cusolver