diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index 05722b8d05..12458ced6f 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -987,6 +987,8 @@ sycl::event keep_args_alive(sycl::queue q, return host_task_ev; } +/*! @brief Check if all allocation queues are the same as the + execution queue */ template bool queues_are_compatible(sycl::queue exec_q, const sycl::queue (&alloc_qs)[num]) @@ -1000,6 +1002,21 @@ bool queues_are_compatible(sycl::queue exec_q, return true; } +/*! @brief Check if all allocation queues of usm_ndarays are the same as + the execution queue */ +template +bool queues_are_compatible(sycl::queue exec_q, + const ::dpctl::tensor::usm_ndarray (&arrs)[num]) +{ + for (std::size_t i = 0; i < num; ++i) { + + if (exec_q != arrs[i].get_queue()) { + return false; + } + } + return true; +} + } // end namespace utils } // end namespace dpctl diff --git a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp index c81430d54b..c629f585c3 100644 --- a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp +++ b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp @@ -160,10 +160,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src, } // check compatibility of execution queue and allocation queue - sycl::queue src_q = src.get_queue(); - sycl::queue dst_q = dst.get_queue(); - - if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) { + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { throw py::value_error( "Execution queue is not compatible with allocation queues"); } diff --git a/dpctl/tensor/libtensor/source/copy_for_reshape.cpp b/dpctl/tensor/libtensor/source/copy_for_reshape.cpp index 7f7e866bb1..3a07ac2bd3 100644 --- a/dpctl/tensor/libtensor/source/copy_for_reshape.cpp +++ b/dpctl/tensor/libtensor/source/copy_for_reshape.cpp @@ -101,10 +101,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src, } // check same contexts - sycl::queue src_q = src.get_queue(); - sycl::queue dst_q = dst.get_queue(); - - if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) { + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { throw py::value_error( "Execution queue is not compatible with allocation queues"); } diff --git a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp index c6b42e48ff..5474d96fe5 100644 --- a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp +++ b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp @@ -101,9 +101,7 @@ void copy_numpy_ndarray_into_usm_ndarray( } } - sycl::queue dst_q = dst.get_queue(); - - if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) { + if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) { throw py::value_error("Execution queue is not compatible with the " "allocation queue"); } diff --git a/dpctl/tensor/libtensor/source/eye_ctor.cpp b/dpctl/tensor/libtensor/source/eye_ctor.cpp index d36447749a..867e862633 100644 --- a/dpctl/tensor/libtensor/source/eye_ctor.cpp +++ b/dpctl/tensor/libtensor/source/eye_ctor.cpp @@ -61,8 +61,7 @@ usm_ndarray_eye(py::ssize_t k, "usm_ndarray_eye: Expecting 2D array to populate"); } - sycl::queue dst_q = dst.get_queue(); - if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) { + if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) { throw py::value_error("Execution queue is not compatible with the " "allocation queue"); } diff --git a/dpctl/tensor/libtensor/source/full_ctor.cpp b/dpctl/tensor/libtensor/source/full_ctor.cpp index e5b1da362b..4ccbcd9277 100644 --- a/dpctl/tensor/libtensor/source/full_ctor.cpp +++ b/dpctl/tensor/libtensor/source/full_ctor.cpp @@ -69,8 +69,7 @@ usm_ndarray_full(py::object py_value, return std::make_pair(sycl::event(), sycl::event()); } - sycl::queue dst_q = dst.get_queue(); - if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) { + if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) { throw py::value_error( "Execution queue is not compatible with the allocation queue"); } diff --git a/dpctl/tensor/libtensor/source/linear_sequences.cpp b/dpctl/tensor/libtensor/source/linear_sequences.cpp index 8b72923679..6c225690a1 100644 --- a/dpctl/tensor/libtensor/source/linear_sequences.cpp +++ b/dpctl/tensor/libtensor/source/linear_sequences.cpp @@ -78,8 +78,7 @@ usm_ndarray_linear_sequence_step(py::object start, "usm_ndarray_linspace: Non-contiguous arrays are not supported"); } - sycl::queue dst_q = dst.get_queue(); - if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) { + if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) { throw py::value_error( "Execution queue is not compatible with the allocation queue"); } @@ -127,8 +126,7 @@ usm_ndarray_linear_sequence_affine(py::object start, "usm_ndarray_linspace: Non-contiguous arrays are not supported"); } - sycl::queue dst_q = dst.get_queue(); - if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) { + if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) { throw py::value_error( "Execution queue context is not the same as allocation context"); } diff --git a/dpctl/tensor/libtensor/source/triul_ctor.cpp b/dpctl/tensor/libtensor/source/triul_ctor.cpp index fccf483931..3967914425 100644 --- a/dpctl/tensor/libtensor/source/triul_ctor.cpp +++ b/dpctl/tensor/libtensor/source/triul_ctor.cpp @@ -121,11 +121,8 @@ usm_ndarray_triul(sycl::queue exec_q, throw py::value_error("Array dtype are not the same."); } - // check same contexts - sycl::queue src_q = src.get_queue(); - sycl::queue dst_q = dst.get_queue(); - - if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) { + // check same queues + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { throw py::value_error( "Execution queue context is not the same as allocation contexts"); }