Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default context #827

Merged
merged 9 commits into from
May 2, 2022
2 changes: 2 additions & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef)
cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef)
cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(
const DPCTLSyclPlatformRef)


cdef extern from "syclinterface/dpctl_sycl_context_interface.h":
Expand Down
20 changes: 20 additions & 0 deletions dpctl/_sycl_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ from ._backend cimport ( # noqa: E211
DPCTLDevice_GetMaxWriteImageArgs,
DPCTLDevice_GetName,
DPCTLDevice_GetParentDevice,
DPCTLDevice_GetPlatform,
DPCTLDevice_GetPreferredVectorWidthChar,
DPCTLDevice_GetPreferredVectorWidthDouble,
DPCTLDevice_GetPreferredVectorWidthFloat,
Expand Down Expand Up @@ -80,6 +81,7 @@ from ._backend cimport ( # noqa: E211
DPCTLSize_t_Array_Delete,
DPCTLSyclDeviceRef,
DPCTLSyclDeviceSelectorRef,
DPCTLSyclPlatformRef,
_aspect_type,
_backend_type,
_device_type,
Expand All @@ -91,6 +93,8 @@ from .enum_types import backend_type, device_type
from libc.stdint cimport int64_t, uint32_t
from libc.stdlib cimport free, malloc

from ._sycl_platform cimport SyclPlatform

import collections
import warnings

Expand Down Expand Up @@ -639,6 +643,22 @@ cdef class SyclDevice(_SyclDevice):
self._device_ref
)

@property
def platform(self):
""" Returns the platform associated with this device.

Returns:
:class:`dpctl.SyclPlatform`: The platform associated with this
device.
"""
cdef DPCTLSyclPlatformRef PRef = (
DPCTLDevice_GetPlatform(self._device_ref)
)
oleksandr-pavlyk marked this conversation as resolved.
Show resolved Hide resolved
if (PRef == NULL):
raise RuntimeError("Could not get platform for device.")
else:
return SyclPlatform._create(PRef)
oleksandr-pavlyk marked this conversation as resolved.
Show resolved Hide resolved

@property
def preferred_vector_width_char(self):
""" Returns the preferred native vector width size for built-in scalar
Expand Down
25 changes: 23 additions & 2 deletions dpctl/_sycl_platform.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ from ._backend cimport ( # noqa: E211
DPCTLPlatform_CreateFromSelector,
DPCTLPlatform_Delete,
DPCTLPlatform_GetBackend,
DPCTLPlatform_GetDefaultContext,
DPCTLPlatform_GetName,
DPCTLPlatform_GetPlatforms,
DPCTLPlatform_GetVendor,
Expand All @@ -40,15 +41,19 @@ from ._backend cimport ( # noqa: E211
DPCTLPlatformVector_GetAt,
DPCTLPlatformVector_Size,
DPCTLPlatformVectorRef,
DPCTLSyclContextRef,
DPCTLSyclDeviceSelectorRef,
DPCTLSyclPlatformRef,
_backend_type,
)

import warnings

from ._sycl_context import SyclContextCreationError
from .enum_types import backend_type

from ._sycl_context cimport SyclContext

__all__ = [
"get_platforms",
"lsplatform",
Expand Down Expand Up @@ -236,10 +241,10 @@ cdef class SyclPlatform(_SyclPlatform):

@property
def backend(self):
"""Returns the backend_type enum value for this device
"""Returns the backend_type enum value for this platform

Returns:
backend_type: The backend for the device.
backend_type: The backend for the platform.
"""
cdef _backend_type BTy = (
DPCTLPlatform_GetBackend(self._platform_ref)
Expand All @@ -255,6 +260,22 @@ cdef class SyclPlatform(_SyclPlatform):
else:
raise ValueError("Unknown backend type.")

@property
def default_context(self):
"""Returns the default platform context for this platform

Returns:
SyclContext: The default context for the platform.
"""
cdef DPCTLSyclContextRef CRef = (
DPCTLPlatform_GetDefaultContext(self._platform_ref)
)

if (CRef == NULL):
raise
else:
return SyclContext._create(CRef)


def lsplatform(verbosity=0):
"""
Expand Down
7 changes: 7 additions & 0 deletions dpctl/tests/test_sycl_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,11 @@ def check_profiling_timer_resolution(device):
assert isinstance(resol, int) and resol > 0


def check_platform(device):
p = device.platform
assert isinstance(p, dpctl.SyclPlatform)


list_of_checks = [
check_get_max_compute_units,
check_get_max_work_item_dims,
Expand Down Expand Up @@ -552,6 +557,8 @@ def check_profiling_timer_resolution(device):
check_repr,
check_get_global_mem_size,
check_get_local_mem_size,
check_profiling_timer_resolution,
oleksandr-pavlyk marked this conversation as resolved.
Show resolved Hide resolved
check_platform,
]


Expand Down
5 changes: 5 additions & 0 deletions dpctl/tests/test_sycl_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def check_repr(platform):
assert r != ""


def check_default_context(platform):
r = platform.default_context
assert type(r) is dpctl.SyclContext
oleksandr-pavlyk marked this conversation as resolved.
Show resolved Hide resolved


list_of_checks = [
check_name,
check_vendor,
Expand Down
12 changes: 12 additions & 0 deletions libsyclinterface/include/dpctl_sycl_platform_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,16 @@ DPCTLPlatform_GetVersion(__dpctl_keep const DPCTLSyclPlatformRef PRef);
DPCTL_API
__dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms(void);

/*!
* @brief Returns a DPCTLSyclContextRef for default platform context.
*
* @param PRef Opaque pointer to a sycl::platform
* @return A DPCTLSyclContextRef value for the default platform associated
* with this platform.
* @ingroup PlatformInterface
*/
DPCTL_API
__dpctl_give DPCTLSyclContextRef
DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef);

DPCTL_C_EXTERN_C_END
16 changes: 16 additions & 0 deletions libsyclinterface/source/dpctl_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ using namespace cl::sycl;
namespace
{
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(platform, DPCTLSyclPlatformRef);
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(context, DPCTLSyclContextRef);
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(device_selector, DPCTLSyclDeviceSelectorRef);
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<DPCTLSyclPlatformRef>,
DPCTLPlatformVectorRef);
Expand Down Expand Up @@ -202,3 +203,18 @@ __dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
// the wrap function is defined inside dpctl_vector_templ.cpp
return wrap(Platforms);
}

__dpctl_give DPCTLSyclContextRef
DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef)
{
auto P = unwrap(PRef);
if (P) {
auto default_ctx = P->ext_oneapi_get_default_context();
return wrap(new context(default_ctx));
}
else {
error_handler("Driver version cannot be looked up for a NULL platform.",
oleksandr-pavlyk marked this conversation as resolved.
Show resolved Hide resolved
__FILE__, __func__, __LINE__);
return nullptr;
}
}
24 changes: 24 additions & 0 deletions libsyclinterface/tests/test_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
//===----------------------------------------------------------------------===//

#include "Support/CBindingWrapping.h"
#include "dpctl_sycl_context_interface.h"
#include "dpctl_sycl_device_selector_interface.h"
#include "dpctl_sycl_platform_interface.h"
#include "dpctl_sycl_platform_manager.h"
Expand Down Expand Up @@ -82,6 +83,16 @@ void check_platform_backend(__dpctl_keep const DPCTLSyclPlatformRef PRef)
}());
}

void check_platform_default_context(
__dpctl_keep const DPCTLSyclPlatformRef PRef)
{
DPCTLSyclContextRef CRef = nullptr;
EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(PRef));
EXPECT_TRUE(CRef != nullptr);

EXPECT_NO_FATAL_FAILURE(DPCTLContext_Delete(CRef));
}

} // namespace

struct TestDPCTLSyclPlatformInterface
Expand Down Expand Up @@ -167,6 +178,14 @@ TEST_F(TestDPCTLSyclPlatformNull, ChkGetVersion)
ASSERT_TRUE(version == nullptr);
}

TEST_F(TestDPCTLSyclPlatformNull, ChkGetDefaultConext)
{
DPCTLSyclContextRef CRef = nullptr;

EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(NullPRef));
EXPECT_TRUE(CRef == nullptr);
}

struct TestDPCTLSyclDefaultPlatform : public ::testing::Test
{
DPCTLSyclPlatformRef PRef = nullptr;
Expand Down Expand Up @@ -207,6 +226,11 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkGetBackend)
check_platform_backend(PRef);
}

TEST_P(TestDPCTLSyclPlatformInterface, ChkGetDefaultContext)
{
check_platform_default_context(PRef);
}

TEST_P(TestDPCTLSyclPlatformInterface, ChkCopy)
{
DPCTLSyclPlatformRef Copied_PRef = nullptr;
Expand Down
1 change: 0 additions & 1 deletion libsyclinterface/tests/test_sycl_queue_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,6 @@ TEST_P(TestDPCTLQueueMemberFunctions, CheckMemset)

ASSERT_NO_FATAL_FAILURE(DPCTLfree_with_queue(p, QRef));

bool equal = true;
oleksandr-pavlyk marked this conversation as resolved.
Show resolved Hide resolved
for (size_t i = 0; i < nbytes; ++i) {
ASSERT_TRUE(host_arr[i] == val);
}
Expand Down