Skip to content

Commit 8a92deb

Browse files
Fix dlpack contiguous stride reconstruction
1 parent 2b3f1ec commit 8a92deb

File tree

1 file changed

+43
-9
lines changed

1 file changed

+43
-9
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ from .._backend cimport (
3636
DPCTLSyclDeviceRef,
3737
DPCTLSyclUSMRef,
3838
)
39-
from ._usmarray cimport USM_ARRAY_WRITABLE, usm_ndarray
39+
from ._usmarray cimport (
40+
USM_ARRAY_C_CONTIGUOUS,
41+
USM_ARRAY_F_CONTIGUOUS,
42+
USM_ARRAY_WRITABLE,
43+
usm_ndarray,
44+
)
4045

4146
import ctypes
4247

@@ -291,14 +296,29 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
291296
for i in range(nd):
292297
shape_strides_ptr[i] = shape_ptr[i]
293298
strides_ptr = usm_ary.get_strides()
299+
flags = usm_ary.flags_
294300
if strides_ptr:
295301
for i in range(nd):
296302
shape_strides_ptr[nd + i] = strides_ptr[i]
297303
else:
298-
si = 1
299-
for i in range(0, nd):
300-
shape_strides_ptr[nd + i] = si
301-
si = si * shape_ptr[i]
304+
if flags & USM_ARRAY_C_CONTIGUOUS:
305+
si = 1
306+
for i in range(nd - 1, -1, -1):
307+
shape_strides_ptr[nd + i] = si
308+
si = si * shape_ptr[i]
309+
elif flags & USM_ARRAY_F_CONTIGUOUS:
310+
si = 1
311+
for i in range(0, nd):
312+
shape_strides_ptr[nd + i] = si
313+
si = si * shape_ptr[i]
314+
else:
315+
stdlib.free(shape_strides_ptr)
316+
stdlib.free(dlm_tensor)
317+
raise BufferError(
318+
"to_dlpack_capsule: Could not reconstruct strides "
319+
"for non-contiguous tensor"
320+
)
321+
302322
strides_ptr = <Py_ssize_t *>&shape_strides_ptr[nd]
303323

304324
ary_dt = usm_ary.dtype
@@ -409,10 +429,24 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
409429
for i in range(nd):
410430
shape_strides_ptr[nd + i] = strides_ptr[i]
411431
else:
412-
si = 1
413-
for i in range(0, nd):
414-
shape_strides_ptr[nd + i] = si
415-
si = si * shape_ptr[i]
432+
if flags & USM_ARRAY_C_CONTIGUOUS:
433+
si = 1
434+
for i in range(nd - 1, -1, -1):
435+
shape_strides_ptr[nd + i] = si
436+
si = si * shape_ptr[i]
437+
elif flags & USM_ARRAY_F_CONTIGUOUS:
438+
si = 1
439+
for i in range(0, nd):
440+
shape_strides_ptr[nd + i] = si
441+
si = si * shape_ptr[i]
442+
else:
443+
stdlib.free(shape_strides_ptr)
444+
stdlib.free(dlmv_tensor)
445+
raise BufferError(
446+
"to_dlpack_versioned_capsule: Could not reconstruct "
447+
"strides for non-contiguous tensor"
448+
)
449+
416450
strides_ptr = <Py_ssize_t *>&shape_strides_ptr[nd]
417451

418452
# this can all be a function for building the dl_tensor

0 commit comments

Comments
 (0)