Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

x.__dlpack_device__() returns ID of the parent device #1560

Merged
merged 5 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions dpctl/_sycl_platform.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,20 @@ cdef class SyclPlatform(_SyclPlatform):
"""Returns the default platform context for this platform

Returns:
SyclContext: The default context for the platform.
SyclContext
The default context for the platform.
Raises:
SyclContextCreationError
If default_context is not supported
"""
cdef DPCTLSyclContextRef CRef = (
DPCTLPlatform_GetDefaultContext(self._platform_ref)
)

if (CRef == NULL):
raise RuntimeError("Getting default error ran into a problem")
raise SyclContextCreationError(
"Getting default_context ran into a problem"
)
else:
return SyclContext._create(CRef)

Expand Down
3 changes: 3 additions & 0 deletions dpctl/tensor/_dlpack.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# cython: language_level=3
# cython: linetrace=True

from .._sycl_device cimport SyclDevice
from ._usmarray cimport usm_ndarray


Expand All @@ -32,6 +33,8 @@ cpdef usm_ndarray from_dlpack_capsule(object dltensor) except +

cpdef from_dlpack(array)

cdef int get_parent_device_ordinal_id(SyclDevice dev) except *

cdef class DLPackCreationError(Exception):
"""
A DLPackCreateError exception is raised when constructing
Expand Down
42 changes: 34 additions & 8 deletions dpctl/tensor/_dlpack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,39 @@ cdef void _managed_tensor_deleter(DLManagedTensor *dlm_tensor) noexcept with gil
dlm_tensor.manager_ctx = NULL
stdlib.free(dlm_tensor)

cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
try:
if _IS_LINUX:
default_context = dev.sycl_platform.default_context
else:
default_context = None
except RuntimeError:
# RT does not support default_context, e.g. Windows
default_context = None

return default_context


cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
cdef DPCTLSyclDeviceRef pDRef = NULL
cdef DPCTLSyclDeviceRef tDRef = NULL
cdef c_dpctl.SyclDevice p_dev

pDRef = DPCTLDevice_GetParentDevice(dev.get_device_ref())
if pDRef is not NULL:
# if dev is a sub-device, find its parent
# and return its overall ordinal id
tDRef = DPCTLDevice_GetParentDevice(pDRef)
while tDRef is not NULL:
DPCTLDevice_Delete(pDRef)
pDRef = tDRef
tDRef = DPCTLDevice_GetParentDevice(pDRef)
p_dev = c_dpctl.SyclDevice._create(pDRef)
return p_dev.get_overall_ordinal()

# return overall ordinal id of argument device
return dev.get_overall_ordinal()


cpdef to_dlpack_capsule(usm_ndarray usm_ary):
"""
Expand Down Expand Up @@ -168,14 +201,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
ary_sycl_queue = usm_ary.get_sycl_queue()
ary_sycl_device = ary_sycl_queue.get_sycl_device()

try:
if _IS_LINUX:
default_context = ary_sycl_device.sycl_platform.default_context
else:
default_context = None
except RuntimeError:
# RT does not support default_context, e.g. Windows
default_context = None
default_context = _get_default_context(ary_sycl_device)
if default_context is None:
# check that ary_sycl_device is a non-partitioned device
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -954,10 +954,10 @@ cdef class usm_ndarray:
DLPackCreationError: when array is allocation on a partitioned
SYCL device
"""
cdef int dev_id = (<c_dpctl.SyclDevice>self.sycl_device).get_overall_ordinal()
cdef int dev_id = c_dlpack.get_parent_device_ordinal_id(<c_dpctl.SyclDevice>self.sycl_device)
if dev_id < 0:
raise c_dlpack.DLPackCreationError(
"DLPack protocol is only supported for non-partitioned devices"
"Could not determine id of the device where array was allocated."
)
else:
return (
Expand Down
36 changes: 36 additions & 0 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,39 @@ def test_from_dlpack_fortran_contig_array_roundtripping():

assert dpt.all(dpt.equal(ar2d_f, ar2d_r))
assert dpt.all(dpt.equal(ar2d_c, ar2d_r))


def test_dlpack_from_subdevice():
"""
This test checks that array allocated on a sub-device,
with memory bound to platform-default SyclContext can be
exported and imported via DLPack.
"""
n = 64
try:
dev = dpctl.SyclDevice()
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")
try:
sdevs = dev.create_sub_devices(partition="next_partitionable")
except dpctl.SyclSubDeviceCreationError:
sdevs = None
try:
sdevs = (
dev.create_sub_devices(partition=[1, 1]) if sdevs is None else sdevs
)
except dpctl.SyclSubDeviceCreationError:
pytest.skip("Default device can not be partitioned")
assert isinstance(sdevs, list) and len(sdevs) > 0
try:
ctx = sdevs[0].sycl_platform.default_context
except dpctl.SyclContextCreationError:
pytest.skip("Platform's default_context is not available")
try:
q = dpctl.SyclQueue(ctx, sdevs[0])
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

ar = dpt.arange(n, dtype=dpt.int32, sycl_queue=q)
ar2 = dpt.from_dlpack(ar)
assert ar2.sycl_device == sdevs[0]
13 changes: 11 additions & 2 deletions libsyclinterface/source/dpctl_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,17 @@ DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef)
{
auto P = unwrap<platform>(PRef);
if (P) {
const auto &default_ctx = P->ext_oneapi_get_default_context();
return wrap<context>(new context(default_ctx));
#ifdef SYCL_EXT_ONEAPI_DEFAULT_CONTEXT
try {
const auto &default_ctx = P->ext_oneapi_get_default_context();
return wrap<context>(new context(default_ctx));
} catch (const std::exception &ex) {
error_handler(ex, __FILE__, __func__, __LINE__);
return nullptr;
}
#else
return nullptr;
#endif
}
else {
error_handler(
Expand Down