diff --git a/dpctl/tensor/libtensor/include/kernels/constructors.hpp b/dpctl/tensor/libtensor/include/kernels/constructors.hpp index 4023d291af..9b77b47a84 100644 --- a/dpctl/tensor/libtensor/include/kernels/constructors.hpp +++ b/dpctl/tensor/libtensor/include/kernels/constructors.hpp @@ -129,6 +129,7 @@ sycl::event lin_space_step_impl(sycl::queue exec_q, char *array_data, const std::vector<sycl::event> &depends) { + dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q); sycl::event lin_space_step_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.parallel_for<linear_sequence_step_kernel<Ty>>( @@ -270,6 +271,8 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q, char *array_data, const std::vector<sycl::event> &depends) { + dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q); + bool device_supports_doubles = exec_q.get_device().has(sycl::aspect::fp64); sycl::event lin_space_affine_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -378,6 +381,7 @@ sycl::event full_contig_impl(sycl::queue q, char *dst_p, const std::vector<sycl::event> &depends) { + dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q); sycl::event fill_ev = q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); dstTy *p = reinterpret_cast<dstTy *>(dst_p); @@ -496,6 +500,7 @@ sycl::event eye_impl(sycl::queue exec_q, char *array_data, const std::vector<sycl::event> &depends) { + dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q); sycl::event eye_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.parallel_for<eye_kernel<Ty>>( @@ -576,6 +581,8 @@ sycl::event tri_impl(sycl::queue exec_q, Ty *src = reinterpret_cast<Ty *>(src_p); Ty *dst = reinterpret_cast<Ty *>(dst_p); + dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q); + sycl::event tri_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.depends_on(additional_depends); diff --git a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp index 0d5a1d21ca..bd70f18334 100644
--- a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp +++ b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp @@ -215,6 +215,9 @@ copy_and_cast_generic_impl(sycl::queue q, const std::vector<sycl::event> &depends, const std::vector<sycl::event> &additional_depends) { + dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q); + dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q); + sycl::event copy_and_cast_ev = q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.depends_on(additional_depends); @@ -317,6 +320,9 @@ copy_and_cast_nd_specialized_impl(sycl::queue q, py::ssize_t dst_offset, const std::vector<sycl::event> &depends) { + dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q); + dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q); + sycl::event copy_and_cast_ev = q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.parallel_for<copy_cast_spec_kernel<srcTy, dstTy, nd>>( @@ -486,6 +492,10 @@ void copy_and_cast_from_host_impl( const std::vector<sycl::event> &additional_depends) { py::ssize_t nelems_range = src_max_nelem_offset - src_min_nelem_offset + 1; + + dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q); + dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q); + sycl::buffer<srcTy, 1> npy_buf( reinterpret_cast<const srcTy *>(host_src_p) + src_min_nelem_offset, sycl::range<1>(nelems_range), {sycl::property::buffer::use_host_ptr{}}); @@ -637,6 +647,8 @@ copy_for_reshape_generic_impl(sycl::queue q, char *dst_p, const std::vector<sycl::event> &depends) { + dpctl::tensor::type_utils::validate_type_for_device<Ty>(q); + sycl::event copy_for_reshape_ev = q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.parallel_for<copy_for_reshape_generic_kernel<Ty>>( diff --git a/dpctl/tensor/libtensor/include/utils/type_utils.hpp b/dpctl/tensor/libtensor/include/utils/type_utils.hpp index 181ff89adc..b6f4a657f4 100644 --- a/dpctl/tensor/libtensor/include/utils/type_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_utils.hpp @@ -23,7 +23,9 @@ //===----------------------------------------------------------------------===// #pragma once +#include <CL/sycl.hpp> #include <complex> +#include <stdexcept> namespace
dpctl { @@ -68,6 +70,29 @@ template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v) } } +template <typename Ty> void validate_type_for_device(const sycl::device &d) +{ + if constexpr (std::is_same_v<Ty, double>) { + if (!d.has(sycl::aspect::fp64)) { + throw std::runtime_error("Device " + + d.get_info<sycl::info::device::name>() + + " does not support type 'double'"); + } + } + else if constexpr (std::is_same_v<Ty, sycl::half>) { + if (!d.has(sycl::aspect::fp16)) { + throw std::runtime_error("Device " + + d.get_info<sycl::info::device::name>() + + " does not support type 'half'"); + } + } +} + +template <typename Ty> void validate_type_for_device(const sycl::queue &q) +{ + validate_type_for_device<Ty>(q.get_device()); +} + } // namespace type_utils } // namespace tensor } // namespace dpctl