Skip to content

Commit

Permalink
Applying the review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Apr 29, 2023
1 parent 153f1ca commit 5859aec
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 21 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# CMake build and local install directory
_skbuild
build
build_cython
dpnp.egg-info

Expand Down
27 changes: 19 additions & 8 deletions dpnp/backend/extensions/lapack/heevd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;

template <typename T, typename RealT>
static inline sycl::event call_heevd(sycl::queue exec_q,
const oneapi::mkl::job jobz,
const oneapi::mkl::uplo upper_lower,
const std::int64_t n,
T* a,
RealT* w,
std::vector<sycl::event> &host_task_events,
const std::vector<sycl::event>& depends)
static sycl::event call_heevd(sycl::queue exec_q,
const oneapi::mkl::job jobz,
const oneapi::mkl::uplo upper_lower,
const std::int64_t n,
T* a,
RealT* w,
std::vector<sycl::event>& host_task_events,
const std::vector<sycl::event>& depends)
{
validate_type_for_device<T>(exec_q);
validate_type_for_device<RealT>(exec_q);
Expand Down Expand Up @@ -171,6 +171,17 @@ std::pair<sycl::event, sycl::event> heevd(sycl::queue exec_q,
// throw py::value_error("Arrays index overlapping segments of memory");
// }

bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous();
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
if (!is_eig_vecs_f_contig)
{
throw py::value_error("An array with input matrix / ouput eigenvectors must be F-contiguous");
}
else if (!is_eig_vals_c_contig)
{
throw py::value_error("An array with output eigenvalues must be C-contiguous");
}

int eig_vecs_typenum = eig_vecs.get_typenum();
int eig_vals_typenum = eig_vals.get_typenum();
auto const& dpctl_capi = dpctl::detail::dpctl_capi::get();
Expand Down
27 changes: 19 additions & 8 deletions dpnp/backend/extensions/lapack/syevd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;

template <typename T>
static inline sycl::event call_syevd(sycl::queue exec_q,
const oneapi::mkl::job jobz,
const oneapi::mkl::uplo upper_lower,
const std::int64_t n,
T* a,
T* w,
std::vector<sycl::event> &host_task_events,
const std::vector<sycl::event>& depends)
static sycl::event call_syevd(sycl::queue exec_q,
const oneapi::mkl::job jobz,
const oneapi::mkl::uplo upper_lower,
const std::int64_t n,
T* a,
T* w,
std::vector<sycl::event>& host_task_events,
const std::vector<sycl::event>& depends)
{
validate_type_for_device<T>(exec_q);

Expand Down Expand Up @@ -170,6 +170,17 @@ std::pair<sycl::event, sycl::event> syevd(sycl::queue exec_q,
// throw py::value_error("Arrays index overlapping segments of memory");
// }

bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous();
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
if (!is_eig_vecs_f_contig)
{
throw py::value_error("An array with input matrix / ouput eigenvectors must be F-contiguous");
}
else if (!is_eig_vals_c_contig)
{
throw py::value_error("An array with output eigenvalues must be C-contiguous");
}

int eig_vecs_typenum = eig_vecs.get_typenum();
int eig_vals_typenum = eig_vals.get_typenum();
auto const& dpctl_capi = dpctl::detail::dpctl_capi::get();
Expand Down
5 changes: 1 addition & 4 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,8 @@ def dpnp_eigh(a, UPLO):
# call LAPACK extension function to get eigenvalues and eigenvectors of a portion of matrix A
ht_lapack_ev[i], _ = getattr(li, lapack_func)(a_sycl_queue, jobz, uplo, eig_vecs[i].get_array(), w[i].get_array(), depends=[copy_ev])

# TODO: remove once dpctl fix is available
ht_lapack_ev[i].wait()

for i in range(op_count):
# ht_lapack_ev[i].wait()
ht_lapack_ev[i].wait()
ht_copy_ev[i].wait()

# combine the list of eigenvectors into a single array
Expand Down

0 comments on commit 5859aec

Please sign in to comment.