Skip to content

Commit

Permalink
Fix device kwarg-only argument being used as positional for calls t…
Browse files Browse the repository at this point in the history
…o `default_dtypes` throughout tests
  • Loading branch information
ndgrigorian committed Apr 1, 2024
1 parent f0ebcf1 commit aa033eb
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dpctl/tests/test_tensor_array_api_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_array_api_inspection_default_dtypes():

info = dpt.__array_namespace_info__()
default_dts_nodev = info.default_dtypes()
default_dts_dev = info.default_dtypes(dev)
default_dts_dev = info.default_dtypes(device=dev)

assert (
int_dt == default_dts_nodev["integral"] == default_dts_dev["integral"]
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def test_mixed_index_getitem():
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
i1b = dpt.ones(10, dtype="?")
info = x.__array_namespace__().__array_namespace_info__()
ind_dt = info.default_dtypes(x.device)["indexing"]
ind_dt = info.default_dtypes(device=x.device)["indexing"]
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
y = x[i0, i1b, i2]
Expand All @@ -503,7 +503,7 @@ def test_mixed_index_setitem():
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
i1b = dpt.ones(10, dtype="?")
info = x.__array_namespace__().__array_namespace_info__()
ind_dt = info.default_dtypes(x.device)["indexing"]
ind_dt = info.default_dtypes(device=x.device)["indexing"]
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
v_shape = (3, int(dpt.sum(i1b, dtype="i8")))
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/test_usm_ndarray_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def _check(hay_stack, needles, needles_np):
assert hay_stack.ndim == 1

info_ = dpt.__array_namespace_info__()
default_dts_dev = info_.default_dtypes(hay_stack.device)
default_dts_dev = info_.default_dtypes(device=hay_stack.device)
index_dt = default_dts_dev["indexing"]

p_left = dpt.searchsorted(hay_stack, needles, side="left")
Expand Down

0 comments on commit aa033eb

Please sign in to comment.