From f22a4cb3472893d47c2c13bc14ef76615e81df7d Mon Sep 17 00:00:00 2001 From: vlad-perevezentsev Date: Wed, 21 Aug 2024 08:26:30 +0200 Subject: [PATCH] Extend `gesv_impl/gesv_batch_impl` for work with oneMKL Interfaces (#2001) * Implement batch solve via getrf_batch and getrs_batch * Pass sycl::queue by reference for getrs/getrs_batch * Extend gesv_impl to use onemkl interfaces * Reduce code duplication in gesv_impl * Extend gesv_batch_impl to use onemkl interfaces * Remove getrs_batch implementation * Pass correct batch_strides to gesv_batch_fn * Reduce dublicate code for gesv_impl * Replace maybe_unused to if defined/else * Expand comments for trans parameter --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com> --- dpnp/backend/extensions/lapack/gesv.cpp | 85 +++++++-- dpnp/backend/extensions/lapack/gesv_batch.cpp | 176 ++++++++++++++++-- dpnp/backend/extensions/lapack/getrs.cpp | 6 +- dpnp/backend/extensions/lapack/getrs.hpp | 2 +- 4 files changed, 240 insertions(+), 29 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 660afb58193..e2f6d3ebd76 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -56,11 +56,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q, char *in_b, const std::vector &depends) { -#if defined(USE_ONEMKL_INTERFACES) - // Temporary flag for build only - // FIXME: Need to implement by using lapack::getrf and lapack::getrs - std::logic_error("Not Implemented"); -#else type_utils::validate_type_for_device(exec_q); T *a = reinterpret_cast(in_a); @@ -69,12 +64,31 @@ static sycl::event gesv_impl(sycl::queue &exec_q, const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); - const std::int64_t scratchpad_size = + std::int64_t scratchpad_size = 0; + sycl::event comp_event; + std::int64_t *ipiv = nullptr; + + std::stringstream error_msg; + bool is_exception_caught = false; + +#if defined(USE_ONEMKL_INTERFACES) + // Use transpose::T if the LU-factorized array is passed as C-contiguous. + // For F-contiguous we use transpose::N. + // Since gesv takes F-contiguous as input, we use transpose::N. + oneapi::mkl::transpose trans = oneapi::mkl::transpose::N; + + scratchpad_size = std::max( + mkl_lapack::getrf_scratchpad_size(exec_q, n, n, lda), + mkl_lapack::getrs_scratchpad_size(exec_q, trans, n, nrhs, lda, ldb)); + +#else + scratchpad_size = mkl_lapack::gesv_scratchpad_size(exec_q, n, nrhs, lda, ldb); +#endif // USE_ONEMKL_INTERFACES + T *scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); - std::int64_t *ipiv = nullptr; try { ipiv = helper::alloc_ipiv(n, exec_q); } catch (const std::exception &e) { @@ -83,12 +97,57 @@ static sycl::event gesv_impl(sycl::queue &exec_q, throw; } - std::stringstream error_msg; - bool is_exception_caught = false; +#if defined(USE_ONEMKL_INTERFACES) + sycl::event getrf_event; + try { + getrf_event = mkl_lapack::getrf( + exec_q, + n, // The order of the square matrix A (0 ≤ n). + // It must be a non-negative integer. + n, // The number of columns in the square matrix A (0 ≤ n). + // It must be a non-negative integer. + a, // Pointer to the square matrix A (n x n). + lda, // The leading dimension of matrix A. + // It must be at least max(1, n). + ipiv, // Pointer to the output array of pivot indices. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); - sycl::event gesv_event; + comp_event = mkl_lapack::getrs( + exec_q, + trans, // Specifies the operation: whether or not to transpose + // matrix A. Can be 'N' for no transpose, 'T' for transpose, + // and 'C' for conjugate transpose. + n, // The order of the square matrix A + // and the number of rows in matrix B (0 ≤ n). + // It must be a non-negative integer. + nrhs, // The number of right-hand sides, + // i.e., the number of columns in matrix B (0 ≤ nrhs). + a, // Pointer to the square matrix A (n x n). + lda, // The leading dimension of matrix A, must be at least max(1, + // n). It must be at least max(1, n). + ipiv, // Pointer to the output array of pivot indices that were used + // during factorization (n, ). + b, // Pointer to the matrix B of right-hand sides (ldb, nrhs). + ldb, // The leading dimension of matrix B, must be at least max(1, + // n). + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, {getrf_event}); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + gesv_utils::handle_lapack_exc(exec_q, lda, a, scratchpad_size, + scratchpad, ipiv, e, error_msg); + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during getrf() or " + "getrs() call:\n" + << e.what(); + } +#else try { - gesv_event = mkl_lapack::gesv( + comp_event = mkl_lapack::gesv( exec_q, n, // The order of the square matrix A // and the number of rows in matrix B (0 ≤ n). @@ -114,6 +173,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q, error_msg << "Unexpected SYCL exception caught during gesv() call:\n" << e.what(); } +#endif // USE_ONEMKL_INTERFACES if (is_exception_caught) // an unexpected error occurs { @@ -125,7 +185,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q, } sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gesv_event); + cgh.depends_on(comp_event); auto ctx = exec_q.get_context(); cgh.host_task([ctx, scratchpad, ipiv]() { sycl::free(scratchpad, ctx); @@ -134,7 +194,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q, }); return ht_ev; -#endif // USE_ONEMKL_INTERFACES } std::pair diff --git a/dpnp/backend/extensions/lapack/gesv_batch.cpp b/dpnp/backend/extensions/lapack/gesv_batch.cpp index 90aaf8ebcfd..3d0813f2484 100644 --- a/dpnp/backend/extensions/lapack/gesv_batch.cpp +++ b/dpnp/backend/extensions/lapack/gesv_batch.cpp @@ -44,6 +44,10 @@ typedef sycl::event (*gesv_batch_impl_fn_ptr_t)( const std::int64_t, const std::int64_t, const std::int64_t, +#if defined(USE_ONEMKL_INTERFACES) + const std::int64_t, + const std::int64_t, +#endif // USE_ONEMKL_INTERFACES char *, char *, const std::vector &); @@ -56,6 +60,10 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, const std::int64_t n, const std::int64_t nrhs, const std::int64_t batch_size, +#if defined(USE_ONEMKL_INTERFACES) + const std::int64_t stride_a, + const std::int64_t stride_b, +#endif // USE_ONEMKL_INTERFACES char *in_a, char *in_b, const std::vector &depends) @@ -65,23 +73,147 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, T *a = reinterpret_cast(in_a); T *b = reinterpret_cast(in_b); - const std::int64_t a_size = n * n; - const std::int64_t b_size = n * nrhs; - const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); + std::int64_t scratchpad_size = 0; + sycl::event comp_event; + std::int64_t *ipiv = nullptr; + T *scratchpad = nullptr; + + std::stringstream error_msg; + bool is_exception_caught = false; + +#if defined(USE_ONEMKL_INTERFACES) + // Use transpose::T if the LU-factorized array is passed as C-contiguous. + // For F-contiguous we use transpose::N. + // Since gesv_batch takes F-contiguous as input, we use transpose::N. + oneapi::mkl::transpose trans = oneapi::mkl::transpose::N; + const std::int64_t stride_ipiv = n; + + scratchpad_size = std::max( + mkl_lapack::getrs_batch_scratchpad_size(exec_q, trans, n, nrhs, lda, + stride_a, stride_ipiv, ldb, + stride_b, batch_size), + mkl_lapack::getrf_batch_scratchpad_size(exec_q, n, n, lda, stride_a, + stride_ipiv, batch_size)); + + scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); + + // pass batch_size * n to allocate the memory for a 2D array of pivot + // indices + try { + ipiv = helper::alloc_ipiv(batch_size * n, exec_q); + } catch (const std::exception &e) { + if (scratchpad != nullptr) + sycl::free(scratchpad, exec_q); + throw; + } + + sycl::event getrf_batch_event; + try { + getrf_batch_event = mkl_lapack::getrf_batch( + exec_q, + n, // The order of each square matrix in the batch; (0 ≤ n). + // It must be a non-negative integer. + n, // The number of columns in each matrix in the batch; (0 ≤ n). + // It must be a non-negative integer. + a, // Pointer to the batch of square matrices, each of size (n x n). + lda, // The leading dimension of each matrix in the batch. + stride_a, // Stride between consecutive matrices in the batch. + ipiv, // Pointer to the array of pivot indices for each matrix in + // the batch. + stride_ipiv, // Stride between pivot indices: Spacing between pivot + // arrays in 'ipiv'. + batch_size, // Stride between pivot index arrays in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + + comp_event = mkl_lapack::getrs_batch( + exec_q, + trans, // Specifies the operation: whether or not to transpose + // matrix A. Can be 'N' for no transpose, 'T' for transpose, + // and 'C' for conjugate transpose. + n, // The order of each square matrix A in the batch + // and the number of rows in each matrix B (0 ≤ n). + // It must be a non-negative integer. + nrhs, // The number of right-hand sides, + // i.e., the number of columns in each matrix B in the batch + // (0 ≤ nrhs). + a, // Pointer to the batch of square matrices A (n x n). + lda, // The leading dimension of each matrix A in the batch. + // It must be at least max(1, n). + stride_a, // Stride between individual matrices in the batch for + // matrix A. + ipiv, // Pointer to the batch of arrays of pivot indices. + stride_ipiv, // Stride between pivot index arrays in the batch. + b, // Pointer to the batch of matrices B (n, nrhs). + ldb, // The leading dimension of each matrix B in the batch. + // Must be at least max(1, n). + stride_b, // Stride between individual matrices in the batch for + // matrix B. + batch_size, // The number of matrices in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, {getrf_batch_event}); + } catch (mkl_lapack::batch_error const &be) { + // Get the indices of matrices within the batch that encountered an + // error + auto error_matrices_ids = be.ids(); + + error_msg << "Singular matrix. Errors in matrices with IDs: "; + for (size_t i = 0; i < error_matrices_ids.size(); ++i) { + error_msg << error_matrices_ids[i]; + if (i < error_matrices_ids.size() - 1) { + error_msg << ", "; + } + } + error_msg << "."; + + if (scratchpad != nullptr) + sycl::free(scratchpad, exec_q); + if (ipiv != nullptr) + sycl::free(ipiv, exec_q); + + throw LinAlgError(error_msg.str().c_str()); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + std::int64_t info = e.info(); + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else { + error_msg << "Unexpected MKL exception caught during getrf_batch() " + "or getrs_batch() call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during getrf() or " + "getrs() call:\n" + << e.what(); + } +#else + const std::int64_t a_size = n * n; + const std::int64_t b_size = n * nrhs; + // Get the number of independent linear streams const std::int64_t n_linear_streams = (batch_size > 16) ? 4 : ((batch_size > 4 ? 2 : 1)); - const std::int64_t scratchpad_size = + scratchpad_size = mkl_lapack::gesv_scratchpad_size(exec_q, n, nrhs, lda, ldb); - T *scratchpad = helper::alloc_scratchpad_batch(scratchpad_size, - n_linear_streams, exec_q); + scratchpad = helper::alloc_scratchpad_batch(scratchpad_size, + n_linear_streams, exec_q); - std::int64_t *ipiv = nullptr; try { ipiv = helper::alloc_ipiv_batch(n, n_linear_streams, exec_q); } catch (const std::exception &e) { @@ -93,9 +225,6 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, // Computation events to manage dependencies for each linear stream std::vector> comp_evs(n_linear_streams, depends); - std::stringstream error_msg; - bool is_exception_caught = false; - // Release GIL to avoid serialization of host task // submissions to the same queue in OneMKL py::gil_scoped_release release; @@ -147,6 +276,7 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, // Update the event dependencies for the current stream comp_evs[stream_id] = {gesv_event}; } +#endif // USE_ONEMKL_INTERFACES if (is_exception_caught) // an unexpected error occurs { @@ -158,9 +288,13 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, } sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) { +#if defined(USE_ONEMKL_INTERFACES) + cgh.depends_on(comp_event); +#else for (const auto &ev : comp_evs) { cgh.depends_on(ev); } +#endif // USE_ONEMKL_INTERFACES auto ctx = exec_q.get_context(); cgh.host_task([ctx, scratchpad, ipiv]() { sycl::free(scratchpad, ctx); @@ -242,9 +376,27 @@ std::pair const std::int64_t nrhs = (dependent_vals_nd > 2) ? dependent_vals_shape[1] : 1; - sycl::event gesv_ev = - gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data, + sycl::event gesv_ev; + +#if defined(USE_ONEMKL_INTERFACES) + auto const &coeff_matrix_strides = coeff_matrix.get_strides_vector(); + auto const &dependent_vals_strides = dependent_vals.get_strides_vector(); + + // Get the strides for the batch matrices. + // Since the matrices are stored in F-contiguous order, + // the stride between batches is the last element in the strides vector. + const std::int64_t coeff_matrix_batch_stride = coeff_matrix_strides.back(); + const std::int64_t dependent_vals_batch_stride = + dependent_vals_strides.back(); + + gesv_ev = + gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_batch_stride, + dependent_vals_batch_stride, coeff_matrix_data, dependent_vals_data, depends); +#else + gesv_ev = gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data, + dependent_vals_data, depends); +#endif // USE_ONEMKL_INTERFACES sycl::event ht_ev = dpctl::utils::keep_args_alive( exec_q, {coeff_matrix, dependent_vals}, {gesv_ev}); diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index 6a8f41a904e..3b3547f1a5c 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -41,7 +41,7 @@ namespace mkl_lapack = oneapi::mkl::lapack; namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; -typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue, +typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue &, oneapi::mkl::transpose, const std::int64_t, const std::int64_t, @@ -56,7 +56,7 @@ typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue, static getrs_impl_fn_ptr_t getrs_dispatch_vector[dpctl_td_ns::num_types]; template -static sycl::event getrs_impl(sycl::queue exec_q, +static sycl::event getrs_impl(sycl::queue &exec_q, oneapi::mkl::transpose trans, const std::int64_t n, const std::int64_t nrhs, @@ -156,7 +156,7 @@ static sycl::event getrs_impl(sycl::queue exec_q, } std::pair - getrs(sycl::queue exec_q, + getrs(sycl::queue &exec_q, dpctl::tensor::usm_ndarray a_array, dpctl::tensor::usm_ndarray ipiv_array, dpctl::tensor::usm_ndarray b_array, diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 9906cf53eee..437799ef224 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -33,7 +33,7 @@ namespace dpnp::extensions::lapack { extern std::pair - getrs(sycl::queue exec_q, + getrs(sycl::queue &exec_q, dpctl::tensor::usm_ndarray a_array, dpctl::tensor::usm_ndarray ipiv_array, dpctl::tensor::usm_ndarray b_array,