From 2a2ca7be7a04c45beb218e2288e30e0b79f18444 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 16 Aug 2024 14:59:55 +0200 Subject: [PATCH 01/12] Implement batch solve via getrf_batch and getrs_batch --- dpnp/backend/extensions/lapack/CMakeLists.txt | 1 + dpnp/backend/extensions/lapack/getrs.hpp | 15 + .../backend/extensions/lapack/getrs_batch.cpp | 343 ++++++++++++++++++ dpnp/backend/extensions/lapack/lapack_py.cpp | 11 + dpnp/linalg/dpnp_utils_linalg.py | 215 ++++++++--- 5 files changed, 537 insertions(+), 48 deletions(-) create mode 100644 dpnp/backend/extensions/lapack/getrs_batch.cpp diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index 91d2f832d58..426c0cb0ec7 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -36,6 +36,7 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/getrs_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 9906cf53eee..49ec7e1306f 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -39,5 +39,20 @@ extern std::pair dpctl::tensor::usm_ndarray b_array, const std::vector &depends = {}); +extern std::pair + getrs_batch(sycl::queue exec_q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray ipiv_array, + dpctl::tensor::usm_ndarray b_array, + std::int64_t batch_size, + std::int64_t n, + std::int64_t nrhs, + std::int64_t stride_a, + std::int64_t stride_ipiv, + std::int64_t stride_b, + py::list dev_info, + const std::vector &depends = {}); + extern void init_getrs_dispatch_vector(void); +extern void init_getrs_batch_dispatch_vector(void); } // namespace dpnp::extensions::lapack diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp new file mode 100644 index 00000000000..b90b04223f2 --- /dev/null +++ b/dpnp/backend/extensions/lapack/getrs_batch.cpp @@ -0,0 +1,343 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "getrs.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp::extensions::lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*getrs_batch_impl_fn_ptr_t)( + sycl::queue, + oneapi::mkl::transpose, + std::int64_t, + std::int64_t, + char *, + std::int64_t, + std::int64_t, + std::int64_t *, + std::int64_t, + char *, + std::int64_t, + std::int64_t, + std::int64_t, + py::list, + const std::vector &); + +static getrs_batch_impl_fn_ptr_t + getrs_batch_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event getrs_batch_impl(sycl::queue exec_q, + oneapi::mkl::transpose trans, + std::int64_t n, + std::int64_t nrhs, + char *in_a, + std::int64_t lda, + std::int64_t stride_a, + std::int64_t *ipiv, + std::int64_t stride_ipiv, + char *in_b, + std::int64_t ldb, + std::int64_t stride_b, + std::int64_t batch_size, + py::list dev_info, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *b = reinterpret_cast(in_b); + + const std::int64_t scratchpad_size = + mkl_lapack::getrs_batch_scratchpad_size(exec_q, trans, n, nrhs, lda, + stride_a, stride_ipiv, ldb, + stride_b, batch_size); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event getrs_batch_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + getrs_batch_event = mkl_lapack::getrs_batch( + exec_q, + trans, // Specifies the operation: whether or not to transpose + // matrix A. Can be 'N' for no transpose, 'T' for transpose, + // and 'C' for conjugate transpose. + n, // The order of each square matrix A in the batch + // and the number of rows in each matrix B (0 ≤ n). + // It must be a non-negative integer. + nrhs, // The number of right-hand sides, + // i.e., the number of columns in each matrix B in the batch + // (0 ≤ nrhs). + a, // Pointer to the batch of square matrices A (n x n). + lda, // The leading dimension of each matrix A in the batch. + // It must be at least max(1, n). + stride_a, // Stride between individual matrices in the batch for + // matrix A. + ipiv, // Pointer to the batch of arrays of pivot indices. + stride_ipiv, // Stride between pivot index arrays in the batch. + b, // Pointer to the batch of matrices B (n, nrhs). + ldb, // The leading dimension of each matrix B in the batch. + // Must be at least max(1, n). + stride_b, // Stride between individual matrices in the batch for + // matrix B. + batch_size, // The number of matrices in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. 
+ scratchpad_size, depends); + } catch (mkl_lapack::batch_error const &be) { + // Get the indices of matrices within the batch that encountered an + // error + auto error_matrices_ids = be.ids(); + // Get the indices of the first zero diagonal elements of these matrices + auto error_info = be.exceptions(); + + auto error_matrices_ids_size = error_matrices_ids.size(); + auto dev_info_size = static_cast(py::len(dev_info)); + if (error_matrices_ids_size != dev_info_size) { + throw py::value_error("The size of `dev_info` must be equal to" + + std::to_string(error_matrices_ids_size) + + ", but currently it is " + + std::to_string(dev_info_size) + "."); + } + + for (size_t i = 0; i < error_matrices_ids.size(); ++i) { + // Assign the index of the first zero diagonal element in each + // error matrix to the corresponding index in 'dev_info' + dev_info[error_matrices_ids[i]] = error_info[i]; + } + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else { + error_msg << "Unexpected MKL exception caught during getrs_batch() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg + << "Unexpected SYCL exception caught during getrs_batch() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + + throw std::runtime_error(error_msg.str()); + } + + sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(getrs_batch_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + + return ht_ev; +} + +std::pair + getrs_batch(sycl::queue exec_q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray ipiv_array, + dpctl::tensor::usm_ndarray b_array, + std::int64_t batch_size, + std::int64_t n, + std::int64_t nrhs, + std::int64_t stride_a, + std::int64_t stride_ipiv, + std::int64_t stride_b, + py::list dev_info, + const std::vector &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int b_array_nd = b_array.get_ndim(); + const int ipiv_array_nd = ipiv_array.get_ndim(); + + if (a_array_nd != 3) { + throw py::value_error( + "The LU-factorized array has ndim=" + std::to_string(a_array_nd) + + ", but a 3-dimensional array is expected."); + } + + if (b_array_nd < 2 || b_array_nd > 3) { + throw py::value_error("The dependent values array has ndim=" + + std::to_string(b_array_nd) + ", but a " + + "2-dimensional or a " + + "3-dimensional array is expected."); + } + + if (ipiv_array_nd != 2) { + throw py::value_error("The array of pivot indices has ndim=" + + std::to_string(ipiv_array_nd) + + ", but a 2-dimensional array is expected."); + } + + const int dev_info_size = py::len(dev_info); + if (dev_info_size != batch_size) { + throw py::value_error("The size of 'dev_info' (" + + std::to_string(dev_info_size) + + ") does not match the expected batch size (" + + std::to_string(batch_size) + ")."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(exec_q, + {a_array, b_array, ipiv_array})) + { + throw py::value_error( + "Execution queue is not compatible with allocation 
queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, ipiv_array) || overlap(a_array, b_array) || + overlap(ipiv_array, b_array)) + { + throw py::value_error("Arrays have overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + bool is_a_array_f_contig = a_array.is_f_contiguous(); + bool is_b_array_f_contig = b_array.is_f_contiguous(); + bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous(); + bool is_ipiv_array_f_contig = ipiv_array.is_f_contiguous(); + if (!is_a_array_c_contig && !is_a_array_f_contig) { + throw py::value_error("The LU-factorized array " + "must be either C-contiguous " + "or F-contiguous"); + } + if (!is_b_array_f_contig) { + throw py::value_error("The right-hand sides array " + "must be F-contiguous"); + } + if (!is_ipiv_array_c_contig && !is_ipiv_array_f_contig) { + throw py::value_error("The array of pivot indices " + "must be contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int b_array_type_id = + array_types.typenum_to_lookup_id(b_array.get_typenum()); + + if (a_array_type_id != b_array_type_id) { + throw py::value_error("The types of the LU-factorized and " + "right-hand sides arrays are mismatched"); + } + + getrs_batch_impl_fn_ptr_t getrs_batch_fn = + getrs_batch_dispatch_vector[a_array_type_id]; + if (getrs_batch_fn == nullptr) { + throw py::value_error( + "No getrs_batch implementation defined for the provided type " + "of the input matrix."); + } + + auto ipiv_types = dpctl_td_ns::usm_ndarray_types(); + int ipiv_array_type_id = + ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); + + if (ipiv_array_type_id != static_cast(dpctl_td_ns::typenum_t::INT64)) { + throw py::value_error("The type of 'ipiv_array' must be int64."); + } + + // Use transpose::T if the LU-factorized array is passed as C-contiguous. + // For F-contiguous we use transpose::N. + oneapi::mkl::transpose trans = is_a_array_c_contig + ? 
oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + + char *a_array_data = a_array.get_data(); + char *b_array_data = b_array.get_data(); + char *ipiv_array_data = ipiv_array.get_data(); + std::int64_t *d_ipiv = reinterpret_cast(ipiv_array_data); + + const std::int64_t lda = std::max(1UL, n); + const std::int64_t ldb = std::max(1UL, n); + + sycl::event getrs_batch_ev = + getrs_batch_fn(exec_q, trans, n, nrhs, a_array_data, lda, stride_a, + d_ipiv, stride_ipiv, b_array_data, ldb, stride_b, + batch_size, dev_info, depends); + + sycl::event ht_ev = dpctl::utils::keep_args_alive( + exec_q, {a_array, ipiv_array, b_array}, {getrs_batch_ev}); + + return std::make_pair(ht_ev, getrs_batch_ev); +} + +template +struct GetrsBatchContigFactory +{ + fnT get() + { + if constexpr (types::GetrsTypePairSupportFactory::is_defined) { + return getrs_batch_impl; + } + else { + return nullptr; + } + } +}; + +void init_getrs_batch_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(getrs_batch_dispatch_vector); +} +} // namespace dpnp::extensions::lapack diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index b2981089339..4667906b3da 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -59,6 +59,7 @@ void init_dispatch_vectors(void) lapack_ext::init_getrf_dispatch_vector(); lapack_ext::init_getri_batch_dispatch_vector(); lapack_ext::init_getrs_dispatch_vector(); + lapack_ext::init_getrs_batch_dispatch_vector(); lapack_ext::init_orgqr_batch_dispatch_vector(); lapack_ext::init_orgqr_dispatch_vector(); lapack_ext::init_potrf_batch_dispatch_vector(); @@ -153,6 +154,16 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), py::arg("b_array"), py::arg("depends") = py::list()); + m.def("_getrs_batch", &lapack_ext::getrs_batch, + "Call `getrs_batch` from OneMKL LAPACK library to return " + "the solves of linear equations with a batch of LU-factored " + "square coefficient matrix, with multiple right-hand sides", + py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), + py::arg("b_array"), py::arg("batch_size"), py::arg("n"), + py::arg("nrhs"), py::arg("stride_a"), py::arg("stride_ipiv"), + py::arg("stride_b"), py::arg("dev_info_array"), + py::arg("depends") = py::list()); + m.def("_orgqr_batch", &lapack_ext::orgqr_batch, "Call `_orgqr_batch` from OneMKL LAPACK library to return " "the real orthogonal matrix Qi of the QR factorization " diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 7865775418d..878cbe3b997 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -272,66 +272,185 @@ def _batched_solve(a, b, exec_q, res_usm_type, res_type): a = dpnp.reshape(a, (-1, a_shape[-2], a_shape[-1])) - # Reorder the elements by moving the last two axes of `a` to the front - # to match fortran-like array order which is assumed by gesv. - a = dpnp.moveaxis(a, (-2, -1), (0, 1)) - # The same for `b` if it is 3D; - # if it is 2D, transpose it. 
- if b.ndim > 2: - b = dpnp.moveaxis(b, (-2, -1), (0, 1)) + new = True + + if new: + _manager = dpu.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + # oneMKL LAPACK getri_batch overwrites `a` + a_h = dpnp.empty_like( + a, order="C", dtype=res_type, usm_type=res_usm_type + ) + + # use DPCTL tensor function to fill the matrix array + # with content from the input array `a` + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a.get_array(), + dst=a_h.get_array(), + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, copy_ev) + + batch_size = a.shape[0] + n = a.shape[1] + + ipiv_h = dpnp.empty( + (batch_size, n), + dtype=dpnp.int64, + usm_type=res_usm_type, + sycl_queue=exec_q, + ) + dev_info = [0] * batch_size + + a_stride = a_h.strides[0] + ipiv_stride = n + + # Call the LAPACK extension function _getrf_batch + # to perform LU decomposition of a batch of general matrices + ht_ev, getrf_batch_ev = li._getrf_batch( + exec_q, + a_h.get_array(), + ipiv_h.get_array(), + dev_info, + n, + a_stride, + ipiv_stride, + batch_size, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, getrf_batch_ev) + + _check_lapack_dev_info(dev_info) + + # The same for `b` if it is 3D; + # if it is 2D, transpose it. + if b.ndim > 2: + b = dpnp.moveaxis(b, (-2, -1), (0, 1)) + else: + b = b.T + b_usm_arr = dpnp.get_usm_ndarray(b) + + # oneMKL LAPACK getrs overwrites `b` and assumes fortran-like array as + # input. + # Allocate 'F' order memory for dpnp arrays to comply with + # these requirements. + b_f = dpnp.empty_like( + b, order="F", dtype=res_type, usm_type=res_usm_type + ) + + # use DPCTL tensor function to fill the array of multiple dependent + # variables with content from the input array `b` + ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr, + dst=b_f.get_array(), + sycl_queue=b.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, b_copy_ev) + + nrhs = b_f.shape[1] if b_f.ndim > 2 else 1 + b_stride = b_f.strides[-1] + + # Call the LAPACK extension function _getrs_batch + # to solve the system of linear equations with a batch of LU-factored + # coefficient square matrix, with multiple right-hand sides. + ht_ev, getrs_batch_ev = li._getrs_batch( + exec_q, + a_h.get_array(), + ipiv_h.get_array(), + b_f.get_array(), + batch_size, + n, + nrhs, + a_stride, + ipiv_stride, + b_stride, + dev_info, + depends=[getrf_batch_ev, b_copy_ev], + ) + _manager.add_event_pair(ht_ev, getrs_batch_ev) + + _check_lapack_dev_info(dev_info) + + # Getrs_batch call overwtires `b` in Fortran order, reorder the axes + # to match C order by moving the last axis to the front and + # reshape it back to the original shape of `b`. + v = dpnp.moveaxis(b_f, -1, 0).reshape(b_shape) + + # dpnp.moveaxis can make the array non-contiguous if it is not 2D + # Convert to contiguous to align with NumPy + if b.ndim > 2: + v = dpnp.ascontiguousarray(v) + + return v + else: - b = b.T + # Reorder the elements by moving the last two axes of `a` to the front + # to match fortran-like array order which is assumed by gesv. + a = dpnp.moveaxis(a, (-2, -1), (0, 1)) + # The same for `b` if it is 3D; + # if it is 2D, transpose it. 
+ if b.ndim > 2: + b = dpnp.moveaxis(b, (-2, -1), (0, 1)) + else: + b = b.T - a_usm_arr = dpnp.get_usm_ndarray(a) - b_usm_arr = dpnp.get_usm_ndarray(b) + a_usm_arr = dpnp.get_usm_ndarray(a) + b_usm_arr = dpnp.get_usm_ndarray(b) - _manager = dpu.SequentialOrderManager[exec_q] - dep_evs = _manager.submitted_events + _manager = dpu.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events - # oneMKL LAPACK gesv destroys `a` and assumes fortran-like array - # as input. - a_f = dpnp.empty_like(a, dtype=res_type, order="F", usm_type=res_usm_type) + # oneMKL LAPACK gesv destroys `a` and assumes fortran-like array + # as input. + a_f = dpnp.empty_like( + a, dtype=res_type, order="F", usm_type=res_usm_type + ) - ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, - dst=a_f.get_array(), - sycl_queue=exec_q, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, a_copy_ev) + ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_f.get_array(), + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, a_copy_ev) - # oneMKL LAPACK gesv overwrites `b` and assumes fortran-like array - # as input. - b_f = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type) + # oneMKL LAPACK gesv overwrites `b` and assumes fortran-like array + # as input. + b_f = dpnp.empty_like( + b, order="F", dtype=res_type, usm_type=res_usm_type + ) - ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=b_usm_arr, - dst=b_f.get_array(), - sycl_queue=exec_q, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, b_copy_ev) + ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr, + dst=b_f.get_array(), + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, b_copy_ev) - ht_ev, gesv_batch_ev = li._gesv_batch( - exec_q, - a_f.get_array(), - b_f.get_array(), - depends=[a_copy_ev, b_copy_ev], - ) + ht_ev, gesv_batch_ev = li._gesv_batch( + exec_q, + a_f.get_array(), + b_f.get_array(), + depends=[a_copy_ev, b_copy_ev], + ) - _manager.add_event_pair(ht_ev, gesv_batch_ev) + _manager.add_event_pair(ht_ev, gesv_batch_ev) - # Gesv call overwtires `b` in Fortran order, reorder the axes - # to match C order by moving the last axis to the front and - # reshape it back to the original shape of `b`. - v = dpnp.moveaxis(b_f, -1, 0).reshape(b_shape) + # Gesv call overwtires `b` in Fortran order, reorder the axes + # to match C order by moving the last axis to the front and + # reshape it back to the original shape of `b`. 
+ v = dpnp.moveaxis(b_f, -1, 0).reshape(b_shape) - # dpnp.moveaxis can make the array non-contiguous if it is not 2D - # Convert to contiguous to align with NumPy - if b.ndim > 2: - v = dpnp.ascontiguousarray(v) + # dpnp.moveaxis can make the array non-contiguous if it is not 2D + # Convert to contiguous to align with NumPy + if b.ndim > 2: + v = dpnp.ascontiguousarray(v) - return v + return v def _batched_qr(a, mode="reduced"): From 3143be6fd08239cd51f50903c471634abfdf4e60 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 16 Aug 2024 15:44:51 +0200 Subject: [PATCH 02/12] Pass sycl::queue by reference for getrs/getrs_batch --- dpnp/backend/extensions/lapack/getrs.cpp | 6 +++--- dpnp/backend/extensions/lapack/getrs.hpp | 4 ++-- dpnp/backend/extensions/lapack/getrs_batch.cpp | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index 6a8f41a904e..3b3547f1a5c 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -41,7 +41,7 @@ namespace mkl_lapack = oneapi::mkl::lapack; namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; -typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue, +typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue &, oneapi::mkl::transpose, const std::int64_t, const std::int64_t, @@ -56,7 +56,7 @@ typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue, static getrs_impl_fn_ptr_t getrs_dispatch_vector[dpctl_td_ns::num_types]; template -static sycl::event getrs_impl(sycl::queue exec_q, +static sycl::event getrs_impl(sycl::queue &exec_q, oneapi::mkl::transpose trans, const std::int64_t n, const std::int64_t nrhs, @@ -156,7 +156,7 @@ static sycl::event getrs_impl(sycl::queue exec_q, } std::pair - getrs(sycl::queue exec_q, + getrs(sycl::queue &exec_q, dpctl::tensor::usm_ndarray a_array, dpctl::tensor::usm_ndarray ipiv_array, dpctl::tensor::usm_ndarray b_array, diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 49ec7e1306f..19c1b19846d 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -33,14 +33,14 @@ namespace dpnp::extensions::lapack { extern std::pair - getrs(sycl::queue exec_q, + getrs(sycl::queue &exec_q, dpctl::tensor::usm_ndarray a_array, dpctl::tensor::usm_ndarray ipiv_array, dpctl::tensor::usm_ndarray b_array, const std::vector &depends = {}); extern std::pair - getrs_batch(sycl::queue exec_q, + getrs_batch(sycl::queue &exec_q, dpctl::tensor::usm_ndarray a_array, dpctl::tensor::usm_ndarray ipiv_array, dpctl::tensor::usm_ndarray b_array, diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp index b90b04223f2..e09095d3107 100644 --- a/dpnp/backend/extensions/lapack/getrs_batch.cpp +++ b/dpnp/backend/extensions/lapack/getrs_batch.cpp @@ -41,7 +41,7 @@ namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; typedef sycl::event (*getrs_batch_impl_fn_ptr_t)( - sycl::queue, + sycl::queue &, oneapi::mkl::transpose, std::int64_t, std::int64_t, @@ -61,7 +61,7 @@ static getrs_batch_impl_fn_ptr_t getrs_batch_dispatch_vector[dpctl_td_ns::num_types]; template -static sycl::event getrs_batch_impl(sycl::queue exec_q, +static sycl::event getrs_batch_impl(sycl::queue &exec_q, oneapi::mkl::transpose trans, std::int64_t n, std::int64_t nrhs, @@ -188,7 +188,7 @@ static sycl::event getrs_batch_impl(sycl::queue exec_q, } 
std::pair - getrs_batch(sycl::queue exec_q, + getrs_batch(sycl::queue &exec_q, dpctl::tensor::usm_ndarray a_array, dpctl::tensor::usm_ndarray ipiv_array, dpctl::tensor::usm_ndarray b_array, From 8267032d1b34a659b70910193c77591cc974b9d9 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Sun, 18 Aug 2024 17:28:16 +0200 Subject: [PATCH 03/12] Extend gesv_impl to use onemkl interfaces --- dpnp/backend/extensions/lapack/gesv.cpp | 120 +++++++++++++++++++++++- 1 file changed, 117 insertions(+), 3 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 660afb58193..efb5853609f 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -57,9 +57,123 @@ static sycl::event gesv_impl(sycl::queue &exec_q, const std::vector &depends) { #if defined(USE_ONEMKL_INTERFACES) - // Temporary flag for build only - // FIXME: Need to implement by using lapack::getrf and lapack::getrs - std::logic_error("Not Implemented"); + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *b = reinterpret_cast(in_b); + + const std::int64_t lda = std::max(1UL, n); + const std::int64_t ldb = std::max(1UL, n); + + // Use transpose::T if the LU-factorized array is passed as C-contiguous. + // For F-contiguous we use transpose::N. + oneapi::mkl::transpose trans = oneapi::mkl::transpose::N; + + const std::int64_t scratchpad_size = std::max( + mkl_lapack::getrf_scratchpad_size(exec_q, n, n, lda), + mkl_lapack::getrs_scratchpad_size(exec_q, trans, n, nrhs, lda, ldb)); + + T *scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); + + std::int64_t *ipiv = nullptr; + try { + ipiv = helper::alloc_ipiv(n, exec_q); + } catch (const std::exception &e) { + if (scratchpad != nullptr) + sycl::free(scratchpad, exec_q); + throw; + } + + std::stringstream error_msg; + bool is_exception_caught = false; + + sycl::event getrf_event; + try { + getrf_event = mkl_lapack::getrf( + exec_q, + n, // The order of the square matrix A (0 ≤ n). + // It must be a non-negative integer. + n, // The number of columns in the square matrix A (0 ≤ n). + // It must be a non-negative integer. + a, // Pointer to the square matrix A (n x n). + lda, // The leading dimension of matrix A. + // It must be at least max(1, n). + ipiv, // Pointer to the output array of pivot indices. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + gesv_utils::handle_lapack_exc(exec_q, lda, a, scratchpad_size, + scratchpad, ipiv, e, error_msg); + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during getrf() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) + sycl::free(scratchpad, exec_q); + if (ipiv != nullptr) + sycl::free(ipiv, exec_q); + throw std::runtime_error(error_msg.str()); + } + + sycl::event getrs_event; + try { + getrs_event = mkl_lapack::getrs( + exec_q, + trans, // Specifies the operation: whether or not to transpose + // matrix A. Can be 'N' for no transpose, 'T' for transpose, + // and 'C' for conjugate transpose. + n, // The order of the square matrix A + // and the number of rows in matrix B (0 ≤ n). + // It must be a non-negative integer. 
+ nrhs, // The number of right-hand sides, + // i.e., the number of columns in matrix B (0 ≤ nrhs). + a, // Pointer to the square matrix A (n x n). + lda, // The leading dimension of matrix A, must be at least max(1, + // n). It must be at least max(1, n). + ipiv, // Pointer to the output array of pivot indices that were used + // during factorization (n, ). + b, // Pointer to the matrix B of right-hand sides (ldb, nrhs). + ldb, // The leading dimension of matrix B, must be at least max(1, + // n). + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, {getrf_event}); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + gesv_utils::handle_lapack_exc(exec_q, lda, a, scratchpad_size, + scratchpad, ipiv, e, error_msg); + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during getrs() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) + sycl::free(scratchpad, exec_q); + if (ipiv != nullptr) + sycl::free(ipiv, exec_q); + throw std::runtime_error(error_msg.str()); + } + + sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(getrs_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad, ipiv]() { + sycl::free(scratchpad, ctx); + sycl::free(ipiv, ctx); + }); + }); + + return ht_ev; + #else type_utils::validate_type_for_device(exec_q); From a637b2951eb60aa80e33c4f9bee746a7d14a619d Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Sun, 18 Aug 2024 18:24:13 +0200 Subject: [PATCH 04/12] Reduce code duplication in gesv_impl --- dpnp/backend/extensions/lapack/gesv.cpp | 88 ++++++------------------- 1 file changed, 19 insertions(+), 69 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index efb5853609f..1f1521e0403 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -56,7 +56,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q, char *in_b, const std::vector &depends) { -#if defined(USE_ONEMKL_INTERFACES) type_utils::validate_type_for_device(exec_q); T *a = reinterpret_cast(in_a); @@ -65,17 +64,25 @@ static sycl::event gesv_impl(sycl::queue &exec_q, const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); + std::int64_t scratchpad_size; + sycl::event comp_event; + std::int64_t *ipiv = nullptr; + T *scratchpad = nullptr; + + std::stringstream error_msg; + bool is_exception_caught = false; + +#if defined(USE_ONEMKL_INTERFACES) // Use transpose::T if the LU-factorized array is passed as C-contiguous. // For F-contiguous we use transpose::N. 
oneapi::mkl::transpose trans = oneapi::mkl::transpose::N; - const std::int64_t scratchpad_size = std::max( + scratchpad_size = std::max( mkl_lapack::getrf_scratchpad_size(exec_q, n, n, lda), mkl_lapack::getrs_scratchpad_size(exec_q, trans, n, nrhs, lda, ldb)); - T *scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); + scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); - std::int64_t *ipiv = nullptr; try { ipiv = helper::alloc_ipiv(n, exec_q); } catch (const std::exception &e) { @@ -84,9 +91,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q, throw; } - std::stringstream error_msg; - bool is_exception_caught = false; - sycl::event getrf_event; try { getrf_event = mkl_lapack::getrf( @@ -102,28 +106,8 @@ static sycl::event gesv_impl(sycl::queue &exec_q, scratchpad, // Pointer to scratchpad memory to be used by MKL // routine for storing intermediate results. scratchpad_size, depends); - } catch (mkl_lapack::exception const &e) { - is_exception_caught = true; - gesv_utils::handle_lapack_exc(exec_q, lda, a, scratchpad_size, - scratchpad, ipiv, e, error_msg); - } catch (sycl::exception const &e) { - is_exception_caught = true; - error_msg << "Unexpected SYCL exception caught during getrf() call:\n" - << e.what(); - } - if (is_exception_caught) // an unexpected error occurs - { - if (scratchpad != nullptr) - sycl::free(scratchpad, exec_q); - if (ipiv != nullptr) - sycl::free(ipiv, exec_q); - throw std::runtime_error(error_msg.str()); - } - - sycl::event getrs_event; - try { - getrs_event = mkl_lapack::getrs( + comp_event = mkl_lapack::getrs( exec_q, trans, // Specifies the operation: whether or not to transpose // matrix A. Can be 'N' for no transpose, 'T' for transpose, @@ -150,45 +134,16 @@ static sycl::event gesv_impl(sycl::queue &exec_q, scratchpad, ipiv, e, error_msg); } catch (sycl::exception const &e) { is_exception_caught = true; - error_msg << "Unexpected SYCL exception caught during getrs() call:\n" + error_msg << "Unexpected SYCL exception caught during getrf and " + "getrs() call:\n" << e.what(); } - - if (is_exception_caught) // an unexpected error occurs - { - if (scratchpad != nullptr) - sycl::free(scratchpad, exec_q); - if (ipiv != nullptr) - sycl::free(ipiv, exec_q); - throw std::runtime_error(error_msg.str()); - } - - sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(getrs_event); - auto ctx = exec_q.get_context(); - cgh.host_task([ctx, scratchpad, ipiv]() { - sycl::free(scratchpad, ctx); - sycl::free(ipiv, ctx); - }); - }); - - return ht_ev; - #else - type_utils::validate_type_for_device(exec_q); - - T *a = reinterpret_cast(in_a); - T *b = reinterpret_cast(in_b); - - const std::int64_t lda = std::max(1UL, n); - const std::int64_t ldb = std::max(1UL, n); - - const std::int64_t scratchpad_size = + scratchpad_size = mkl_lapack::gesv_scratchpad_size(exec_q, n, nrhs, lda, ldb); - T *scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); + scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); - std::int64_t *ipiv = nullptr; try { ipiv = helper::alloc_ipiv(n, exec_q); } catch (const std::exception &e) { @@ -197,12 +152,8 @@ static sycl::event gesv_impl(sycl::queue &exec_q, throw; } - std::stringstream error_msg; - bool is_exception_caught = false; - - sycl::event gesv_event; try { - gesv_event = mkl_lapack::gesv( + comp_event = mkl_lapack::gesv( exec_q, n, // The order of the square matrix A // and the number of rows in matrix B (0 ≤ n). 
@@ -228,7 +179,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q, error_msg << "Unexpected SYCL exception caught during gesv() call:\n" << e.what(); } - +#endif if (is_exception_caught) // an unexpected error occurs { if (scratchpad != nullptr) @@ -239,7 +190,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q, } sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gesv_event); + cgh.depends_on(comp_event); auto ctx = exec_q.get_context(); cgh.host_task([ctx, scratchpad, ipiv]() { sycl::free(scratchpad, ctx); @@ -248,7 +199,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q, }); return ht_ev; -#endif // USE_ONEMKL_INTERFACES } std::pair From 16290c2228845bf4f8d1925374cf26b1b646b4f4 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 19 Aug 2024 00:51:17 +0200 Subject: [PATCH 05/12] Extend gesv_batch_impl to use onemkl interfaces --- dpnp/backend/extensions/lapack/gesv_batch.cpp | 164 ++++++++++++++++-- 1 file changed, 151 insertions(+), 13 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv_batch.cpp b/dpnp/backend/extensions/lapack/gesv_batch.cpp index 90aaf8ebcfd..d77ccc6d120 100644 --- a/dpnp/backend/extensions/lapack/gesv_batch.cpp +++ b/dpnp/backend/extensions/lapack/gesv_batch.cpp @@ -44,6 +44,8 @@ typedef sycl::event (*gesv_batch_impl_fn_ptr_t)( const std::int64_t, const std::int64_t, const std::int64_t, + const std::int64_t, + const std::int64_t, char *, char *, const std::vector &); @@ -56,6 +58,8 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, const std::int64_t n, const std::int64_t nrhs, const std::int64_t batch_size, + [[maybe_unused]] const std::int64_t stride_a, + [[maybe_unused]] const std::int64_t stride_b, char *in_a, char *in_b, const std::vector &depends) @@ -65,23 +69,146 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, T *a = reinterpret_cast(in_a); T *b = reinterpret_cast(in_b); - const std::int64_t a_size = n * n; - const std::int64_t b_size = n * nrhs; - const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); + std::int64_t scratchpad_size; + sycl::event comp_event; + std::int64_t *ipiv = nullptr; + T *scratchpad = nullptr; + + std::stringstream error_msg; + bool is_exception_caught = false; + +#if defined(USE_ONEMKL_INTERFACES) + // Use transpose::T if the LU-factorized array is passed as C-contiguous. + // For F-contiguous we use transpose::N. + oneapi::mkl::transpose trans = oneapi::mkl::transpose::N; + const std::int64_t stride_ipiv = n; + + scratchpad_size = std::max( + mkl_lapack::getrs_batch_scratchpad_size(exec_q, trans, n, nrhs, lda, + stride_a, stride_ipiv, ldb, + stride_b, batch_size), + mkl_lapack::getrf_batch_scratchpad_size(exec_q, n, n, lda, stride_a, + stride_ipiv, batch_size)); + + scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); + + // pass batch_size * n to allocate the memory for a 2D array of pivot + // indices + try { + ipiv = helper::alloc_ipiv(batch_size * n, exec_q); + } catch (const std::exception &e) { + if (scratchpad != nullptr) + sycl::free(scratchpad, exec_q); + throw; + } + + sycl::event getrf_batch_event; + try { + getrf_batch_event = mkl_lapack::getrf_batch( + exec_q, + n, // The order of each square matrix in the batch; (0 ≤ n). + // It must be a non-negative integer. + n, // The number of columns in each matrix in the batch; (0 ≤ n). + // It must be a non-negative integer. + a, // Pointer to the batch of square matrices, each of size (n x n). + lda, // The leading dimension of each matrix in the batch. 
+ stride_a, // Stride between consecutive matrices in the batch. + ipiv, // Pointer to the array of pivot indices for each matrix in + // the batch. + stride_ipiv, // Stride between pivot indices: Spacing between pivot + // arrays in 'ipiv'. + batch_size, // Stride between pivot index arrays in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + + comp_event = mkl_lapack::getrs_batch( + exec_q, + trans, // Specifies the operation: whether or not to transpose + // matrix A. Can be 'N' for no transpose, 'T' for transpose, + // and 'C' for conjugate transpose. + n, // The order of each square matrix A in the batch + // and the number of rows in each matrix B (0 ≤ n). + // It must be a non-negative integer. + nrhs, // The number of right-hand sides, + // i.e., the number of columns in each matrix B in the batch + // (0 ≤ nrhs). + a, // Pointer to the batch of square matrices A (n x n). + lda, // The leading dimension of each matrix A in the batch. + // It must be at least max(1, n). + stride_a, // Stride between individual matrices in the batch for + // matrix A. + ipiv, // Pointer to the batch of arrays of pivot indices. + stride_ipiv, // Stride between pivot index arrays in the batch. + b, // Pointer to the batch of matrices B (n, nrhs). + ldb, // The leading dimension of each matrix B in the batch. + // Must be at least max(1, n). + stride_b, // Stride between individual matrices in the batch for + // matrix B. + batch_size, // The number of matrices in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, {getrf_batch_event}); + } catch (mkl_lapack::batch_error const &be) { + // Get the indices of matrices within the batch that encountered an + // error + auto error_matrices_ids = be.ids(); + + error_msg << "Singular matrix. Errors in matrices with IDs: "; + for (size_t i = 0; i < error_matrices_ids.size(); ++i) { + error_msg << error_matrices_ids[i]; + if (i < error_matrices_ids.size() - 1) { + error_msg << ", "; + } + } + error_msg << "."; + + if (scratchpad != nullptr) + sycl::free(scratchpad, exec_q); + if (ipiv != nullptr) + sycl::free(ipiv, exec_q); + + throw LinAlgError(error_msg.str().c_str()); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + std::int64_t info = e.info(); + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else { + error_msg << "Unexpected MKL exception caught during getrf_batch() " + "and getrs_batch() call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during getrf and " + "getrs() call:\n" + << e.what(); + } +#else + const std::int64_t a_size = n * n; + const std::int64_t b_size = n * nrhs; + // Get the number of independent linear streams const std::int64_t n_linear_streams = (batch_size > 16) ? 4 : ((batch_size > 4 ? 
2 : 1)); - const std::int64_t scratchpad_size = + scratchpad_size = mkl_lapack::gesv_scratchpad_size(exec_q, n, nrhs, lda, ldb); - T *scratchpad = helper::alloc_scratchpad_batch(scratchpad_size, - n_linear_streams, exec_q); + scratchpad = helper::alloc_scratchpad_batch(scratchpad_size, + n_linear_streams, exec_q); - std::int64_t *ipiv = nullptr; try { ipiv = helper::alloc_ipiv_batch(n, n_linear_streams, exec_q); } catch (const std::exception &e) { @@ -93,9 +220,6 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, // Computation events to manage dependencies for each linear stream std::vector> comp_evs(n_linear_streams, depends); - std::stringstream error_msg; - bool is_exception_caught = false; - // Release GIL to avoid serialization of host task // submissions to the same queue in OneMKL py::gil_scoped_release release; @@ -147,7 +271,7 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, // Update the event dependencies for the current stream comp_evs[stream_id] = {gesv_event}; } - +#endif if (is_exception_caught) // an unexpected error occurs { if (scratchpad != nullptr) @@ -158,9 +282,13 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, } sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) { +#if defined(USE_ONEMKL_INTERFACES) + cgh.depends_on(comp_event); +#else for (const auto &ev : comp_evs) { cgh.depends_on(ev); } +#endif auto ctx = exec_q.get_context(); cgh.host_task([ctx, scratchpad, ipiv]() { sycl::free(scratchpad, ctx); @@ -242,9 +370,19 @@ std::pair const std::int64_t nrhs = (dependent_vals_nd > 2) ? dependent_vals_shape[1] : 1; + auto const &coeff_matrix_strides = coeff_matrix.get_strides_vector(); + auto const &dependent_vals_strides = dependent_vals.get_strides_vector(); + + // Get the strides for the batch matrices. + // Since the matrices are stored in F-contiguous order, + // the stride between batches is the last element in the strides vector. 
+ const std::int64_t coeff_matrix_batch_stride = coeff_matrix_strides.back(); + const std::int64_t dependent_vals_batch_stride = + dependent_vals_strides.back(); + sycl::event gesv_ev = - gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data, - dependent_vals_data, depends); + gesv_batch_fn(exec_q, n, nrhs, batch_size, stride_a, stride_b, + coeff_matrix_data, dependent_vals_data, depends); sycl::event ht_ev = dpctl::utils::keep_args_alive( exec_q, {coeff_matrix, dependent_vals}, {gesv_ev}); From 0bd09897d86f5f9a546433da3c643789730c0480 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 19 Aug 2024 00:59:26 +0200 Subject: [PATCH 06/12] Remove getrs_batch implementation --- dpnp/backend/extensions/lapack/CMakeLists.txt | 1 - dpnp/backend/extensions/lapack/getrs.hpp | 15 - .../backend/extensions/lapack/getrs_batch.cpp | 343 ------------------ dpnp/backend/extensions/lapack/lapack_py.cpp | 11 - dpnp/linalg/dpnp_utils_linalg.py | 215 +++-------- 5 files changed, 48 insertions(+), 537 deletions(-) delete mode 100644 dpnp/backend/extensions/lapack/getrs_batch.cpp diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index 426c0cb0ec7..91d2f832d58 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -36,7 +36,6 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/getrs_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 19c1b19846d..437799ef224 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -39,20 +39,5 @@ extern std::pair dpctl::tensor::usm_ndarray b_array, const std::vector &depends = {}); -extern std::pair - getrs_batch(sycl::queue &exec_q, - dpctl::tensor::usm_ndarray a_array, - dpctl::tensor::usm_ndarray ipiv_array, - dpctl::tensor::usm_ndarray b_array, - std::int64_t batch_size, - std::int64_t n, - std::int64_t nrhs, - std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t stride_b, - py::list dev_info, - const std::vector &depends = {}); - extern void init_getrs_dispatch_vector(void); -extern void init_getrs_batch_dispatch_vector(void); } // namespace dpnp::extensions::lapack diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp deleted file mode 100644 index e09095d3107..00000000000 --- a/dpnp/backend/extensions/lapack/getrs_batch.cpp +++ /dev/null @@ -1,343 +0,0 @@ -//***************************************************************************** -// Copyright (c) 2024, Intel Corporation -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// - Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// - Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. 
-// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -// THE POSSIBILITY OF SUCH DAMAGE. -//***************************************************************************** - -#include - -// dpctl tensor headers -#include "utils/memory_overlap.hpp" -#include "utils/type_utils.hpp" - -#include "getrs.hpp" -#include "types_matrix.hpp" - -#include "dpnp_utils.hpp" - -namespace dpnp::extensions::lapack -{ -namespace mkl_lapack = oneapi::mkl::lapack; -namespace py = pybind11; -namespace type_utils = dpctl::tensor::type_utils; - -typedef sycl::event (*getrs_batch_impl_fn_ptr_t)( - sycl::queue &, - oneapi::mkl::transpose, - std::int64_t, - std::int64_t, - char *, - std::int64_t, - std::int64_t, - std::int64_t *, - std::int64_t, - char *, - std::int64_t, - std::int64_t, - std::int64_t, - py::list, - const std::vector &); - -static getrs_batch_impl_fn_ptr_t - getrs_batch_dispatch_vector[dpctl_td_ns::num_types]; - -template -static sycl::event getrs_batch_impl(sycl::queue &exec_q, - oneapi::mkl::transpose trans, - std::int64_t n, - std::int64_t nrhs, - char *in_a, - std::int64_t lda, - std::int64_t stride_a, - std::int64_t *ipiv, - std::int64_t stride_ipiv, - char *in_b, - std::int64_t ldb, - std::int64_t stride_b, - std::int64_t batch_size, - py::list dev_info, - const std::vector &depends) -{ - type_utils::validate_type_for_device(exec_q); - - T *a = reinterpret_cast(in_a); - T *b = reinterpret_cast(in_b); - - const std::int64_t scratchpad_size = - mkl_lapack::getrs_batch_scratchpad_size(exec_q, trans, n, nrhs, lda, - stride_a, stride_ipiv, ldb, - stride_b, batch_size); - T *scratchpad = nullptr; - - std::stringstream error_msg; - std::int64_t info = 0; - bool is_exception_caught = false; - - sycl::event getrs_batch_event; - try { - scratchpad = sycl::malloc_device(scratchpad_size, exec_q); - - getrs_batch_event = mkl_lapack::getrs_batch( - exec_q, - trans, // Specifies the operation: whether or not to transpose - // matrix A. Can be 'N' for no transpose, 'T' for transpose, - // and 'C' for conjugate transpose. - n, // The order of each square matrix A in the batch - // and the number of rows in each matrix B (0 ≤ n). - // It must be a non-negative integer. - nrhs, // The number of right-hand sides, - // i.e., the number of columns in each matrix B in the batch - // (0 ≤ nrhs). - a, // Pointer to the batch of square matrices A (n x n). - lda, // The leading dimension of each matrix A in the batch. - // It must be at least max(1, n). - stride_a, // Stride between individual matrices in the batch for - // matrix A. - ipiv, // Pointer to the batch of arrays of pivot indices. - stride_ipiv, // Stride between pivot index arrays in the batch. - b, // Pointer to the batch of matrices B (n, nrhs). - ldb, // The leading dimension of each matrix B in the batch. - // Must be at least max(1, n). 
- stride_b, // Stride between individual matrices in the batch for - // matrix B. - batch_size, // The number of matrices in the batch. - scratchpad, // Pointer to scratchpad memory to be used by MKL - // routine for storing intermediate results. - scratchpad_size, depends); - } catch (mkl_lapack::batch_error const &be) { - // Get the indices of matrices within the batch that encountered an - // error - auto error_matrices_ids = be.ids(); - // Get the indices of the first zero diagonal elements of these matrices - auto error_info = be.exceptions(); - - auto error_matrices_ids_size = error_matrices_ids.size(); - auto dev_info_size = static_cast(py::len(dev_info)); - if (error_matrices_ids_size != dev_info_size) { - throw py::value_error("The size of `dev_info` must be equal to" + - std::to_string(error_matrices_ids_size) + - ", but currently it is " + - std::to_string(dev_info_size) + "."); - } - - for (size_t i = 0; i < error_matrices_ids.size(); ++i) { - // Assign the index of the first zero diagonal element in each - // error matrix to the corresponding index in 'dev_info' - dev_info[error_matrices_ids[i]] = error_info[i]; - } - } catch (mkl_lapack::exception const &e) { - is_exception_caught = true; - info = e.info(); - - if (info < 0) { - error_msg << "Parameter number " << -info - << " had an illegal value."; - } - else if (info == scratchpad_size && e.detail() != 0) { - error_msg - << "Insufficient scratchpad size. Required size is at least " - << e.detail(); - } - else { - error_msg << "Unexpected MKL exception caught during getrs_batch() " - "call:\nreason: " - << e.what() << "\ninfo: " << e.info(); - } - } catch (sycl::exception const &e) { - is_exception_caught = true; - error_msg - << "Unexpected SYCL exception caught during getrs_batch() call:\n" - << e.what(); - } - - if (is_exception_caught) // an unexpected error occurs - { - if (scratchpad != nullptr) { - sycl::free(scratchpad, exec_q); - } - - throw std::runtime_error(error_msg.str()); - } - - sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(getrs_batch_event); - auto ctx = exec_q.get_context(); - cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); - }); - - return ht_ev; -} - -std::pair - getrs_batch(sycl::queue &exec_q, - dpctl::tensor::usm_ndarray a_array, - dpctl::tensor::usm_ndarray ipiv_array, - dpctl::tensor::usm_ndarray b_array, - std::int64_t batch_size, - std::int64_t n, - std::int64_t nrhs, - std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t stride_b, - py::list dev_info, - const std::vector &depends) -{ - const int a_array_nd = a_array.get_ndim(); - const int b_array_nd = b_array.get_ndim(); - const int ipiv_array_nd = ipiv_array.get_ndim(); - - if (a_array_nd != 3) { - throw py::value_error( - "The LU-factorized array has ndim=" + std::to_string(a_array_nd) + - ", but a 3-dimensional array is expected."); - } - - if (b_array_nd < 2 || b_array_nd > 3) { - throw py::value_error("The dependent values array has ndim=" + - std::to_string(b_array_nd) + ", but a " + - "2-dimensional or a " + - "3-dimensional array is expected."); - } - - if (ipiv_array_nd != 2) { - throw py::value_error("The array of pivot indices has ndim=" + - std::to_string(ipiv_array_nd) + - ", but a 2-dimensional array is expected."); - } - - const int dev_info_size = py::len(dev_info); - if (dev_info_size != batch_size) { - throw py::value_error("The size of 'dev_info' (" + - std::to_string(dev_info_size) + - ") does not match the expected batch size (" + - std::to_string(batch_size) 
+ ")."); - } - - // check compatibility of execution queue and allocation queue - if (!dpctl::utils::queues_are_compatible(exec_q, - {a_array, b_array, ipiv_array})) - { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } - - auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(a_array, ipiv_array) || overlap(a_array, b_array) || - overlap(ipiv_array, b_array)) - { - throw py::value_error("Arrays have overlapping segments of memory"); - } - - bool is_a_array_c_contig = a_array.is_c_contiguous(); - bool is_a_array_f_contig = a_array.is_f_contiguous(); - bool is_b_array_f_contig = b_array.is_f_contiguous(); - bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous(); - bool is_ipiv_array_f_contig = ipiv_array.is_f_contiguous(); - if (!is_a_array_c_contig && !is_a_array_f_contig) { - throw py::value_error("The LU-factorized array " - "must be either C-contiguous " - "or F-contiguous"); - } - if (!is_b_array_f_contig) { - throw py::value_error("The right-hand sides array " - "must be F-contiguous"); - } - if (!is_ipiv_array_c_contig && !is_ipiv_array_f_contig) { - throw py::value_error("The array of pivot indices " - "must be contiguous"); - } - - auto array_types = dpctl_td_ns::usm_ndarray_types(); - int a_array_type_id = - array_types.typenum_to_lookup_id(a_array.get_typenum()); - int b_array_type_id = - array_types.typenum_to_lookup_id(b_array.get_typenum()); - - if (a_array_type_id != b_array_type_id) { - throw py::value_error("The types of the LU-factorized and " - "right-hand sides arrays are mismatched"); - } - - getrs_batch_impl_fn_ptr_t getrs_batch_fn = - getrs_batch_dispatch_vector[a_array_type_id]; - if (getrs_batch_fn == nullptr) { - throw py::value_error( - "No getrs_batch implementation defined for the provided type " - "of the input matrix."); - } - - auto ipiv_types = dpctl_td_ns::usm_ndarray_types(); - int ipiv_array_type_id = - ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); - - if (ipiv_array_type_id != static_cast(dpctl_td_ns::typenum_t::INT64)) { - throw py::value_error("The type of 'ipiv_array' must be int64."); - } - - // Use transpose::T if the LU-factorized array is passed as C-contiguous. - // For F-contiguous we use transpose::N. - oneapi::mkl::transpose trans = is_a_array_c_contig - ? 
oneapi::mkl::transpose::T - : oneapi::mkl::transpose::N; - - char *a_array_data = a_array.get_data(); - char *b_array_data = b_array.get_data(); - char *ipiv_array_data = ipiv_array.get_data(); - std::int64_t *d_ipiv = reinterpret_cast(ipiv_array_data); - - const std::int64_t lda = std::max(1UL, n); - const std::int64_t ldb = std::max(1UL, n); - - sycl::event getrs_batch_ev = - getrs_batch_fn(exec_q, trans, n, nrhs, a_array_data, lda, stride_a, - d_ipiv, stride_ipiv, b_array_data, ldb, stride_b, - batch_size, dev_info, depends); - - sycl::event ht_ev = dpctl::utils::keep_args_alive( - exec_q, {a_array, ipiv_array, b_array}, {getrs_batch_ev}); - - return std::make_pair(ht_ev, getrs_batch_ev); -} - -template -struct GetrsBatchContigFactory -{ - fnT get() - { - if constexpr (types::GetrsTypePairSupportFactory::is_defined) { - return getrs_batch_impl; - } - else { - return nullptr; - } - } -}; - -void init_getrs_batch_dispatch_vector(void) -{ - dpctl_td_ns::DispatchVectorBuilder - contig; - contig.populate_dispatch_vector(getrs_batch_dispatch_vector); -} -} // namespace dpnp::extensions::lapack diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 4667906b3da..b2981089339 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -59,7 +59,6 @@ void init_dispatch_vectors(void) lapack_ext::init_getrf_dispatch_vector(); lapack_ext::init_getri_batch_dispatch_vector(); lapack_ext::init_getrs_dispatch_vector(); - lapack_ext::init_getrs_batch_dispatch_vector(); lapack_ext::init_orgqr_batch_dispatch_vector(); lapack_ext::init_orgqr_dispatch_vector(); lapack_ext::init_potrf_batch_dispatch_vector(); @@ -154,16 +153,6 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), py::arg("b_array"), py::arg("depends") = py::list()); - m.def("_getrs_batch", &lapack_ext::getrs_batch, - "Call `getrs_batch` from OneMKL LAPACK library to return " - "the solves of linear equations with a batch of LU-factored " - "square coefficient matrix, with multiple right-hand sides", - py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), - py::arg("b_array"), py::arg("batch_size"), py::arg("n"), - py::arg("nrhs"), py::arg("stride_a"), py::arg("stride_ipiv"), - py::arg("stride_b"), py::arg("dev_info_array"), - py::arg("depends") = py::list()); - m.def("_orgqr_batch", &lapack_ext::orgqr_batch, "Call `_orgqr_batch` from OneMKL LAPACK library to return " "the real orthogonal matrix Qi of the QR factorization " diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 878cbe3b997..7865775418d 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -272,185 +272,66 @@ def _batched_solve(a, b, exec_q, res_usm_type, res_type): a = dpnp.reshape(a, (-1, a_shape[-2], a_shape[-1])) - new = True - - if new: - _manager = dpu.SequentialOrderManager[exec_q] - dep_evs = _manager.submitted_events - - # oneMKL LAPACK getri_batch overwrites `a` - a_h = dpnp.empty_like( - a, order="C", dtype=res_type, usm_type=res_usm_type - ) - - # use DPCTL tensor function to fill the matrix array - # with content from the input array `a` - ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a.get_array(), - dst=a_h.get_array(), - sycl_queue=exec_q, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, copy_ev) - - batch_size = a.shape[0] - n = a.shape[1] - - ipiv_h = dpnp.empty( - (batch_size, n), - dtype=dpnp.int64, - 
usm_type=res_usm_type, - sycl_queue=exec_q, - ) - dev_info = [0] * batch_size - - a_stride = a_h.strides[0] - ipiv_stride = n - - # Call the LAPACK extension function _getrf_batch - # to perform LU decomposition of a batch of general matrices - ht_ev, getrf_batch_ev = li._getrf_batch( - exec_q, - a_h.get_array(), - ipiv_h.get_array(), - dev_info, - n, - a_stride, - ipiv_stride, - batch_size, - depends=[copy_ev], - ) - _manager.add_event_pair(ht_ev, getrf_batch_ev) - - _check_lapack_dev_info(dev_info) - - # The same for `b` if it is 3D; - # if it is 2D, transpose it. - if b.ndim > 2: - b = dpnp.moveaxis(b, (-2, -1), (0, 1)) - else: - b = b.T - b_usm_arr = dpnp.get_usm_ndarray(b) - - # oneMKL LAPACK getrs overwrites `b` and assumes fortran-like array as - # input. - # Allocate 'F' order memory for dpnp arrays to comply with - # these requirements. - b_f = dpnp.empty_like( - b, order="F", dtype=res_type, usm_type=res_usm_type - ) - - # use DPCTL tensor function to fill the array of multiple dependent - # variables with content from the input array `b` - ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=b_usm_arr, - dst=b_f.get_array(), - sycl_queue=b.sycl_queue, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, b_copy_ev) - - nrhs = b_f.shape[1] if b_f.ndim > 2 else 1 - b_stride = b_f.strides[-1] - - # Call the LAPACK extension function _getrs_batch - # to solve the system of linear equations with a batch of LU-factored - # coefficient square matrix, with multiple right-hand sides. - ht_ev, getrs_batch_ev = li._getrs_batch( - exec_q, - a_h.get_array(), - ipiv_h.get_array(), - b_f.get_array(), - batch_size, - n, - nrhs, - a_stride, - ipiv_stride, - b_stride, - dev_info, - depends=[getrf_batch_ev, b_copy_ev], - ) - _manager.add_event_pair(ht_ev, getrs_batch_ev) - - _check_lapack_dev_info(dev_info) - - # Getrs_batch call overwtires `b` in Fortran order, reorder the axes - # to match C order by moving the last axis to the front and - # reshape it back to the original shape of `b`. - v = dpnp.moveaxis(b_f, -1, 0).reshape(b_shape) - - # dpnp.moveaxis can make the array non-contiguous if it is not 2D - # Convert to contiguous to align with NumPy - if b.ndim > 2: - v = dpnp.ascontiguousarray(v) - - return v - + # Reorder the elements by moving the last two axes of `a` to the front + # to match fortran-like array order which is assumed by gesv. + a = dpnp.moveaxis(a, (-2, -1), (0, 1)) + # The same for `b` if it is 3D; + # if it is 2D, transpose it. + if b.ndim > 2: + b = dpnp.moveaxis(b, (-2, -1), (0, 1)) else: - # Reorder the elements by moving the last two axes of `a` to the front - # to match fortran-like array order which is assumed by gesv. - a = dpnp.moveaxis(a, (-2, -1), (0, 1)) - # The same for `b` if it is 3D; - # if it is 2D, transpose it. - if b.ndim > 2: - b = dpnp.moveaxis(b, (-2, -1), (0, 1)) - else: - b = b.T + b = b.T - a_usm_arr = dpnp.get_usm_ndarray(a) - b_usm_arr = dpnp.get_usm_ndarray(b) + a_usm_arr = dpnp.get_usm_ndarray(a) + b_usm_arr = dpnp.get_usm_ndarray(b) - _manager = dpu.SequentialOrderManager[exec_q] - dep_evs = _manager.submitted_events + _manager = dpu.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events - # oneMKL LAPACK gesv destroys `a` and assumes fortran-like array - # as input. - a_f = dpnp.empty_like( - a, dtype=res_type, order="F", usm_type=res_usm_type - ) + # oneMKL LAPACK gesv destroys `a` and assumes fortran-like array + # as input. 
+ a_f = dpnp.empty_like(a, dtype=res_type, order="F", usm_type=res_usm_type) - ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, - dst=a_f.get_array(), - sycl_queue=exec_q, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, a_copy_ev) + ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_f.get_array(), + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, a_copy_ev) - # oneMKL LAPACK gesv overwrites `b` and assumes fortran-like array - # as input. - b_f = dpnp.empty_like( - b, order="F", dtype=res_type, usm_type=res_usm_type - ) + # oneMKL LAPACK gesv overwrites `b` and assumes fortran-like array + # as input. + b_f = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type) - ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=b_usm_arr, - dst=b_f.get_array(), - sycl_queue=exec_q, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, b_copy_ev) + ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr, + dst=b_f.get_array(), + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, b_copy_ev) - ht_ev, gesv_batch_ev = li._gesv_batch( - exec_q, - a_f.get_array(), - b_f.get_array(), - depends=[a_copy_ev, b_copy_ev], - ) + ht_ev, gesv_batch_ev = li._gesv_batch( + exec_q, + a_f.get_array(), + b_f.get_array(), + depends=[a_copy_ev, b_copy_ev], + ) - _manager.add_event_pair(ht_ev, gesv_batch_ev) + _manager.add_event_pair(ht_ev, gesv_batch_ev) - # Gesv call overwtires `b` in Fortran order, reorder the axes - # to match C order by moving the last axis to the front and - # reshape it back to the original shape of `b`. - v = dpnp.moveaxis(b_f, -1, 0).reshape(b_shape) + # Gesv call overwtires `b` in Fortran order, reorder the axes + # to match C order by moving the last axis to the front and + # reshape it back to the original shape of `b`. 
+ v = dpnp.moveaxis(b_f, -1, 0).reshape(b_shape) - # dpnp.moveaxis can make the array non-contiguous if it is not 2D - # Convert to contiguous to align with NumPy - if b.ndim > 2: - v = dpnp.ascontiguousarray(v) + # dpnp.moveaxis can make the array non-contiguous if it is not 2D + # Convert to contiguous to align with NumPy + if b.ndim > 2: + v = dpnp.ascontiguousarray(v) - return v + return v def _batched_qr(a, mode="reduced"): From 6d58280cd6340aa615e424a83e43f3b6cc936981 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 19 Aug 2024 11:12:34 +0200 Subject: [PATCH 07/12] Pass correct batch_strides to gesv_batch_fn --- dpnp/backend/extensions/lapack/gesv_batch.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv_batch.cpp b/dpnp/backend/extensions/lapack/gesv_batch.cpp index d77ccc6d120..3afa7ed3013 100644 --- a/dpnp/backend/extensions/lapack/gesv_batch.cpp +++ b/dpnp/backend/extensions/lapack/gesv_batch.cpp @@ -381,8 +381,9 @@ std::pair dependent_vals_strides.back(); sycl::event gesv_ev = - gesv_batch_fn(exec_q, n, nrhs, batch_size, stride_a, stride_b, - coeff_matrix_data, dependent_vals_data, depends); + gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_batch_stride, + dependent_vals_batch_stride, coeff_matrix_data, + dependent_vals_data, depends); sycl::event ht_ev = dpctl::utils::keep_args_alive( exec_q, {coeff_matrix, dependent_vals}, {gesv_ev}); From cee01c1f0768b2b986c512108022eb60e0ecf656 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 20 Aug 2024 14:28:59 +0200 Subject: [PATCH 08/12] Apply review comments --- dpnp/backend/extensions/lapack/gesv.cpp | 5 +++-- dpnp/backend/extensions/lapack/gesv_batch.cpp | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 1f1521e0403..fe18add12a1 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -134,7 +134,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q, scratchpad, ipiv, e, error_msg); } catch (sycl::exception const &e) { is_exception_caught = true; - error_msg << "Unexpected SYCL exception caught during getrf and " + error_msg << "Unexpected SYCL exception caught during getrf() or " "getrs() call:\n" << e.what(); } @@ -179,7 +179,8 @@ static sycl::event gesv_impl(sycl::queue &exec_q, error_msg << "Unexpected SYCL exception caught during gesv() call:\n" << e.what(); } -#endif +#endif // USE_ONEMKL_INTERFACES + if (is_exception_caught) // an unexpected error occurs { if (scratchpad != nullptr) diff --git a/dpnp/backend/extensions/lapack/gesv_batch.cpp b/dpnp/backend/extensions/lapack/gesv_batch.cpp index 3afa7ed3013..ae470168601 100644 --- a/dpnp/backend/extensions/lapack/gesv_batch.cpp +++ b/dpnp/backend/extensions/lapack/gesv_batch.cpp @@ -186,12 +186,12 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, } else { error_msg << "Unexpected MKL exception caught during getrf_batch() " - "and getrs_batch() call:\nreason: " + "or getrs_batch() call:\nreason: " << e.what() << "\ninfo: " << e.info(); } } catch (sycl::exception const &e) { is_exception_caught = true; - error_msg << "Unexpected SYCL exception caught during getrf and " + error_msg << "Unexpected SYCL exception caught during getrf() or " "getrs() call:\n" << e.what(); } @@ -271,7 +271,8 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, // Update the event dependencies for the current stream comp_evs[stream_id] 
= {gesv_event}; } -#endif +#endif // USE_ONEMKL_INTERFACES + if (is_exception_caught) // an unexpected error occurs { if (scratchpad != nullptr) @@ -288,7 +289,7 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, for (const auto &ev : comp_evs) { cgh.depends_on(ev); } -#endif +#endif // USE_ONEMKL_INTERFACES auto ctx = exec_q.get_context(); cgh.host_task([ctx, scratchpad, ipiv]() { sycl::free(scratchpad, ctx); From 07106b23b174d590797b225f0d028e76a3aaaac7 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 20 Aug 2024 14:39:56 +0200 Subject: [PATCH 09/12] Reduce dublicate code for gesv_impl --- dpnp/backend/extensions/lapack/gesv.cpp | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index fe18add12a1..04237a1fb64 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -64,7 +64,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q, const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); - std::int64_t scratchpad_size; + std::int64_t scratchpad_size = 0; sycl::event comp_event; std::int64_t *ipiv = nullptr; T *scratchpad = nullptr; @@ -81,6 +81,12 @@ static sycl::event gesv_impl(sycl::queue &exec_q, mkl_lapack::getrf_scratchpad_size(exec_q, n, n, lda), mkl_lapack::getrs_scratchpad_size(exec_q, trans, n, nrhs, lda, ldb)); +#else + scratchpad_size = + mkl_lapack::gesv_scratchpad_size(exec_q, n, nrhs, lda, ldb); + +#endif // USE_ONEMKL_INTERFACES + scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); try { @@ -91,6 +97,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q, throw; } +#if defined(USE_ONEMKL_INTERFACES) sycl::event getrf_event; try { getrf_event = mkl_lapack::getrf( @@ -139,19 +146,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q, << e.what(); } #else - scratchpad_size = - mkl_lapack::gesv_scratchpad_size(exec_q, n, nrhs, lda, ldb); - - scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); - - try { - ipiv = helper::alloc_ipiv(n, exec_q); - } catch (const std::exception &e) { - if (scratchpad != nullptr) - sycl::free(scratchpad, exec_q); - throw; - } - try { comp_event = mkl_lapack::gesv( exec_q, From 09df96738131bfb7c6ce288af97abad10f5c15c1 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 20 Aug 2024 15:01:42 +0200 Subject: [PATCH 10/12] Replace maybe_unused to if defined/else --- dpnp/backend/extensions/lapack/gesv_batch.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv_batch.cpp b/dpnp/backend/extensions/lapack/gesv_batch.cpp index ae470168601..899f0a8843c 100644 --- a/dpnp/backend/extensions/lapack/gesv_batch.cpp +++ b/dpnp/backend/extensions/lapack/gesv_batch.cpp @@ -44,8 +44,10 @@ typedef sycl::event (*gesv_batch_impl_fn_ptr_t)( const std::int64_t, const std::int64_t, const std::int64_t, +#if defined(USE_ONEMKL_INTERFACES) const std::int64_t, const std::int64_t, +#endif // USE_ONEMKL_INTERFACES char *, char *, const std::vector &); @@ -58,8 +60,10 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, const std::int64_t n, const std::int64_t nrhs, const std::int64_t batch_size, - [[maybe_unused]] const std::int64_t stride_a, - [[maybe_unused]] const std::int64_t stride_b, +#if defined(USE_ONEMKL_INTERFACES) + const std::int64_t stride_a, + const std::int64_t stride_b, +#endif // USE_ONEMKL_INTERFACES char *in_a, char *in_b, const 
std::vector &depends) @@ -72,7 +76,7 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); - std::int64_t scratchpad_size; + std::int64_t scratchpad_size = 0; sycl::event comp_event; std::int64_t *ipiv = nullptr; T *scratchpad = nullptr; @@ -374,6 +378,9 @@ std::pair auto const &coeff_matrix_strides = coeff_matrix.get_strides_vector(); auto const &dependent_vals_strides = dependent_vals.get_strides_vector(); + sycl::event gesv_ev; + +#if defined(USE_ONEMKL_INTERFACES) // Get the strides for the batch matrices. // Since the matrices are stored in F-contiguous order, // the stride between batches is the last element in the strides vector. @@ -381,10 +388,14 @@ std::pair const std::int64_t dependent_vals_batch_stride = dependent_vals_strides.back(); - sycl::event gesv_ev = + gesv_ev = gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_batch_stride, dependent_vals_batch_stride, coeff_matrix_data, dependent_vals_data, depends); +#else + gesv_ev = gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data, + dependent_vals_data, depends); +#endif // USE_ONEMKL_INTERFACES sycl::event ht_ev = dpctl::utils::keep_args_alive( exec_q, {coeff_matrix, dependent_vals}, {gesv_ev}); From c64935c4c7d0c2c1f9882f00a3f92ade9c679856 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 20 Aug 2024 15:05:33 +0200 Subject: [PATCH 11/12] Expand comments for trans parameter --- dpnp/backend/extensions/lapack/gesv.cpp | 1 + dpnp/backend/extensions/lapack/gesv_batch.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 04237a1fb64..5169fd338d3 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -75,6 +75,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q, #if defined(USE_ONEMKL_INTERFACES) // Use transpose::T if the LU-factorized array is passed as C-contiguous. // For F-contiguous we use transpose::N. + // Since gesv takes F-contiguous as input, we use transpose::N. oneapi::mkl::transpose trans = oneapi::mkl::transpose::N; scratchpad_size = std::max( diff --git a/dpnp/backend/extensions/lapack/gesv_batch.cpp b/dpnp/backend/extensions/lapack/gesv_batch.cpp index 899f0a8843c..d611458f196 100644 --- a/dpnp/backend/extensions/lapack/gesv_batch.cpp +++ b/dpnp/backend/extensions/lapack/gesv_batch.cpp @@ -87,6 +87,7 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q, #if defined(USE_ONEMKL_INTERFACES) // Use transpose::T if the LU-factorized array is passed as C-contiguous. // For F-contiguous we use transpose::N. + // Since gesv_batch takes F-contiguous as input, we use transpose::N. 
oneapi::mkl::transpose trans = oneapi::mkl::transpose::N; const std::int64_t stride_ipiv = n; From 37c84dda1f9d32306d5bafff308202ad84d3c1c4 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 20 Aug 2024 17:21:49 +0200 Subject: [PATCH 12/12] Apply review comments --- dpnp/backend/extensions/lapack/gesv.cpp | 3 +-- dpnp/backend/extensions/lapack/gesv_batch.cpp | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 5169fd338d3..e2f6d3ebd76 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -67,7 +67,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q, std::int64_t scratchpad_size = 0; sycl::event comp_event; std::int64_t *ipiv = nullptr; - T *scratchpad = nullptr; std::stringstream error_msg; bool is_exception_caught = false; @@ -88,7 +87,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q, #endif // USE_ONEMKL_INTERFACES - scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); + T *scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); try { ipiv = helper::alloc_ipiv(n, exec_q); diff --git a/dpnp/backend/extensions/lapack/gesv_batch.cpp b/dpnp/backend/extensions/lapack/gesv_batch.cpp index d611458f196..3d0813f2484 100644 --- a/dpnp/backend/extensions/lapack/gesv_batch.cpp +++ b/dpnp/backend/extensions/lapack/gesv_batch.cpp @@ -376,12 +376,12 @@ std::pair const std::int64_t nrhs = (dependent_vals_nd > 2) ? dependent_vals_shape[1] : 1; - auto const &coeff_matrix_strides = coeff_matrix.get_strides_vector(); - auto const &dependent_vals_strides = dependent_vals.get_strides_vector(); - sycl::event gesv_ev; #if defined(USE_ONEMKL_INTERFACES) + auto const &coeff_matrix_strides = coeff_matrix.get_strides_vector(); + auto const &dependent_vals_strides = dependent_vals.get_strides_vector(); + // Get the strides for the batch matrices. // Since the matrices are stored in F-contiguous order, // the stride between batches is the last element in the strides vector.
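
For context, a minimal, self-contained sketch of the axis bookkeeping that the reworked `_batched_solve` path relies on. This is not the dpnp implementation: plain NumPy and `numpy.linalg.solve` stand in for dpnp arrays and the oneMKL `gesv_batch` call, the helper name `batched_solve_reference` is invented for illustration, and the input shapes are assumed to be `a.shape == (batch, n, n)` and `b.shape == (batch, n, nrhs)`.

import numpy as np


def batched_solve_reference(a, b):
    # Remember the original shape of `b`; the result is reshaped back to it.
    b_shape = b.shape

    # Move the last two axes of `a` and `b` to the front so each matrix and
    # right-hand-side block is laid out column-by-column, i.e. the
    # Fortran-like (F-contiguous) layout the batched solver assumes.
    a_f = np.asfortranarray(np.moveaxis(a, (-2, -1), (0, 1)))  # (n, n, batch)
    b_f = np.asfortranarray(np.moveaxis(b, (-2, -1), (0, 1)))  # (n, nrhs, batch)

    # gesv_batch would overwrite b_f in place; numpy.linalg.solve applied per
    # batch element is used here as a stand-in.
    x_f = np.stack(
        [np.linalg.solve(a_f[..., i], b_f[..., i]) for i in range(a.shape[0])],
        axis=-1,
    )

    # Undo the reordering: move the batch axis back to the front, restore the
    # original shape of `b`, and return a C-contiguous result to match NumPy.
    return np.ascontiguousarray(np.moveaxis(x_f, -1, 0).reshape(b_shape))


A quick check against NumPy's own batched solver (shapes chosen arbitrarily):

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 3, 3))
b = rng.standard_normal((4, 3, 2))
assert np.allclose(batched_solve_reference(a, b), np.linalg.solve(a, b))

Note that in this F-contiguous layout the stride between consecutive matrices in the batch is the last element of the strides vector, which is exactly the value the extension passes to the solver as coeff_matrix_batch_stride and dependent_vals_batch_stride.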