diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py index b534e59328..baf870c036 100644 --- a/dpctl/tests/test_usm_ndarray_dlpack.py +++ b/dpctl/tests/test_usm_ndarray_dlpack.py @@ -101,9 +101,7 @@ def test_from_dlpack(shape, typestr, usm_type): X = dpt.empty(shape, dtype=typestr, usm_type=usm_type, device=sycl_dev) Y = dpt.from_dlpack(X) assert X.shape == Y.shape - assert X.dtype == Y.dtype or ( - str(X.dtype) == "bool" and str(Y.dtype) == "uint8" - ) + assert X.dtype == Y.dtype assert X.sycl_device == Y.sycl_device assert X.usm_type == Y.usm_type assert X._pointer == Y._pointer @@ -125,9 +123,7 @@ def test_from_dlpack_strides(mod, typestr, usm_type): X = X0[slice(-start - 1, None, -mod)] Y = dpt.from_dlpack(X) assert X.shape == Y.shape - assert X.dtype == Y.dtype or ( - str(X.dtype) == "bool" and str(Y.dtype) == "uint8" - ) + assert X.dtype == Y.dtype assert X.sycl_device == Y.sycl_device assert X.usm_type == Y.usm_type assert X._pointer == Y._pointer