Skip to content

Commit

Permalink
Add additional tests for array API inspection utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Jan 24, 2025
1 parent 5932e2f commit 06f266c
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions dpctl/tests/test_tensor_array_api_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,51 @@ def test_array_api_inspection_dtype_kind():
)
== info.dtypes()
)
assert info.dtypes(
kind=("integral", "real floating", "complex floating")
) == info.dtypes(kind="numeric")


def test_array_api_inspection_dtype_kind_errors():
info = dpt.__array_namespace_info__()
try:
info.default_device()
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")

with pytest.raises(ValueError):
info.dtypes(kind="error")

with pytest.raises(TypeError):
info.dtypes(kind={0: "real floating"})


def test_array_api_inspection_device_types():
info = dpt.__array_namespace_info__()
try:
dev = info.default_device()
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")

q = dpctl.SyclQueue(dev)
assert info.default_dtypes(device=q)
assert info.dtypes(device=q)

dev_dpt = dpt.Device.create_device(dev)
assert info.default_dtypes(device=dev_dpt)
assert info.dtypes(device=dev_dpt)

filter = dev.get_filter_string()
assert info.default_dtypes(device=filter)
assert info.dtypes(device=filter)


def test_array_api_inspection_device_errors():
info = dpt.__array_namespace_info__()

bad_dev = dict()
with pytest.raises(TypeError):
info.dtypes(device=bad_dev)

with pytest.raises(TypeError):
info.default_dtypes(device=bad_dev)

0 comments on commit 06f266c

Please sign in to comment.