Skip to content

Commit

Permalink
restore code
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Nov 26, 2024
1 parent c5b82db commit 7b99449
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 64 deletions.
8 changes: 4 additions & 4 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import pytest

wrapped_libraries = ["numpy", "paddle", "torch"]
all_libraries = wrapped_libraries + []
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array", "paddle"]
all_libraries = wrapped_libraries + ["jax.numpy"]

# `sparse` added array API support as of Python 3.10.
# if sys.version_info >= (3, 10):
# all_libraries.append('sparse')
if sys.version_info >= (3, 10):
all_libraries.append('sparse')

def import_(library, wrapper=False):
if library == 'cupy':
Expand Down
4 changes: 2 additions & 2 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,5 @@ def test_all(library):
all_names = module.__all__

if set(dir_names) != set(all_names):
assert set(dir_names) - set(all_names) == set(), f"Failed in library '{library}', some dir() names not included in __all__ for {mod_name}"
assert set(all_names) - set(dir_names) == set(), f"Failed in library '{library}', some __all__ names not in dir() for {mod_name}"
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
55 changes: 17 additions & 38 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# import jax
import numpy as np
import pytest
# import torch
import torch
import paddle

import array_api_compat
Expand Down Expand Up @@ -73,11 +73,11 @@ def test_array_namespace(library, api_version, use_compat):
"""
subprocess.run([sys.executable, "-c", code], check=True)

# def test_jax_zero_gradient():
# jx = jax.numpy.arange(4)
# jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
# assert (array_api_compat.get_namespace(jax_zero) is
# array_api_compat.get_namespace(jx))
def test_jax_zero_gradient():
jx = jax.numpy.arange(4)

Check failure on line 77 in tests/test_array_namespace.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F821)

tests/test_array_namespace.py:77:10: F821 Undefined name `jax`
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)

Check failure on line 78 in tests/test_array_namespace.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F821)

tests/test_array_namespace.py:78:16: F821 Undefined name `jax`

Check failure on line 78 in tests/test_array_namespace.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F821)

tests/test_array_namespace.py:78:25: F821 Undefined name `jax`

Check failure on line 78 in tests/test_array_namespace.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F821)

tests/test_array_namespace.py:78:34: F821 Undefined name `jax`
assert (array_api_compat.get_namespace(jax_zero) is
array_api_compat.get_namespace(jx))

def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
Expand All @@ -87,53 +87,32 @@ def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace((x, x)))
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))

# def test_array_namespace_errors_torch():
# y = torch.asarray([1, 2])
# x = np.asarray([1, 2])
# pytest.raises(TypeError, lambda: array_namespace(x, y))
def test_array_namespace_errors_torch():
y = torch.asarray([1, 2])
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))


def test_array_namespace_errors_paddle():
y = paddle.to_tensor([1, 2])
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))


# def test_api_version():
# x = torch.asarray([1, 2])
# torch_ = import_("torch", wrapper=True)
# assert array_namespace(x, api_version="2023.12") == torch_
# assert array_namespace(x, api_version=None) == torch_
# assert array_namespace(x) == torch_
# # Should issue a warning
# with warnings.catch_warnings(record=True) as w:
# assert array_namespace(x, api_version="2021.12") == torch_
# assert len(w) == 1
# assert "2021.12" in str(w[0].message)

# # Should issue a warning
# with warnings.catch_warnings(record=True) as w:
# assert array_namespace(x, api_version="2022.12") == torch_
# assert len(w) == 1
# assert "2022.12" in str(w[0].message)

# pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))

def test_api_version():
x = paddle.asarray([1, 2])
paddle_ = import_("paddle", wrapper=True)
assert array_namespace(x, api_version="2023.12") == paddle_
assert array_namespace(x, api_version=None) == paddle_
assert array_namespace(x) == paddle_
x = torch.asarray([1, 2])
torch_ = import_("torch", wrapper=True)
assert array_namespace(x, api_version="2023.12") == torch_
assert array_namespace(x, api_version=None) == torch_
assert array_namespace(x) == torch_
# Should issue a warning
with warnings.catch_warnings(record=True) as w:
assert array_namespace(x, api_version="2021.12") == paddle_
assert array_namespace(x, api_version="2021.12") == torch_
assert len(w) == 1
assert "2021.12" in str(w[0].message)

# Should issue a warning
with warnings.catch_warnings(record=True) as w:
assert array_namespace(x, api_version="2022.12") == paddle_
assert array_namespace(x, api_version="2022.12") == torch_
assert len(w) == 1
assert "2022.12" in str(w[0].message)

Expand Down
16 changes: 8 additions & 8 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@

is_array_functions = {
'numpy': 'is_numpy_array',
# 'cupy': 'is_cupy_array',
'cupy': 'is_cupy_array',
'torch': 'is_torch_array',
# 'dask.array': 'is_dask_array',
# 'jax.numpy': 'is_jax_array',
# 'sparse': 'is_pydata_sparse_array',
'dask.array': 'is_dask_array',
'jax.numpy': 'is_jax_array',
'sparse': 'is_pydata_sparse_array',
'paddle': 'is_paddle_array',
}

is_namespace_functions = {
'numpy': 'is_numpy_namespace',
# 'cupy': 'is_cupy_namespace',
'cupy': 'is_cupy_namespace',
'torch': 'is_torch_namespace',
# 'dask.array': 'is_dask_namespace',
# 'jax.numpy': 'is_jax_namespace',
# 'sparse': 'is_pydata_sparse_namespace',
'dask.array': 'is_dask_namespace',
'jax.numpy': 'is_jax_namespace',
'sparse': 'is_pydata_sparse_namespace',
'paddle': 'is_paddle_namespace',
}

Expand Down
4 changes: 2 additions & 2 deletions tests/test_no_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def _test_dependency(mod):

@pytest.mark.parametrize("library",
[
"numpy",
"paddle", "array_api_strict",
"numpy", "cupy", "numpy", "torch", "dask.array",
"jax.numpy", "sparse", "paddle", "array_api_strict"
]
)
def test_numpy_dependency(library):
Expand Down
20 changes: 10 additions & 10 deletions tests/test_vendoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ def test_vendoring_numpy():
uses_numpy._test_numpy()


# def test_vendoring_cupy():
# pytest.importorskip("cupy")
def test_vendoring_cupy():
pytest.importorskip("cupy")

# from vendor_test import uses_cupy
from vendor_test import uses_cupy

# uses_cupy._test_cupy()
uses_cupy._test_cupy()


# def test_vendoring_torch():
# from vendor_test import uses_torch
def test_vendoring_torch():
from vendor_test import uses_torch

# uses_torch._test_torch()
uses_torch._test_torch()


# def test_vendoring_dask():
# from vendor_test import uses_dask
# uses_dask._test_dask()
def test_vendoring_dask():
from vendor_test import uses_dask
uses_dask._test_dask()


def test_vendoring_paddle():
Expand Down

0 comments on commit 7b99449

Please sign in to comment.