Skip to content

Commit

Permalink
Implement syevd_batch and heevd_batch (#1936)
Browse files Browse the repository at this point in the history
* Implement syevd_batch and heevd_batch

* Move include of dpctl type_utils header to source files

* Add memory allocation check for scratchpad

* Add more checks for scratchpad_size

* Move includes

* Allocate memory for w with expected shape

* Applied review comments

* Add common_evd_checks to reduce duplicate code

* Remove host_task_events from syevd and heevd

* Applied review comments

* Use init_evd_dispatch_table instead of init_evd_batch_dispatch_table

* Move init_evd_dispatch_table to evd_common_utils.hpp

* Add helper function check_zeros_shape

* Implement alloc_scratchpad function in evd_batch_common.hpp

* Make round_up_mult as inline

* Add comment for check_zeros_shape

* Make alloc_scratchpad as inline
  • Loading branch information
vlad-perevezentsev authored Jul 26, 2024
1 parent e33a82b commit aa534f8
Show file tree
Hide file tree
Showing 15 changed files with 851 additions and 146 deletions.
2 changes: 2 additions & 0 deletions dpnp/backend/extensions/lapack/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/heevd_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/orgqr_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/syevd_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ungqr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ungqr_batch.cpp
)
Expand Down
23 changes: 23 additions & 0 deletions dpnp/backend/extensions/lapack/common_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@
//*****************************************************************************

#pragma once
#include <complex>
#include <cstring>
#include <pybind11/pybind11.h>
#include <stdexcept>

namespace dpnp::extensions::lapack::helper
{
namespace py = pybind11;

template <typename T>
struct value_type_of
{
Expand All @@ -40,4 +44,23 @@ struct value_type_of<std::complex<T>>
{
using type = T;
};

// Rounds `value` up to the nearest multiple of `mult`.
// E.g. round_up_mult(10, 8) == 16, round_up_mult(16, 8) == 16.
template <typename intT>
inline intT round_up_mult(intT value, intT mult)
{
    return ((value + mult - 1) / mult) * mult;
}

// Returns true if the shape describes an empty array, i.e. if any of its
// `ndim` dimensions is zero (equivalently, the total element count is 0).
inline bool check_zeros_shape(int ndim, const py::ssize_t *shape)
{
    // Product of all dimensions; a single zero dimension makes it zero.
    size_t src_nelems(1);

    for (int i = 0; i < ndim; ++i) {
        src_nelems *= static_cast<size_t>(shape[i]);
    }
    return src_nelems == 0;
}
} // namespace dpnp::extensions::lapack::helper
152 changes: 152 additions & 0 deletions dpnp/backend/extensions/lapack/evd_batch_common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <oneapi/mkl.hpp>
#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/type_dispatch.hpp"

#include "common_helpers.hpp"
#include "evd_common_utils.hpp"
#include "types_matrix.hpp"

namespace dpnp::extensions::lapack::evd
{
typedef sycl::event (*evd_batch_impl_fn_ptr_t)(
sycl::queue &,
const oneapi::mkl::job,
const oneapi::mkl::uplo,
const std::int64_t,
const std::int64_t,
char *,
char *,
const std::vector<sycl::event> &);

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
namespace py = pybind11;

// Validates the input/output arrays and dispatches a batched eigenvalue
// decomposition (syevd_batch / heevd_batch) implementation.
//
// `eig_vecs` is a 3D F-contiguous array holding the stacked input matrices
// (overwritten with eigenvectors); given the use of shape[2] as batch_size
// below, its layout is presumably (n, n, batch_size). `eig_vals` is a 2D
// C-contiguous array of shape (batch_size, n) receiving the eigenvalues.
//
// Returns a pair of events: a host-task event that keeps the arrays alive
// and the computation event itself.
template <typename dispatchT>
std::pair<sycl::event, sycl::event>
evd_batch_func(sycl::queue &exec_q,
               const std::int8_t jobz,
               const std::int8_t upper_lower,
               dpctl::tensor::usm_ndarray &eig_vecs,
               dpctl::tensor::usm_ndarray &eig_vals,
               const std::vector<sycl::event> &depends,
               const dispatchT &evd_batch_dispatch_table)
{
    const int eig_vecs_nd = eig_vecs.get_ndim();

    const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
    const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();

    constexpr int expected_eig_vecs_nd = 3;
    constexpr int expected_eig_vals_nd = 2;

    // Shared validation: queue compatibility, writability, memory overlap,
    // contiguity and expected dimensionality of both arrays.
    common_evd_checks(exec_q, eig_vecs, eig_vals, eig_vecs_shape,
                      expected_eig_vecs_nd, expected_eig_vals_nd);

    if (eig_vecs_shape[2] != eig_vals_shape[0] ||
        eig_vecs_shape[0] != eig_vals_shape[1])
    {
        // Report the actual batch dimension (shape[2]) that is used as
        // batch_size below, rather than shape[0].
        throw py::value_error(
            "The shape of 'eig_vals' must be (batch_size, n), "
            "where batch_size = " +
            std::to_string(eig_vecs_shape[2]) +
            " and n = " + std::to_string(eig_vecs_shape[1]));
    }

    // Ensure `batch_size` and `n` are non-zero, otherwise return empty events
    if (helper::check_zeros_shape(eig_vecs_nd, eig_vecs_shape)) {
        // nothing to do
        return std::make_pair(sycl::event(), sycl::event());
    }

    auto array_types = dpctl_td_ns::usm_ndarray_types();
    const int eig_vecs_type_id =
        array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
    const int eig_vals_type_id =
        array_types.typenum_to_lookup_id(eig_vals.get_typenum());

    // Pick the per-type implementation; a null entry means the
    // (vecs, vals) dtype combination is unsupported.
    evd_batch_impl_fn_ptr_t evd_batch_fn =
        evd_batch_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
    if (evd_batch_fn == nullptr) {
        throw py::value_error(
            "Types of input vectors and result array are mismatched.");
    }

    char *eig_vecs_data = eig_vecs.get_data();
    char *eig_vals_data = eig_vals.get_data();

    const std::int64_t batch_size = eig_vecs_shape[2];
    const std::int64_t n = eig_vecs_shape[1];

    const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
    const oneapi::mkl::uplo uplo_val =
        static_cast<oneapi::mkl::uplo>(upper_lower);

    sycl::event evd_batch_ev =
        evd_batch_fn(exec_q, jobz_val, uplo_val, batch_size, n, eig_vecs_data,
                     eig_vals_data, depends);

    // Keep the Python-owned arrays alive until the computation finishes.
    sycl::event ht_ev = dpctl::utils::keep_args_alive(
        exec_q, {eig_vecs, eig_vals}, {evd_batch_ev});

    return std::make_pair(ht_ev, evd_batch_ev);
}

// Allocates one contiguous device scratchpad sized to serve all
// `n_linear_streams` concurrent computation streams, padded so that each
// per-stream slice starts on a 256-byte boundary.
//
// Throws std::runtime_error on a non-positive requested size or on
// allocation failure. The caller owns the returned pointer and must free
// it with sycl::free on `exec_q`'s context.
template <typename T>
inline T *alloc_scratchpad(std::int64_t scratchpad_size,
                           std::int64_t n_linear_streams,
                           sycl::queue &exec_q)
{
    if (scratchpad_size <= 0) {
        throw std::runtime_error(
            "Invalid scratchpad size: must be greater than zero."
            " Calculated scratchpad size: " +
            std::to_string(scratchpad_size));
    }

    // 256 bytes expressed in elements of T, used as the alignment quantum
    // for better performance.
    const std::int64_t padding = 256 / sizeof(T);

    // Total element count across all streams, rounded up for alignment.
    const size_t total_alloc_size =
        helper::round_up_mult(n_linear_streams * scratchpad_size, padding);

    T *scratchpad = sycl::malloc_device<T>(total_alloc_size, exec_q);
    if (scratchpad == nullptr) {
        throw std::runtime_error("Device allocation for scratchpad failed");
    }

    return scratchpad;
}
} // namespace dpnp::extensions::lapack::evd
89 changes: 18 additions & 71 deletions dpnp/backend/extensions/lapack/evd_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,22 @@
#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/output_validation.hpp"
#include "utils/type_dispatch.hpp"
#include "utils/type_utils.hpp"

#include "common_helpers.hpp"
#include "evd_common_utils.hpp"
#include "types_matrix.hpp"

namespace dpnp::extensions::lapack::evd
{
using dpnp::extensions::lapack::helper::check_zeros_shape;

typedef sycl::event (*evd_impl_fn_ptr_t)(sycl::queue &,
const oneapi::mkl::job,
const oneapi::mkl::uplo,
const std::int64_t,
char *,
char *,
std::vector<sycl::event> &,
const std::vector<sycl::event> &);

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
Expand All @@ -61,70 +61,30 @@ std::pair<sycl::event, sycl::event>
const dispatchT &evd_dispatch_table)
{
const int eig_vecs_nd = eig_vecs.get_ndim();
const int eig_vals_nd = eig_vals.get_ndim();

if (eig_vecs_nd != 2) {
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) +
" of an output array with eigenvectors");
}
else if (eig_vals_nd != 1) {
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) +
" of an output array with eigenvalues");
}

const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();

if (eig_vecs_shape[0] != eig_vecs_shape[1]) {
throw py::value_error("Output array with eigenvectors with be square");
}
else if (eig_vecs_shape[0] != eig_vals_shape[0]) {
throw py::value_error(
"Eigenvectors and eigenvalues have different shapes");
}
constexpr int expected_eig_vecs_nd = 2;
constexpr int expected_eig_vals_nd = 1;

size_t src_nelems(1);
common_evd_checks(exec_q, eig_vecs, eig_vals, eig_vecs_shape,
expected_eig_vecs_nd, expected_eig_vals_nd);

for (int i = 0; i < eig_vecs_nd; ++i) {
src_nelems *= static_cast<size_t>(eig_vecs_shape[i]);
if (eig_vecs_shape[0] != eig_vals_shape[0]) {
throw py::value_error(
"Eigenvectors and eigenvalues have different shapes");
}

if (src_nelems == 0) {
if (check_zeros_shape(eig_vecs_nd, eig_vecs_shape)) {
// nothing to do
return std::make_pair(sycl::event(), sycl::event());
}

dpctl::tensor::validation::CheckWritable::throw_if_not_writable(eig_vecs);
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(eig_vals);

// check compatibility of execution queue and allocation queue
if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) {
throw py::value_error(
"Execution queue is not compatible with allocation queues");
}

auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
if (overlap(eig_vecs, eig_vals)) {
throw py::value_error("Arrays with eigenvectors and eigenvalues are "
"overlapping segments of memory");
}

bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous();
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
if (!is_eig_vecs_f_contig) {
throw py::value_error(
"An array with input matrix / output eigenvectors "
"must be F-contiguous");
}
else if (!is_eig_vals_c_contig) {
throw py::value_error(
"An array with output eigenvalues must be C-contiguous");
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
int eig_vecs_type_id =
const int eig_vecs_type_id =
array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
int eig_vals_type_id =
const int eig_vals_type_id =
array_types.typenum_to_lookup_id(eig_vals.get_typenum());

evd_impl_fn_ptr_t evd_fn =
Expand All @@ -142,25 +102,12 @@ std::pair<sycl::event, sycl::event>
const oneapi::mkl::uplo uplo_val =
static_cast<oneapi::mkl::uplo>(upper_lower);

std::vector<sycl::event> host_task_events;
sycl::event evd_ev = evd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data,
eig_vals_data, host_task_events, depends);
eig_vals_data, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {eig_vecs, eig_vals}, host_task_events);
sycl::event ht_ev =
dpctl::utils::keep_args_alive(exec_q, {eig_vecs, eig_vals}, {evd_ev});

return std::make_pair(args_ev, evd_ev);
}

template <typename dispatchT,
template <typename fnT, typename T, typename RealT>
typename factoryT>
void init_evd_dispatch_table(
dispatchT evd_dispatch_table[][dpctl_td_ns::num_types])
{
dpctl_td_ns::DispatchTableBuilder<dispatchT, factoryT,
dpctl_td_ns::num_types>
contig;
contig.populate_dispatch_table(evd_dispatch_table);
return std::make_pair(ht_ev, evd_ev);
}
} // namespace dpnp::extensions::lapack::evd
Loading

0 comments on commit aa534f8

Please sign in to comment.