Skip to content

Commit

Permalink
Reuse get parent device ordinal id routine (#1672)
Browse files Browse the repository at this point in the history
* Reused get_parent_device_ordinal_id routine

* test_legacy_dlpack_capsule uses 4 kinds of dtype

Added test to use non-default copy keyword, and non-default device keyword
argument.
  • Loading branch information
oleksandr-pavlyk authored May 12, 2024
1 parent 1205b0b commit 3e3ab03
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
20 changes: 3 additions & 17 deletions dpctl/tensor/_dlpack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ cdef int get_array_dlpack_device_id(
cdef c_dpctl.SyclQueue ary_sycl_queue
cdef c_dpctl.SyclDevice ary_sycl_device
cdef DPCTLSyclDeviceRef pDRef = NULL
cdef DPCTLSyclDeviceRef tDRef = NULL
cdef int device_id = -1

ary_sycl_queue = usm_ary.get_sycl_queue()
Expand All @@ -228,26 +227,18 @@ cdef int get_array_dlpack_device_id(
"on non-partitioned SYCL devices on platforms where "
"default_context oneAPI extension is not supported."
)
device_id = ary_sycl_device.get_overall_ordinal()
else:
if not usm_ary.sycl_context == default_context:
raise DLPackCreationError(
"to_dlpack_capsule: DLPack can only export arrays based on USM "
"allocations bound to a default platform SYCL context"
)
# Find the unpartitioned parent of the allocation device
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
if pDRef is not NULL:
tDRef = DPCTLDevice_GetParentDevice(pDRef)
while tDRef is not NULL:
DPCTLDevice_Delete(pDRef)
pDRef = tDRef
tDRef = DPCTLDevice_GetParentDevice(pDRef)
ary_sycl_device = c_dpctl.SyclDevice._create(pDRef)
device_id = get_parent_device_ordinal_id(ary_sycl_device)

device_id = ary_sycl_device.get_overall_ordinal()
if device_id < 0:
raise DLPackCreationError(
"to_dlpack_capsule: failed to determine device_id"
"get_array_dlpack_device_id: failed to determine device_id"
)

return device_id
Expand Down Expand Up @@ -295,11 +286,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):

device_id = get_array_dlpack_device_id(usm_ary)

if device_id < 0:
raise DLPackCreationError(
"to_dlpack_capsule: failed to determine device_id"
)

dlm_tensor = <DLManagedTensor *> stdlib.malloc(
sizeof(DLManagedTensor))
if dlm_tensor is NULL:
Expand Down
16 changes: 16 additions & 0 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,19 +261,22 @@ def test_legacy_dlpack_capsule():
del cap
assert x._pointer == y._pointer

x = dpt.arange(100, dtype="u4")
x2 = dpt.reshape(x, (10, 10)).mT
cap = x2.__dlpack__(max_version=legacy_ver)
y = _dlp.from_dlpack_capsule(cap)
del cap
assert x2._pointer == y._pointer
del x2

x = dpt.arange(100, dtype="f4")
x2 = dpt.asarray(dpt.reshape(x, (10, 10)), order="F")
cap = x2.__dlpack__(max_version=legacy_ver)
y = _dlp.from_dlpack_capsule(cap)
del cap
assert x2._pointer == y._pointer

x = dpt.arange(100, dtype="c8")
x3 = x[::-2]
cap = x3.__dlpack__(max_version=legacy_ver)
y = _dlp.from_dlpack_capsule(cap)
Expand Down Expand Up @@ -321,3 +324,16 @@ def test_versioned_dlpack_capsule():
y = _dlp.from_dlpack_versioned_capsule(cap)
assert x._pointer != y._pointer
assert not y.flags.writable


def test_from_dlpack_kwargs():
try:
x = dpt.arange(100, dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")

y = dpt.from_dlpack(x, copy=True)
assert x._pointer != y._pointer

z = dpt.from_dlpack(x, device=x.sycl_device)
assert z._pointer == x._pointer

0 comments on commit 3e3ab03

Please sign in to comment.