@@ -36,7 +36,12 @@ from .._backend cimport (
3636 DPCTLSyclDeviceRef,
3737 DPCTLSyclUSMRef,
3838)
39- from ._usmarray cimport USM_ARRAY_WRITABLE, usm_ndarray
39+ from ._usmarray cimport (
40+ USM_ARRAY_C_CONTIGUOUS,
41+ USM_ARRAY_F_CONTIGUOUS,
42+ USM_ARRAY_WRITABLE,
43+ usm_ndarray,
44+ )
4045
4146import ctypes
4247
@@ -291,14 +296,29 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
291296 for i in range (nd):
292297 shape_strides_ptr[i] = shape_ptr[i]
293298 strides_ptr = usm_ary.get_strides()
299+ flags = usm_ary.flags_
294300 if strides_ptr:
295301 for i in range (nd):
296302 shape_strides_ptr[nd + i] = strides_ptr[i]
297303 else :
298- si = 1
299- for i in range (0 , nd):
300- shape_strides_ptr[nd + i] = si
301- si = si * shape_ptr[i]
304+ if flags & USM_ARRAY_C_CONTIGUOUS:
305+ si = 1
306+ for i in range (nd - 1 , - 1 , - 1 ):
307+ shape_strides_ptr[nd + i] = si
308+ si = si * shape_ptr[i]
309+ elif flags & USM_ARRAY_F_CONTIGUOUS:
310+ si = 1
311+ for i in range (0 , nd):
312+ shape_strides_ptr[nd + i] = si
313+ si = si * shape_ptr[i]
314+ else :
315+ stdlib.free(shape_strides_ptr)
316+ stdlib.free(dlm_tensor)
317+ raise BufferError(
318+ " to_dlpack_capsule: Could not reconstruct strides "
319+ " for non-contiguous tensor"
320+ )
321+
302322 strides_ptr = < Py_ssize_t * > & shape_strides_ptr[nd]
303323
304324 ary_dt = usm_ary.dtype
@@ -409,10 +429,24 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
409429 for i in range (nd):
410430 shape_strides_ptr[nd + i] = strides_ptr[i]
411431 else :
412- si = 1
413- for i in range (0 , nd):
414- shape_strides_ptr[nd + i] = si
415- si = si * shape_ptr[i]
432+ if flags & USM_ARRAY_C_CONTIGUOUS:
433+ si = 1
434+ for i in range (nd - 1 , - 1 , - 1 ):
435+ shape_strides_ptr[nd + i] = si
436+ si = si * shape_ptr[i]
437+ elif flags & USM_ARRAY_F_CONTIGUOUS:
438+ si = 1
439+ for i in range (0 , nd):
440+ shape_strides_ptr[nd + i] = si
441+ si = si * shape_ptr[i]
442+ else :
443+ stdlib.free(shape_strides_ptr)
444+ stdlib.free(dlmv_tensor)
445+ raise BufferError(
446+ " to_dlpack_versioned_capsule: Could not reconstruct "
447+ " strides for non-contiguous tensor"
448+ )
449+
416450 strides_ptr = < Py_ssize_t * > & shape_strides_ptr[nd]
417451
418452 # this can all be a function for building the dl_tensor
0 commit comments