From e61b50d27200e31fd90c766c1e8e9f47a19827e3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 13:53:44 -0600 Subject: [PATCH 01/56] Add a @requires_api_version decorator --- array_api_strict/_flags.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 6cc503a..9c87229 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -277,6 +277,21 @@ def set_flags_from_environment(): set_flags_from_environment() +# Decorators + +def requires_api_version(version): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if version > API_VERSION: + raise RuntimeError( + f"The function {func.__name__} requires API version {version} or later, " + f"but the current API version for array-api-strict is {API_VERSION}" + ) + return func(*args, **kwargs) + return wrapper + return decorator + def requires_data_dependent_shapes(func): @functools.wraps(func) def wrapper(*args, **kwargs): From f49845aa3e27af808b5f1fd1d0afc724143dc638 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 13:59:58 -0600 Subject: [PATCH 02/56] Don't re-enable disabled extensions when setting the api version --- array_api_strict/_flags.py | 2 +- array_api_strict/tests/test_flags.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 9c87229..cd33290 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -140,7 +140,7 @@ def set_array_api_strict_flags( ) ENABLED_EXTENSIONS = tuple(enabled_extensions) else: - ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= API_VERSION]) + ENABLED_EXTENSIONS = tuple([ext for ext in ENABLED_EXTENSIONS if extension_versions[ext] <= API_VERSION]) # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 303c930..996a684 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -42,6 +42,16 @@ def test_flags(): assert flags == { 'api_version': '2021.12', 'data_dependent_shapes': False, + 'enabled_extensions': (), + } + reset_array_api_strict_flags() + + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2021.12') + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2021.12', + 'data_dependent_shapes': True, 'enabled_extensions': ('linalg',), } From 71c523198b8f15d9f3920237b191cc9fc7efe710 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 14:01:43 -0600 Subject: [PATCH 03/56] Add support for setting the api version to 2023.12 --- array_api_strict/_flags.py | 6 ++++++ array_api_strict/tests/test_array_object.py | 5 ++++- array_api_strict/tests/test_flags.py | 13 +++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index cd33290..205c325 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -21,6 +21,7 @@ supported_versions = ( "2021.12", "2022.12", + "2023.12", ) API_VERSION = default_version = "2022.12" @@ -67,6 +68,9 @@ def set_array_api_strict_flags( Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). + 2023.12 support is preliminary. Some features in 2023.12 may still be + missing, and it hasn't been fully tested. + - `data_dependent_shapes`: Whether data-dependent shapes are enabled in array-api-strict. @@ -123,6 +127,8 @@ def set_array_api_strict_flags( raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") + if api_version == "2023.12": + warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.") API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index bae0553..9d9dad0 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -410,9 +410,12 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2022.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2022.12" + assert a.__array_namespace__(api_version="2023.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2023.12" + with pytest.warns(UserWarning): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 996a684..b1ad61f 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -54,6 +54,19 @@ def test_flags(): 'data_dependent_shapes': True, 'enabled_extensions': ('linalg',), } + reset_array_api_strict_flags() + + # 2023.12 should issue a warning + with pytest.warns(UserWarning) as record: + set_array_api_strict_flags(api_version='2023.12') + assert len(record) == 1 + assert '2023.12' in str(record[0].message) + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2023.12', + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } # Test setting flags with invalid values pytest.raises(ValueError, lambda: From c39fdbfcea9c755b72f6387c92c3f8c8f1a7acdd Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 15:48:25 -0600 Subject: [PATCH 04/56] Set the stacklevel in the set_array_api_strict_flags() warnings --- array_api_strict/_flags.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 205c325..476ffb9 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -126,9 +126,9 @@ def set_array_api_strict_flags( if api_version not in supported_versions: raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") + warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) if api_version == "2023.12": - warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.") + warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.", stacklevel=2) API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION From 31c5a89902508087cdaec663631fd51a1f95aeb6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 15:52:32 -0600 Subject: [PATCH 05/56] Add clip() It is only enabled for when the api version is 2023.12. I have only tested that it works manually. There is no test suite support for clip() yet. --- array_api_strict/__init__.py | 2 + array_api_strict/_elementwise_functions.py | 67 ++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 3f418d8..6a9079a 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -134,6 +134,7 @@ bitwise_right_shift, bitwise_xor, ceil, + clip, conj, cos, cosh, @@ -196,6 +197,7 @@ "bitwise_right_shift", "bitwise_xor", "ceil", + "clip", "cos", "cosh", "divide", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 8b69677..800ee70 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -12,6 +12,11 @@ _result_type, ) from ._array_object import Array +from ._flags import requires_api_version +from ._creation_functions import asarray +from ._utility_functions import any as xp_any + +from typing import Optional, Union import numpy as np @@ -240,6 +245,68 @@ def ceil(x: Array, /) -> Array: return x return Array._new(np.ceil(x._array)) +# WARNING: This function is not yet tested by the array-api-tests test suite. + +# Note: min and max argument names are different and not optional in numpy. +@requires_api_version('2023.12') +def clip( + x: Array, + /, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.clip `. + + See its docstring for more information. + """ + if (x.dtype not in _real_numeric_dtypes + or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes + or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes): + raise TypeError("Only real numeric dtypes are allowed in clip") + if not isinstance(min, (int, float, Array, type(None))): + raise TypeError("min must be an None, int, float, or an array") + if not isinstance(max, (int, float, Array, type(None))): + raise TypeError("max must be an None, int, float, or an array") + + # Mixed dtype kinds is implementation defined + if (x.dtype in _integer_dtypes + and (isinstance(min, float) or + isinstance(min, Array) and min.dtype in _real_floating_dtypes)): + raise TypeError("min must be integral when x is integral") + if (x.dtype in _integer_dtypes + and (isinstance(max, float) or + isinstance(max, Array) and max.dtype in _real_floating_dtypes)): + raise TypeError("max must be integral when x is integral") + if (x.dtype in _real_floating_dtypes + and (isinstance(min, int) or + isinstance(min, Array) and min.dtype in _integer_dtypes)): + raise TypeError("min must be floating-point when x is floating-point") + if (x.dtype in _real_floating_dtypes + and (isinstance(max, int) or + isinstance(max, Array) and max.dtype in _integer_dtypes)): + raise TypeError("max must be floating-point when x is floating-point") + + if min is max is None: + # Note: NumPy disallows min = max = None + return x + + # Normalize to make the below logic simpler + if min is not None: + min = asarray(min)._array + if max is not None: + max = asarray(max)._array + + # min > max is implementation defined + if min is not None and max is not None and np.any(min > max): + raise ValueError("min must be less than or equal to max") + + result = np.clip(x._array, min, max) + # Note: NumPy applies type promotion, but the standard specifies the + # return dtype should be the same as x + if result.dtype != x.dtype._np_dtype: + result = result.astype(x.dtype._np_dtype) + return Array._new(result) def conj(x: Array, /) -> Array: """ From 6a6719ff6115e5bde084682cc5870e012f0a70c5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 15:58:49 -0600 Subject: [PATCH 06/56] Add a TODO note for clip() --- array_api_strict/_elementwise_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 800ee70..c9272bb 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -305,6 +305,8 @@ def clip( # Note: NumPy applies type promotion, but the standard specifies the # return dtype should be the same as x if result.dtype != x.dtype._np_dtype: + # TODO: I'm not completely sure this always gives the correct thing + # for integer dtypes. See https://github.com/numpy/numpy/issues/24976 result = result.astype(x.dtype._np_dtype) return Array._new(result) From 77e6177eecf4b330101965bf4f93ae18da01041e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 16:30:13 -0600 Subject: [PATCH 07/56] Remove unused import --- array_api_strict/_elementwise_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index c9272bb..ea52d96 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -14,7 +14,6 @@ from ._array_object import Array from ._flags import requires_api_version from ._creation_functions import asarray -from ._utility_functions import any as xp_any from typing import Optional, Union From 04c24d72be76305a19fb1de4153961b3848330aa Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 16:35:29 -0600 Subject: [PATCH 08/56] Add copysign copysign is not tested yet by the test suite, but the standard does not appear to deviate from NumPy (except in the restriction to floating-point dtypes). --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index e2212f1..9d9aca6 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -139,6 +139,7 @@ ceil, clip, conj, + copysign, cos, cosh, divide, @@ -202,6 +203,7 @@ "ceil", "clip", "conj", + "copysign", "cos", "cosh", "divide", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index ea52d96..994bcb2 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -319,6 +319,16 @@ def conj(x: Array, /) -> Array: raise TypeError("Only complex floating-point dtypes are allowed in conj") return Array._new(np.conj(x)) +@requires_api_version('2023.12') +def copysign(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.copysign `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in copysign") + return Array._new(np.copysign(x1._array, x2._array)) def cos(x: Array, /) -> Array: """ From c4587a4654db065ef8dba7e9115f43cacc871774 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 23:38:11 -0600 Subject: [PATCH 09/56] Implement cumulative_sum (still needs to be tested) --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_statistical_functions.py | 25 ++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 9d9aca6..c1fba30 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -290,9 +290,9 @@ __all__ += ["argsort", "sort"] -from ._statistical_functions import max, mean, min, prod, std, sum, var +from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var -__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"] +__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] from ._utility_functions import all, any diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index cbe9d0d..c65f50e 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -7,6 +7,9 @@ ) from ._array_object import Array from ._dtypes import float32, complex64 +from ._flags import requires_api_version +from ._creation_functions import zeros +from ._manipulation_functions import concat from typing import TYPE_CHECKING @@ -16,6 +19,28 @@ import numpy as np +@requires_api_version('2023.12') +def cumulative_sum( + x: Array, + /, + *, + axis: Optional[int] = None, + dtype: Optional[Dtype] = None, + include_initial: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in cumulative_sum") + if dtype is None: + dtype = x.dtype + + if axis is None: + if x.ndim > 1: + raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + axis = 0 + # np.cumsum does not support include_initial + if include_initial: + x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis) + return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype._np_dtype)) def max( x: Array, From 16b38d305a160e3d4825b65aea9b3ad2c5b54b48 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 14:53:58 -0600 Subject: [PATCH 10/56] Add a comment about cumulative_sum and 0-D inputs --- array_api_strict/_statistical_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index c65f50e..b35d26f 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -33,6 +33,7 @@ def cumulative_sum( if dtype is None: dtype = x.dtype + # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: if x.ndim > 1: raise ValueError("axis must be specified in cumulative_sum for more than one dimension") From b689d43b9906b5e9091c23f628a0a7bfee7654d1 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 15:18:46 -0600 Subject: [PATCH 11/56] Add hypot() This is untested, but the NumPy hypot() should match the standard. --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index c1fba30..b9df986 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -150,6 +150,7 @@ floor_divide, greater, greater_equal, + hypot, imag, isfinite, isinf, @@ -214,6 +215,7 @@ "floor_divide", "greater", "greater_equal", + "hypot", "imag", "isfinite", "isinf", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 994bcb2..f144c69 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -455,6 +455,18 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) +def hypot(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.hypot `. + + See its docstring for more information. + """ + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in hypot") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.hypot(x1._array, x2._array)) def imag(x: Array, /) -> Array: """ From 9ee08c7252e5e924755da043d9a4a3bd2c80b9bc Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 15:26:46 -0600 Subject: [PATCH 12/56] Update elementwise tests for new elementwise functions Also add a meta-test to ensure the elementwise tests stay up-to-date. --- .../tests/test_elementwise_functions.py | 140 ++++++++++-------- 1 file changed, 76 insertions(+), 64 deletions(-) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 1228d0a..abb02f8 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,4 +1,4 @@ -from inspect import getfullargspec +from inspect import getfullargspec, getmodule from numpy.testing import assert_raises @@ -10,79 +10,88 @@ _floating_dtypes, _integer_dtypes, ) - +from .._flags import set_array_api_strict_flags def nargs(func): return len(getfullargspec(func).args) +elementwise_function_input_types = { + "abs": "numeric", + "acos": "floating-point", + "acosh": "floating-point", + "add": "numeric", + "asin": "floating-point", + "asinh": "floating-point", + "atan": "floating-point", + "atan2": "real floating-point", + "atanh": "floating-point", + "bitwise_and": "integer or boolean", + "bitwise_invert": "integer or boolean", + "bitwise_left_shift": "integer", + "bitwise_or": "integer or boolean", + "bitwise_right_shift": "integer", + "bitwise_xor": "integer or boolean", + "ceil": "real numeric", + "clip": "real numeric", + "conj": "complex floating-point", + "copysign": "real floating-point", + "cos": "floating-point", + "cosh": "floating-point", + "divide": "floating-point", + "equal": "all", + "exp": "floating-point", + "expm1": "floating-point", + "floor": "real numeric", + "floor_divide": "real numeric", + "greater": "real numeric", + "greater_equal": "real numeric", + "hypot": "real floating-point", + "imag": "complex floating-point", + "isfinite": "numeric", + "isinf": "numeric", + "isnan": "numeric", + "less": "real numeric", + "less_equal": "real numeric", + "log": "floating-point", + "logaddexp": "real floating-point", + "log10": "floating-point", + "log1p": "floating-point", + "log2": "floating-point", + "logical_and": "boolean", + "logical_not": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "multiply": "numeric", + "negative": "numeric", + "not_equal": "all", + "positive": "numeric", + "pow": "numeric", + "real": "complex floating-point", + "remainder": "real numeric", + "round": "numeric", + "sign": "numeric", + "sin": "floating-point", + "sinh": "floating-point", + "sqrt": "floating-point", + "square": "numeric", + "subtract": "numeric", + "tan": "floating-point", + "tanh": "floating-point", + "trunc": "real numeric", +} + +def test_missing_functions(): + # Ensure the above dictionary is complete. + import array_api_strict._elementwise_functions as mod + mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod] + assert set(mod_funcs) == set(elementwise_function_input_types) + def test_function_types(): # Test that every function accepts only the required input types. We only # test the negative cases here (error). The positive cases are tested in # the array API test suite. - elementwise_function_input_types = { - "abs": "numeric", - "acos": "floating-point", - "acosh": "floating-point", - "add": "numeric", - "asin": "floating-point", - "asinh": "floating-point", - "atan": "floating-point", - "atan2": "real floating-point", - "atanh": "floating-point", - "bitwise_and": "integer or boolean", - "bitwise_invert": "integer or boolean", - "bitwise_left_shift": "integer", - "bitwise_or": "integer or boolean", - "bitwise_right_shift": "integer", - "bitwise_xor": "integer or boolean", - "ceil": "real numeric", - "conj": "complex floating-point", - "cos": "floating-point", - "cosh": "floating-point", - "divide": "floating-point", - "equal": "all", - "exp": "floating-point", - "expm1": "floating-point", - "floor": "real numeric", - "floor_divide": "real numeric", - "greater": "real numeric", - "greater_equal": "real numeric", - "imag": "complex floating-point", - "isfinite": "numeric", - "isinf": "numeric", - "isnan": "numeric", - "less": "real numeric", - "less_equal": "real numeric", - "log": "floating-point", - "logaddexp": "real floating-point", - "log10": "floating-point", - "log1p": "floating-point", - "log2": "floating-point", - "logical_and": "boolean", - "logical_not": "boolean", - "logical_or": "boolean", - "logical_xor": "boolean", - "multiply": "numeric", - "negative": "numeric", - "not_equal": "all", - "positive": "numeric", - "pow": "numeric", - "real": "complex floating-point", - "remainder": "real numeric", - "round": "numeric", - "sign": "numeric", - "sin": "floating-point", - "sinh": "floating-point", - "sqrt": "floating-point", - "square": "numeric", - "subtract": "numeric", - "tan": "floating-point", - "tanh": "floating-point", - "trunc": "real numeric", - } - def _array_vals(): for d in _integer_dtypes: yield asarray(1, dtype=d) @@ -91,6 +100,9 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) + # Use the latest version of the standard so all functions are included + set_array_api_strict_flags(api_version="2023.12") + for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] From e24f55ea9a6cef94e4c025429f29f098db353b40 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 15:32:52 -0600 Subject: [PATCH 13/56] Clear trailing whitespace --- array_api_strict/tests/test_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/tests/test_manipulation_functions.py b/array_api_strict/tests/test_manipulation_functions.py index 70b42f3..9969651 100644 --- a/array_api_strict/tests/test_manipulation_functions.py +++ b/array_api_strict/tests/test_manipulation_functions.py @@ -25,7 +25,7 @@ def test_reshape_copy(): a = asarray(np.ones((2, 3))) b = reshape(a, (3, 2), copy=True) assert not np.shares_memory(a._array, b._array) - + a = asarray(np.ones((2, 3))) b = reshape(a, (3, 2), copy=False) assert np.shares_memory(a._array, b._array) From f5fbf78055d0c9d677a169ef2d5a82dc214b0430 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 15:39:13 -0600 Subject: [PATCH 14/56] Silence warnings output in the tests --- array_api_strict/tests/test_array_object.py | 3 ++- array_api_strict/tests/test_elementwise_functions.py | 5 ++++- array_api_strict/tests/test_flags.py | 10 ++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 24fcf57..a66637f 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -410,7 +410,8 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2022.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2022.12" - assert a.__array_namespace__(api_version="2023.12") is array_api_strict + with pytest.warns(UserWarning): + assert a.__array_namespace__(api_version="2023.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2023.12" with pytest.warns(UserWarning): diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index abb02f8..3bfcbae 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -12,6 +12,8 @@ ) from .._flags import set_array_api_strict_flags +import pytest + def nargs(func): return len(getfullargspec(func).args) @@ -101,7 +103,8 @@ def _array_vals(): yield asarray(1.0, dtype=d) # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2023.12") + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2023.12") for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index b1ad61f..f6fbc0d 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -73,9 +73,10 @@ def test_flags(): set_array_api_strict_flags(api_version='2020.12')) pytest.raises(ValueError, lambda: set_array_api_strict_flags( enabled_extensions=('linalg', 'fft', 'invalid'))) - pytest.raises(ValueError, lambda: set_array_api_strict_flags( - api_version='2021.12', - enabled_extensions=('linalg', 'fft'))) + with pytest.warns(UserWarning): + pytest.raises(ValueError, lambda: set_array_api_strict_flags( + api_version='2021.12', + enabled_extensions=('linalg', 'fft'))) # Test resetting flags with pytest.warns(UserWarning): @@ -96,7 +97,8 @@ def test_api_version(): assert xp.__array_api_version__ == '2022.12' # Test setting the version - set_array_api_strict_flags(api_version='2021.12') + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2021.12') assert xp.__array_api_version__ == '2021.12' def test_data_dependent_shapes(): From 3e2d46de96b6c87f5a2552a45e6aed5f8ab7a882 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:27:32 -0600 Subject: [PATCH 15/56] Add missing requires_api_version decorator to hypot() --- array_api_strict/_elementwise_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index f144c69..d1e589b 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -455,6 +455,7 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) +@requires_api_version('2023.12') def hypot(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.hypot `. From 250ba869fc786e61c5c5ad3ef3bf66dc1cd24b71 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:27:50 -0600 Subject: [PATCH 16/56] Add maximum and minimum --- array_api_strict/__init__.py | 4 +++ array_api_strict/_elementwise_functions.py | 29 +++++++++++++++++++ .../tests/test_elementwise_functions.py | 2 ++ 3 files changed, 35 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index b9df986..3c0e147 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -166,6 +166,8 @@ logical_not, logical_or, logical_xor, + maximum, + minimum, multiply, negative, not_equal, @@ -231,6 +233,8 @@ "logical_not", "logical_or", "logical_xor", + "maximum", + "minimum", "multiply", "negative", "not_equal", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index d1e589b..a82818b 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -651,6 +651,35 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_xor(x1._array, x2._array)) +@requires_api_version('2023.12') +def maximum(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.maximum `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in maximum") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + # TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error + # in that case? + return Array._new(np.maximum(x1._array, x2._array)) + +@requires_api_version('2023.12') +def minimum(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.minimum `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in minimum") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.minimum(x1._array, x2._array)) def multiply(x1: Array, x2: Array, /) -> Array: """ diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 3bfcbae..6b4a5ec 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -64,6 +64,8 @@ def nargs(func): "logical_not": "boolean", "logical_or": "boolean", "logical_xor": "boolean", + "maximum": "real numeric", + "minimum": "real numeric", "multiply": "numeric", "negative": "numeric", "not_equal": "all", From eb063e2580a7abe65b2cdd92738663f2b4fb84ae Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:35:16 -0600 Subject: [PATCH 17/56] Add moveaxis --- array_api_strict/__init__.py | 3 ++- array_api_strict/_manipulation_functions.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 3c0e147..5110c72 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -275,6 +275,7 @@ concat, expand_dims, flip, + moveaxis, permute_dims, reshape, roll, @@ -282,7 +283,7 @@ stack, ) -__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"] +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "reshape", "roll", "squeeze", "stack"] from ._searching_functions import argmax, argmin, nonzero, where diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index af9a3dd..c22ea1b 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._data_type_functions import result_type +from ._flags import requires_api_version from typing import TYPE_CHECKING @@ -43,6 +44,20 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> """ return Array._new(np.flip(x._array, axis=axis)) +@requires_api_version('2023.12') +def moveaxis( + x: Array, + source: Union[int, Tuple[int, ...]], + destination: Union[int, Tuple[int, ...]], + /, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.moveaxis `. + + See its docstring for more information. + """ + return Array._new(np.moveaxis(x._array, source, destination)) + # Note: The function name is different here (see also matrix_transpose). # Unlike transpose(), the axes argument is required. From 993805957ba9849d66bdb79e7a18870bd31d96ee Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:49:31 -0600 Subject: [PATCH 18/56] Add repeat() --- array_api_strict/__init__.py | 3 ++- array_api_strict/_flags.py | 6 ++--- array_api_strict/_manipulation_functions.py | 27 +++++++++++++++++++-- array_api_strict/tests/test_flags.py | 11 +++++++-- 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 5110c72..39eafc0 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -277,13 +277,14 @@ flip, moveaxis, permute_dims, + repeat, reshape, roll, squeeze, stack, ) -__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "reshape", "roll", "squeeze", "stack"] +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack"] from ._searching_functions import argmax, argmin, nonzero, where diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 476ffb9..fd36139 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -81,10 +81,10 @@ def set_array_api_strict_flags( The functions that make use of data-dependent shapes, and are therefore disabled by setting this flag to False are - - `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`. - - `nonzero` + - `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`. + - `nonzero()` - Boolean array indexing - - `repeat` when the `repeats` argument is an array (requires 2023.12 + - `repeat()` when the `repeats` argument is an array (requires 2023.12 version of the standard) See diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index c22ea1b..1f9a50f 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -1,8 +1,9 @@ from __future__ import annotations from ._array_object import Array +from ._creation_functions import asarray from ._data_type_functions import result_type -from ._flags import requires_api_version +from ._flags import requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING @@ -58,7 +59,6 @@ def moveaxis( """ return Array._new(np.moveaxis(x._array, source, destination)) - # Note: The function name is different here (see also matrix_transpose). # Unlike transpose(), the axes argument is required. def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: @@ -69,6 +69,29 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: """ return Array._new(np.transpose(x._array, axes)) +@requires_api_version('2023.12') +def repeat( + x: Array, + repeats: Union[int, Array], + /, + *, + axis: Optional[int] = None, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.repeat `. + + See its docstring for more information. + """ + if isinstance(repeats, Array): + data_dependent_shapes = get_array_api_strict_flags()['data_dependent_shapes'] + if not data_dependent_shapes: + raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + elif isinstance(repeats, int): + repeats = asarray(repeats) + else: + raise TypeError("repeats must be an int or array") + + return Array._new(np.repeat(x._array, repeats, axis=axis)) # Note: the optional argument is called 'shape', not 'newshape' def reshape(x: Array, diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index f6fbc0d..2eba40c 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -2,7 +2,7 @@ reset_array_api_strict_flags) from .. import (asarray, unique_all, unique_counts, unique_inverse, - unique_values, nonzero) + unique_values, nonzero, repeat) import array_api_strict as xp @@ -102,8 +102,12 @@ def test_api_version(): assert xp.__array_api_version__ == '2021.12' def test_data_dependent_shapes(): + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') # to enable repeat() + a = asarray([0, 0, 1, 2, 2]) mask = asarray([True, False, True, False, True]) + repeats = asarray([1, 1, 2, 2, 2]) # Should not error unique_all(a) @@ -112,7 +116,8 @@ def test_data_dependent_shapes(): unique_values(a) nonzero(a) a[mask] - # TODO: add repeat when it is implemented + repeat(a, repeats) + repeat(a, 2) set_array_api_strict_flags(data_dependent_shapes=False) @@ -122,6 +127,8 @@ def test_data_dependent_shapes(): pytest.raises(RuntimeError, lambda: unique_values(a)) pytest.raises(RuntimeError, lambda: nonzero(a)) pytest.raises(RuntimeError, lambda: a[mask]) + pytest.raises(RuntimeError, lambda: repeat(a, repeats)) + repeat(a, 2) # Should never error linalg_examples = { 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), From 095be2f32a8373a95ae094d249fdaa0f765cb30c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:53:58 -0600 Subject: [PATCH 19/56] Require the repeats array to have an integer dtype NumPy allows it to be bool (casting it to int). --- array_api_strict/_manipulation_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 1f9a50f..3380b4e 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -3,6 +3,7 @@ from ._array_object import Array from ._creation_functions import asarray from ._data_type_functions import result_type +from ._dtypes import _integer_dtypes from ._flags import requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING @@ -86,6 +87,8 @@ def repeat( data_dependent_shapes = get_array_api_strict_flags()['data_dependent_shapes'] if not data_dependent_shapes: raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + if repeats.dtype not in _integer_dtypes: + raise TypeError("The repeats array must have an integer dtype") elif isinstance(repeats, int): repeats = asarray(repeats) else: From 1c4460d8abeeaf8235ef1fd081ff1a56667f844e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:13:53 -0600 Subject: [PATCH 20/56] Add searchsorted As far as I can tell, except for the dtype restriction, the standard is the same as NumPy. --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_searching_functions.py | 24 ++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 39eafc0..d9a4aab 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -286,9 +286,9 @@ __all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack"] -from ._searching_functions import argmax, argmin, nonzero, where +from ._searching_functions import argmax, argmin, nonzero, searchsorted, where -__all__ += ["argmax", "argmin", "nonzero", "where"] +__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"] from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 1ef2556..89e50f3 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -2,11 +2,11 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes -from ._flags import requires_data_dependent_shapes +from ._flags import requires_data_dependent_shapes, requires_api_version from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Tuple + from typing import Literal, Optional, Tuple import numpy as np @@ -45,6 +45,26 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: raise ValueError("nonzero is not allowed on 0-dimensional arrays") return tuple(Array._new(i) for i in np.nonzero(x._array)) +@requires_api_version('2023.12') +def searchsorted( + x1: Array, + x2: Array, + /, + *, + side: Literal["left", "right"] = "left", + sorter: Optional[Array] = None, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.searchsorted `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in searchsorted") + sorter = sorter._array if sorter is not None else None + # TODO: The sort order of nans and signed zeros is implementation + # dependent. Should we error/warn if they are present? + return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter)) def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ From 730e71616a22e05d3ce88654233b6ebefa794774 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:18:54 -0600 Subject: [PATCH 21/56] Add comment about x1 being 1-D in searchsorted --- array_api_strict/_searching_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 89e50f3..7314895 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -64,6 +64,8 @@ def searchsorted( sorter = sorter._array if sorter is not None else None # TODO: The sort order of nans and signed zeros is implementation # dependent. Should we error/warn if they are present? + + # x1 must be 1-D, but NumPy already requires this. return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter)) def where(condition: Array, x1: Array, x2: Array, /) -> Array: From f26bd499f4ec2e68a6555258c8868bbd02e71586 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:19:57 -0600 Subject: [PATCH 22/56] Add signbit --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 12 ++++++++++++ array_api_strict/tests/test_elementwise_functions.py | 1 + 3 files changed, 15 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index d9a4aab..7c9bbef 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -177,6 +177,7 @@ remainder, round, sign, + signbit, sin, sinh, square, @@ -244,6 +245,7 @@ "remainder", "round", "sign", + "signbit", "sin", "sinh", "square", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index a82818b..9ef71bd 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -791,6 +791,18 @@ def sign(x: Array, /) -> Array: return Array._new(np.sign(x._array)) +@requires_api_version('2023.12') +def signbit(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.signbit `. + + See its docstring for more information. + """ + if x.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in signbit") + return Array._new(np.signbit(x._array)) + + def sin(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sin `. diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 6b4a5ec..90994f3 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -75,6 +75,7 @@ def nargs(func): "remainder": "real numeric", "round": "numeric", "sign": "numeric", + "signbit": "real floating-point", "sin": "floating-point", "sinh": "floating-point", "sqrt": "floating-point", From dc1baad8ef7bfe07d1563647952930653e24c418 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:26:18 -0600 Subject: [PATCH 23/56] Add tile() --- array_api_strict/__init__.py | 3 ++- array_api_strict/_manipulation_functions.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 7c9bbef..5cd6e53 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -284,9 +284,10 @@ roll, squeeze, stack, + tile, ) -__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack"] +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile"] from ._searching_functions import argmax, argmin, nonzero, searchsorted, where diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 3380b4e..ee6066f 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -154,3 +154,16 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> result_type(*arrays) arrays = tuple(a._array for a in arrays) return Array._new(np.stack(arrays, axis=axis)) + + +@requires_api_version('2023.12') +def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.tile `. + + See its docstring for more information. + """ + # Note: NumPy allows repetitions to be an int or array + if not isinstance(repetitions, tuple): + raise TypeError("repetitions must be a tuple") + return Array._new(np.tile(x._array, repetitions)) From a30536b8a32f6e15031d629dbc4619595aaf2b6d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:34:38 -0600 Subject: [PATCH 24/56] Add unstack() --- array_api_strict/__init__.py | 3 ++- array_api_strict/_manipulation_functions.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 5cd6e53..17cb2c3 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -285,9 +285,10 @@ squeeze, stack, tile, + unstack, ) -__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile"] +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"] from ._searching_functions import argmax, argmin, nonzero, searchsorted, where diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index ee6066f..7652028 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -167,3 +167,15 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: if not isinstance(repetitions, tuple): raise TypeError("repetitions must be a tuple") return Array._new(np.tile(x._array, repetitions)) + +# Note: this function is new +@requires_api_version('2023.12') +def unstack(x: Array, /, *, axis: int = 0) -> Tuple[Array, ...]: + if not (-x.ndim <= axis < x.ndim): + raise ValueError("axis out of range") + + if axis < 0: + axis += x.ndim + + slices = (slice(None),) * axis + return tuple(x[slices + (i, ...)] for i in range(x.shape[axis])) From 161acaa38d29add635a5fb1e52af561c957f4d40 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 25 Apr 2024 16:20:03 -0600 Subject: [PATCH 25/56] Add the inspection APIs --- array_api_strict/__init__.py | 6 ++ array_api_strict/_flags.py | 8 ++ array_api_strict/_info.py | 141 +++++++++++++++++++++++++++++++++++ array_api_strict/_typing.py | 38 ++++++++++ 4 files changed, 193 insertions(+) create mode 100644 array_api_strict/_info.py diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 17cb2c3..82a3cdd 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -260,6 +260,12 @@ __all__ += ["take"] +from ._info import __array_namespace_info__ + +__all__ += [ + "__array_namespace_info__", +] + # linalg is an extension in the array API spec, which is a sub-namespace. Only # a subset of functions in it are imported into the top-level namespace. from . import linalg diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 76dc96e..632c42b 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -178,6 +178,14 @@ def get_array_api_strict_flags(): This function is **not** part of the array API standard. It only exists in array-api-strict. + .. note:: + + The `inspection API + `__ + provides a portable way to access most of this information. However, it + is only present in standard versions starting with 2023.12. The array + API version can be accessed portably using `xp.__array_api_version__`. + Returns ------- dict diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py new file mode 100644 index 0000000..5f8c841 --- /dev/null +++ b/array_api_strict/_info.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +__all__ = [ + "__array_namespace_info__", + "capabilities", + "default_device", + "default_dtypes", + "devices", + "dtypes", +] + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Optional, Union, Tuple, List + from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info + +from ._array_object import CPU_DEVICE +from ._flags import get_array_api_strict_flags, requires_api_version +from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 + +@requires_api_version('2023.12') +def __array_namespace_info__() -> Info: + import array_api_strict._info + return array_api_strict._info + +@requires_api_version('2023.12') +def capabilities() -> Capabilities: + flags = get_array_api_strict_flags() + return {"boolean indexing": flags['boolean_indexing'], + "data-dependent shapes": flags['data_dependent_shapes'], + } + +@requires_api_version('2023.12') +def default_device() -> device: + return CPU_DEVICE + +@requires_api_version('2023.12') +def default_dtypes( + *, + device: Optional[device] = None, +) -> DefaultDataTypes: + return { + "real floating": float64, + "complex floating": complex128, + "integral": int64, + "indexing": int64, + } + +@requires_api_version('2023.12') +def dtypes( + *, + device: Optional[device] = None, + kind: Optional[Union[str, Tuple[str, ...]]] = None, +) -> DataTypes: + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + +@requires_api_version('2023.12') +def devices() -> List[device]: + return [CPU_DEVICE] + +__all__ = [ + "capabilities", + "default_device", + "default_dtypes", + "devices", + "dtypes", +] diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index ce25d4c..eb1b834 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -21,6 +21,8 @@ from typing import ( Any, + ModuleType, + TypedDict, TypeVar, Protocol, ) @@ -39,6 +41,8 @@ def __len__(self, /) -> int: ... Dtype = _DType +Info = ModuleType + if sys.version_info >= (3, 12): from collections.abc import Buffer as SupportsBufferProtocol else: @@ -48,3 +52,37 @@ def __len__(self, /) -> int: ... class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... + +Capabilities = TypedDict( + "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool} +) + +DefaultDataTypes = TypedDict( + "DefaultDataTypes", + { + "real floating": Dtype, + "complex floating": Dtype, + "integral": Dtype, + "indexing": Dtype, + }, +) + +DataTypes = TypedDict( + "DataTypes", + { + "bool": Dtype, + "float32": Dtype, + "float64": Dtype, + "complex64": Dtype, + "complex128": Dtype, + "int8": Dtype, + "int16": Dtype, + "int32": Dtype, + "int64": Dtype, + "uint8": Dtype, + "uint16": Dtype, + "uint32": Dtype, + "uint64": Dtype, + }, + total=False, +) From 4d3ff6c3ebe5cea5d15d453cf749ca93a3882540 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 14:34:44 -0600 Subject: [PATCH 26/56] Fix test failures --- array_api_strict/tests/test_flags.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index f1b20cc..0cb670b 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -55,6 +55,7 @@ def test_flags(): flags = get_array_api_strict_flags() assert flags == { 'api_version': '2021.12', + 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg',), } @@ -68,6 +69,7 @@ def test_flags(): flags = get_array_api_strict_flags() assert flags == { 'api_version': '2023.12', + 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), } @@ -132,6 +134,8 @@ def test_data_dependent_shapes(): pytest.raises(RuntimeError, lambda: unique_inverse(a)) pytest.raises(RuntimeError, lambda: unique_values(a)) pytest.raises(RuntimeError, lambda: nonzero(a)) + pytest.raises(RuntimeError, lambda: repeat(a, repeats)) + repeat(a, 2) # Should never error a[mask] # No error (boolean indexing is a separate flag) def test_boolean_indexing(): @@ -144,8 +148,6 @@ def test_boolean_indexing(): set_array_api_strict_flags(boolean_indexing=False) pytest.raises(RuntimeError, lambda: a[mask]) - pytest.raises(RuntimeError, lambda: repeat(a, repeats)) - repeat(a, 2) # Should never error linalg_examples = { 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), From 84d2aa5a94fc48627ba454edb337d8f70912a5c8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 15:52:30 -0600 Subject: [PATCH 27/56] Always make warnings errors in the tests We might need to remove this if we ever test things that NumPy raises warnings for. --- pytest.ini | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 pytest.ini diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..0c84ee3 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +filterwarnings = error From 05fa0b5f8d6fd11e3be166bb1e75e2e3a55bc95a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 15:56:58 -0600 Subject: [PATCH 28/56] Add tests that the new 2023.12 functions are properly decorated --- array_api_strict/tests/test_flags.py | 38 ++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 0cb670b..65aa26f 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -1,5 +1,7 @@ from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) +from .._info import (capabilities, default_device, default_dtypes, devices, + dtypes) from .. import (asarray, unique_all, unique_counts, unique_inverse, unique_values, nonzero, repeat) @@ -237,3 +239,39 @@ def test_fft(func_name): set_array_api_strict_flags(enabled_extensions=('fft',)) func() + +api_version_2023_12_examples = { + '__array_namespace_info__': lambda: xp.__array_namespace_info__(), + # Test these functions directly to ensure they are properly decorated + 'capabilities': capabilities, + 'default_device': default_device, + 'default_dtypes': default_dtypes, + 'devices': devices, + 'dtypes': dtypes, + 'clip': lambda: xp.clip(xp.asarray([1, 2, 3]), 1, 2), + 'copysign': lambda: xp.copysign(xp.asarray([1., 2., 3.]), xp.asarray([-1., -1., -1.])), + 'cumulative_sum': lambda: xp.cumulative_sum(xp.asarray([1, 2, 3])), + 'hypot': lambda: xp.hypot(xp.asarray([3., 4.]), xp.asarray([4., 3.])), + 'maximum': lambda: xp.maximum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])), + 'minimum': lambda: xp.minimum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])), + 'moveaxis': lambda: xp.moveaxis(xp.ones((3, 3)), 0, 1), + 'repeat': lambda: xp.repeat(xp.asarray([1, 2, 3]), 3), + 'searchsorted': lambda: xp.searchsorted(xp.asarray([1, 2, 3]), xp.asarray([0, 1, 2, 3, 4])), + 'signbit': lambda: xp.signbit(xp.asarray([-1., 0., 1.])), + 'tile': lambda: xp.tile(xp.ones((3, 3)), (2, 3)), + 'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0), +} + +@pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) +def test_api_version_2023_12(func_name): + func = api_version_2023_12_examples[func_name] + + # By default, these functions should error + pytest.raises(RuntimeError, func) + + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + func() + + set_array_api_strict_flags(api_version='2022.12') + pytest.raises(RuntimeError, func) From 83331076bfa5824942af6bf9cc27b38903d6dc85 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 16:04:14 -0600 Subject: [PATCH 29/56] Update documentation for 2023.12 support --- array_api_strict/_flags.py | 3 +-- docs/api.rst | 2 ++ docs/index.md | 17 +++++++++++------ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 632c42b..c0b744e 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -71,10 +71,9 @@ def set_array_api_strict_flags( Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). - 2023.12 support is preliminary. Some features in 2023.12 may still be + 2023.12 support is experimental. Some features in 2023.12 may still be missing, and it hasn't been fully tested. - - `boolean_indexing`: Whether indexing by a boolean array is supported. Note that although boolean array indexing does result in data-dependent shapes, this flag is independent of the `data_dependent_shapes` flag diff --git a/docs/api.rst b/docs/api.rst index 15ce4e9..ed702dc 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -11,6 +11,8 @@ Array API Strict Flags .. currentmodule:: array_api_strict .. autofunction:: get_array_api_strict_flags + +.. _set_array_api_strict_flags: .. autofunction:: set_array_api_strict_flags .. autofunction:: reset_array_api_strict_flags .. autoclass:: ArrayAPIStrictFlags diff --git a/docs/index.md b/docs/index.md index 6e84efa..fc385d4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -15,9 +15,12 @@ libraries. Consuming library code should use the support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. -array-api-strict currently supports the 2022.12 version of the standard. -2023.12 support is planned and is tracked by [this -issue](https://github.com/data-apis/array-api-strict/issues/25). +array-api-strict currently supports the +[2022.12](https://data-apis.org/array-api/latest/changelog.html#v2022-12) +version of the standard. Experimental +[2023.12](https://data-apis.org/array-api/latest/changelog.html#v2023-12) +support is implemented, [but must be enabled with a +flag](set_array_api_strict_flags). ## Install @@ -179,9 +182,11 @@ issue, but this hasn't necessarily been tested thoroughly. function. array-api-strict currently implements all of these. In the future, [there may be a way to disable them](https://github.com/data-apis/array-api-strict/issues/7). -6. array-api-strict currently only supports the 2022.12 version of the array - API standard. [Support for 2023.12 is - planned](https://github.com/data-apis/array-api-strict/issues/25). +6. array-api-strict currently uses the 2022.12 version of the array API + standard. Support for 2023.12 is implemented but is still experimental and + not fully tested. It can be enabled with + [`array_api_strict.set_array_api_strict_flags(api_version='2023.12')`](set_array_api_strict_flags). + (numpy.array_api)= ## Relationship to `numpy.array_api` From a437da32033e1dc09b040b52e178648c4b472b76 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 23:58:50 -0600 Subject: [PATCH 30/56] Implement 2023.12 behavior for sum() and prod() --- array_api_strict/_statistical_functions.py | 34 ++++++++++++---------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index b35d26f..7a42d25 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -7,7 +7,7 @@ ) from ._array_object import Array from ._dtypes import float32, complex64 -from ._flags import requires_api_version +from ._flags import requires_api_version, get_array_api_strict_flags from ._creation_functions import zeros from ._manipulation_functions import concat @@ -89,14 +89,16 @@ def prod( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in prod") - # Note: sum() and prod() always upcast for dtype=None. `np.prod` does that - # for integers, but not for float32 or complex64, so we need to - # special-case it here + if dtype is None: - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 + # Note: In versions prior to 2023.12, sum() and prod() upcast for all + # dtypes when dtype=None. For 2023.12, the behavior is the same as in + # NumPy (only upcast for integral dtypes). + if get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + dtype = np.float64 + elif x.dtype == complex64: + dtype = np.complex128 else: dtype = dtype._np_dtype return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) @@ -126,14 +128,16 @@ def sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sum") - # Note: sum() and prod() always upcast for dtype=None. `np.sum` does that - # for integers, but not for float32 or complex64, so we need to - # special-case it here + if dtype is None: - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 + # Note: In versions prior to 2023.12, sum() and prod() upcast for all + # dtypes when dtype=None. For 2023.12, the behavior is the same as in + # NumPy (only upcast for integral dtypes). + if get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + dtype = np.float64 + elif x.dtype == complex64: + dtype = np.complex128 else: dtype = dtype._np_dtype return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) From 9f954e63265bfb1d655865ad390abe2ebe3ac585 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Apr 2024 14:28:48 -0600 Subject: [PATCH 31/56] Implement 2023.12 behavior for trace --- array_api_strict/linalg.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index 1f548f0..3a0657e 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -11,7 +11,7 @@ from ._manipulation_functions import reshape from ._elementwise_functions import conj from ._array_object import Array -from ._flags import requires_extension +from ._flags import requires_extension, get_array_api_strict_flags try: from numpy._core.numeric import normalize_axis_tuple @@ -377,10 +377,11 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr # Note: trace() works the same as sum() and prod() (see # _statistical_functions.py) if dtype is None: - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 + if get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + dtype = np.float64 + elif x.dtype == complex64: + dtype = np.complex128 else: dtype = dtype._np_dtype # Note: trace always operates on the last two axes, whereas np.trace From 8572df37a1e0bdb543b505900be54e728d0bf79c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Apr 2024 14:29:06 -0600 Subject: [PATCH 32/56] Add a test for sum/trace/prod 2023.12 upcasting behavior --- .../tests/test_statistical_functions.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 array_api_strict/tests/test_statistical_functions.py diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py new file mode 100644 index 0000000..fcf8f7f --- /dev/null +++ b/array_api_strict/tests/test_statistical_functions.py @@ -0,0 +1,27 @@ +import pytest + +import array_api_strict as xp + +@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) +def test_sum_prod_trace_2023_12(func_name): + # sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes + # with dtype=None + if func_name == 'trace': + func = getattr(xp.linalg, func_name) + else: + func = getattr(xp, func_name) + + a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32) + a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64) + a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32) + + assert func(a_real).dtype == xp.float64 + assert func(a_complex).dtype == xp.complex128 + assert func(a_int).dtype == xp.int64 + + with pytest.warns(UserWarning): + xp.set_array_api_strict_flags(api_version='2023.12') + + assert func(a_real).dtype == xp.float32 + assert func(a_complex).dtype == xp.complex64 + assert func(a_int).dtype == xp.int64 From 47894ff54bc9b0cd40018e105aa2ef99ff3dd19c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Apr 2024 15:14:37 -0600 Subject: [PATCH 33/56] Add 2023.12 axis restrictions to vecdot() and cross() --- array_api_strict/_linear_algebra_functions.py | 15 +- array_api_strict/linalg.py | 11 ++ array_api_strict/tests/test_linalg.py | 133 ++++++++++++++++++ .../tests/test_statistical_functions.py | 4 +- 4 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 array_api_strict/tests/test_linalg.py diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index 1ff08d4..6a1a921 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -8,8 +8,8 @@ from __future__ import annotations from ._dtypes import _numeric_dtypes - from ._array_object import Array +from ._flags import get_array_api_strict_flags from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -54,6 +54,19 @@ def matrix_transpose(x: Array, /) -> Array: def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in vecdot') + + if get_array_api_strict_flags()['api_version'] >= '2023.12': + if axis >= 0: + raise ValueError("axis must be negative in vecdot") + elif axis < min(-1, -x1.ndim, -x2.ndim): + raise ValueError("axis is out of bounds for x1 and x2") + + # In versions if the standard prior to 2023.12, vecdot applied axis after + # broadcasting. This is different from applying it before broadcasting + # when axis is nonnegative. The below code keeps this behavior for + # 2022.12, primarily for backwards compatibility. Note that the behavior + # is unambiguous when axis is negative, so the below code should work + # correctly in that case regardless of which version is used. ndim = max(x1.ndim, x2.ndim) x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index 3a0657e..bd11aa4 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -80,6 +80,17 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: # Note: this is different from np.cross(), which allows dimension 2 if x1.shape[axis] != 3: raise ValueError('cross() dimension must equal 3') + + if get_array_api_strict_flags()['api_version'] >= '2023.12': + if axis >= 0: + raise ValueError("axis must be negative in cross") + elif axis < min(-1, -x1.ndim, -x2.ndim): + raise ValueError("axis is out of bounds for x1 and x2") + + # Prior to 2023.12, there was ambiguity in the standard about whether + # positive axis applied before or after broadcasting. NumPy applies + # the axis before broadcasting. Since that behavior is what has always + # been implemented here, we keep it for backwards compatibility. return Array._new(np.cross(x1._array, x2._array, axis=axis)) @requires_extension('linalg') diff --git a/array_api_strict/tests/test_linalg.py b/array_api_strict/tests/test_linalg.py new file mode 100644 index 0000000..5e6cda2 --- /dev/null +++ b/array_api_strict/tests/test_linalg.py @@ -0,0 +1,133 @@ +import pytest + +from .._flags import set_array_api_strict_flags + +import array_api_strict as xp + +# TODO: Maybe all of these exceptions should be IndexError? + +# Technically this is linear_algebra, not linalg, but it's simpler to keep +# both of these tests together +def test_vecdot_2023_12(): + # Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >= + # 0 behavior (which is primarily kept for backwards compatibility). + + a = xp.ones((2, 3, 4, 5)) + b = xp.ones(( 3, 4, 1)) + + # 2022.12 behavior, which is to apply axis >= 0 after broadcasting + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) + assert xp.linalg.vecdot(a, b, axis=1).shape == (2, 4, 5) + assert xp.linalg.vecdot(a, b, axis=2).shape == (2, 3, 5) + # This is disallowed because the arrays must have the same values before + # broadcasting + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3)) + + # Out-of-bounds axes even after broadcasting + pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=4)) + pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=-5)) + + # negative axis behavior is unambiguous when it's within the bounds of + # both arrays before broadcasting + assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5) + assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5) + + # 2023.12 behavior, which is to only allow axis < 0 and axis >= + # min(x1.ndim, x2.ndim), which is unambiguous + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=1)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=2)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=4)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-5)) + + assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5) + assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5) + +@pytest.mark.parametrize('api_version', ['2021.12', '2022.12', '2023.12']) +def test_cross(api_version): + # This test tests everything that should be the same across all supported + # API versions. + + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = xp.ones((2, 4, 5, 3)) + b = xp.ones(( 4, 1, 3)) + assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3) + + a = xp.ones((2, 4, 3, 5)) + b = xp.ones(( 4, 3, 1)) + assert xp.linalg.cross(a, b, axis=-2).shape == (2, 4, 3, 5) + + # This is disallowed because the axes must equal 3 before broadcasting + a = xp.ones((3, 2, 3, 5)) + b = xp.ones(( 2, 1, 1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4)) + + # Out-of-bounds axes even after broadcasting + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4)) + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5)) + +@pytest.mark.parametrize('api_version', ['2021.12', '2022.12']) +def test_cross_2022_12(api_version): + # Test the 2022.12 axis >= 0 behavior, which is primarily kept for + # backwards compatibility. Note that unlike vecdot, array_api_strict + # cross() never implemented the "after broadcasting" axis behavior, but + # just reused NumPy cross(), which applies axes before broadcasting. + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = xp.ones((3, 2, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5) + + # ambiguous case + a = xp.ones(( 3, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5) + +def test_cross_2023_12(): + # 2023.12 behavior, which is to only allow axis < 0 and axis >= + # min(x1.ndim, x2.ndim), which is unambiguous + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + + a = xp.ones((3, 2, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0)) + + a = xp.ones(( 3, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + pytest.raises(ValueError, lambda: xp. linalg.cross(a, b, axis=0)) + + a = xp.ones((2, 4, 5, 3)) + b = xp.ones(( 4, 1, 3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=2)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4)) + + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4)) + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5)) + + assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3) diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py index fcf8f7f..61e848c 100644 --- a/array_api_strict/tests/test_statistical_functions.py +++ b/array_api_strict/tests/test_statistical_functions.py @@ -1,5 +1,7 @@ import pytest +from .._flags import set_array_api_strict_flags + import array_api_strict as xp @pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) @@ -20,7 +22,7 @@ def test_sum_prod_trace_2023_12(func_name): assert func(a_int).dtype == xp.int64 with pytest.warns(UserWarning): - xp.set_array_api_strict_flags(api_version='2023.12') + set_array_api_strict_flags(api_version='2023.12') assert func(a_real).dtype == xp.float32 assert func(a_complex).dtype == xp.complex64 From 6b431946e4135b3815d7ff050cac8c26c0ce6d5d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Apr 2024 15:31:32 -0600 Subject: [PATCH 34/56] Add device flag to astype in 2023.12 Also clean up imports in test_data_type_functions.py --- array_api_strict/_data_type_functions.py | 20 +++++-- .../tests/test_data_type_functions.py | 54 ++++++++++++++----- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 41f70c5..7ae6244 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ._array_object import Array +from ._array_object import Array, CPU_DEVICE from ._dtypes import ( _DType, _all_dtypes, @@ -13,19 +13,31 @@ _numeric_dtypes, _result_type, ) +from ._flags import get_array_api_strict_flags from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import List, Tuple, Union - from ._typing import Dtype + from typing import List, Tuple, Union, Optional + from ._typing import Dtype, Device import numpy as np +# Use to emulate the asarray(device) argument not existing in 2022.12 +_default = object() # Note: astype is a function, not an array method as in NumPy. -def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array: +def astype( + x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = _default +) -> Array: + if device is not _default: + if get_array_api_strict_flags()['api_version'] >= '2023.12': + if device not in [CPU_DEVICE, None]: + raise ValueError(f"Unsupported device {device!r}") + else: + raise TypeError("The device argument to astype requires the 2023.12 version of the array API") + if not copy and dtype == x.dtype: return x return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy)) diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 60a7f29..40cab55 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -3,38 +3,68 @@ import pytest from numpy.testing import assert_raises -import array_api_strict as xp import numpy as np +from .._creation_functions import asarray +from .._data_type_functions import astype, can_cast, isdtype +from .._dtypes import ( + bool, int8, int16, uint8, float64, +) +from .._flags import set_array_api_strict_flags + + @pytest.mark.parametrize( "from_, to, expected", [ - (xp.int8, xp.int16, True), - (xp.int16, xp.int8, False), - (xp.bool, xp.int8, False), - (xp.asarray(0, dtype=xp.uint8), xp.int8, False), + (int8, int16, True), + (int16, int8, False), + (bool, int8, False), + (asarray(0, dtype=uint8), int8, False), ], ) def test_can_cast(from_, to, expected): """ can_cast() returns correct result """ - assert xp.can_cast(from_, to) == expected + assert can_cast(from_, to) == expected def test_isdtype_strictness(): - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, 64)) - assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8')) + assert_raises(TypeError, lambda: isdtype(float64, 64)) + assert_raises(ValueError, lambda: isdtype(float64, 'f8')) - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),))) + assert_raises(TypeError, lambda: isdtype(float64, (('integral',),))) with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - xp.isdtype(xp.float64, np.object_) + isdtype(float64, np.object_) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None)) + assert_raises(TypeError, lambda: isdtype(float64, None)) with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - xp.isdtype(xp.float64, np.float64) + isdtype(float64, np.float64) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) + + +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) +def astype_device(api_version): + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = asarray([1, 2, 3], dtype=int8) + # Never an error + astype(a, int16) + + # Always an error + astype(a, int16, device="cpu") + + if api_version >= '2023.12': + astype(a, int8, device=None) + astype(a, int8, device=a.device) + else: + pytest.raises(TypeError, lambda: astype(a, int8, device=None)) + pytest.raises(TypeError, lambda: astype(a, int8, device=a.device)) From 3fde5ddfbe7b0e60d9ab2732676521e93a0e1a07 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 00:07:51 -0600 Subject: [PATCH 35/56] Factor out device checks into a helper function --- array_api_strict/_creation_functions.py | 75 +++++++++++++----------- array_api_strict/_data_type_functions.py | 6 +- 2 files changed, 43 insertions(+), 38 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index ad7ec82..dd3e74f 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -28,6 +28,14 @@ def _supports_buffer_protocol(obj): return False return True +def _check_device(device): + # _array_object imports in this file are inside the functions to avoid + # circular imports + from ._array_object import CPU_DEVICE + + if device not in [CPU_DEVICE, None]: + raise ValueError(f"Unsupported device {device!r}") + def asarray( obj: Union[ Array, @@ -48,16 +56,13 @@ def asarray( See its docstring for more information. """ - # _array_object imports in this file are inside the functions to avoid - # circular imports - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) _np_dtype = None if dtype is not None: _np_dtype = dtype._np_dtype - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) if np.__version__[0] < '2': if copy is False: @@ -106,11 +111,11 @@ def arange( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype)) @@ -127,11 +132,11 @@ def empty( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.empty(shape, dtype=dtype)) @@ -145,11 +150,11 @@ def empty_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.empty_like(x._array, dtype=dtype)) @@ -197,11 +202,11 @@ def full( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if isinstance(fill_value, Array) and fill_value.ndim == 0: fill_value = fill_value._array if dtype is not None: @@ -227,11 +232,11 @@ def full_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype res = np.full_like(x._array, fill_value, dtype=dtype) @@ -257,11 +262,11 @@ def linspace( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) @@ -298,11 +303,11 @@ def ones( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.ones(shape, dtype=dtype)) @@ -316,11 +321,11 @@ def ones_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.ones_like(x._array, dtype=dtype)) @@ -365,11 +370,11 @@ def zeros( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.zeros(shape, dtype=dtype)) @@ -383,11 +388,11 @@ def zeros_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.zeros_like(x._array, dtype=dtype)) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 7ae6244..e43125a 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations -from ._array_object import Array, CPU_DEVICE +from ._array_object import Array +from ._creation_functions import _check_device from ._dtypes import ( _DType, _all_dtypes, @@ -33,8 +34,7 @@ def astype( ) -> Array: if device is not _default: if get_array_api_strict_flags()['api_version'] >= '2023.12': - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) else: raise TypeError("The device argument to astype requires the 2023.12 version of the array API") From 1ac528821ac90d741cf7f0245a383e877366df5c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 00:08:03 -0600 Subject: [PATCH 36/56] Add 2023.12 device and copy keywords to from_dlpack The copy keyword just raises NotImplementedError for now. --- array_api_strict/_creation_functions.py | 28 ++++++++++++++++++++---- array_api_strict/_data_type_functions.py | 2 +- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index dd3e74f..b24af98 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -12,6 +12,7 @@ SupportsBufferProtocol, ) from ._dtypes import _DType, _all_dtypes +from ._flags import get_array_api_strict_flags import numpy as np @@ -174,19 +175,38 @@ def eye( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) -def from_dlpack(x: object, /) -> Array: +_default = object() + +def from_dlpack( + x: object, + /, + *, + device: Optional[Device] = _default, + copy: Optional[bool] = _default, +) -> Array: from ._array_object import Array + if get_array_api_strict_flags()['api_version'] < '2023.12': + if device is not _default: + raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API") + if copy is not _default: + raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API") + + if device is not _default: + _check_device(device) + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") + return Array._new(np.from_dlpack(x)) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index e43125a..3405710 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -36,7 +36,7 @@ def astype( if get_array_api_strict_flags()['api_version'] >= '2023.12': _check_device(device) else: - raise TypeError("The device argument to astype requires the 2023.12 version of the array API") + raise TypeError("The device argument to astype requires at least version 2023.12 of the array API") if not copy and dtype == x.dtype: return x From dc4684b4be7e70add3c197dff56deeef299d43bd Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 13:56:00 -0600 Subject: [PATCH 37/56] Update the signature of __dlpack__ for 2023.12 The new arguments are not actually supported yet, and probably won't be until upstream NumPy does. --- array_api_strict/_array_object.py | 28 ++++++++++++++++++++++++- array_api_strict/_creation_functions.py | 1 + 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 8849ce3..26c4330 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -51,6 +51,8 @@ def __repr__(self): CPU_DEVICE = _cpu_device() +_default = object() + class Array: """ n-d array object for the array API namespace. @@ -525,10 +527,34 @@ def __complex__(self: Array, /) -> complex: res = self._array.__complex__() return res - def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule: + def __dlpack__( + self: Array, + /, + *, + stream: Optional[Union[int, Any]] = None, + max_version: Optional[tuple[int, int]] = _default, + dl_device: Optional[tuple[IntEnum, int]] = _default, + copy: Optional[bool] = _default, + ) -> PyCapsule: """ Performs the operation __dlpack__. """ + if get_array_api_strict_flags()['api_version'] < '2023.12': + if max_version is not _default: + raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API") + if dl_device is not _default: + raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API") + if copy is not _default: + raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") + + # Going to wait for upstream numpy support + if max_version not in [_default, None]: + raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") + if dl_device not in [_default, None]: + raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") + return self._array.__dlpack__(stream=stream) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index b24af98..0e85cdc 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -202,6 +202,7 @@ def from_dlpack( if copy is not _default: raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API") + # Going to wait for upstream numpy support if device is not _default: _check_device(device) if copy not in [_default, None]: From 647a5f004053a42203707297d609793a0ef25210 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 15:22:45 -0600 Subject: [PATCH 38/56] Add tests for from_dlpack and __dlpack__ 2023.12 behavior --- array_api_strict/tests/test_array_object.py | 32 +++++++++++++++++++ .../tests/test_creation_functions.py | 26 ++++++++++++++- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a66637f..f0efdfa 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -23,6 +23,8 @@ uint64, bool as bool_, ) +from .._flags import set_array_api_strict_flags + import array_api_strict def test_validate_index(): @@ -420,3 +422,33 @@ def test_array_namespace(): pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) + + +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) +def dlpack_2023_12(api_version): + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = asarray([1, 2, 3], dtype=int8) + # Never an error + a.__dlpack__() + + + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=CPU_DEVICE)) + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=None)) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=(1, 0))) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=None)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=False)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=True)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=None)) diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 78d4c80..819afad 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -3,6 +3,8 @@ from numpy.testing import assert_raises import numpy as np +import pytest + from .. import all from .._creation_functions import ( asarray, @@ -10,6 +12,7 @@ empty, empty_like, eye, + from_dlpack, full, full_like, linspace, @@ -21,7 +24,7 @@ ) from .._dtypes import float32, float64 from .._array_object import Array, CPU_DEVICE - +from .._flags import set_array_api_strict_flags def test_asarray_errors(): # Test various protections against incorrect usage @@ -188,3 +191,24 @@ def test_meshgrid_dtype_errors(): meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32)) assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64))) + + +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) +def from_dlpack_2023_12(api_version): + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = asarray([1., 2., 3.], dtype=float64) + # Never an error + capsule = a.__dlpack__() + from_dlpack(capsule) + + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: from_dlpack(capsule, device=CPU_DEVICE)) + pytest.raises(exception, lambda: from_dlpack(capsule, device=None)) + pytest.raises(exception, lambda: from_dlpack(capsule, copy=False)) + pytest.raises(exception, lambda: from_dlpack(capsule, copy=True)) + pytest.raises(exception, lambda: from_dlpack(capsule, copy=None)) From 306de9bd5f2810636f0c7e1f83a119194b25f47c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 15:29:13 -0600 Subject: [PATCH 39/56] Add 2023.12 testing to the CI --- .github/workflows/array-api-tests.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index bfb7dcf..ce246e4 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -12,6 +12,7 @@ jobs: matrix: python-version: ['3.9', '3.10', '3.11', '3.12'] numpy-version: ['1.26', 'dev'] + api_version: ['2022.12', '2023.12'] exclude: - python-version: '3.8' numpy-version: 'dev' @@ -49,5 +50,12 @@ jobs: # tests fail in numpy 1.26 on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak run: | + export ARRAY_API_STRICT_API_VERSION=${{ matrix.api_version }} + + # Only signature tests work for now for 2023.12 + if [[ "${{ matrix.api_version }}" == "2023.12" ]]; then + PYTEST_ARGS="${PYTEST_ARGS} -k signature + fi + cd ${GITHUB_WORKSPACE}/array-api-tests pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} From 44bbdb214fa555c8f8877b720d75ebc07d6c0afc Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 15:31:31 -0600 Subject: [PATCH 40/56] Better error message --- array_api_strict/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 26c4330..0fff27a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -439,7 +439,7 @@ def _validate_index(self, key): "Array API when the array is the sole index." ) if not get_array_api_strict_flags()['boolean_indexing']: - raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict") + raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict") elif i.dtype in _integer_dtypes and i.ndim != 0: raise IndexError( From e5225ed7f54de3c0e82cc1a63c66f6093c5d204f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 15:34:13 -0600 Subject: [PATCH 41/56] Parameterize the API version in a loop instead of in the matrix --- .github/workflows/array-api-tests.yml | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index ce246e4..b37ec04 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -4,6 +4,7 @@ on: [push, pull_request] env: PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200" + API_VERSIONS: "2022.12 2023.12" jobs: array-api-tests: @@ -12,7 +13,6 @@ jobs: matrix: python-version: ['3.9', '3.10', '3.11', '3.12'] numpy-version: ['1.26', 'dev'] - api_version: ['2022.12', '2023.12'] exclude: - python-version: '3.8' numpy-version: 'dev' @@ -50,12 +50,13 @@ jobs: # tests fail in numpy 1.26 on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak run: | - export ARRAY_API_STRICT_API_VERSION=${{ matrix.api_version }} + # Parameterizing this in the CI matrix is wasteful. Just do a loop here. + for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do + # Only signature tests work for now for 2023.12 + if [[ "$ARRAY_API_STRICT_API_VERSION" == "2023.12" ]]; then + PYTEST_ARGS="${PYTEST_ARGS} -k signature + fi - # Only signature tests work for now for 2023.12 - if [[ "${{ matrix.api_version }}" == "2023.12" ]]; then - PYTEST_ARGS="${PYTEST_ARGS} -k signature - fi - - cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} + cd ${GITHUB_WORKSPACE}/array-api-tests + pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} + done From 3e0be7df7eff77d4a443b0b7d265d7c0e34ee4e5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 3 May 2024 16:05:15 -0600 Subject: [PATCH 42/56] Ensure a.mT works even if the linalg extension is disabled --- array_api_strict/_array_object.py | 2 +- array_api_strict/tests/test_flags.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 0fff27a..18dd219 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -1159,7 +1159,7 @@ def device(self) -> Device: # Note: mT is new in array API spec (see matrix_transpose) @property def mT(self) -> Array: - from .linalg import matrix_transpose + from ._linear_algebra_functions import matrix_transpose return matrix_transpose(self) @property diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 65aa26f..b68e7aa 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -184,9 +184,10 @@ def test_boolean_indexing(): 'matrix_transpose': lambda: xp.matrix_transpose(xp.eye(3)), 'tensordot': lambda: xp.tensordot(xp.eye(3), xp.eye(3)), 'vecdot': lambda: xp.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'mT': lambda: xp.eye(3).mT, } -assert set(linalg_main_namespace_examples) == set(xp.__all__) & set(xp.linalg.__all__) +assert set(linalg_main_namespace_examples) == (set(xp.__all__) & set(xp.linalg.__all__)) | {"mT"} @pytest.mark.parametrize('func_name', linalg_examples.keys()) def test_linalg(func_name): From 338ebfefce1ff395391154b33851359dc32565c0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 16:41:36 -0600 Subject: [PATCH 43/56] Don't allow environment variables to be set during test runs --- array_api_strict/_flags.py | 7 +++++++ array_api_strict/tests/conftest.py | 9 ++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index c0b744e..866e4f5 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -293,6 +293,13 @@ def __exit__(self, exc_type, exc_value, traceback): # Private functions +ENVIRONMENT_VARIABLES = [ + "ARRAY_API_STRICT_API_VERSION", + "ARRAY_API_STRICT_BOOLEAN_INDEXING", + "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES", + "ARRAY_API_STRICT_ENABLED_EXTENSIONS", +] + def set_flags_from_environment(): if "ARRAY_API_STRICT_API_VERSION" in os.environ: set_array_api_strict_flags( diff --git a/array_api_strict/tests/conftest.py b/array_api_strict/tests/conftest.py index 5000d5d..322675c 100644 --- a/array_api_strict/tests/conftest.py +++ b/array_api_strict/tests/conftest.py @@ -1,7 +1,14 @@ -from .._flags import reset_array_api_strict_flags +import os + +from .._flags import reset_array_api_strict_flags, ENVIRONMENT_VARIABLES import pytest +def pytest_sessionstart(session): + for env_var in ENVIRONMENT_VARIABLES: + if env_var in os.environ: + pytest.exit(f"ERROR: {env_var} is set. array-api-strict environment variables must not be set when the tests are run.") + @pytest.fixture(autouse=True) def reset_flags(): reset_array_api_strict_flags() From 6a466f45cbe30fdec4049dee4bfd6e352c90c4a0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 2 May 2024 21:55:58 -0600 Subject: [PATCH 44/56] Use a more robust way to fail the tests if an env var is set --- array_api_strict/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/tests/conftest.py b/array_api_strict/tests/conftest.py index 322675c..1a9d507 100644 --- a/array_api_strict/tests/conftest.py +++ b/array_api_strict/tests/conftest.py @@ -4,7 +4,7 @@ import pytest -def pytest_sessionstart(session): +def pytest_configure(config): for env_var in ENVIRONMENT_VARIABLES: if env_var in os.environ: pytest.exit(f"ERROR: {env_var} is set. array-api-strict environment variables must not be set when the tests are run.") From dd01b12c75e044ec3b0d3f1a50eb22885670f6ea Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 3 May 2024 15:59:15 -0600 Subject: [PATCH 45/56] Fix setting ARRAY_API_STRICT_ENABLED_EXTENSIONS='' --- array_api_strict/_flags.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 866e4f5..b02b869 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -317,9 +317,10 @@ def set_flags_from_environment(): ) if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: - set_array_api_strict_flags( - enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") - ) + enabled_extensions = os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") + if enabled_extensions == [""]: + enabled_extensions = [] + set_array_api_strict_flags(enabled_extensions=enabled_extensions) set_flags_from_environment() From 43b9088ce2d7f33d50486e03b45fd99ef6166cd6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 3 May 2024 16:06:09 -0600 Subject: [PATCH 46/56] Make extensions give AttributeError when they are disabled This is how the test suite and presumably some other codes detect if extensions are enabled or not. This also dynamically updates __all__ whenever extensions are enabled or disabled. --- array_api_strict/__init__.py | 32 ++-- array_api_strict/{fft.py => _fft.py} | 0 array_api_strict/_flags.py | 7 + array_api_strict/{linalg.py => _linalg.py} | 0 array_api_strict/tests/test_flags.py | 168 ++++++++++++++++----- 5 files changed, 160 insertions(+), 47 deletions(-) rename array_api_strict/{fft.py => _fft.py} (100%) rename array_api_strict/{linalg.py => _linalg.py} (100%) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 82a3cdd..8dfa09f 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,13 +16,15 @@ """ +__all__ = [] + # Warning: __array_api_version__ could change globally with # set_array_api_strict_flags(). This should always be accessed as an # attribute, like xp.__array_api_version__, or using # array_api_strict.get_array_api_strict_flags()['api_version']. from ._flags import API_VERSION as __array_api_version__ -__all__ = ["__array_api_version__"] +__all__ += ["__array_api_version__"] from ._constants import e, inf, nan, pi, newaxis @@ -266,19 +268,10 @@ "__array_namespace_info__", ] -# linalg is an extension in the array API spec, which is a sub-namespace. Only -# a subset of functions in it are imported into the top-level namespace. -from . import linalg - -__all__ += ["linalg"] - from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] -from . import fft -__all__ += ["fft"] - from ._manipulation_functions import ( concat, expand_dims, @@ -330,3 +323,22 @@ from . import _version __version__ = _version.get_versions()['version'] del _version + + +# Extensions can be enabled or disabled dynamically. In order to make +# "array_api_strict.linalg" give an AttributeError when it is disabled, we +# use __getattr__. Note that linalg and fft are dynamically added and removed +# from __all__ in set_array_api_strict_flags. + +def __getattr__(name): + if name in ['linalg', 'fft']: + if name in get_array_api_strict_flags()['enabled_extensions']: + if name == 'linalg': + from . import _linalg + return _linalg + elif name == 'fft': + from . import _fft + return _fft + else: + raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/array_api_strict/fft.py b/array_api_strict/_fft.py similarity index 100% rename from array_api_strict/fft.py rename to array_api_strict/_fft.py diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index b02b869..221d0d3 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -161,6 +161,10 @@ def set_array_api_strict_flags( else: ENABLED_EXTENSIONS = tuple([ext for ext in ENABLED_EXTENSIONS if extension_versions[ext] <= API_VERSION]) + array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) | + set(array_api_strict.__all__) - + set(default_extensions)) + # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( supported_versions=supported_versions, @@ -321,6 +325,9 @@ def set_flags_from_environment(): if enabled_extensions == [""]: enabled_extensions = [] set_array_api_strict_flags(enabled_extensions=enabled_extensions) + else: + # Needed at first import to add linalg and fft to __all__ + set_array_api_strict_flags(enabled_extensions=default_extensions) set_flags_from_environment() diff --git a/array_api_strict/linalg.py b/array_api_strict/_linalg.py similarity index 100% rename from array_api_strict/linalg.py rename to array_api_strict/_linalg.py diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index b68e7aa..38b1a3b 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -1,7 +1,15 @@ +import sys +import subprocess + from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) from .._info import (capabilities, default_device, default_dtypes, devices, dtypes) +from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft, + ihfft, fftfreq, rfftfreq, fftshift, ifftshift) +from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv, + matmul, matrix_norm, matrix_power, matrix_rank, matrix_transpose, outer, pinv, + qr, slogdet, solve, svd, svdvals, tensordot, trace, vecdot, vector_norm) from .. import (asarray, unique_all, unique_counts, unique_inverse, unique_values, nonzero, repeat) @@ -152,29 +160,29 @@ def test_boolean_indexing(): pytest.raises(RuntimeError, lambda: a[mask]) linalg_examples = { - 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), - 'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), - 'det': lambda: xp.linalg.det(xp.eye(3)), - 'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)), - 'eigh': lambda: xp.linalg.eigh(xp.eye(3)), - 'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)), - 'inv': lambda: xp.linalg.inv(xp.eye(3)), - 'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)), - 'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)), - 'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2), - 'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)), - 'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)), - 'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), - 'pinv': lambda: xp.linalg.pinv(xp.eye(3)), - 'qr': lambda: xp.linalg.qr(xp.eye(3)), - 'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)), - 'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)), - 'svd': lambda: xp.linalg.svd(xp.eye(3)), - 'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)), - 'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)), - 'trace': lambda: xp.linalg.trace(xp.eye(3)), - 'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), - 'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])), + 'cholesky': lambda: cholesky(xp.eye(3)), + 'cross': lambda: cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), + 'det': lambda: det(xp.eye(3)), + 'diagonal': lambda: diagonal(xp.eye(3)), + 'eigh': lambda: eigh(xp.eye(3)), + 'eigvalsh': lambda: eigvalsh(xp.eye(3)), + 'inv': lambda: inv(xp.eye(3)), + 'matmul': lambda: matmul(xp.eye(3), xp.eye(3)), + 'matrix_norm': lambda: matrix_norm(xp.eye(3)), + 'matrix_power': lambda: matrix_power(xp.eye(3), 2), + 'matrix_rank': lambda: matrix_rank(xp.eye(3)), + 'matrix_transpose': lambda: matrix_transpose(xp.eye(3)), + 'outer': lambda: outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'pinv': lambda: pinv(xp.eye(3)), + 'qr': lambda: qr(xp.eye(3)), + 'slogdet': lambda: slogdet(xp.eye(3)), + 'solve': lambda: solve(xp.eye(3), xp.eye(3)), + 'svd': lambda: svd(xp.eye(3)), + 'svdvals': lambda: svdvals(xp.eye(3)), + 'tensordot': lambda: tensordot(xp.eye(3), xp.eye(3)), + 'trace': lambda: trace(xp.eye(3)), + 'vecdot': lambda: vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'vector_norm': lambda: vector_norm(xp.asarray([1., 2., 3.])), } assert set(linalg_examples) == set(xp.linalg.__all__) @@ -210,20 +218,20 @@ def test_linalg(func_name): main_namespace_func() fft_examples = { - 'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])), - 'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])), - 'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])), - 'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), - 'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])), - 'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])), - 'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])), - 'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), - 'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])), - 'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])), - 'fftfreq': lambda: xp.fft.fftfreq(4), - 'rfftfreq': lambda: xp.fft.rfftfreq(4), - 'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])), - 'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])), + 'fft': lambda: fft(xp.asarray([0j, 1j, 0j, 0j])), + 'ifft': lambda: ifft(xp.asarray([0j, 1j, 0j, 0j])), + 'fftn': lambda: fftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'ifftn': lambda: ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'rfft': lambda: rfft(xp.asarray([0., 1., 0., 0.])), + 'irfft': lambda: irfft(xp.asarray([0j, 1j, 0j, 0j])), + 'rfftn': lambda: rfftn(xp.asarray([[0., 1.], [0., 0.]])), + 'irfftn': lambda: irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'hfft': lambda: hfft(xp.asarray([0j, 1j, 0j, 0j])), + 'ihfft': lambda: ihfft(xp.asarray([0., 1., 0., 0.])), + 'fftfreq': lambda: fftfreq(4), + 'rfftfreq': lambda: rfftfreq(4), + 'fftshift': lambda: fftshift(xp.asarray([0j, 1j, 0j, 0j])), + 'ifftshift': lambda: ifftshift(xp.asarray([0j, 1j, 0j, 0j])), } assert set(fft_examples) == set(xp.fft.__all__) @@ -276,3 +284,89 @@ def test_api_version_2023_12(func_name): set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) + +def test_disabled_extensions(): + # Test that xp.extension errors when an extension is disabled, and that + # xp.__all__ is updated properly. + + # First test that things are correct on the initial import. Since we have + # already called set_array_api_strict_flags many times throughout running + # the tests, we have to test this in a subprocess. + subprocess_tests = [('''\ +import array_api_strict + +array_api_strict.linalg # No error +array_api_strict.fft # No error +assert "linalg" in array_api_strict.__all__ +assert "fft" in array_api_strict.__all__ +assert len(array_api_strict.__all__) == len(set(array_api_strict.__all__)) +''', {}), +# Test that the initial population of __all__ works correctly +('''\ +from array_api_strict import * # No error +linalg # Should have been imported by the previous line +fft +''', {}), +('''\ +from array_api_strict import * # No error +linalg # Should have been imported by the previous line +assert 'fft' not in globals() +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "linalg"}), +('''\ +from array_api_strict import * # No error +fft # Should have been imported by the previous line +assert 'linalg' not in globals() +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "fft"}), +('''\ +from array_api_strict import * # No error +assert 'linalg' not in globals() +assert 'fft' not in globals() +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": ""}), +] + for test, env in subprocess_tests: + try: + subprocess.run([sys.executable, '-c', test], check=True, + capture_output=True, encoding='utf-8', env=env) + except subprocess.CalledProcessError as e: + print(e.stdout, end='') + # Ensure the exception is shown in the output log + raise AssertionError(e.stderr) + + assert 'linalg' in xp.__all__ + assert 'fft' in xp.__all__ + xp.linalg # No error + xp.fft # No error + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' in ns + assert 'fft' in ns + + set_array_api_strict_flags(enabled_extensions=('linalg',)) + assert 'linalg' in xp.__all__ + assert 'fft' not in xp.__all__ + xp.linalg # No error + pytest.raises(AttributeError, lambda: xp.fft) + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' in ns + assert 'fft' not in ns + + set_array_api_strict_flags(enabled_extensions=('fft',)) + assert 'linalg' not in xp.__all__ + assert 'fft' in xp.__all__ + pytest.raises(AttributeError, lambda: xp.linalg) + xp.fft # No error + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' not in ns + assert 'fft' in ns + + set_array_api_strict_flags(enabled_extensions=()) + assert 'linalg' not in xp.__all__ + assert 'fft' not in xp.__all__ + pytest.raises(AttributeError, lambda: xp.linalg) + pytest.raises(AttributeError, lambda: xp.fft) + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' not in ns + assert 'fft' not in ns From 7bc29d6b638fcdc666028b4e33bd9529e77b4213 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 May 2024 15:04:03 -0600 Subject: [PATCH 47/56] Fix setting ARRAY_API_STRICT_API_VERSION to 2021.12 --- array_api_strict/_flags.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 221d0d3..f6cef29 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -305,29 +305,25 @@ def __exit__(self, exc_type, exc_value, traceback): ] def set_flags_from_environment(): + kwargs = {} if "ARRAY_API_STRICT_API_VERSION" in os.environ: - set_array_api_strict_flags( - api_version=os.environ["ARRAY_API_STRICT_API_VERSION"] - ) + kwargs["api_version"] = os.environ["ARRAY_API_STRICT_API_VERSION"] if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os.environ: - set_array_api_strict_flags( - boolean_indexing=os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true" - ) + kwargs["boolean_indexing"] = os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true" if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: - set_array_api_strict_flags( - data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" - ) + kwargs["data_dependent_shapes"] = os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: enabled_extensions = os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") if enabled_extensions == [""]: enabled_extensions = [] - set_array_api_strict_flags(enabled_extensions=enabled_extensions) - else: - # Needed at first import to add linalg and fft to __all__ - set_array_api_strict_flags(enabled_extensions=default_extensions) + kwargs["enabled_extensions"] = enabled_extensions + + # Called unconditionally because it is needed at first import to add + # linalg and fft to __all__ + set_array_api_strict_flags(**kwargs) set_flags_from_environment() From c770c9b5503b7e80ba4aac8026ffe06bccbff9cb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 May 2024 15:04:19 -0600 Subject: [PATCH 48/56] Add tests for environment variables They're not pretty, but they get the job done. --- array_api_strict/tests/test_flags.py | 121 +++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 38b1a3b..86ad8e2 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -370,3 +370,124 @@ def test_disabled_extensions(): exec('from array_api_strict import *', ns) assert 'linalg' not in ns assert 'fft' not in ns + + +def test_environment_variables(): + # Test that the environment variables work as expected + subprocess_tests = [ + # ARRAY_API_STRICT_API_VERSION + ('''\ +import array_api_strict as xp +assert xp.__array_api_version__ == '2022.12' + +assert xp.get_array_api_strict_flags()['api_version'] == '2022.12' + +''', {}), + *[ + (f'''\ +import array_api_strict as xp +assert xp.__array_api_version__ == '{version}' + +assert xp.get_array_api_strict_flags()['api_version'] == '{version}' + +if {version} == '2021.12': + assert hasattr(xp, 'linalg') + assert not hasattr(xp, 'fft') + +''', {"ARRAY_API_STRICT_API_VERSION": version}) for version in ('2021.12', '2022.12', '2023.12')], + + # ARRAY_API_STRICT_BOOLEAN_INDEXING + ('''\ +import array_api_strict as xp + +a = xp.ones(3) +mask = xp.asarray([True, False, True]) + +assert xp.all(a[mask] == xp.asarray([1., 1.])) +assert xp.get_array_api_strict_flags()['boolean_indexing'] == True +''', {}), + *[(f'''\ +import array_api_strict as xp + +a = xp.ones(3) +mask = xp.asarray([True, False, True]) + +if {boolean_indexing}: + assert xp.all(a[mask] == xp.asarray([1., 1.])) +else: + try: + a[mask] + except RuntimeError: + pass + else: + assert False + +assert xp.get_array_api_strict_flags()['boolean_indexing'] == {boolean_indexing} +''', {"ARRAY_API_STRICT_BOOLEAN_INDEXING": boolean_indexing}) + for boolean_indexing in ('True', 'False')], + + # ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES + ('''\ +import array_api_strict as xp + +a = xp.ones(3) +xp.unique_all(a) + +assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == True +''', {}), + *[(f'''\ +import array_api_strict as xp + +a = xp.ones(3) +if {data_dependent_shapes}: + xp.unique_all(a) +else: + try: + xp.unique_all(a) + except RuntimeError: + pass + else: + assert False + +assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == {data_dependent_shapes} +''', {"ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES": data_dependent_shapes}) + for data_dependent_shapes in ('True', 'False')], + + # ARRAY_API_STRICT_ENABLED_EXTENSIONS + ('''\ +import array_api_strict as xp +assert hasattr(xp, 'linalg') +assert hasattr(xp, 'fft') + +assert xp.get_array_api_strict_flags()['enabled_extensions'] == ('linalg', 'fft') +''', {}), + *[(f'''\ +import array_api_strict as xp + +assert hasattr(xp, 'linalg') == ('linalg' in {extensions.split(',')}) +assert hasattr(xp, 'fft') == ('fft' in {extensions.split(',')}) + +assert sorted(xp.get_array_api_strict_flags()['enabled_extensions']) == {sorted(set(extensions.split(','))-{''})} +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": extensions}) + for extensions in ('', 'linalg', 'fft', 'linalg,fft')], + ] + + for test, env in subprocess_tests: + try: + subprocess.run([sys.executable, '-c', test], check=True, + capture_output=True, encoding='utf-8', env=env) + except subprocess.CalledProcessError as e: + print(e.stdout, end='') + # Ensure the exception is shown in the output log + raise AssertionError(f"""\ +STDOUT: +{e.stderr} + +STDERR: +{e.stderr} + +TEST: +{test} + +ENV: +{env}""") From a8f8fdcd665ed664b5335afda0b909c411529c2a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 May 2024 15:18:41 -0600 Subject: [PATCH 49/56] More than signature tests are now implemented for 2023.12 --- .github/workflows/array-api-tests.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index b37ec04..af91d2a 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -52,11 +52,6 @@ jobs: run: | # Parameterizing this in the CI matrix is wasteful. Just do a loop here. for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do - # Only signature tests work for now for 2023.12 - if [[ "$ARRAY_API_STRICT_API_VERSION" == "2023.12" ]]; then - PYTEST_ARGS="${PYTEST_ARGS} -k signature - fi - cd ${GITHUB_WORKSPACE}/array-api-tests pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} done From 752b70667aea493b417ee6464080813ce78c4c01 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 16 May 2024 14:08:49 -0600 Subject: [PATCH 50/56] Add more info to an error message --- array_api_strict/_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 0e85cdc..67ba67c 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -20,7 +20,7 @@ def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. if dtype not in (None,) + _all_dtypes: - raise ValueError("dtype must be one of the supported dtypes") + raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}") def _supports_buffer_protocol(obj): try: From 7cb321498647c3e17ccf7156b5626609e30e1e19 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 17 May 2024 15:06:52 -0600 Subject: [PATCH 51/56] Fix some issues with cumulative_sum - The behavior for dtype=None was incorrect. - Fix an error with axis=-1, include_initial=True. --- array_api_strict/_statistical_functions.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 7a42d25..39e3736 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -30,8 +30,9 @@ def cumulative_sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in cumulative_sum") - if dtype is None: - dtype = x.dtype + dt = x.dtype if dtype is None else dtype + if dtype is not None: + dtype = dtype._np_dtype # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: @@ -40,8 +41,10 @@ def cumulative_sum( axis = 0 # np.cumsum does not support include_initial if include_initial: - x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis) - return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype._np_dtype)) + if axis < 0: + axis += x.ndim + x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) + return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype)) def max( x: Array, From beb95ae1857d3cff554ead2f8cc1c23964cfa8e6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 24 May 2024 16:07:02 -0600 Subject: [PATCH 52/56] Fix typo --- array_api_strict/_linear_algebra_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index 6a1a921..dcb654d 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -61,7 +61,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: elif axis < min(-1, -x1.ndim, -x2.ndim): raise ValueError("axis is out of bounds for x1 and x2") - # In versions if the standard prior to 2023.12, vecdot applied axis after + # In versions of the standard prior to 2023.12, vecdot applied axis after # broadcasting. This is different from applying it before broadcasting # when axis is nonnegative. The below code keeps this behavior for # 2022.12, primarily for backwards compatibility. Note that the behavior From c721f3d552c7203a7233075d1cf13063d7128039 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 28 May 2024 15:36:04 -0600 Subject: [PATCH 53/56] Remove duplicate __all__ definition from _info.py --- array_api_strict/_info.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 5f8c841..ab5447a 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -1,14 +1,5 @@ from __future__ import annotations -__all__ = [ - "__array_namespace_info__", - "capabilities", - "default_device", - "default_dtypes", - "devices", - "dtypes", -] - from typing import TYPE_CHECKING if TYPE_CHECKING: From 5cf028c51fa22f2926d3a85deacf21c5815d420c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 26 Jun 2024 15:11:38 -0600 Subject: [PATCH 54/56] Fix NumPy 1.26 type promotion in copysign --- array_api_strict/_elementwise_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 9ef71bd..b39bd86 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -328,6 +328,9 @@ def copysign(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in copysign") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.copysign(x1._array, x2._array)) def cos(x: Array, /) -> Array: From 5e607c3f72bcf32443970f4c2b03757f02e6bdd0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 26 Jun 2024 15:19:51 -0600 Subject: [PATCH 55/56] Remove NPY_PROMOTION_STATE=weak from the CI The strict library should be explicitly working around all the bad promotion issues from NumPy 1.26. --- .github/workflows/array-api-tests.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index af91d2a..ab7dbb8 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -46,9 +46,6 @@ jobs: - name: Run the array API testsuite env: ARRAY_API_TESTS_MODULE: array_api_strict - # This enables the NEP 50 type promotion behavior (without it a lot of - # tests fail in numpy 1.26 on bad scalar type promotion behavior) - NPY_PROMOTION_STATE: weak run: | # Parameterizing this in the CI matrix is wasteful. Just do a loop here. for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do From 6f8c07f548d4e90bcfae71e74a10b0043385adb6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 26 Jun 2024 17:11:12 -0600 Subject: [PATCH 56/56] Trigger CI