Skip to content

Commit da72671

Browse files
Add more dlpack tests
1 parent 8a92deb commit da72671

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,45 @@ def test_dlpack_capsule_readonly_array_to_kdlcpu():
664664
assert not y1.flags["W"]
665665

666666

667+
def test_to_dlpack_capsule_c_and_f_contig():
668+
try:
669+
x = dpt.asarray(np.random.rand(2, 3))
670+
except dpctl.SyclDeviceCreationError:
671+
pytest.skip("No default device available")
672+
673+
cap = _dlp.to_dlpack_capsule(x)
674+
y = _dlp.from_dlpack_capsule(cap)
675+
assert np.allclose(dpt.asnumpy(x), dpt.asnumpy(y))
676+
assert x.strides == y.strides
677+
678+
x_f = x.T
679+
cap = _dlp.to_dlpack_capsule(x_f)
680+
yf = _dlp.from_dlpack_capsule(cap)
681+
assert np.allclose(dpt.asnumpy(x_f), dpt.asnumpy(yf))
682+
assert x_f.strides == yf.strides
683+
del cap
684+
685+
686+
def test_to_dlpack_versioned_capsule_c_and_f_contig():
687+
try:
688+
x = dpt.asarray(np.random.rand(2, 3))
689+
max_supported_ver = _dlp.get_build_dlpack_version()
690+
except dpctl.SyclDeviceCreationError:
691+
pytest.skip("No default device available")
692+
693+
cap = x.__dlpack__(max_version=max_supported_ver)
694+
y = _dlp.from_dlpack_capsule(cap)
695+
assert np.allclose(dpt.asnumpy(x), dpt.asnumpy(y))
696+
assert x.strides == y.strides
697+
698+
x_f = x.T
699+
cap = x_f.__dlpack__(max_version=max_supported_ver)
700+
yf = _dlp.from_dlpack_capsule(cap)
701+
assert np.allclose(dpt.asnumpy(x_f), dpt.asnumpy(yf))
702+
assert x_f.strides == yf.strides
703+
del cap
704+
705+
667706
def test_used_dlpack_capsule_from_numpy():
668707
get_queue_or_skip()
669708

0 commit comments

Comments
 (0)