Skip to content

Commit

Permalink
Merge pull request #1560 from IntelPython/dlpack-support-sub-device
Browse files Browse the repository at this point in the history
x.__dlpack_device__() returns ID of the parent device
  • Loading branch information
oleksandr-pavlyk authored Feb 26, 2024
2 parents 16f23f7 + db9dd04 commit fe75a16
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 14 deletions.
10 changes: 8 additions & 2 deletions dpctl/_sycl_platform.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,20 @@ cdef class SyclPlatform(_SyclPlatform):
"""Returns the default platform context for this platform
Returns:
SyclContext: The default context for the platform.
SyclContext
The default context for the platform.
Raises:
SyclContextCreationError
If default_context is not supported
"""
cdef DPCTLSyclContextRef CRef = (
DPCTLPlatform_GetDefaultContext(self._platform_ref)
)

if (CRef == NULL):
raise RuntimeError("Getting default error ran into a problem")
raise SyclContextCreationError(
"Getting default_context ran into a problem"
)
else:
return SyclContext._create(CRef)

Expand Down
3 changes: 3 additions & 0 deletions dpctl/tensor/_dlpack.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# cython: language_level=3
# cython: linetrace=True

from .._sycl_device cimport SyclDevice
from ._usmarray cimport usm_ndarray


Expand All @@ -32,6 +33,8 @@ cpdef usm_ndarray from_dlpack_capsule(object dltensor) except +

cpdef from_dlpack(array)

cdef int get_parent_device_ordinal_id(SyclDevice dev) except *

cdef class DLPackCreationError(Exception):
"""
A DLPackCreateError exception is raised when constructing
Expand Down
42 changes: 34 additions & 8 deletions dpctl/tensor/_dlpack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,39 @@ cdef void _managed_tensor_deleter(DLManagedTensor *dlm_tensor) noexcept with gil
dlm_tensor.manager_ctx = NULL
stdlib.free(dlm_tensor)

cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
try:
if _IS_LINUX:
default_context = dev.sycl_platform.default_context
else:
default_context = None
except RuntimeError:
# RT does not support default_context, e.g. Windows
default_context = None

return default_context


cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
cdef DPCTLSyclDeviceRef pDRef = NULL
cdef DPCTLSyclDeviceRef tDRef = NULL
cdef c_dpctl.SyclDevice p_dev

pDRef = DPCTLDevice_GetParentDevice(dev.get_device_ref())
if pDRef is not NULL:
# if dev is a sub-device, find its parent
# and return its overall ordinal id
tDRef = DPCTLDevice_GetParentDevice(pDRef)
while tDRef is not NULL:
DPCTLDevice_Delete(pDRef)
pDRef = tDRef
tDRef = DPCTLDevice_GetParentDevice(pDRef)
p_dev = c_dpctl.SyclDevice._create(pDRef)
return p_dev.get_overall_ordinal()

# return overall ordinal id of argument device
return dev.get_overall_ordinal()


cpdef to_dlpack_capsule(usm_ndarray usm_ary):
"""
Expand Down Expand Up @@ -168,14 +201,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
ary_sycl_queue = usm_ary.get_sycl_queue()
ary_sycl_device = ary_sycl_queue.get_sycl_device()

try:
if _IS_LINUX:
default_context = ary_sycl_device.sycl_platform.default_context
else:
default_context = None
except RuntimeError:
# RT does not support default_context, e.g. Windows
default_context = None
default_context = _get_default_context(ary_sycl_device)
if default_context is None:
# check that ary_sycl_device is a non-partitioned device
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -954,10 +954,10 @@ cdef class usm_ndarray:
DLPackCreationError: when array is allocation on a partitioned
SYCL device
"""
cdef int dev_id = (<c_dpctl.SyclDevice>self.sycl_device).get_overall_ordinal()
cdef int dev_id = c_dlpack.get_parent_device_ordinal_id(<c_dpctl.SyclDevice>self.sycl_device)
if dev_id < 0:
raise c_dlpack.DLPackCreationError(
"DLPack protocol is only supported for non-partitioned devices"
"Could not determine id of the device where array was allocated."
)
else:
return (
Expand Down
36 changes: 36 additions & 0 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,39 @@ def test_from_dlpack_fortran_contig_array_roundtripping():

assert dpt.all(dpt.equal(ar2d_f, ar2d_r))
assert dpt.all(dpt.equal(ar2d_c, ar2d_r))


def test_dlpack_from_subdevice():
"""
This test checks that array allocated on a sub-device,
with memory bound to platform-default SyclContext can be
exported and imported via DLPack.
"""
n = 64
try:
dev = dpctl.SyclDevice()
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")
try:
sdevs = dev.create_sub_devices(partition="next_partitionable")
except dpctl.SyclSubDeviceCreationError:
sdevs = None
try:
sdevs = (
dev.create_sub_devices(partition=[1, 1]) if sdevs is None else sdevs
)
except dpctl.SyclSubDeviceCreationError:
pytest.skip("Default device can not be partitioned")
assert isinstance(sdevs, list) and len(sdevs) > 0
try:
ctx = sdevs[0].sycl_platform.default_context
except dpctl.SyclContextCreationError:
pytest.skip("Platform's default_context is not available")
try:
q = dpctl.SyclQueue(ctx, sdevs[0])
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

ar = dpt.arange(n, dtype=dpt.int32, sycl_queue=q)
ar2 = dpt.from_dlpack(ar)
assert ar2.sycl_device == sdevs[0]
13 changes: 11 additions & 2 deletions libsyclinterface/source/dpctl_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,17 @@ DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef)
{
auto P = unwrap<platform>(PRef);
if (P) {
const auto &default_ctx = P->ext_oneapi_get_default_context();
return wrap<context>(new context(default_ctx));
#ifdef SYCL_EXT_ONEAPI_DEFAULT_CONTEXT
try {
const auto &default_ctx = P->ext_oneapi_get_default_context();
return wrap<context>(new context(default_ctx));
} catch (const std::exception &ex) {
error_handler(ex, __FILE__, __func__, __LINE__);
return nullptr;
}
#else
return nullptr;
#endif
}
else {
error_handler(
Expand Down

0 comments on commit fe75a16

Please sign in to comment.