diff --git a/dpctl/tensor/_dlpack.pyx b/dpctl/tensor/_dlpack.pyx index e479e6996c..aede1a17a0 100644 --- a/dpctl/tensor/_dlpack.pyx +++ b/dpctl/tensor/_dlpack.pyx @@ -202,6 +202,57 @@ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *: return dev.get_overall_ordinal() +cdef int get_array_dlpack_device_id( + usm_ndarray usm_ary +) except *: + """Finds ordinal number of the parent of device where array + was allocated. + """ + cdef c_dpctl.SyclQueue ary_sycl_queue + cdef c_dpctl.SyclDevice ary_sycl_device + cdef DPCTLSyclDeviceRef pDRef = NULL + cdef DPCTLSyclDeviceRef tDRef = NULL + cdef int device_id = -1 + + ary_sycl_queue = usm_ary.get_sycl_queue() + ary_sycl_device = ary_sycl_queue.get_sycl_device() + + default_context = _get_default_context(ary_sycl_device) + if default_context is None: + # check that ary_sycl_device is a non-partitioned device + pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) + if pDRef is not NULL: + DPCTLDevice_Delete(pDRef) + raise DLPackCreationError( + "to_dlpack_capsule: DLPack can only export arrays allocated " + "on non-partitioned SYCL devices on platforms where " + "default_context oneAPI extension is not supported." + ) + else: + if not usm_ary.sycl_context == default_context: + raise DLPackCreationError( + "to_dlpack_capsule: DLPack can only export arrays based on USM " + "allocations bound to a default platform SYCL context" + ) + # Find the unpartitioned parent of the allocation device + pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) + if pDRef is not NULL: + tDRef = DPCTLDevice_GetParentDevice(pDRef) + while tDRef is not NULL: + DPCTLDevice_Delete(pDRef) + pDRef = tDRef + tDRef = DPCTLDevice_GetParentDevice(pDRef) + ary_sycl_device = c_dpctl.SyclDevice._create(pDRef) + + device_id = ary_sycl_device.get_overall_ordinal() + if device_id < 0: + raise DLPackCreationError( + "to_dlpack_capsule: failed to determine device_id" + ) + + return device_id + + cpdef to_dlpack_capsule(usm_ndarray usm_ary): """ to_dlpack_capsule(usm_ary) @@ -225,10 +276,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary): ValueError: when array elements data type could not be represented in ``DLManagedTensor``. """ - cdef c_dpctl.SyclQueue ary_sycl_queue - cdef c_dpctl.SyclDevice ary_sycl_device - cdef DPCTLSyclDeviceRef pDRef = NULL - cdef DPCTLSyclDeviceRef tDRef = NULL cdef DLManagedTensor *dlm_tensor = NULL cdef DLTensor *dl_tensor = NULL cdef int nd = usm_ary.get_ndim() @@ -245,38 +292,9 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary): cdef Py_ssize_t si = 1 ary_base = usm_ary.get_base() - ary_sycl_queue = usm_ary.get_sycl_queue() - ary_sycl_device = ary_sycl_queue.get_sycl_device() - default_context = _get_default_context(ary_sycl_device) - if default_context is None: - # check that ary_sycl_device is a non-partitioned device - pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) - if pDRef is not NULL: - DPCTLDevice_Delete(pDRef) - raise DLPackCreationError( - "to_dlpack_capsule: DLPack can only export arrays allocated " - "on non-partitioned SYCL devices on platforms where " - "default_context oneAPI extension is not supported." - ) - else: - if not usm_ary.sycl_context == default_context: - raise DLPackCreationError( - "to_dlpack_capsule: DLPack can only export arrays based on USM " - "allocations bound to a default platform SYCL context" - ) - # Find the unpartitioned parent of the allocation device - pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) - if pDRef is not NULL: - tDRef = DPCTLDevice_GetParentDevice(pDRef) - while tDRef is not NULL: - DPCTLDevice_Delete(pDRef) - pDRef = tDRef - tDRef = DPCTLDevice_GetParentDevice(pDRef) - ary_sycl_device = c_dpctl.SyclDevice._create(pDRef) + device_id = get_array_dlpack_device_id(usm_ary) - # Find ordinal number of the parent device - device_id = ary_sycl_device.get_overall_ordinal() if device_id < 0: raise DLPackCreationError( "to_dlpack_capsule: failed to determine device_id" @@ -376,10 +394,6 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied): ValueError: when array elements data type could not be represented in ``DLManagedTensorVersioned``. """ - cdef c_dpctl.SyclQueue ary_sycl_queue - cdef c_dpctl.SyclDevice ary_sycl_device - cdef DPCTLSyclDeviceRef pDRef = NULL - cdef DPCTLSyclDeviceRef tDRef = NULL cdef DLManagedTensorVersioned *dlmv_tensor = NULL cdef DLTensor *dl_tensor = NULL cdef uint32_t dlmv_flags = 0 @@ -397,43 +411,9 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied): cdef Py_ssize_t si = 1 ary_base = usm_ary.get_base() - ary_sycl_queue = usm_ary.get_sycl_queue() - ary_sycl_device = ary_sycl_queue.get_sycl_device() - - default_context = _get_default_context(ary_sycl_device) - if default_context is None: - # check that ary_sycl_device is a non-partitioned device - pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) - if pDRef is not NULL: - DPCTLDevice_Delete(pDRef) - raise DLPackCreationError( - "to_dlpack_versioned_capsule: DLPack can only export arrays " - "allocated on non-partitioned SYCL devices on platforms " - "where default_context oneAPI extension is not supported." - ) - else: - if not usm_ary.sycl_context == default_context: - raise DLPackCreationError( - "to_dlpack_versioned_capsule: DLPack can only export arrays " - "based on USM allocations bound to a default platform SYCL " - "context" - ) - # Find the unpartitioned parent of the allocation device - pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref()) - if pDRef is not NULL: - tDRef = DPCTLDevice_GetParentDevice(pDRef) - while tDRef is not NULL: - DPCTLDevice_Delete(pDRef) - pDRef = tDRef - tDRef = DPCTLDevice_GetParentDevice(pDRef) - ary_sycl_device = c_dpctl.SyclDevice._create(pDRef) # Find ordinal number of the parent device - device_id = ary_sycl_device.get_overall_ordinal() - if device_id < 0: - raise DLPackCreationError( - "to_dlpack_versioned_capsule: failed to determine device_id" - ) + device_id = get_array_dlpack_device_id(usm_ary) dlmv_tensor = stdlib.malloc( sizeof(DLManagedTensorVersioned))