Skip to content

Commit

Permalink
Move check for incompatible container earlier
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Aug 27, 2024
1 parent 6a533df commit 4f39b22
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions src/sparse_blas/backends/mkl_common/mkl_spmm.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::
oneapi::mkl::sparse::dense_matrix_handle_t C_handle,
oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg,
spmm_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
}
common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg,
spmm_descr);
if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) {
return;
}
Expand All @@ -166,12 +166,12 @@ sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA,
oneapi::mkl::sparse::spmm_alg alg,
oneapi::mkl::sparse::spmm_descr_t spmm_descr, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg,
spmm_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
}
common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg,
spmm_descr);
if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) {
return detail::collapse_dependencies(queue, dependencies);
}
Expand Down
8 changes: 4 additions & 4 deletions src/sparse_blas/backends/mkl_common/mkl_spmv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg,
spmv_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
}
common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg,
spmv_descr);
if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) {
return;
}
Expand All @@ -166,12 +166,12 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::spmv_alg alg,
oneapi::mkl::sparse::spmv_descr_t spmv_descr, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg,
spmv_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
}
common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg,
spmv_descr);
if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) {
return detail::collapse_dependencies(queue, dependencies);
}
Expand Down
4 changes: 2 additions & 2 deletions src/sparse_blas/backends/mkl_common/mkl_spsv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t spsv_descr,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
}
common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr);
if (alg == oneapi::mkl::sparse::spsv_alg::no_optimize_alg) {
return;
}
Expand All @@ -151,11 +151,11 @@ sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::spsv_alg alg,
oneapi::mkl::sparse::spsv_descr_t spsv_descr, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__func__);
}
common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr);
if (alg == oneapi::mkl::sparse::spsv_alg::no_optimize_alg) {
return detail::collapse_dependencies(queue, dependencies);
}
Expand Down

0 comments on commit 4f39b22

Please sign in to comment.