diff --git a/README.md b/README.md index fc60ac71..3050b9a3 100644 --- a/README.md +++ b/README.md @@ -36,13 +36,7 @@ You need to specify the array library to test. It can be specified via the $ export ARRAY_API_TESTS_MODULE=numpy.array_api ``` -Alternately, change the `array_module` variable in `array_api_tests/_array_module.py` -line, e.g. - -```diff -- array_module = None -+ import numpy.array_api as array_module -``` +Alternately, import/define the `xp` variable in `array_api_tests/__init__.py`. ### Run the suite diff --git a/array-api b/array-api index c5808f2b..ab69aa24 160000 --- a/array-api +++ b/array-api @@ -1 +1 @@ -Subproject commit c5808f2b173ea52d813c450bec7b1beaf2973299 +Subproject commit ab69aa240025ff1d52525ce3859b69ebfd6b7faf diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index e083d522..9af6796b 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -1,13 +1,36 @@ +import os from functools import wraps -from os import getenv +from importlib import import_module from hypothesis import strategies as st from hypothesis.extra import array_api from . import _version -from ._array_module import mod as _xp -__all__ = ["api_version", "xps"] +__all__ = ["xp", "api_version", "xps"] + + +# You can comment the following out and instead import the specific array module +# you want to test, e.g. `import numpy.array_api as xp`. +if "ARRAY_API_TESTS_MODULE" in os.environ: + xp_name = os.environ["ARRAY_API_TESTS_MODULE"] + _module, _sub = xp_name, None + if "." in xp_name: + _module, _sub = xp_name.split(".", 1) + xp = import_module(_module) + if _sub: + try: + xp = getattr(xp, _sub) + except AttributeError: + # _sub may be a submodule that needs to be imported. WE can't + # do this in every case because some array modules are not + # submodules that can be imported (like mxnet.nd). + xp = import_module(xp_name) +else: + raise RuntimeError( + "No array module specified - either edit __init__.py or set the " + "ARRAY_API_TESTS_MODULE environment variable." + ) # We monkey patch floats() to always disable subnormals as they are out-of-scope @@ -43,9 +66,9 @@ def _from_dtype(*a, **kw): pass -api_version = getenv( - "ARRAY_API_TESTS_VERSION", getattr(_xp, "__array_api_version__", "2021.12") +api_version = os.getenv( + "ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2021.12") ) -xps = array_api.make_strategies_namespace(_xp, api_version=api_version) +xps = array_api.make_strategies_namespace(xp, api_version=api_version) __version__ = _version.get_versions()["version"] diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index 8a7c7887..1c52a983 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -1,35 +1,5 @@ -import os -from importlib import import_module +from . import stubs, xp -from . import stubs - -# Replace this with a specific array module to test it, for example, -# -# import numpy as array_module -array_module = None - -if array_module is None: - if 'ARRAY_API_TESTS_MODULE' in os.environ: - mod_name = os.environ['ARRAY_API_TESTS_MODULE'] - _module, _sub = mod_name, None - if '.' in mod_name: - _module, _sub = mod_name.split('.', 1) - mod = import_module(_module) - if _sub: - try: - mod = getattr(mod, _sub) - except AttributeError: - # _sub may be a submodule that needs to be imported. WE can't - # do this in every case because some array modules are not - # submodules that can be imported (like mxnet.nd). - mod = import_module(mod_name) - else: - raise RuntimeError("No array module specified. Either edit _array_module.py or set the ARRAY_API_TESTS_MODULE environment variable") -else: - mod = array_module - mod_name = mod.__name__ -# Names from the spec. This is what should actually be imported from this -# file. class _UndefinedStub: """ @@ -45,7 +15,7 @@ def __init__(self, name): self.name = name def _raise(self, *args, **kwargs): - raise AssertionError(f"{self.name} is not defined in {mod_name}") + raise AssertionError(f"{self.name} is not defined in {xp.__name__}") def __repr__(self): return f"" @@ -67,6 +37,6 @@ def __repr__(self): for attr in _top_level_attrs: try: - globals()[attr] = getattr(mod, attr) + globals()[attr] = getattr(xp, attr) except AttributeError: globals()[attr] = _UndefinedStub(attr) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 9c2f3bfe..3052d54f 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -6,7 +6,7 @@ from warnings import warn from . import api_version -from ._array_module import mod as xp +from . import xp from .stubs import name_to_func from .typing import DataType, ScalarType @@ -352,6 +352,9 @@ def result_type(*dtypes: DataType): "boolean": (xp.bool,), "integer": all_int_dtypes, "floating-point": real_float_dtypes, + "real-valued": real_float_dtypes, + "real-valued floating-point": real_float_dtypes, + "complex floating-point": complex_dtypes, "numeric": numeric_dtypes, "integer or boolean": bool_and_all_int_dtypes, } @@ -364,8 +367,6 @@ def result_type(*dtypes: DataType): dtype_category = "floating-point" dtypes = category_to_dtypes[dtype_category] func_in_dtypes[name] = dtypes -# See https://github.com/data-apis/array-api/pull/413 -func_in_dtypes["expm1"] = real_float_dtypes func_returns_bool = { diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py index 0134765b..39bb1223 100644 --- a/array_api_tests/stubs.py +++ b/array_api_tests/stubs.py @@ -6,6 +6,8 @@ from types import FunctionType, ModuleType from typing import Dict, List +from . import api_version + __all__ = [ "name_to_func", "array_methods", @@ -15,20 +17,21 @@ "extension_to_funcs", ] +spec_module = "_" + api_version.replace('.', '_') -spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / "API_specification" +spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / api_version / "API_specification" assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`" -sigs_dir = spec_dir / "signatures" +sigs_dir = Path(__file__).parent.parent / "array-api" / "src" / "array_api_stubs" / spec_module assert sigs_dir.exists() -spec_abs_path: str = str(spec_dir.resolve()) -sys.path.append(spec_abs_path) -assert find_spec("signatures") is not None +sigs_abs_path: str = str(sigs_dir.parent.parent.resolve()) +sys.path.append(sigs_abs_path) +assert find_spec(f"array_api_stubs.{spec_module}") is not None name_to_mod: Dict[str, ModuleType] = {} for path in sigs_dir.glob("*.py"): name = path.name.replace(".py", "") - name_to_mod[name] = import_module(f"signatures.{name}") + name_to_mod[name] = import_module(f"array_api_stubs.{spec_module}.{name}") array = name_to_mod["array_object"].array array_methods = [ @@ -70,3 +73,7 @@ for func in funcs: if func.__name__ not in name_to_func.keys(): name_to_func[func.__name__] = func + +# sanity check public attributes are not empty +for attr in __all__: + assert len(locals()[attr]) != 0, f"{attr} is empty" diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 42c0aef0..d44bdeba 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -13,7 +13,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from ._array_module import mod as _xp +from . import xp as _xp from .typing import DataType, Index, Param, Scalar, ScalarType, Shape pytestmark = pytest.mark.ci diff --git a/array_api_tests/test_constants.py b/array_api_tests/test_constants.py index 51c02714..01bc5456 100644 --- a/array_api_tests/test_constants.py +++ b/array_api_tests/test_constants.py @@ -4,7 +4,7 @@ import pytest from . import dtype_helpers as dh -from ._array_module import mod as xp +from . import xp from .typing import Array pytestmark = pytest.mark.ci diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index fa69bbcd..ccae9930 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -11,7 +11,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from ._array_module import mod as _xp +from . import xp as _xp from .typing import DataType pytestmark = pytest.mark.ci diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index 7dc70d56..39d96d6c 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -14,7 +14,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from ._array_module import mod as xp +from . import xp pytestmark = [ pytest.mark.ci, diff --git a/array_api_tests/test_has_names.py b/array_api_tests/test_has_names.py index 3c5c3263..53eb0965 100644 --- a/array_api_tests/test_has_names.py +++ b/array_api_tests/test_has_names.py @@ -5,7 +5,7 @@ import pytest -from ._array_module import mod as xp, mod_name +from . import xp from .stubs import (array_attributes, array_methods, category_to_funcs, extension_to_funcs, EXTENSIONS) @@ -27,13 +27,13 @@ def test_has_names(category, name): if category in EXTENSIONS: ext_mod = getattr(xp, category) - assert hasattr(ext_mod, name), f"{mod_name} is missing the {category} extension function {name}()" + assert hasattr(ext_mod, name), f"{xp.__name__} is missing the {category} extension function {name}()" elif category.startswith('array_'): # TODO: This would fail if ones() is missing. arr = xp.ones((1, 1)) if category == 'array_attribute': - assert hasattr(arr, name), f"The {mod_name} array object is missing the attribute {name}" + assert hasattr(arr, name), f"The {xp.__name__} array object is missing the attribute {name}" else: - assert hasattr(arr, name), f"The {mod_name} array object is missing the method {name}()" + assert hasattr(arr, name), f"The {xp.__name__} array object is missing the method {name}()" else: - assert hasattr(xp, name), f"{mod_name} is missing the {category} function {name}()" + assert hasattr(xp, name), f"{xp.__name__} is missing the {category} function {name}()" diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index e30f0755..ed68e99f 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -30,7 +30,7 @@ def squeeze(x, /, axis): import pytest from . import dtype_helpers as dh -from ._array_module import mod as xp +from . import xp from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func pytestmark = pytest.mark.ci diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 345d7fe5..cd9c81ba 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -32,8 +32,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import shape_helpers as sh -from . import xps -from ._array_module import mod as xp +from . import xp, xps from .stubs import category_to_funcs pytestmark = pytest.mark.ci @@ -126,6 +125,8 @@ def abs_cond(i: float) -> bool: "infinity": float("inf"), "0": 0.0, "1": 1.0, + "False": 0.0, + "True": 1.0, } r_value = re.compile(r"([+-]?)(.+)") r_pi = re.compile(r"(\d?)π(?:/(\d))?") @@ -158,7 +159,10 @@ def parse_value(value_str: str) -> float: if denominator := pi_m.group(2): value /= int(denominator) else: - value = repr_to_value[m.group(2)] + try: + value = repr_to_value[m.group(2)] + except KeyError as e: + raise ParseError(value_str) from e if sign := m.group(1): if sign == "-": value *= -1 @@ -507,7 +511,10 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(<{self}>)" -r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters") +r_case_block = re.compile( + r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*" + r"(?:.+\n--+)?(?:\.\. versionchanged.*)?" +) r_case = re.compile(r"\s+-\s*(.*)\.") @@ -1121,6 +1128,9 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: iop_params = [] func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()} for stub in category_to_funcs["elementwise"]: + # if stub.__name__ == "abs": + # import ipdb; ipdb.set_trace() + if stub.__doc__ is None: warn(f"{stub.__name__}() stub has no docstring") continue @@ -1167,6 +1177,8 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: op = getattr(operator, op_name) name_to_func[op_name] = op # We collect inplace operator test cases seperately + if stub.__name__ == "equal": + break iop_name = "__i" + op_name[2:] iop = getattr(operator, iop_name) for case in cases: @@ -1197,6 +1209,11 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: # its False - Hypothesis will complain if we reject too many examples, thus # indicating we've done something wrong. +# sanity checks +assert len(unary_params) != 0 +assert len(binary_params) != 0 +assert len(iop_params) != 0 + @pytest.mark.parametrize("func_name, func, case", unary_params) @given( @@ -1254,7 +1271,12 @@ def test_binary(func_name, func, case, x1, x2, data): res = func(x1, x2) # sanity check - ph.assert_result_shape(func_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=result_shape) + ph.assert_result_shape( + func_name, + in_shapes=[x1.shape, x2.shape], + out_shape=res.shape, + expected=result_shape, + ) good_example = False for l_idx, r_idx, o_idx in all_indices: @@ -1306,7 +1328,9 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): res = xp.asarray(x1, copy=True) res = iop(res, x2) # sanity check - ph.assert_result_shape(iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape) + ph.assert_result_shape( + iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape + ) good_example = False for l_idx, r_idx, o_idx in all_indices: diff --git a/reporting.py b/reporting.py index f7c7d6b9..d73085ff 100644 --- a/reporting.py +++ b/reporting.py @@ -49,7 +49,7 @@ def pytest_metadata(metadata): """ Additional global metadata for --json-report. """ - metadata['array_api_tests_module'] = xp.mod_name + metadata['array_api_tests_module'] = xp.__name__ metadata['array_api_tests_version'] = __version__ @fixture(autouse=True)