diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py
index 9ef50705..8acf11fa 100644
--- a/array_api_tests/dtype_helpers.py
+++ b/array_api_tests/dtype_helpers.py
@@ -34,6 +34,7 @@
     "is_int_dtype",
     "is_float_dtype",
     "get_scalar_type",
+    "is_scalar",
     "dtype_ranges",
     "default_int",
     "default_uint",
@@ -189,6 +190,8 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
     else:
         return bool
 
+def is_scalar(x):
+    return isinstance(x, (int, float, complex, bool))
 
 def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
     dtype_value_pairs = []
@@ -275,6 +278,9 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
             _dtype = x_dtype
         else:
             _dtype = default_dtype
+    elif api_version >= '2023.12':
+        # Starting in 2023.12, floats should not promote with dtype=None
+        _dtype = x_dtype
     elif is_float_dtype(x_dtype, include_complex=False):
         if dtype_nbits[x_dtype] > dtype_nbits[default_float]:
             _dtype = x_dtype
@@ -322,6 +328,7 @@
     )
 else:
     default_complex = None
+
 if dtype_nbits[default_int] == 32:
     default_uint = _name_to_dtype.get("uint32")
 else:
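For illustration, the new 2023.12 branch above encodes the rule that a float input keeps its own dtype when an accumulation is called with dtype=None. A minimal sketch of that expectation, assuming a 2023.12-compliant module such as numpy>=2.1 (which provides cumulative_sum) as the array module:

    import numpy as xp  # any 2023.12-compliant array module would do

    x = xp.asarray([1.0, 2.0, 3.0], dtype=xp.float32)
    out = xp.cumulative_sum(x)  # dtype=None
    # float32 must not promote to the default float (float64 on most builds)
    assert out.dtype == xp.float32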
diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py
index 7377274b..08eab19b 100644
--- a/array_api_tests/hypothesis_helpers.py
+++ b/array_api_tests/hypothesis_helpers.py
@@ -370,7 +370,7 @@ def scalars(draw, dtypes, finite=False):
     dtypes should be one of the shared_* dtypes strategies.
     """
     dtype = draw(dtypes)
-    if dtype in dh.dtype_ranges:
+    if dh.is_int_dtype(dtype):
         m, M = dh.dtype_ranges[dtype]
         return draw(integers(m, M))
     elif dtype == bool_dtype:
diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py
index 50f99c1c..9025c461 100644
--- a/array_api_tests/stubs.py
+++ b/array_api_tests/stubs.py
@@ -44,7 +44,7 @@
 category_to_funcs: Dict[str, List[FunctionType]] = {}
 
 for name, mod in name_to_mod.items():
-    if name.endswith("_functions") or name == "info":  # info functions file just named info.py
+    if name.endswith("_functions"):
         category = name.replace("_functions", "")
         objects = [getattr(mod, name) for name in mod.__all__]
         assert all(isinstance(o, FunctionType) for o in objects)  # sanity check
@@ -55,7 +55,26 @@
     all_funcs.extend(funcs)
 name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
 
-EXTENSIONS: List[str] = ["linalg"]  # TODO: add "fft" once stubs available
+info_funcs = []
+if api_version >= "2023.12":
+    # The info functions in the stubs are in info.py, but this is not a name
+    # in the standard.
+    info_mod = name_to_mod["info"]
+
+    # Note that __array_namespace_info__ is in info.__all__ but it is in the
+    # top-level namespace, not the info namespace.
+    info_funcs = [getattr(info_mod, name) for name in info_mod.__all__
+                  if name != '__array_namespace_info__']
+    assert all(isinstance(f, FunctionType) for f in info_funcs)
+    name_to_func.update({f.__name__: f for f in info_funcs})
+
+    all_funcs.append(info_mod.__array_namespace_info__)
+    name_to_func['__array_namespace_info__'] = info_mod.__array_namespace_info__
+    category_to_funcs['info'] = [info_mod.__array_namespace_info__]
+
+EXTENSIONS: List[str] = ["linalg"]
+if api_version >= "2022.12":
+    EXTENSIONS.append("fft")
 extension_to_funcs: Dict[str, List[FunctionType]] = {}
 for ext in EXTENSIONS:
     mod = name_to_mod[ext]
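For context on the scalars() guard change: dh.dtype_ranges carries min/max entries for float dtypes as well, so the old `dtype in dh.dtype_ranges` check routed float dtypes into the integer branch. A sketch of the failure mode, using illustrative stand-in tables rather than the suite's real data:

    int_ranges = {"int8": (-128, 127)}
    float_ranges = {"float32": (-3.4e38, 3.4e38)}
    dtype_ranges = {**int_ranges, **float_ranges}  # stand-in for dh.dtype_ranges

    dtype = "float32"
    assert dtype in dtype_ranges    # old guard: True, so floats were drawn as ints
    assert dtype not in int_ranges  # an is_int_dtype-style guard takes the right branch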
diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py
index 55391e43..a77a0b6a 100644
--- a/array_api_tests/test_manipulation_functions.py
+++ b/array_api_tests/test_manipulation_functions.py
@@ -426,6 +426,22 @@ def test_tile(x, data):
 def test_unstack(x, data):
     axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
     kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
-    out = xp.asarray(xp.unstack(x, **kw), dtype=x.dtype)
-    ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=out.dtype)
-    # TODO: shapes and values testing
+    out = xp.unstack(x, **kw)
+
+    assert isinstance(out, tuple)
+    assert len(out) == x.shape[axis]
+    expected_shape = list(x.shape)
+    expected_shape.pop(axis)
+    expected_shape = tuple(expected_shape)
+    for i in range(x.shape[axis]):
+        arr = out[i]
+        ph.assert_result_shape("unstack", in_shapes=[x.shape],
+                               out_shape=arr.shape, expected=expected_shape,
+                               kw=kw, repr_name=f"out[{i}].shape")
+
+        ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype,
+                        repr_name=f"out[{i}].dtype")
+
+        idx = [slice(None)] * x.ndim
+        idx[axis] = i
+        ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]")
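The rewritten test_unstack checks three properties: the result is a tuple of length x.shape[axis], each element has the input shape with that axis removed, and each element equals the corresponding slice of the input. The same round trip, sketched directly with NumPy (np.unstack requires numpy>=2.1; illustrative only):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    pieces = np.unstack(x, axis=1)
    assert isinstance(pieces, tuple) and len(pieces) == x.shape[1]
    for i, arr in enumerate(pieces):
        assert arr.shape == (2, 4)  # axis 1 removed from (2, 3, 4)
        np.testing.assert_array_equal(arr, x[:, i, :])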
diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py
index 18ed9f55..4c019a2d 100644
--- a/array_api_tests/test_operators_and_elementwise_functions.py
+++ b/array_api_tests/test_operators_and_elementwise_functions.py
@@ -4,6 +4,7 @@
 import cmath
 import math
 import operator
+import builtins
 from copy import copy
 from enum import Enum, auto
 from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union
@@ -369,6 +370,8 @@ def right_scalar_assert_against_refimpl(
 
     See unary_assert_against_refimpl for more information.
     """
+    if expr_template is None:
+        expr_template = func_name + "({}, {})={}"
     if left.dtype in dh.complex_dtypes:
         component_filter = copy(filter_)
         filter_ = lambda s: component_filter(s.real) and component_filter(s.imag)
@@ -422,7 +425,7 @@ def right_scalar_assert_against_refimpl(
     )
 
 
-# When appropiate, this module tests operators alongside their respective
+# When appropriate, this module tests operators alongside their respective
 # elementwise methods. We do this by parametrizing a generalised test method
 # with every relevant method and operator.
 #
@@ -432,8 +435,8 @@
 # - The argument strategies, which can be used to draw arguments for the test
 #   case. They may require additional filtering for certain test cases.
 # - right_is_scalar (binary parameters only), which denotes if the right
-#   argument is a scalar in a test case. This can be used to appropiately adjust
-#   draw filtering and test logic.
+#   argument is a scalar in a test case. This can be used to appropriately
+#   adjust draw filtering and test logic.
 
 func_to_op = {v: k for k, v in dh.op_to_func.items()}
@@ -475,7 +478,7 @@ def make_unary_params(
     )
     if api_version < min_version:
         marks = pytest.mark.skip(
-            reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}"
+            reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
         )
     else:
         marks = ()
@@ -924,15 +927,125 @@ def test_ceil(x):
 
 
 @pytest.mark.min_version("2023.12")
-@given(hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes()))
-def test_clip(x):
+@given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data())
+def test_clip(x, data):
     # TODO: test min/max kwargs, adjust values testing accordingly
-    out = xp.clip(x)
-    ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
-    ph.assert_shape("clip", out_shape=out.shape, expected=x.shape)
-    ph.assert_array_elements("clip", out=out, expected=x)
+    # Ensure that if both min and max are arrays that all three of x, min, max
+    # are broadcast compatible.
+    shape1, shape2 = data.draw(hh.mutually_broadcastable_shapes(2,
+                                                                base_shape=x.shape),
+                               label="min.shape, max.shape")
+
+    dtypes = hh.real_floating_dtypes if dh.is_float_dtype(x.dtype) else hh.int_dtypes
+
+    min = data.draw(st.one_of(
+        st.none(),
+        hh.scalars(dtypes=st.just(x.dtype)),
+        hh.arrays(dtype=dtypes, shape=shape1),
+    ), label="min")
+    max = data.draw(st.one_of(
+        st.none(),
+        hh.scalars(dtypes=st.just(x.dtype)),
+        hh.arrays(dtype=dtypes, shape=shape2),
+    ), label="max")
+
+    # min > max is undefined (but allow nans)
+    assume(min is None or max is None or not xp.any(xp.asarray(min) > xp.asarray(max)))
+
+    kw = data.draw(
+        hh.specified_kwargs(
+            ("min", min, None),
+            ("max", max, None)),
+        label="kwargs")
+
+    out = xp.clip(x, **kw)
+
+    # min and max do not participate in type promotion
+    ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
+
+    shapes = [x.shape]
+    if min is not None and not dh.is_scalar(min):
+        shapes.append(min.shape)
+    if max is not None and not dh.is_scalar(max):
+        shapes.append(max.shape)
+    expected_shape = sh.broadcast_shapes(*shapes)
+    ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape)
+
+    if min is max is None:
+        ph.assert_array_elements("clip", out=out, expected=x)
+    elif max is None:
+        # If one operand is nan, the result is nan. See
+        # https://github.com/data-apis/array-api/pull/813.
+        def refimpl(_x, _min):
+            if math.isnan(_x) or math.isnan(_min):
+                return math.nan
+            return builtins.max(_x, _min)
+        if dh.is_scalar(min):
+            right_scalar_assert_against_refimpl(
+                "clip", x, min, out, refimpl,
+                left_sym="x",
+                expr_template="clip({}, min={})",
+            )
+        else:
+            binary_assert_against_refimpl(
+                "clip", x, min, out, refimpl,
+                left_sym="x", right_sym="min",
+                expr_template="clip({}, min={})",
+            )
+    elif min is None:
+        def refimpl(_x, _max):
+            if math.isnan(_x) or math.isnan(_max):
+                return math.nan
+            return builtins.min(_x, _max)
+        if dh.is_scalar(max):
+            right_scalar_assert_against_refimpl(
+                "clip", x, max, out, refimpl,
+                left_sym="x",
+                expr_template="clip({}, max={})",
+            )
+        else:
+            binary_assert_against_refimpl(
+                "clip", x, max, out, refimpl,
+                left_sym="x", right_sym="max",
+                expr_template="clip({}, max={})",
+            )
+    else:
+        def refimpl(_x, _min, _max):
+            if math.isnan(_x) or math.isnan(_min) or math.isnan(_max):
+                return math.nan
+            return builtins.min(builtins.max(_x, _min), _max)
+
+        # This is based on right_scalar_assert_against_refimpl and
+        # binary_assert_against_refimpl. clip() is currently the only ternary
+        # elementwise function and the only function that supports arrays and
+        # scalars. However, where() (in test_searching_functions) is similar
+        # and if scalar support is added to it, we may want to factor out and
+        # reuse this logic.
+
+        stype = dh.get_scalar_type(x.dtype)
+        min_shape = () if dh.is_scalar(min) else min.shape
+        max_shape = () if dh.is_scalar(max) else max.shape
+
+        for x_idx, min_idx, max_idx, o_idx in sh.iter_indices(
+                x.shape, min_shape, max_shape, out.shape):
+            x_val = stype(x[x_idx])
+            min_val = min if dh.is_scalar(min) else min[min_idx]
+            min_val = stype(min_val)
+            max_val = max if dh.is_scalar(max) else max[max_idx]
+            max_val = stype(max_val)
+            expected = refimpl(x_val, min_val, max_val)
+            out_val = stype(out[o_idx])
+            if math.isnan(expected):
+                assert math.isnan(out_val), (
+                    f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n"
+                    f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
+                )
+            else:
+                assert out_val == expected, (
+                    f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n"
+                    f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
+                )
+
 
 if api_version >= "2022.12":
     @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py
index a12e9d52..c479d2f9 100644
--- a/array_api_tests/test_searching_functions.py
+++ b/array_api_tests/test_searching_functions.py
@@ -173,16 +173,14 @@ def test_where(shapes, dtypes, data):
 @given(data=st.data())
 def test_searchsorted(data):
     # TODO: test side="right"
+    # TODO: Allow different dtypes for x1 and x2
     _x1 = data.draw(
         st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
         label="_x1",
     )
     x1 = xp.asarray(_x1, dtype=dh.default_float)
     if data.draw(st.booleans(), label="use sorter?"):
-        sorter = data.draw(
-            st.permutations(_x1).map(lambda o: xp.asarray(o, dtype=dh.default_float)),
-            label="sorter",
-        )
+        sorter = xp.argsort(x1)
     else:
         sorter = None
     x1 = xp.sort(x1)
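Two behavioural points above are worth pulling out. The clip reference implementations encode nan-in, nan-out (per data-apis/array-api PR 813) precisely because Python's builtins.max/min do not propagate nans, and the searchsorted fix replaces a drawn permutation of values (which is not a valid sorter) with index-based argsort. A standalone sketch of the first point:

    import builtins
    import math

    def clip_ref(x, lo, hi):
        # nan-in, nan-out, as the test's refimpl requires
        if any(map(math.isnan, (x, lo, hi))):
            return math.nan
        return builtins.min(builtins.max(x, lo), hi)

    assert clip_ref(5.0, 0.0, 2.0) == 2.0
    # builtins.max(1.0, float('nan')) silently returns 1.0, hence the explicit check
    assert math.isnan(clip_ref(1.0, math.nan, 2.0))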
diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py
index b109ebd0..1c9a8ef6 100644
--- a/array_api_tests/test_signatures.py
+++ b/array_api_tests/test_signatures.py
@@ -31,7 +31,8 @@ def squeeze(x, /, axis):
 
 from . import dtype_helpers as dh
 from . import xp
-from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func
+from .stubs import (array_methods, category_to_funcs, extension_to_funcs,
+                    name_to_func, info_funcs)
 
 ParameterKind = Literal[
     Parameter.POSITIONAL_ONLY,
@@ -308,3 +309,15 @@ def test_array_method_signature(stub: FunctionType):
     assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
     method = getattr(x, stub.__name__)
     _test_func_signature(method, stub, is_method=True)
+
+
+if info_funcs:  # pytest fails collecting if info_funcs is empty
+    @pytest.mark.min_version("2023.12")
+    @pytest.mark.parametrize("stub", info_funcs, ids=lambda f: f.__name__)
+    def test_info_func_signature(stub: FunctionType):
+        try:
+            info_namespace = xp.__array_namespace_info__()
+        except Exception as e:
+            raise AssertionError(f"Could not get info namespace from xp.__array_namespace_info__(): {e}")
+
+        func = getattr(info_namespace, stub.__name__)
+        _test_func_signature(func, stub)
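What test_info_func_signature exercises, sketched against a concrete implementation (numpy>=2.0 ships __array_namespace_info__; illustrative, any compliant module works):

    import inspect
    import numpy as xp

    info = xp.__array_namespace_info__()
    sig = inspect.signature(info.default_dtypes)
    # 2023.12 specifies a device keyword on default_dtypes
    assert "device" in sig.parameters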
diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py
index 4ac421c7..817805b2 100644
--- a/array_api_tests/test_special_cases.py
+++ b/array_api_tests/test_special_cases.py
@@ -629,7 +629,7 @@ def check_result(i: float, result: float) -> bool:
     return check_result
 
 
-def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
+def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
     """
     Parses a Sphinx-formatted docstring of a unary function to return a list of
     codified unary cases, e.g.
@@ -660,7 +660,7 @@
     ...    '''
     ...
     >>> case_block = r_case_block.search(sqrt.__doc__).group(1)
-    >>> unary_cases = parse_unary_case_block(case_block)
+    >>> unary_cases = parse_unary_case_block(case_block, 'sqrt')
     >>> for case in unary_cases:
     ...     print(repr(case))
     UnaryCase(<x_i is NaN -> NaN>)
@@ -691,7 +691,7 @@
             cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
             _check_result, result_expr = parse_result(m.group(2))
         except ParseError as e:
-            warn(f"not machine-readable: '{e.value}'")
+            warn(f"case for {func_name} not machine-readable: '{e.value}'")
             continue
         cond_expr = cond_expr_template.replace("{}", "x_i")
         # Do not define check_result in this function's body - see
@@ -708,7 +708,7 @@
             cases.append(case)
         else:
             if not r_remaining_case.search(case_str):
-                warn(f"case not machine-readable: '{case_str}'")
+                warn(f"case for {func_name} not machine-readable: '{case_str}'")
 
     return cases
 
@@ -1102,7 +1102,7 @@ def cond(i1: float, i2: float) -> bool:
 r_redundant_case = re.compile("result.+determined by the rule already stated above")
 
 
-def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
+def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]:
     """
     Parses a Sphinx-formatted docstring of a binary function to return a list
     of codified binary cases, e.g.
@@ -1133,7 +1133,7 @@
     ...    '''
     ...
     >>> case_block = r_case_block.search(logaddexp.__doc__).group(1)
-    >>> binary_cases = parse_binary_case_block(case_block)
+    >>> binary_cases = parse_binary_case_block(case_block, 'logaddexp')
     >>> for case in binary_cases:
     ...     print(repr(case))
     BinaryCase(<x1_i is NaN or x2_i is NaN -> NaN>)
@@ -1151,10 +1151,10 @@
             case = parse_binary_case(case_str)
             cases.append(case)
         except ParseError as e:
-            warn(f"not machine-readable: '{e.value}'")
+            warn(f"case for {func_name} not machine-readable: '{e.value}'")
         else:
             if not r_remaining_case.match(case_str):
-                warn(f"case not machine-readable: '{case_str}'")
+                warn(f"case for {func_name} not machine-readable: '{case_str}'")
 
     return cases
 
@@ -1163,8 +1163,9 @@
 iop_params = []
 func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
 for stub in category_to_funcs["elementwise"]:
+    func_name = stub.__name__
     if stub.__doc__ is None:
-        warn(f"{stub.__name__}() stub has no docstring")
+        warn(f"{func_name}() stub has no docstring")
         continue
     if m := r_case_block.search(stub.__doc__):
         case_block = m.group(1)
@@ -1172,10 +1173,10 @@
         continue
     marks = []
     try:
-        func = getattr(xp, stub.__name__)
+        func = getattr(xp, func_name)
     except AttributeError:
         marks.append(
-            pytest.mark.skip(reason=f"{stub.__name__} not found in array module")
+            pytest.mark.skip(reason=f"{func_name} not found in array module")
         )
         func = None
     sig = inspect.signature(stub)
@@ -1184,10 +1185,10 @@
         warn(f"{func=} has no parameters")
         continue
     if param_names[0] == "x":
-        if cases := parse_unary_case_block(case_block):
-            name_to_func = {stub.__name__: func}
-            if stub.__name__ in func_to_op.keys():
-                op_name = func_to_op[stub.__name__]
+        if cases := parse_unary_case_block(case_block, func_name):
+            name_to_func = {func_name: func}
+            if func_name in func_to_op.keys():
+                op_name = func_to_op[func_name]
                 op = getattr(operator, op_name)
                 name_to_func[op_name] = op
             for func_name, func in name_to_func.items():
@@ -1196,20 +1197,20 @@
                 p = pytest.param(func_name, func, case, id=id_)
                 unary_params.append(p)
         else:
-            warn(f"Special cases found for {stub.__name__} but none were parsed")
+            warn(f"Special cases found for {func_name} but none were parsed")
         continue
     if len(sig.parameters) == 1:
         warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
         continue
     if param_names[0] == "x1" and param_names[1] == "x2":
-        if cases := parse_binary_case_block(case_block):
-            name_to_func = {stub.__name__: func}
-            if stub.__name__ in func_to_op.keys():
-                op_name = func_to_op[stub.__name__]
+        if cases := parse_binary_case_block(case_block, func_name):
+            name_to_func = {func_name: func}
+            if func_name in func_to_op.keys():
+                op_name = func_to_op[func_name]
                 op = getattr(operator, op_name)
                 name_to_func[op_name] = op
                 # We collect inplace operator test cases separately
-                if "equal" in stub.__name__:
+                if "equal" in func_name:
                     continue
                 iop_name = "__i" + op_name[2:]
                 iop = getattr(operator, iop_name)
@@ -1223,7 +1224,7 @@
                 p = pytest.param(func_name, func, case, id=id_)
                 binary_params.append(p)
         else:
-            warn(f"Special cases found for {stub.__name__} but none were parsed")
+            warn(f"Special cases found for {func_name} but none were parsed")
         continue
     else:
         warn(
@@ -1323,7 +1324,8 @@ def test_empty_arrays(func_name, expected):  # TODO: parse docstrings to get expected
 
 
 @pytest.mark.parametrize(
-    "func_name", [f.__name__ for f in category_to_funcs["statistical"]]
+    "func_name", [f.__name__ for f in category_to_funcs["statistical"]
+                  if f.__name__ != 'cumulative_sum']
 )
 @given(
     x=hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes(min_side=1)),
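The substantive change in test_special_cases.py is that parse warnings now carry the function name; previously, identical "not machine-readable" messages from different docstrings could not be told apart. In isolation (the case text here is hypothetical):

    from warnings import warn

    func_name = "atan2"
    case_str = "If ``x1_i`` is sufficiently mysterious, the result is ``NaN``."  # hypothetical
    # Before: "case not machine-readable: '...'" with no hint of the source
    warn(f"case for {func_name} not machine-readable: '{case_str}'")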
diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py
index bb8c0ef2..778cdea1 100644
--- a/array_api_tests/test_statistical_functions.py
+++ b/array_api_tests/test_statistical_functions.py
@@ -5,25 +5,84 @@
 import pytest
 from hypothesis import assume, given
 from hypothesis import strategies as st
+from ndindex import iter_indices
 
 from . import _array_module as xp
 from . import dtype_helpers as dh
 from . import hypothesis_helpers as hh
 from . import pytest_helpers as ph
 from . import shape_helpers as sh
-from . import api_version
 from ._array_module import _UndefinedStub
 from .typing import DataType
 
 
 @pytest.mark.min_version("2023.12")
-@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes(min_dims=1, max_dims=1)))
-def test_cumulative_sum(x):
-    # TODO: test kwargs + diff shapes, adjust shape and values testing accordingly
-    out = xp.cumulative_sum(x)
-    # TODO: assert dtype
-    ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=x.shape)
-    # TODO: assert values
+@given(
+    x=hh.arrays(
+        dtype=hh.numeric_dtypes,
+        shape=hh.shapes(min_dims=1)),
+    data=st.data(),
+)
+def test_cumulative_sum(x, data):
+    axes = st.integers(-x.ndim, x.ndim - 1)
+    if x.ndim == 1:
+        axes = axes | st.none()
+    axis = data.draw(axes, label='axis')
+    _axis, = sh.normalise_axis(axis, x.ndim)
+    dtype = data.draw(kwarg_dtypes(x.dtype))
+    include_initial = data.draw(st.booleans(), label="include_initial")
+
+    kw = data.draw(
+        hh.specified_kwargs(
+            ("axis", axis, None),
+            ("dtype", dtype, None),
+            ("include_initial", include_initial, False),
+        ),
+        label="kw",
+    )
+
+    out = xp.cumulative_sum(x, **kw)
+
+    expected_shape = list(x.shape)
+    if include_initial:
+        expected_shape[_axis] += 1
+    expected_shape = tuple(expected_shape)
+    ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=expected_shape)
+
+    expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
+    if expected_dtype is None:
+        # If a default uint cannot exist (i.e. in PyTorch which doesn't support
+        # uint32 or uint64), we skip testing the output dtype.
+        # See https://github.com/data-apis/array-api-tests/issues/106
+        if x.dtype in dh.uint_dtypes:
+            assert dh.is_int_dtype(out.dtype)  # sanity check
+    else:
+        ph.assert_dtype("cumulative_sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
+
+    scalar_type = dh.get_scalar_type(out.dtype)
+
+    for x_idx, out_idx in iter_indices(x.shape, expected_shape, skip_axes=_axis):
+        x_arr = x[x_idx.raw]
+        out_arr = out[out_idx.raw]
+
+        if include_initial:
+            ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=0)
+
+        for n in range(x.shape[_axis]):
+            start = 1 if include_initial else 0
+            out_val = out_arr[n + start]
+            assume(cmath.isfinite(out_val))
+            elements = []
+            for idx in range(n + 1):
+                s = scalar_type(x_arr[idx])
+                elements.append(s)
+            expected = sum(elements)
+            if dh.is_int_dtype(out.dtype):
+                m, M = dh.dtype_ranges[out.dtype]
+                assume(m <= expected <= M)
+            ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
+                                    idx=out_idx.raw, out=out_val,
+                                    expected=expected)
 
 
 def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
@@ -148,7 +207,7 @@ def test_prod(x, data):
         # See https://github.com/data-apis/array-api-tests/issues/106
         if x.dtype in dh.uint_dtypes:
             assert dh.is_int_dtype(out.dtype)  # sanity check
-    elif api_version < "2023.12":  # TODO: update dtype assertion for >2023.12 - see #234
+    else:
         ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
     _axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
     ph.assert_keepdimable_shape(
@@ -237,7 +296,7 @@ def test_sum(x, data):
         # See https://github.com/data-apis/array-api-tests/issues/160
         if x.dtype in dh.uint_dtypes:
             assert dh.is_int_dtype(out.dtype)  # sanity check
-    elif api_version < "2023.12":  # TODO: update dtype assertion for >2023.12 - see #234
+    else:
         ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
     _axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
     ph.assert_keepdimable_shape(
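The reference behaviour the new test_cumulative_sum encodes, including the extra leading identity element under include_initial=True, sketched with NumPy (np.cumulative_sum requires numpy>=2.1; illustrative only):

    import numpy as np

    x = np.asarray([1, 2, 3])
    out = np.cumulative_sum(x, include_initial=True)
    assert out.shape == (4,)             # one longer along the accumulated axis
    assert out[0] == 0                   # leading identity element
    assert out.tolist() == [0, 1, 3, 6]  # prefix sums of x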
diff --git a/conftest.py b/conftest.py
index 76a2f0b2..2646a537 100644
--- a/conftest.py
+++ b/conftest.py
@@ -12,7 +12,7 @@
 from array_api_tests import api_version
 from array_api_tests._array_module import _UndefinedStub
 from array_api_tests.stubs import EXTENSIONS
-from array_api_tests import xp_name
+from array_api_tests import xp_name, xp as array_module
 
 from reporting import pytest_metadata, pytest_json_modifyreport, add_extra_json_metadata  # noqa
 
@@ -22,7 +22,12 @@ def pytest_report_header(config):
         ext for ext in EXTENSIONS + ['fft']
         if ext not in disabled_extensions and xp_has_ext(ext)
     })
-    return f"Array API Tests Module: {xp_name}. API Version: {api_version}. Enabled Extensions: {', '.join(enabled_extensions)}"
+    try:
+        array_module_version = array_module.__version__
+    except AttributeError:
+        array_module_version = "version unknown"
+
+    return f"Array API Tests Module: {xp_name} ({array_module_version}). API Version: {api_version}. Enabled Extensions: {', '.join(enabled_extensions)}"
 
 def pytest_addoption(parser):
     # Hypothesis max examples
@@ -200,7 +205,7 @@ def pytest_collection_modifyitems(config, items):
             if api_version < min_version:
                 item.add_marker(
                     mark.skip(
-                        reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}"
+                        reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
                     )
                 )
         # reduce max generated Hypothesis example for unvectorized tests
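The try/except around array_module.__version__ simply degrades the report header gracefully for modules lacking the attribute; getattr with a default is the equivalent one-liner, shown here with a stand-in module:

    import types

    array_module = types.ModuleType("fakexp")  # hypothetical module without __version__
    version = getattr(array_module, "__version__", "version unknown")
    assert version == "version unknown"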