Skip to content

Commit

Permalink
Extend gesv_impl/gesv_batch_impl to work with oneMKL Interfaces (#…
Browse files Browse the repository at this point in the history
…2001)

* Implement batch solve via getrf_batch and getrs_batch

* Pass sycl::queue by reference for getrs/getrs_batch

* Extend gesv_impl to use onemkl interfaces

* Reduce code duplication in gesv_impl

* Extend gesv_batch_impl to use onemkl interfaces

* Remove getrs_batch implementation

* Pass correct batch_strides to gesv_batch_fn

* Reduce duplicate code for gesv_impl

* Replace maybe_unused with #if defined/#else

* Expand comments for trans parameter

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
vlad-perevezentsev and antonwolfy authored Aug 21, 2024
1 parent 20cfa81 commit f22a4cb
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 29 deletions.
85 changes: 72 additions & 13 deletions dpnp/backend/extensions/lapack/gesv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
char *in_b,
const std::vector<sycl::event> &depends)
{
#if defined(USE_ONEMKL_INTERFACES)
// Temporary flag for build only
// FIXME: Need to implement by using lapack::getrf and lapack::getrs
std::logic_error("Not Implemented");
#else
type_utils::validate_type_for_device<T>(exec_q);

T *a = reinterpret_cast<T *>(in_a);
Expand All @@ -69,12 +64,31 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
const std::int64_t lda = std::max<size_t>(1UL, n);
const std::int64_t ldb = std::max<size_t>(1UL, n);

const std::int64_t scratchpad_size =
std::int64_t scratchpad_size = 0;
sycl::event comp_event;
std::int64_t *ipiv = nullptr;

std::stringstream error_msg;
bool is_exception_caught = false;

#if defined(USE_ONEMKL_INTERFACES)
// Use transpose::T if the LU-factorized array is passed as C-contiguous.
// For F-contiguous we use transpose::N.
// Since gesv takes F-contiguous as input, we use transpose::N.
oneapi::mkl::transpose trans = oneapi::mkl::transpose::N;

scratchpad_size = std::max(
mkl_lapack::getrf_scratchpad_size<T>(exec_q, n, n, lda),
mkl_lapack::getrs_scratchpad_size<T>(exec_q, trans, n, nrhs, lda, ldb));

#else
scratchpad_size =
mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb);

#endif // USE_ONEMKL_INTERFACES

T *scratchpad = helper::alloc_scratchpad<T>(scratchpad_size, exec_q);

std::int64_t *ipiv = nullptr;
try {
ipiv = helper::alloc_ipiv(n, exec_q);
} catch (const std::exception &e) {
Expand All @@ -83,12 +97,57 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
throw;
}

std::stringstream error_msg;
bool is_exception_caught = false;
#if defined(USE_ONEMKL_INTERFACES)
sycl::event getrf_event;
try {
getrf_event = mkl_lapack::getrf(
exec_q,
n, // The order of the square matrix A (0 ≤ n).
// It must be a non-negative integer.
n, // The number of columns in the square matrix A (0 ≤ n).
// It must be a non-negative integer.
a, // Pointer to the square matrix A (n x n).
lda, // The leading dimension of matrix A.
// It must be at least max(1, n).
ipiv, // Pointer to the output array of pivot indices.
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, depends);

sycl::event gesv_event;
comp_event = mkl_lapack::getrs(
exec_q,
trans, // Specifies the operation: whether or not to transpose
// matrix A. Can be 'N' for no transpose, 'T' for transpose,
// and 'C' for conjugate transpose.
n, // The order of the square matrix A
// and the number of rows in matrix B (0 ≤ n).
// It must be a non-negative integer.
nrhs, // The number of right-hand sides,
// i.e., the number of columns in matrix B (0 ≤ nrhs).
a, // Pointer to the square matrix A (n x n).
lda, // The leading dimension of matrix A.
// It must be at least max(1, n).
ipiv, // Pointer to the output array of pivot indices that were used
// during factorization (n, ).
b, // Pointer to the matrix B of right-hand sides (ldb, nrhs).
ldb, // The leading dimension of matrix B, must be at least max(1,
// n).
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, {getrf_event});
} catch (mkl_lapack::exception const &e) {
is_exception_caught = true;
gesv_utils::handle_lapack_exc(exec_q, lda, a, scratchpad_size,
scratchpad, ipiv, e, error_msg);
} catch (sycl::exception const &e) {
is_exception_caught = true;
error_msg << "Unexpected SYCL exception caught during getrf() or "
"getrs() call:\n"
<< e.what();
}
#else
try {
gesv_event = mkl_lapack::gesv(
comp_event = mkl_lapack::gesv(
exec_q,
n, // The order of the square matrix A
// and the number of rows in matrix B (0 ≤ n).
Expand All @@ -114,6 +173,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
error_msg << "Unexpected SYCL exception caught during gesv() call:\n"
<< e.what();
}
#endif // USE_ONEMKL_INTERFACES

if (is_exception_caught) // an unexpected error occurs
{
Expand All @@ -125,7 +185,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
}

sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(gesv_event);
cgh.depends_on(comp_event);
auto ctx = exec_q.get_context();
cgh.host_task([ctx, scratchpad, ipiv]() {
sycl::free(scratchpad, ctx);
Expand All @@ -134,7 +194,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
});

return ht_ev;
#endif // USE_ONEMKL_INTERFACES
}

std::pair<sycl::event, sycl::event>
Expand Down
176 changes: 164 additions & 12 deletions dpnp/backend/extensions/lapack/gesv_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ typedef sycl::event (*gesv_batch_impl_fn_ptr_t)(
const std::int64_t,
const std::int64_t,
const std::int64_t,
#if defined(USE_ONEMKL_INTERFACES)
const std::int64_t,
const std::int64_t,
#endif // USE_ONEMKL_INTERFACES
char *,
char *,
const std::vector<sycl::event> &);
Expand All @@ -56,6 +60,10 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
const std::int64_t n,
const std::int64_t nrhs,
const std::int64_t batch_size,
#if defined(USE_ONEMKL_INTERFACES)
const std::int64_t stride_a,
const std::int64_t stride_b,
#endif // USE_ONEMKL_INTERFACES
char *in_a,
char *in_b,
const std::vector<sycl::event> &depends)
Expand All @@ -65,23 +73,147 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
T *a = reinterpret_cast<T *>(in_a);
T *b = reinterpret_cast<T *>(in_b);

const std::int64_t a_size = n * n;
const std::int64_t b_size = n * nrhs;

const std::int64_t lda = std::max<size_t>(1UL, n);
const std::int64_t ldb = std::max<size_t>(1UL, n);

std::int64_t scratchpad_size = 0;
sycl::event comp_event;
std::int64_t *ipiv = nullptr;
T *scratchpad = nullptr;

std::stringstream error_msg;
bool is_exception_caught = false;

#if defined(USE_ONEMKL_INTERFACES)
// Use transpose::T if the LU-factorized array is passed as C-contiguous.
// For F-contiguous we use transpose::N.
// Since gesv_batch takes F-contiguous as input, we use transpose::N.
oneapi::mkl::transpose trans = oneapi::mkl::transpose::N;
const std::int64_t stride_ipiv = n;

scratchpad_size = std::max(
mkl_lapack::getrs_batch_scratchpad_size<T>(exec_q, trans, n, nrhs, lda,
stride_a, stride_ipiv, ldb,
stride_b, batch_size),
mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, n, n, lda, stride_a,
stride_ipiv, batch_size));

scratchpad = helper::alloc_scratchpad<T>(scratchpad_size, exec_q);

// pass batch_size * n to allocate the memory for a 2D array of pivot
// indices
try {
ipiv = helper::alloc_ipiv(batch_size * n, exec_q);
} catch (const std::exception &e) {
if (scratchpad != nullptr)
sycl::free(scratchpad, exec_q);
throw;
}

sycl::event getrf_batch_event;
try {
getrf_batch_event = mkl_lapack::getrf_batch(
exec_q,
n, // The order of each square matrix in the batch; (0 ≤ n).
// It must be a non-negative integer.
n, // The number of columns in each matrix in the batch; (0 ≤ n).
// It must be a non-negative integer.
a, // Pointer to the batch of square matrices, each of size (n x n).
lda, // The leading dimension of each matrix in the batch.
stride_a, // Stride between consecutive matrices in the batch.
ipiv, // Pointer to the array of pivot indices for each matrix in
// the batch.
stride_ipiv, // Stride between pivot indices: Spacing between pivot
// arrays in 'ipiv'.
batch_size, // The number of matrices in the batch.
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, depends);

comp_event = mkl_lapack::getrs_batch(
exec_q,
trans, // Specifies the operation: whether or not to transpose
// matrix A. Can be 'N' for no transpose, 'T' for transpose,
// and 'C' for conjugate transpose.
n, // The order of each square matrix A in the batch
// and the number of rows in each matrix B (0 ≤ n).
// It must be a non-negative integer.
nrhs, // The number of right-hand sides,
// i.e., the number of columns in each matrix B in the batch
// (0 ≤ nrhs).
a, // Pointer to the batch of square matrices A (n x n).
lda, // The leading dimension of each matrix A in the batch.
// It must be at least max(1, n).
stride_a, // Stride between individual matrices in the batch for
// matrix A.
ipiv, // Pointer to the batch of arrays of pivot indices.
stride_ipiv, // Stride between pivot index arrays in the batch.
b, // Pointer to the batch of matrices B (n, nrhs).
ldb, // The leading dimension of each matrix B in the batch.
// Must be at least max(1, n).
stride_b, // Stride between individual matrices in the batch for
// matrix B.
batch_size, // The number of matrices in the batch.
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, {getrf_batch_event});
} catch (mkl_lapack::batch_error const &be) {
// Get the indices of matrices within the batch that encountered an
// error
auto error_matrices_ids = be.ids();

error_msg << "Singular matrix. Errors in matrices with IDs: ";
for (size_t i = 0; i < error_matrices_ids.size(); ++i) {
error_msg << error_matrices_ids[i];
if (i < error_matrices_ids.size() - 1) {
error_msg << ", ";
}
}
error_msg << ".";

if (scratchpad != nullptr)
sycl::free(scratchpad, exec_q);
if (ipiv != nullptr)
sycl::free(ipiv, exec_q);

throw LinAlgError(error_msg.str().c_str());
} catch (mkl_lapack::exception const &e) {
is_exception_caught = true;
std::int64_t info = e.info();
if (info < 0) {
error_msg << "Parameter number " << -info
<< " had an illegal value.";
}
else if (info == scratchpad_size && e.detail() != 0) {
error_msg
<< "Insufficient scratchpad size. Required size is at least "
<< e.detail();
}
else {
error_msg << "Unexpected MKL exception caught during getrf_batch() "
"or getrs_batch() call:\nreason: "
<< e.what() << "\ninfo: " << e.info();
}
} catch (sycl::exception const &e) {
is_exception_caught = true;
error_msg << "Unexpected SYCL exception caught during getrf() or "
"getrs() call:\n"
<< e.what();
}
#else
const std::int64_t a_size = n * n;
const std::int64_t b_size = n * nrhs;

// Get the number of independent linear streams
const std::int64_t n_linear_streams =
(batch_size > 16) ? 4 : ((batch_size > 4 ? 2 : 1));

const std::int64_t scratchpad_size =
scratchpad_size =
mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb);

T *scratchpad = helper::alloc_scratchpad_batch<T>(scratchpad_size,
n_linear_streams, exec_q);
scratchpad = helper::alloc_scratchpad_batch<T>(scratchpad_size,
n_linear_streams, exec_q);

std::int64_t *ipiv = nullptr;
try {
ipiv = helper::alloc_ipiv_batch<T>(n, n_linear_streams, exec_q);
} catch (const std::exception &e) {
Expand All @@ -93,9 +225,6 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
// Computation events to manage dependencies for each linear stream
std::vector<std::vector<sycl::event>> comp_evs(n_linear_streams, depends);

std::stringstream error_msg;
bool is_exception_caught = false;

// Release GIL to avoid serialization of host task
// submissions to the same queue in OneMKL
py::gil_scoped_release release;
Expand Down Expand Up @@ -147,6 +276,7 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
// Update the event dependencies for the current stream
comp_evs[stream_id] = {gesv_event};
}
#endif // USE_ONEMKL_INTERFACES

if (is_exception_caught) // an unexpected error occurs
{
Expand All @@ -158,9 +288,13 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
}

sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) {
#if defined(USE_ONEMKL_INTERFACES)
cgh.depends_on(comp_event);
#else
for (const auto &ev : comp_evs) {
cgh.depends_on(ev);
}
#endif // USE_ONEMKL_INTERFACES
auto ctx = exec_q.get_context();
cgh.host_task([ctx, scratchpad, ipiv]() {
sycl::free(scratchpad, ctx);
Expand Down Expand Up @@ -242,9 +376,27 @@ std::pair<sycl::event, sycl::event>
const std::int64_t nrhs =
(dependent_vals_nd > 2) ? dependent_vals_shape[1] : 1;

sycl::event gesv_ev =
gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data,
sycl::event gesv_ev;

#if defined(USE_ONEMKL_INTERFACES)
auto const &coeff_matrix_strides = coeff_matrix.get_strides_vector();
auto const &dependent_vals_strides = dependent_vals.get_strides_vector();

// Get the strides for the batch matrices.
// Since the matrices are stored in F-contiguous order,
// the stride between batches is the last element in the strides vector.
const std::int64_t coeff_matrix_batch_stride = coeff_matrix_strides.back();
const std::int64_t dependent_vals_batch_stride =
dependent_vals_strides.back();

gesv_ev =
gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_batch_stride,
dependent_vals_batch_stride, coeff_matrix_data,
dependent_vals_data, depends);
#else
gesv_ev = gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data,
dependent_vals_data, depends);
#endif // USE_ONEMKL_INTERFACES

sycl::event ht_ev = dpctl::utils::keep_args_alive(
exec_q, {coeff_matrix, dependent_vals}, {gesv_ev});
Expand Down
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/lapack/getrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue,
typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue &,
oneapi::mkl::transpose,
const std::int64_t,
const std::int64_t,
Expand All @@ -56,7 +56,7 @@ typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue,
static getrs_impl_fn_ptr_t getrs_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
static sycl::event getrs_impl(sycl::queue exec_q,
static sycl::event getrs_impl(sycl::queue &exec_q,
oneapi::mkl::transpose trans,
const std::int64_t n,
const std::int64_t nrhs,
Expand Down Expand Up @@ -156,7 +156,7 @@ static sycl::event getrs_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
getrs(sycl::queue exec_q,
getrs(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray a_array,
dpctl::tensor::usm_ndarray ipiv_array,
dpctl::tensor::usm_ndarray b_array,
Expand Down
Loading

0 comments on commit f22a4cb

Please sign in to comment.