diff --git a/libsyclinterface/include/dpctl_sycl_platform_interface.h b/libsyclinterface/include/dpctl_sycl_platform_interface.h index 1c01dcfb69..1d2238e652 100644 --- a/libsyclinterface/include/dpctl_sycl_platform_interface.h +++ b/libsyclinterface/include/dpctl_sycl_platform_interface.h @@ -142,4 +142,16 @@ DPCTLPlatform_GetVersion(__dpctl_keep const DPCTLSyclPlatformRef PRef); DPCTL_API __dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms(void); +/*! + * @brief Returns a DPCTLSyclContextRef for default platform context. + * + * @param PRef Opaque pointer to a sycl::platform + * @return A DPCTLSyclContextRef value for the default platform associated + * with this platform. + * @ingroup PlatformInterface + */ +DPCTL_API +__dpctl_give DPCTLSyclContextRef +DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef); + DPCTL_C_EXTERN_C_END diff --git a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp index 36c9bd16b7..5d81fcf204 100644 --- a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp @@ -41,6 +41,7 @@ using namespace cl::sycl; namespace { DEFINE_SIMPLE_CONVERSION_FUNCTIONS(platform, DPCTLSyclPlatformRef); +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(context, DPCTLSyclContextRef); DEFINE_SIMPLE_CONVERSION_FUNCTIONS(device_selector, DPCTLSyclDeviceSelectorRef); DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector, DPCTLPlatformVectorRef); @@ -202,3 +203,18 @@ __dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms() // the wrap function is defined inside dpctl_vector_templ.cpp return wrap(Platforms); } + +__dpctl_give DPCTLSyclContextRef +DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef) +{ + auto P = unwrap(PRef); + if (P) { + auto default_ctx = P->ext_oneapi_get_default_context(); + return wrap(new context(default_ctx)); + } + else { + error_handler("Driver version cannot be looked up for a NULL platform.", + __FILE__, __func__, __LINE__); + return nullptr; + } +} diff --git a/libsyclinterface/tests/test_sycl_platform_interface.cpp b/libsyclinterface/tests/test_sycl_platform_interface.cpp index 594d4856e2..1fe9c80117 100644 --- a/libsyclinterface/tests/test_sycl_platform_interface.cpp +++ b/libsyclinterface/tests/test_sycl_platform_interface.cpp @@ -25,6 +25,7 @@ //===----------------------------------------------------------------------===// #include "Support/CBindingWrapping.h" +#include "dpctl_sycl_context_interface.h" #include "dpctl_sycl_device_selector_interface.h" #include "dpctl_sycl_platform_interface.h" #include "dpctl_sycl_platform_manager.h" @@ -82,6 +83,16 @@ void check_platform_backend(__dpctl_keep const DPCTLSyclPlatformRef PRef) }()); } +void check_platform_default_context( + __dpctl_keep const DPCTLSyclPlatformRef PRef) +{ + DPCTLSyclContextRef CRef = nullptr; + EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(PRef)); + EXPECT_TRUE(CRef != nullptr); + + EXPECT_NO_FATAL_FAILURE(DPCTLContext_Delete(CRef)); +} + } // namespace struct TestDPCTLSyclPlatformInterface @@ -167,6 +178,14 @@ TEST_F(TestDPCTLSyclPlatformNull, ChkGetVersion) ASSERT_TRUE(version == nullptr); } +TEST_F(TestDPCTLSyclPlatformNull, ChkGetDefaultConext) +{ + DPCTLSyclContextRef CRef = nullptr; + + EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(NullPRef)); + EXPECT_TRUE(CRef == nullptr); +} + struct TestDPCTLSyclDefaultPlatform : public ::testing::Test { DPCTLSyclPlatformRef PRef = nullptr; @@ -207,6 +226,11 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkGetBackend) check_platform_backend(PRef); } +TEST_P(TestDPCTLSyclPlatformInterface, ChkGetDefaultContext) +{ + check_platform_default_context(PRef); +} + TEST_P(TestDPCTLSyclPlatformInterface, ChkCopy) { DPCTLSyclPlatformRef Copied_PRef = nullptr;