diff --git a/dpctl/tensor/_dlpack.pyx b/dpctl/tensor/_dlpack.pyx index aede1a17a0..7c0f96ec7f 100644 --- a/dpctl/tensor/_dlpack.pyx +++ b/dpctl/tensor/_dlpack.pyx @@ -211,7 +211,6 @@ cdef int get_array_dlpack_device_id( cdef c_dpctl.SyclQueue ary_sycl_queue cdef c_dpctl.SyclDevice ary_sycl_device cdef DPCTLSyclDeviceRef pDRef = NULL - cdef DPCTLSyclDeviceRef tDRef = NULL cdef int device_id = -1 ary_sycl_queue = usm_ary.get_sycl_queue() @@ -228,26 +227,18 @@ cdef int get_array_dlpack_device_id( "on non-partitioned SYCL devices on platforms where " "default_context oneAPI extension is not supported." ) + device_id = ary_sycl_device.get_overall_ordinal() else: if not usm_ary.sycl_context == default_context: raise DLPackCreationError( "to_dlpack_capsule: DLPack can only export arrays based on USM " "allocations bound to a default platform SYCL context" ) - # Find the unpartitioned parent of the allocation device - pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) - if pDRef is not NULL: - tDRef = DPCTLDevice_GetParentDevice(pDRef) - while tDRef is not NULL: - DPCTLDevice_Delete(pDRef) - pDRef = tDRef - tDRef = DPCTLDevice_GetParentDevice(pDRef) - ary_sycl_device = c_dpctl.SyclDevice._create(pDRef) + device_id = get_parent_device_ordinal_id(ary_sycl_device) - device_id = ary_sycl_device.get_overall_ordinal() if device_id < 0: raise DLPackCreationError( - "to_dlpack_capsule: failed to determine device_id" + "get_array_dlpack_device_id: failed to determine device_id" ) return device_id @@ -295,11 +286,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary): device_id = get_array_dlpack_device_id(usm_ary) - if device_id < 0: - raise DLPackCreationError( - "to_dlpack_capsule: failed to determine device_id" - ) - dlm_tensor = stdlib.malloc( sizeof(DLManagedTensor)) if dlm_tensor is NULL: diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py index b85ff77152..6baa899c9d 100644 --- a/dpctl/tests/test_usm_ndarray_dlpack.py +++ b/dpctl/tests/test_usm_ndarray_dlpack.py @@ -261,6 +261,7 @@ def test_legacy_dlpack_capsule(): del cap assert x._pointer == y._pointer + x = dpt.arange(100, dtype="u4") x2 = dpt.reshape(x, (10, 10)).mT cap = x2.__dlpack__(max_version=legacy_ver) y = _dlp.from_dlpack_capsule(cap) @@ -268,12 +269,14 @@ def test_legacy_dlpack_capsule(): assert x2._pointer == y._pointer del x2 + x = dpt.arange(100, dtype="f4") x2 = dpt.asarray(dpt.reshape(x, (10, 10)), order="F") cap = x2.__dlpack__(max_version=legacy_ver) y = _dlp.from_dlpack_capsule(cap) del cap assert x2._pointer == y._pointer + x = dpt.arange(100, dtype="c8") x3 = x[::-2] cap = x3.__dlpack__(max_version=legacy_ver) y = _dlp.from_dlpack_capsule(cap) @@ -321,3 +324,16 @@ def test_versioned_dlpack_capsule(): y = _dlp.from_dlpack_versioned_capsule(cap) assert x._pointer != y._pointer assert not y.flags.writable + + +def test_from_dlpack_kwargs(): + try: + x = dpt.arange(100, dtype="i4") + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + + y = dpt.from_dlpack(x, copy=True) + assert x._pointer != y._pointer + + z = dpt.from_dlpack(x, device=x.sycl_device) + assert z._pointer == x._pointer