From 4f39b22bac1443e1d81e3bc1b53dba94d86612a0 Mon Sep 17 00:00:00 2001 From: "romain.biessy" Date: Tue, 27 Aug 2024 16:10:40 +0200 Subject: [PATCH] Move check for incompatible container earlier --- src/sparse_blas/backends/mkl_common/mkl_spmm.cxx | 8 ++++---- src/sparse_blas/backends/mkl_common/mkl_spmv.cxx | 8 ++++---- src/sparse_blas/backends/mkl_common/mkl_spsv.cxx | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx b/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx index eb1b45ebf..857a1983c 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spmm.cxx @@ -144,12 +144,12 @@ void spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl:: oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr, sycl::buffer /*workspace*/) { - common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg, - spmm_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); } + common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg, + spmm_descr); if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) { return; } @@ -166,12 +166,12 @@ sycl::event spmm_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, oneapi::mkl::sparse::spmm_alg alg, oneapi::mkl::sparse::spmm_descr_t spmm_descr, void * /*workspace*/, const std::vector &dependencies) { - common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg, - spmm_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); } + common_spmm_optimize(queue, opA, opB, alpha, A_view, A_handle, B_handle, beta, C_handle, alg, + spmm_descr); if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) { return detail::collapse_dependencies(queue, dependencies); } diff --git a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx index 4e5aeffdb..7ddd534d8 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spmv.cxx @@ -134,12 +134,12 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr, sycl::buffer /*workspace*/) { - common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg, - spmv_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); } + common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg, + spmv_descr); if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) { return; } @@ -166,12 +166,12 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr, void * /*workspace*/, const std::vector &dependencies) { - common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg, - spmv_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); } + common_spmv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, beta, y_handle, alg, + spmv_descr); if (alg == oneapi::mkl::sparse::spmv_alg::no_optimize_alg) { return detail::collapse_dependencies(queue, dependencies); } diff --git a/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx b/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx index 371fac38b..078a5abac 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx +++ b/src/sparse_blas/backends/mkl_common/mkl_spsv.cxx @@ -130,11 +130,11 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a oneapi::mkl::sparse::dense_vector_handle_t y_handle, oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t spsv_descr, sycl::buffer /*workspace*/) { - common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (!internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); } + common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr); if (alg == oneapi::mkl::sparse::spsv_alg::no_optimize_alg) { return; } @@ -151,11 +151,11 @@ sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const oneapi::mkl::sparse::spsv_alg alg, oneapi::mkl::sparse::spsv_descr_t spsv_descr, void * /*workspace*/, const std::vector &dependencies) { - common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr); auto internal_A_handle = detail::get_internal_handle(A_handle); if (internal_A_handle->all_use_buffer()) { detail::throw_incompatible_container(__func__); } + common_spsv_optimize(queue, opA, alpha, A_view, A_handle, x_handle, y_handle, alg, spsv_descr); if (alg == oneapi::mkl::sparse::spsv_alg::no_optimize_alg) { return detail::collapse_dependencies(queue, dependencies); }