diff --git a/sycl/source/detail/queue_impl.hpp b/sycl/source/detail/queue_impl.hpp index 9a74c21f8fd39..875c4065ae370 100644 --- a/sycl/source/detail/queue_impl.hpp +++ b/sycl/source/detail/queue_impl.hpp @@ -62,10 +62,8 @@ class queue_impl { ContextImplPtr DefaultContext = detail::getSyclObjImpl( Device->get_platform().ext_oneapi_get_default_context()); - - if (DefaultContext->hasDevice(Device)) + if (isValidDevice(DefaultContext, Device)) return DefaultContext; - return detail::getSyclObjImpl( context{createSyclObjFromImpl(Device), {}, {}}); } @@ -109,11 +107,20 @@ class queue_impl { "Queue cannot be constructed with both of " "discard_events and enable_profiling."); } - if (!Context->hasDevice(Device)) - throw cl::sycl::invalid_object_error( + if (!isValidDevice(Context, Device)) { + if (!Context->is_host() && + Context->getPlugin().getBackend() == backend::opencl) + throw sycl::invalid_object_error( + "Queue cannot be constructed with the given context and device " + "since the device is not a member of the context (descendants of " + "devices from the context are not supported on OpenCL yet).", + PI_ERROR_INVALID_DEVICE); + throw sycl::invalid_object_error( "Queue cannot be constructed with the given context and device " - "as the context does not contain the given device.", + "since the device is neither a member of the context nor a " + "descendant of its member.", PI_ERROR_INVALID_DEVICE); + } if (!MHostQueue) { const QueueOrder QOrder = MPropList.has_property() @@ -476,6 +483,27 @@ class queue_impl { } private: + /// Helper function for checking whether a device is either a member of a + /// context or a descendnant of its member. + /// \return True iff the device or its parent is a member of the context. + static bool isValidDevice(const ContextImplPtr &Context, + DeviceImplPtr Device) { + // OpenCL does not support creating a queue with a descendant of a device + // from the given context yet. + // TODO remove once this limitation is lifted + if (!Context->is_host() && + Context->getPlugin().getBackend() == backend::opencl) + return Context->hasDevice(Device); + + while (!Context->hasDevice(Device)) { + if (Device->isRootDevice()) + return false; + Device = detail::getSyclObjImpl( + Device->get_info()); + } + return true; + } + /// Performs command group submission to the queue. /// /// \param CGF is a function object containing command group. diff --git a/sycl/unittests/queue/CMakeLists.txt b/sycl/unittests/queue/CMakeLists.txt index 80b236d8235f3..d5210051a1298 100644 --- a/sycl/unittests/queue/CMakeLists.txt +++ b/sycl/unittests/queue/CMakeLists.txt @@ -1,4 +1,5 @@ add_sycl_unittest(QueueTests OBJECT + DeviceCheck.cpp EventClear.cpp USM.cpp Wait.cpp diff --git a/sycl/unittests/queue/DeviceCheck.cpp b/sycl/unittests/queue/DeviceCheck.cpp new file mode 100644 index 0000000000000..b2e8b4499d42f --- /dev/null +++ b/sycl/unittests/queue/DeviceCheck.cpp @@ -0,0 +1,170 @@ +//==----------------- DeviceCheck.cpp --- queue unit tests -----------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +namespace { + +inline constexpr auto EnableDefaultContextsName = + "SYCL_ENABLE_DEFAULT_CONTEXTS"; + +pi_result redefinedContextCreate(const pi_context_properties *properties, + pi_uint32 num_devices, + const pi_device *devices, + void (*pfn_notify)(const char *errinfo, + const void *private_info, + size_t cb, void *user_data), + void *user_data, pi_context *ret_context) { + *ret_context = reinterpret_cast(1); + return PI_SUCCESS; +} + +pi_result redefinedContextRelease(pi_context context) { return PI_SUCCESS; } + +pi_device ParentDevice = nullptr; +pi_platform PiPlatform = nullptr; + +pi_result redefinedDeviceGetInfo(pi_device device, pi_device_info param_name, + size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + if (param_name == PI_DEVICE_INFO_PARTITION_PROPERTIES) { + if (param_value) { + auto *Result = + reinterpret_cast(param_value); + *Result = PI_DEVICE_PARTITION_EQUALLY; + } + if (param_value_size_ret) + *param_value_size_ret = sizeof(pi_device_partition_property); + } else if (param_name == PI_DEVICE_INFO_MAX_COMPUTE_UNITS) { + auto *Result = reinterpret_cast(param_value); + *Result = 2; + } else if (param_name == PI_DEVICE_INFO_PARENT_DEVICE) { + auto *Result = reinterpret_cast(param_value); + *Result = (device == ParentDevice) ? nullptr : ParentDevice; + } else if (param_name == PI_DEVICE_INFO_PLATFORM) { + auto *Result = reinterpret_cast(param_value); + *Result = PiPlatform; + } else if (param_name == PI_DEVICE_INFO_EXTENSIONS) { + if (param_value_size_ret) { + *param_value_size_ret = 0; + } + } + return PI_SUCCESS; +} + +pi_result redefinedDevicePartition( + pi_device device, const pi_device_partition_property *properties, + pi_uint32 num_devices, pi_device *out_devices, pi_uint32 *out_num_devices) { + if (out_devices) { + for (pi_uint32 I = 0; I < num_devices; ++I) { + out_devices[I] = reinterpret_cast(1); + } + } + if (out_num_devices) + *out_num_devices = num_devices; + return PI_SUCCESS; +} + +pi_result redefinedDeviceRetain(pi_device device) { return PI_SUCCESS; } + +pi_result redefinedDeviceRelease(pi_device device) { return PI_SUCCESS; } + +pi_result redefinedQueueCreate(pi_context context, pi_device device, + pi_queue_properties properties, + pi_queue *queue) { + return PI_SUCCESS; +} + +pi_result redefinedQueueRelease(pi_queue queue) { return PI_SUCCESS; } + +// Check that the device is verified to be either a member of the context or a +// descendant of its member. +TEST(QueueDeviceCheck, CheckDeviceRestriction) { + unittest::ScopedEnvVar EnableDefaultContexts( + EnableDefaultContextsName, "1", + detail::SYCLConfig::reset); + + platform Plt{default_selector()}; + if (Plt.is_host()) { + std::cout << "The test is not supported on host, skipping" << std::endl; + GTEST_SKIP(); + } + PiPlatform = detail::getSyclObjImpl(Plt)->getHandleRef(); + // Create default context normally to avoid issues during its release, which + // takes plase after Mock is destroyed. + context DefaultCtx = Plt.ext_oneapi_get_default_context(); + device Dev = DefaultCtx.get_devices()[0]; + + unittest::PiMock Mock{Plt}; + Mock.redefine(redefinedContextCreate); + Mock.redefine(redefinedContextRelease); + Mock.redefine(redefinedDeviceGetInfo); + Mock.redefine(redefinedDevicePartition); + Mock.redefine(redefinedDeviceRelease); + Mock.redefine(redefinedDeviceRetain); + Mock.redefine(redefinedQueueCreate); + Mock.redefine(redefinedQueueRelease); + + // Device is a member of the context. + { + queue Q{Dev}; + EXPECT_EQ(Q.get_context().get_platform(), Plt); + EXPECT_EQ(Q.get_context(), DefaultCtx); + queue Q2{DefaultCtx, Dev}; + } + // Device is a descendant of a member of the context. + { + ParentDevice = detail::getSyclObjImpl(Dev)->getHandleRef(); + std::vector Subdevices = + Dev.create_sub_devices(2); + queue Q{Subdevices[0]}; + // OpenCL backend does not support using a descendant here yet. + EXPECT_EQ(Q.get_context() == DefaultCtx, + Q.get_backend() != backend::opencl); + try { + queue Q2{DefaultCtx, Subdevices[0]}; + EXPECT_NE(Q.get_backend(), backend::opencl); + } catch (sycl::invalid_object_error &e) { + EXPECT_EQ(Q.get_backend(), backend::opencl); + EXPECT_EQ(std::strcmp( + e.what(), + "Queue cannot be constructed with the given context and " + "device since the device is not a member of the context " + "(descendants of devices from the context are not " + "supported on OpenCL yet). -33 (PI_ERROR_INVALID_DEVICE)"), + 0); + } + } + // Device is neither of the two. + { + ParentDevice = nullptr; + device Device = detail::createSyclObjFromImpl( + std::make_shared(reinterpret_cast(0x01), + detail::getSyclObjImpl(Plt))); + queue Q{Device}; + EXPECT_NE(Q.get_context(), DefaultCtx); + try { + queue Q2{DefaultCtx, Device}; + EXPECT_TRUE(false); + } catch (sycl::invalid_object_error &e) { + EXPECT_NE( + std::strstr(e.what(), + "Queue cannot be constructed with the given context and " + "device"), + nullptr); + } + } +} +} // anonymous namespace