Skip to content

Bump array-api submodule and utilise its all-versions setup #213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,7 @@ You need to specify the array library to test. It can be specified via the
$ export ARRAY_API_TESTS_MODULE=numpy.array_api
```

Alternately, change the `array_module` variable in `array_api_tests/_array_module.py`
line, e.g.

```diff
- array_module = None
+ import numpy.array_api as array_module
```
Alternately, import/define the `xp` variable in `array_api_tests/__init__.py`.

### Run the suite

Expand Down
2 changes: 1 addition & 1 deletion array-api
Submodule array-api updated 211 files
35 changes: 29 additions & 6 deletions array_api_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,36 @@
import os
from functools import wraps
from os import getenv
from importlib import import_module

from hypothesis import strategies as st
from hypothesis.extra import array_api

from . import _version
from ._array_module import mod as _xp

__all__ = ["api_version", "xps"]
__all__ = ["xp", "api_version", "xps"]


# You can comment the following out and instead import the specific array module
# you want to test, e.g. `import numpy.array_api as xp`.
if "ARRAY_API_TESTS_MODULE" in os.environ:
xp_name = os.environ["ARRAY_API_TESTS_MODULE"]
_module, _sub = xp_name, None
if "." in xp_name:
_module, _sub = xp_name.split(".", 1)
xp = import_module(_module)
if _sub:
try:
xp = getattr(xp, _sub)
except AttributeError:
# _sub may be a submodule that needs to be imported. WE can't
# do this in every case because some array modules are not
# submodules that can be imported (like mxnet.nd).
xp = import_module(xp_name)
else:
raise RuntimeError(
"No array module specified - either edit __init__.py or set the "
"ARRAY_API_TESTS_MODULE environment variable."
)


# We monkey patch floats() to always disable subnormals as they are out-of-scope
Expand Down Expand Up @@ -43,9 +66,9 @@ def _from_dtype(*a, **kw):
pass


api_version = getenv(
"ARRAY_API_TESTS_VERSION", getattr(_xp, "__array_api_version__", "2021.12")
api_version = os.getenv(
"ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2021.12")
)
xps = array_api.make_strategies_namespace(_xp, api_version=api_version)
xps = array_api.make_strategies_namespace(xp, api_version=api_version)

__version__ = _version.get_versions()["version"]
36 changes: 3 additions & 33 deletions array_api_tests/_array_module.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,5 @@
import os
from importlib import import_module
from . import stubs, xp

from . import stubs

# Replace this with a specific array module to test it, for example,
#
# import numpy as array_module
array_module = None

if array_module is None:
if 'ARRAY_API_TESTS_MODULE' in os.environ:
mod_name = os.environ['ARRAY_API_TESTS_MODULE']
_module, _sub = mod_name, None
if '.' in mod_name:
_module, _sub = mod_name.split('.', 1)
mod = import_module(_module)
if _sub:
try:
mod = getattr(mod, _sub)
except AttributeError:
# _sub may be a submodule that needs to be imported. WE can't
# do this in every case because some array modules are not
# submodules that can be imported (like mxnet.nd).
mod = import_module(mod_name)
else:
raise RuntimeError("No array module specified. Either edit _array_module.py or set the ARRAY_API_TESTS_MODULE environment variable")
else:
mod = array_module
mod_name = mod.__name__
# Names from the spec. This is what should actually be imported from this
# file.

class _UndefinedStub:
"""
Expand All @@ -45,7 +15,7 @@ def __init__(self, name):
self.name = name

def _raise(self, *args, **kwargs):
raise AssertionError(f"{self.name} is not defined in {mod_name}")
raise AssertionError(f"{self.name} is not defined in {xp.__name__}")

def __repr__(self):
return f"<undefined stub for {self.name!r}>"
Expand All @@ -67,6 +37,6 @@ def __repr__(self):

for attr in _top_level_attrs:
try:
globals()[attr] = getattr(mod, attr)
globals()[attr] = getattr(xp, attr)
except AttributeError:
globals()[attr] = _UndefinedStub(attr)
7 changes: 4 additions & 3 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from warnings import warn

from . import api_version
from ._array_module import mod as xp
from . import xp
from .stubs import name_to_func
from .typing import DataType, ScalarType

Expand Down Expand Up @@ -352,6 +352,9 @@ def result_type(*dtypes: DataType):
"boolean": (xp.bool,),
"integer": all_int_dtypes,
"floating-point": real_float_dtypes,
"real-valued": real_float_dtypes,
"real-valued floating-point": real_float_dtypes,
"complex floating-point": complex_dtypes,
"numeric": numeric_dtypes,
"integer or boolean": bool_and_all_int_dtypes,
}
Expand All @@ -364,8 +367,6 @@ def result_type(*dtypes: DataType):
dtype_category = "floating-point"
dtypes = category_to_dtypes[dtype_category]
func_in_dtypes[name] = dtypes
# See https://github.com/data-apis/array-api/pull/413
func_in_dtypes["expm1"] = real_float_dtypes


func_returns_bool = {
Expand Down
19 changes: 13 additions & 6 deletions array_api_tests/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from types import FunctionType, ModuleType
from typing import Dict, List

from . import api_version

__all__ = [
"name_to_func",
"array_methods",
Expand All @@ -15,20 +17,21 @@
"extension_to_funcs",
]

spec_module = "_" + api_version.replace('.', '_')

spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / "API_specification"
spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / api_version / "API_specification"
assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`"
sigs_dir = spec_dir / "signatures"
sigs_dir = Path(__file__).parent.parent / "array-api" / "src" / "array_api_stubs" / spec_module
assert sigs_dir.exists()

spec_abs_path: str = str(spec_dir.resolve())
sys.path.append(spec_abs_path)
assert find_spec("signatures") is not None
sigs_abs_path: str = str(sigs_dir.parent.parent.resolve())
sys.path.append(sigs_abs_path)
assert find_spec(f"array_api_stubs.{spec_module}") is not None

name_to_mod: Dict[str, ModuleType] = {}
for path in sigs_dir.glob("*.py"):
name = path.name.replace(".py", "")
name_to_mod[name] = import_module(f"signatures.{name}")
name_to_mod[name] = import_module(f"array_api_stubs.{spec_module}.{name}")

array = name_to_mod["array_object"].array
array_methods = [
Expand Down Expand Up @@ -70,3 +73,7 @@
for func in funcs:
if func.__name__ not in name_to_func.keys():
name_to_func[func.__name__] = func

# sanity check public attributes are not empty
for attr in __all__:
assert len(locals()[attr]) != 0, f"{attr} is empty"
2 changes: 1 addition & 1 deletion array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from ._array_module import mod as _xp
from . import xp as _xp
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape

pytestmark = pytest.mark.ci
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from . import dtype_helpers as dh
from ._array_module import mod as xp
from . import xp
from .typing import Array

pytestmark = pytest.mark.ci
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from ._array_module import mod as _xp
from . import xp as _xp
from .typing import DataType

pytestmark = pytest.mark.ci
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from ._array_module import mod as xp
from . import xp

pytestmark = [
pytest.mark.ci,
Expand Down
10 changes: 5 additions & 5 deletions array_api_tests/test_has_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from ._array_module import mod as xp, mod_name
from . import xp
from .stubs import (array_attributes, array_methods, category_to_funcs,
extension_to_funcs, EXTENSIONS)

Expand All @@ -27,13 +27,13 @@
def test_has_names(category, name):
if category in EXTENSIONS:
ext_mod = getattr(xp, category)
assert hasattr(ext_mod, name), f"{mod_name} is missing the {category} extension function {name}()"
assert hasattr(ext_mod, name), f"{xp.__name__} is missing the {category} extension function {name}()"
elif category.startswith('array_'):
# TODO: This would fail if ones() is missing.
arr = xp.ones((1, 1))
if category == 'array_attribute':
assert hasattr(arr, name), f"The {mod_name} array object is missing the attribute {name}"
assert hasattr(arr, name), f"The {xp.__name__} array object is missing the attribute {name}"
else:
assert hasattr(arr, name), f"The {mod_name} array object is missing the method {name}()"
assert hasattr(arr, name), f"The {xp.__name__} array object is missing the method {name}()"
else:
assert hasattr(xp, name), f"{mod_name} is missing the {category} function {name}()"
assert hasattr(xp, name), f"{xp.__name__} is missing the {category} function {name}()"
2 changes: 1 addition & 1 deletion array_api_tests/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def squeeze(x, /, axis):
import pytest

from . import dtype_helpers as dh
from ._array_module import mod as xp
from . import xp
from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func

pytestmark = pytest.mark.ci
Expand Down
36 changes: 30 additions & 6 deletions array_api_tests/test_special_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
from . import hypothesis_helpers as hh
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from ._array_module import mod as xp
from . import xp, xps
from .stubs import category_to_funcs

pytestmark = pytest.mark.ci
Expand Down Expand Up @@ -126,6 +125,8 @@ def abs_cond(i: float) -> bool:
"infinity": float("inf"),
"0": 0.0,
"1": 1.0,
"False": 0.0,
"True": 1.0,
}
r_value = re.compile(r"([+-]?)(.+)")
r_pi = re.compile(r"(\d?)π(?:/(\d))?")
Expand Down Expand Up @@ -158,7 +159,10 @@ def parse_value(value_str: str) -> float:
if denominator := pi_m.group(2):
value /= int(denominator)
else:
value = repr_to_value[m.group(2)]
try:
value = repr_to_value[m.group(2)]
except KeyError as e:
raise ParseError(value_str) from e
if sign := m.group(1):
if sign == "-":
value *= -1
Expand Down Expand Up @@ -507,7 +511,10 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(<{self}>)"


r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters")
r_case_block = re.compile(
r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*"
r"(?:.+\n--+)?(?:\.\. versionchanged.*)?"
)
r_case = re.compile(r"\s+-\s*(.*)\.")


Expand Down Expand Up @@ -1121,6 +1128,9 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
iop_params = []
func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
for stub in category_to_funcs["elementwise"]:
# if stub.__name__ == "abs":
# import ipdb; ipdb.set_trace()

if stub.__doc__ is None:
warn(f"{stub.__name__}() stub has no docstring")
continue
Expand Down Expand Up @@ -1167,6 +1177,8 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
op = getattr(operator, op_name)
name_to_func[op_name] = op
# We collect inplace operator test cases seperately
if stub.__name__ == "equal":
break
iop_name = "__i" + op_name[2:]
iop = getattr(operator, iop_name)
for case in cases:
Expand Down Expand Up @@ -1197,6 +1209,11 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
# its False - Hypothesis will complain if we reject too many examples, thus
# indicating we've done something wrong.

# sanity checks
assert len(unary_params) != 0
assert len(binary_params) != 0
assert len(iop_params) != 0


@pytest.mark.parametrize("func_name, func, case", unary_params)
@given(
Expand Down Expand Up @@ -1254,7 +1271,12 @@ def test_binary(func_name, func, case, x1, x2, data):

res = func(x1, x2)
# sanity check
ph.assert_result_shape(func_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=result_shape)
ph.assert_result_shape(
func_name,
in_shapes=[x1.shape, x2.shape],
out_shape=res.shape,
expected=result_shape,
)

good_example = False
for l_idx, r_idx, o_idx in all_indices:
Expand Down Expand Up @@ -1306,7 +1328,9 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
res = xp.asarray(x1, copy=True)
res = iop(res, x2)
# sanity check
ph.assert_result_shape(iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape)
ph.assert_result_shape(
iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape
)

good_example = False
for l_idx, r_idx, o_idx in all_indices:
Expand Down
2 changes: 1 addition & 1 deletion reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def pytest_metadata(metadata):
"""
Additional global metadata for --json-report.
"""
metadata['array_api_tests_module'] = xp.mod_name
metadata['array_api_tests_module'] = xp.__name__
metadata['array_api_tests_version'] = __version__

@fixture(autouse=True)
Expand Down