Skip to content

Commit

Permalink
Simplified tests to leverage support for __eq__
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksandr-pavlyk committed Oct 9, 2022
1 parent 5eb5365 commit 3dec850
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ def test_allocate_usm_ndarray(shape, usm_type):


def test_usm_ndarray_flags():
assert dpt.usm_ndarray((5,)).flags.flags == 3
assert dpt.usm_ndarray((5, 2)).flags.flags == 1
assert dpt.usm_ndarray((5, 2), order="F").flags.flags == 2
assert dpt.usm_ndarray((5, 1, 2), order="F").flags.flags == 2
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags.flags == 1
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags.flags == 2
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags.flags == 3
assert dpt.usm_ndarray((5,)).flags.fnc
assert dpt.usm_ndarray((5, 2)).flags.c_contiguous
assert dpt.usm_ndarray((5, 2), order="F").flags.f_contiguous
assert dpt.usm_ndarray((5, 1, 2), order="F").flags.f_contiguous
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags.c_contiguous
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags.f_contiguous
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags.fnc


@pytest.mark.parametrize(
Expand Down Expand Up @@ -326,7 +326,7 @@ def test_usm_ndarray_props():
Xusm = dpt.usm_ndarray((10, 5), dtype="c16", order="F")
Xusm.ndim
repr(Xusm)
Xusm.flags.flags
Xusm.flags
Xusm.__sycl_usm_array_interface__
Xusm.device
Xusm.strides
Expand Down Expand Up @@ -465,7 +465,7 @@ def test_pyx_capi_get_flags():
fn_restype=ctypes.c_int,
)
flags = get_flags_fn(X)
assert type(flags) is int and flags == X.flags.flags
assert type(flags) is int and X.flags == flags


def test_pyx_capi_get_offset():
Expand Down Expand Up @@ -919,7 +919,7 @@ def test_reshape():

X = dpt.usm_ndarray((1,))
Y = dpt.reshape(X, X.shape)
assert Y.flags.flags == X.flags.flags
assert Y.flags == X.flags

A = dpt.usm_ndarray((0,), "i4")
A1 = dpt.reshape(A, (0,))
Expand Down Expand Up @@ -1402,7 +1402,7 @@ def test_triu_order_k(order, k):
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
Ynp = np.triu(Xnp, k)
assert Y.dtype == Ynp.dtype
assert X.flags.flags == Y.flags.flags
assert X.flags == Y.flags
assert np.array_equal(Ynp, dpt.asnumpy(Y))


Expand All @@ -1423,7 +1423,7 @@ def test_tril_order_k(order, k):
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
Ynp = np.tril(Xnp, k)
assert Y.dtype == Ynp.dtype
assert X.flags.flags == Y.flags.flags
assert X.flags == Y.flags
assert np.array_equal(Ynp, dpt.asnumpy(Y))


Expand Down

0 comments on commit 3dec850

Please sign in to comment.