From 36d15bbe24eeb0c597bd98dd248b1d9adbe21c40 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 13 Sep 2024 15:18:12 -0600 Subject: [PATCH 1/3] Fix __all__ not getting updated with reset_array_api_strict_flags() --- array_api_strict/_flags.py | 4 +++- array_api_strict/tests/test_flags.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 62acddf..46c0786 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -262,7 +262,9 @@ def reset_array_api_strict_flags(): BOOLEAN_INDEXING = True DATA_DEPENDENT_SHAPES = True ENABLED_EXTENSIONS = default_extensions - + array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) | + set(array_api_strict.__all__) - + set(default_extensions)) class ArrayAPIStrictFlags: """ diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 86ad8e2..76ca596 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -371,6 +371,15 @@ def test_disabled_extensions(): assert 'linalg' not in ns assert 'fft' not in ns + reset_array_api_strict_flags() + assert 'linalg' in xp.__all__ + assert 'fft' in xp.__all__ + xp.linalg # No error + xp.fft # No error + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' in ns + assert 'fft' in ns def test_environment_variables(): # Test that the environment variables work as expected From b3d5214e453d12780897048b9732da1ca0c57b39 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 13 Sep 2024 15:20:47 -0600 Subject: [PATCH 2/3] Make 2023.12 the default version SciPy and others have been using it and haven't found any issues. Test suite support is still not 100% but is pretty strong at this point. This also splits some of the tests to avoid setting different versions of flags within the same test. --- README.md | 4 - array_api_strict/_flags.py | 8 +- array_api_strict/tests/test_array_object.py | 9 +- .../tests/test_elementwise_functions.py | 3 +- array_api_strict/tests/test_flags.py | 96 +++++++++---------- array_api_strict/tests/test_linalg.py | 27 +++--- .../tests/test_statistical_functions.py | 20 +++- docs/index.md | 16 +--- 8 files changed, 87 insertions(+), 96 deletions(-) diff --git a/README.md b/README.md index 8172237..0d52fec 100644 --- a/README.md +++ b/README.md @@ -15,8 +15,4 @@ libraries. Consuming library code should use the support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. -array-api-strict currently supports the 2022.12 version of the standard. -2023.12 support is planned and is tracked by [this -issue](https://github.com/data-apis/array-api-strict/issues/25). - See the documentation for more details https://data-apis.org/array-api-strict/ diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 46c0786..c393ad9 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -24,7 +24,7 @@ "2023.12", ) -API_VERSION = default_version = "2022.12" +API_VERSION = default_version = "2023.12" BOOLEAN_INDEXING = True @@ -76,10 +76,6 @@ def set_array_api_strict_flags( Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). - 2023.12 support is experimental. Some features in 2023.12 may still be - missing, and it hasn't been fully tested. A future version of - array-api-strict will change the default version to 2023.12. - boolean_indexing : bool, optional Whether indexing by a boolean array is supported. This flag is enabled by default. Note that although boolean array indexing does result in @@ -142,8 +138,6 @@ def set_array_api_strict_flags( raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) - if api_version == "2023.12": - warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.", stacklevel=2) API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index b0d4868..dad6696 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -406,16 +406,15 @@ def test_array_keys_use_private_array(): def test_array_namespace(): a = ones((3, 3)) assert a.__array_namespace__() == array_api_strict - assert array_api_strict.__array_api_version__ == "2022.12" + assert array_api_strict.__array_api_version__ == "2023.12" assert a.__array_namespace__(api_version=None) is array_api_strict - assert array_api_strict.__array_api_version__ == "2022.12" + assert array_api_strict.__array_api_version__ == "2023.12" assert a.__array_namespace__(api_version="2022.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2022.12" - with pytest.warns(UserWarning): - assert a.__array_namespace__(api_version="2023.12") is array_api_strict + assert a.__array_namespace__(api_version="2023.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2023.12" with pytest.warns(UserWarning): @@ -435,7 +434,7 @@ def test_iter(): @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) def dlpack_2023_12(api_version): - if api_version != '2022.12': + if api_version == '2021.12': with pytest.warns(UserWarning): set_array_api_strict_flags(api_version=api_version) else: diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index fa3405a..8f3ce7a 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -111,8 +111,7 @@ def _array_vals(): yield asarray(1.0, dtype=d) # Use the latest version of the standard so all functions are included - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version="2023.12") + set_array_api_strict_flags(api_version="2023.12") for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 76ca596..2603f35 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -18,21 +18,38 @@ import pytest -def test_flags(): - # Test defaults +def test_flag_defaults(): flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2022.12', + 'api_version': '2023.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + + +def test_reset_flags(): + with pytest.warns(UserWarning): + set_array_api_strict_flags( + api_version='2021.12', + boolean_indexing=False, + data_dependent_shapes=False, + enabled_extensions=()) + reset_array_api_strict_flags() + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2023.12', 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), } - # Test setting flags + +def test_setting_flags(): set_array_api_strict_flags(data_dependent_shapes=False) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2022.12', + 'api_version': '2023.12', 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('linalg', 'fft'), @@ -40,11 +57,13 @@ def test_flags(): set_array_api_strict_flags(enabled_extensions=('fft',)) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2022.12', + 'api_version': '2023.12', 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('fft',), } + +def test_flags_api_version_2021_12(): # Make sure setting the version to 2021.12 disables fft and issues a # warning. with pytest.warns(UserWarning) as record: @@ -55,27 +74,23 @@ def test_flags(): assert flags == { 'api_version': '2021.12', 'boolean_indexing': True, - 'data_dependent_shapes': False, - 'enabled_extensions': (), + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg',), } - reset_array_api_strict_flags() - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2021.12') +def test_flags_api_version_2022_12(): + set_array_api_strict_flags(api_version='2022.12') flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2021.12', + 'api_version': '2022.12', 'boolean_indexing': True, 'data_dependent_shapes': True, - 'enabled_extensions': ('linalg',), + 'enabled_extensions': ('linalg', 'fft'), } - reset_array_api_strict_flags() - # 2023.12 should issue a warning - with pytest.warns(UserWarning) as record: - set_array_api_strict_flags(api_version='2023.12') - assert len(record) == 1 - assert '2023.12' in str(record[0].message) + +def test_flags_api_version_2023_12(): + set_array_api_strict_flags(api_version='2023.12') flags = get_array_api_strict_flags() assert flags == { 'api_version': '2023.12', @@ -84,6 +99,7 @@ def test_flags(): 'enabled_extensions': ('linalg', 'fft'), } +def test_setting_flags_invalid(): # Test setting flags with invalid values pytest.raises(ValueError, lambda: set_array_api_strict_flags(api_version='2020.12')) @@ -94,35 +110,15 @@ def test_flags(): api_version='2021.12', enabled_extensions=('linalg', 'fft'))) - # Test resetting flags - with pytest.warns(UserWarning): - set_array_api_strict_flags( - api_version='2021.12', - boolean_indexing=False, - data_dependent_shapes=False, - enabled_extensions=()) - reset_array_api_strict_flags() - flags = get_array_api_strict_flags() - assert flags == { - 'api_version': '2022.12', - 'boolean_indexing': True, - 'data_dependent_shapes': True, - 'enabled_extensions': ('linalg', 'fft'), - } - def test_api_version(): # Test defaults - assert xp.__array_api_version__ == '2022.12' + assert xp.__array_api_version__ == '2023.12' # Test setting the version - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2021.12') - assert xp.__array_api_version__ == '2021.12' + set_array_api_strict_flags(api_version='2022.12') + assert xp.__array_api_version__ == '2022.12' def test_data_dependent_shapes(): - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') # to enable repeat() - a = asarray([0, 0, 1, 2, 2]) mask = asarray([True, False, True, False, True]) repeats = asarray([1, 1, 2, 2, 2]) @@ -275,12 +271,16 @@ def test_fft(func_name): def test_api_version_2023_12(func_name): func = api_version_2023_12_examples[func_name] - # By default, these functions should error + # By default, these functions should not error + func() + + # In 2022.12, these functions should error + set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') - func() + # Test the behavior gets updated properly + set_array_api_strict_flags(api_version='2023.12') + func() set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) @@ -387,9 +387,9 @@ def test_environment_variables(): # ARRAY_API_STRICT_API_VERSION ('''\ import array_api_strict as xp -assert xp.__array_api_version__ == '2022.12' +assert xp.__array_api_version__ == '2023.12' -assert xp.get_array_api_strict_flags()['api_version'] == '2022.12' +assert xp.get_array_api_strict_flags()['api_version'] == '2023.12' ''', {}), *[ diff --git a/array_api_strict/tests/test_linalg.py b/array_api_strict/tests/test_linalg.py index 5e6cda2..04023bc 100644 --- a/array_api_strict/tests/test_linalg.py +++ b/array_api_strict/tests/test_linalg.py @@ -8,14 +8,17 @@ # Technically this is linear_algebra, not linalg, but it's simpler to keep # both of these tests together -def test_vecdot_2023_12(): - # Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >= - # 0 behavior (which is primarily kept for backwards compatibility). + + +# Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >= +# 0 behavior (which is primarily kept for backwards compatibility). +def test_vecdot_2022_12(): + # 2022.12 behavior, which is to apply axis >= 0 after broadcasting + set_array_api_strict_flags(api_version='2022.12') a = xp.ones((2, 3, 4, 5)) b = xp.ones(( 3, 4, 1)) - # 2022.12 behavior, which is to apply axis >= 0 after broadcasting pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) assert xp.linalg.vecdot(a, b, axis=1).shape == (2, 4, 5) assert xp.linalg.vecdot(a, b, axis=2).shape == (2, 3, 5) @@ -34,10 +37,13 @@ def test_vecdot_2023_12(): assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5) assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5) +def test_vecdot_2023_12(): # 2023.12 behavior, which is to only allow axis < 0 and axis >= # min(x1.ndim, x2.ndim), which is unambiguous - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') + set_array_api_strict_flags(api_version='2023.12') + + a = xp.ones((2, 3, 4, 5)) + b = xp.ones(( 3, 4, 1)) pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=1)) @@ -56,7 +62,7 @@ def test_cross(api_version): # This test tests everything that should be the same across all supported # API versions. - if api_version != '2022.12': + if api_version == '2021.12': with pytest.warns(UserWarning): set_array_api_strict_flags(api_version=api_version) else: @@ -88,7 +94,7 @@ def test_cross_2022_12(api_version): # backwards compatibility. Note that unlike vecdot, array_api_strict # cross() never implemented the "after broadcasting" axis behavior, but # just reused NumPy cross(), which applies axes before broadcasting. - if api_version != '2022.12': + if api_version == '2021.12': with pytest.warns(UserWarning): set_array_api_strict_flags(api_version=api_version) else: @@ -104,11 +110,6 @@ def test_cross_2022_12(api_version): assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5) def test_cross_2023_12(): - # 2023.12 behavior, which is to only allow axis < 0 and axis >= - # min(x1.ndim, x2.ndim), which is unambiguous - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') - a = xp.ones((3, 2, 4, 5)) b = xp.ones((3, 2, 4, 1)) pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0)) diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py index 61e848c..7f2a457 100644 --- a/array_api_strict/tests/test_statistical_functions.py +++ b/array_api_strict/tests/test_statistical_functions.py @@ -4,10 +4,12 @@ import array_api_strict as xp +# sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes +# with dtype=None @pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) -def test_sum_prod_trace_2023_12(func_name): - # sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes - # with dtype=None +def test_sum_prod_trace_2022_12(func_name): + set_array_api_strict_flags(api_version='2022.12') + if func_name == 'trace': func = getattr(xp.linalg, func_name) else: @@ -21,8 +23,16 @@ def test_sum_prod_trace_2023_12(func_name): assert func(a_complex).dtype == xp.complex128 assert func(a_int).dtype == xp.int64 - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') +@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) +def test_sum_prod_trace_2023_12(func_name): + a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32) + a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64) + a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32) + + if func_name == 'trace': + func = getattr(xp.linalg, func_name) + else: + func = getattr(xp, func_name) assert func(a_real).dtype == xp.float32 assert func(a_complex).dtype == xp.complex64 diff --git a/docs/index.md b/docs/index.md index a14fbcb..12aadbb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -16,10 +16,10 @@ support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. array-api-strict currently supports the -[2022.12](https://data-apis.org/array-api/latest/changelog.html#v2022-12) -version of the standard. Experimental -[2023.12](https://data-apis.org/array-api/latest/changelog.html#v2023-12) -support is implemented, [but must be enabled with a +[2023.12](https://data-apis.org/array-api/latest/changelog.html#v2022-12) +version of the standard. +[2022.12](https://data-apis.org/array-api/latest/changelog.html#v2023-12) +support is also implemented, [and can be enabled with a flag](array-api-strict-flags). ## Install @@ -176,14 +176,6 @@ issue, but this hasn't necessarily been tested thoroughly. this deviation may be tested with type checking. This [behavior may improve in the future](https://github.com/data-apis/array-api-strict/issues/6). -5. array-api-strict currently uses the 2022.12 version of the array API - standard by default. Support for 2023.12 is implemented but is still - experimental and not fully tested. It can be enabled with - {func}`array_api_strict.set_array_api_strict_flags(api_version='2023.12') - ` or by setting the - environment variable {envvar}`ARRAY_API_STRICT_API_VERSION=2023.12 - `. - (numpy.array_api)= ## Relationship to `numpy.array_api` From 2aae491a421bc8aa74b35d37d3006cf20d479305 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 13 Sep 2024 15:23:54 -0600 Subject: [PATCH 3/3] Remove unused import --- array_api_strict/tests/test_elementwise_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 8f3ce7a..870361e 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -17,7 +17,6 @@ ) from .._flags import set_array_api_strict_flags -import pytest def nargs(func): return len(getfullargspec(func).args)