diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index c6e90f75..d9d2362d 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,8 +1,8 @@ import re +from collections import defaultdict from collections.abc import Mapping from functools import lru_cache -from inspect import signature -from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union +from typing import Any, DefaultDict, NamedTuple, Sequence, Tuple, Union from warnings import warn from . import _array_module as xp @@ -323,7 +323,7 @@ def result_type(*dtypes: DataType): "numeric": numeric_dtypes, "integer or boolean": bool_and_all_int_dtypes, } -func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {} +func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes) for name, func in name_to_func.items(): if m := r_in_dtypes.search(func.__doc__): dtype_category = m.group(1) @@ -331,8 +331,6 @@ def result_type(*dtypes: DataType): dtype_category = "floating-point" dtypes = category_to_dtypes[dtype_category] func_in_dtypes[name] = dtypes - elif any("x" in name for name in signature(func).parameters.keys()): - func_in_dtypes[name] = all_dtypes # See https://github.com/data-apis/array-api/pull/413 func_in_dtypes["expm1"] = float_dtypes diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index 2db804b1..e30f0755 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -20,22 +20,18 @@ def squeeze(x, /, axis): ... """ +from collections import defaultdict +from copy import copy from inspect import Parameter, Signature, signature from types import FunctionType -from typing import Any, Callable, Dict, List, Literal, get_args +from typing import Any, Callable, Dict, Literal, get_args +from warnings import warn import pytest -from hypothesis import given, note, settings -from hypothesis import strategies as st -from hypothesis.strategies import DataObject from . import dtype_helpers as dh -from . import hypothesis_helpers as hh -from . import xps -from ._array_module import _UndefinedStub from ._array_module import mod as xp -from .stubs import array_methods, category_to_funcs, extension_to_funcs -from .typing import Array, DataType +from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func pytestmark = pytest.mark.ci @@ -93,6 +89,7 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature): stub_param.name in sig.parameters.keys() ), f"Argument '{stub_param.name}' missing from signature" param = next(p for p in params if p.name == stub_param.name) + f_stub_kind = kind_to_str[stub_param.kind] assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD,], ( f"{param.name} is a {kind_to_str[param.kind]}, " f"but should be a {f_stub_kind} " @@ -100,17 +97,7 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature): ) -def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]: - if func_name in dh.func_in_dtypes.keys(): - dtypes = dh.func_in_dtypes[func_name] - if hh.FILTER_UNDEFINED_DTYPES: - dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] - return st.sampled_from(dtypes) - else: - return xps.scalar_dtypes() - - -def make_pretty_func(func_name: str, *args: Any, **kwargs: Any): +def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str: f_sig = f"{func_name}(" f_sig += ", ".join(str(a) for a in args) if len(kwargs) != 0: @@ -121,96 +108,165 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any): return f_sig -matrixy_funcs: List[FunctionType] = [ - *category_to_funcs["linear_algebra"], - *extension_to_funcs["linalg"], +# We test uninspectable signatures by passing valid, manually-defined arguments +# to the signature's function/method. +# +# Arguments which require use of the array module are specified as string +# expressions to be eval()'d on runtime. This is as opposed to just using the +# array module whilst setting up the tests, which is prone to halt the entire +# test suite if an array module doesn't support a given expression. +func_to_specified_args = defaultdict( + dict, + { + "permute_dims": {"axes": 0}, + "reshape": {"shape": (1, 5)}, + "broadcast_to": {"shape": (1, 5)}, + "asarray": {"obj": [0, 1, 2, 3, 4]}, + "full_like": {"fill_value": 42}, + "matrix_power": {"n": 2}, + }, +) +func_to_specified_arg_exprs = defaultdict( + dict, + { + "stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"}, + "iinfo": {"type": "xp.int64"}, + "finfo": {"type": "xp.float64"}, + "cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"}, + "inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"}, + "solve": { + a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"] + }, + }, +) +# We default most array arguments heuristically. As functions/methods work only +# with arrays of certain dtypes and shapes, we specify only supported arrays +# respective to the function. +casty_names = ["__bool__", "__int__", "__float__", "__complex__", "__index__"] +matrixy_names = [ + f.__name__ + for f in category_to_funcs["linear_algebra"] + extension_to_funcs["linalg"] ] -matrixy_names: List[str] = [f.__name__ for f in matrixy_funcs] matrixy_names += ["__matmul__", "triu", "tril"] +for func_name, func in name_to_func.items(): + stub_sig = signature(func) + array_argnames = set(stub_sig.parameters.keys()) & {"x", "x1", "x2", "other"} + if func in array_methods: + array_argnames.add("self") + array_argnames -= set(func_to_specified_arg_exprs[func_name].keys()) + if len(array_argnames) > 0: + in_dtypes = dh.func_in_dtypes[func_name] + for dtype_name in ["float64", "bool", "int64", "complex128"]: + # We try float64 first because uninspectable numerical functions + # tend to support float inputs first-and-foremost (i.e. PyTorch) + try: + dtype = getattr(xp, dtype_name) + except AttributeError: + pass + else: + if dtype in in_dtypes: + if func_name in casty_names: + shape = () + elif func_name in matrixy_names: + shape = (3, 3) + else: + shape = (5,) + fallback_array_expr = f"xp.ones({shape}, dtype=xp.{dtype_name})" + break + else: + warn( + f"{dh.func_in_dtypes['{func_name}']}={in_dtypes} seemingly does " + "not contain any assumed dtypes, so skipping specifying fallback array." + ) + continue + for argname in array_argnames: + func_to_specified_arg_exprs[func_name][argname] = fallback_array_expr + +def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature): + params = list(stub_sig.parameters.values()) -@given(data=st.data()) -@settings(max_examples=1) -def _test_uninspectable_func( - func_name: str, func: Callable, stub_sig: Signature, array: Array, data: DataObject -): - skip_msg = ( - f"Signature for {func_name}() is not inspectable " - "and is too troublesome to test for otherwise" + if len(params) == 0: + func() + return + + uninspectable_msg = ( + f"Note {func_name}() is not inspectable so arguments are passed " + "manually to test the signature." ) - if func_name in [ - # 0d shapes - "__bool__", - "__int__", - "__index__", - "__float__", - # x2 elements must be >=0 - "pow", - "bitwise_left_shift", - "bitwise_right_shift", - # axis default invalid with 0d shapes - "sort", - # shape requirements - *matrixy_names, - ]: - pytest.skip(skip_msg) - - param_to_value: Dict[Parameter, Any] = {} - for param in stub_sig.parameters.values(): - if param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]: + + argname_to_arg = copy(func_to_specified_args[func_name]) + argname_to_expr = func_to_specified_arg_exprs[func_name] + for argname, expr in argname_to_expr.items(): + assert argname not in argname_to_arg.keys() # sanity check + try: + argname_to_arg[argname] = eval(expr, {"xp": xp}) + except Exception as e: pytest.skip( - skip_msg + f" (because '{param.name}' is a {kind_to_str[param.kind]})" - ) - elif param.default != Parameter.empty: - value = param.default - elif param.name in ["x", "x1"]: - dtypes = get_dtypes_strategy(func_name) - value = data.draw( - xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label=param.name + f"Exception occured when evaluating {argname}={expr}: {e}\n" + f"{uninspectable_msg}" ) - elif param.name in ["x2", "other"]: - if param.name == "x2": - assert "x1" in [p.name for p in param_to_value.keys()] # sanity check - orig = next(v for p, v in param_to_value.items() if p.name == "x1") + + posargs = [] + posorkw_args = {} + kwargs = {} + no_arg_msg = ( + "We have no argument specified for '{}'. Please ensure you're using " + "the latest version of array-api-tests, then open an issue if one " + f"doesn't already exist. {uninspectable_msg}" + ) + for param in params: + if param.kind == Parameter.POSITIONAL_ONLY: + try: + posargs.append(argname_to_arg[param.name]) + except KeyError: + pytest.skip(no_arg_msg.format(param.name)) + elif param.kind == Parameter.POSITIONAL_OR_KEYWORD: + if param.default == Parameter.empty: + try: + posorkw_args[param.name] = argname_to_arg[param.name] + except KeyError: + pytest.skip(no_arg_msg.format(param.name)) else: - assert array is not None # sanity check - orig = array - value = data.draw( - xps.arrays(dtype=orig.dtype, shape=orig.shape), label=param.name - ) + assert argname_to_arg[param.name] + posorkw_args[param.name] = param.default + elif param.kind == Parameter.KEYWORD_ONLY: + assert param.default != Parameter.empty # sanity check + kwargs[param.name] = param.default else: - pytest.skip( - skip_msg + f" (because no default was found for argument {param.name})" - ) - param_to_value[param] = value - - args: List[Any] = [ - v for p, v in param_to_value.items() if p.kind == Parameter.POSITIONAL_ONLY - ] - kwargs: Dict[str, Any] = { - p.name: v for p, v in param_to_value.items() if p.kind == Parameter.KEYWORD_ONLY - } - f_func = make_pretty_func(func_name, *args, **kwargs) - note(f"trying {f_func}") - func(*args, **kwargs) + assert param.kind in VAR_KINDS # sanity check + pytest.skip(no_arg_msg.format(param.name)) + if len(posorkw_args) == 0: + func(*posargs, **kwargs) + else: + posorkw_name_to_arg_pairs = list(posorkw_args.items()) + for i in range(len(posorkw_name_to_arg_pairs), -1, -1): + extra_posargs = [arg for _, arg in posorkw_name_to_arg_pairs[:i]] + extra_kwargs = dict(posorkw_name_to_arg_pairs[i:]) + func(*posargs, *extra_posargs, **kwargs, **extra_kwargs) -def _test_func_signature(func: Callable, stub: FunctionType, array=None): +def _test_func_signature(func: Callable, stub: FunctionType, is_method=False): stub_sig = signature(stub) # If testing against array, ignore 'self' arg in stub as it won't be present # in func (which should be a method). - if array is not None: + if is_method: stub_params = list(stub_sig.parameters.values()) - del stub_params[0] + if stub_params[0].name == "self": + del stub_params[0] stub_sig = Signature( parameters=stub_params, return_annotation=stub_sig.return_annotation ) try: sig = signature(func) - _test_inspectable_func(sig, stub_sig) except ValueError: - _test_uninspectable_func(stub.__name__, func, stub_sig, array) + try: + _test_uninspectable_func(stub.__name__, func, stub_sig) + except Exception as e: + raise e from None # suppress parent exception for cleaner pytest output + else: + _test_inspectable_func(sig, stub_sig) @pytest.mark.parametrize( @@ -244,11 +300,12 @@ def test_extension_func_signature(extension: str, stub: FunctionType): @pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__) -@given(st.data()) -@settings(max_examples=1) -def test_array_method_signature(stub: FunctionType, data: DataObject): - dtypes = get_dtypes_strategy(stub.__name__) - x = data.draw(xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label="x") +def test_array_method_signature(stub: FunctionType): + x_expr = func_to_specified_arg_exprs[stub.__name__]["self"] + try: + x = eval(x_expr, {"xp": xp}) + except Exception as e: + pytest.skip(f"Exception occured when evaluating x={x_expr}: {e}") assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}" method = getattr(x, stub.__name__) - _test_func_signature(method, stub, array=x) + _test_func_signature(method, stub, is_method=True)