Skip to content

Commit

Permalink
Merge pull request #61 from asmeurer/2023.12-default
Browse files Browse the repository at this point in the history
Make 2023.12 the default version
  • Loading branch information
asmeurer authored Sep 18, 2024
2 parents 718f15b + 2aae491 commit 05c8b0f
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 98 deletions.
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,4 @@ libraries. Consuming library code should use the
support the array API. Rather, it is intended to be used in the test suites of
consuming libraries to test their array API usage.

array-api-strict currently supports the 2022.12 version of the standard.
2023.12 support is planned and is tracked by [this
issue](https://github.com/data-apis/array-api-strict/issues/25).

See the documentation for more details https://data-apis.org/array-api-strict/
12 changes: 4 additions & 8 deletions array_api_strict/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"2023.12",
)

API_VERSION = default_version = "2022.12"
API_VERSION = default_version = "2023.12"

BOOLEAN_INDEXING = True

Expand Down Expand Up @@ -76,10 +76,6 @@ def set_array_api_strict_flags(
Note that 2021.12 is supported, but currently gives the same thing as
2022.12 (except that the fft extension will be disabled).
2023.12 support is experimental. Some features in 2023.12 may still be
missing, and it hasn't been fully tested. A future version of
array-api-strict will change the default version to 2023.12.
boolean_indexing : bool, optional
Whether indexing by a boolean array is supported. This flag is enabled
by default. Note that although boolean array indexing does result in
Expand Down Expand Up @@ -142,8 +138,6 @@ def set_array_api_strict_flags(
raise ValueError(f"Unsupported standard version {api_version!r}")
if api_version == "2021.12":
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2)
if api_version == "2023.12":
warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.", stacklevel=2)
API_VERSION = api_version
array_api_strict.__array_api_version__ = API_VERSION

Expand Down Expand Up @@ -262,7 +256,9 @@ def reset_array_api_strict_flags():
BOOLEAN_INDEXING = True
DATA_DEPENDENT_SHAPES = True
ENABLED_EXTENSIONS = default_extensions

array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) |
set(array_api_strict.__all__) -
set(default_extensions))

class ArrayAPIStrictFlags:
"""
Expand Down
9 changes: 4 additions & 5 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,16 +406,15 @@ def test_array_keys_use_private_array():
def test_array_namespace():
a = ones((3, 3))
assert a.__array_namespace__() == array_api_strict
assert array_api_strict.__array_api_version__ == "2022.12"
assert array_api_strict.__array_api_version__ == "2023.12"

assert a.__array_namespace__(api_version=None) is array_api_strict
assert array_api_strict.__array_api_version__ == "2022.12"
assert array_api_strict.__array_api_version__ == "2023.12"

assert a.__array_namespace__(api_version="2022.12") is array_api_strict
assert array_api_strict.__array_api_version__ == "2022.12"

with pytest.warns(UserWarning):
assert a.__array_namespace__(api_version="2023.12") is array_api_strict
assert a.__array_namespace__(api_version="2023.12") is array_api_strict
assert array_api_strict.__array_api_version__ == "2023.12"

with pytest.warns(UserWarning):
Expand All @@ -435,7 +434,7 @@ def test_iter():

@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
def dlpack_2023_12(api_version):
if api_version != '2022.12':
if api_version == '2021.12':
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version=api_version)
else:
Expand Down
4 changes: 1 addition & 3 deletions array_api_strict/tests/test_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
)
from .._flags import set_array_api_strict_flags

import pytest

def nargs(func):
return len(getfullargspec(func).args)
Expand Down Expand Up @@ -111,8 +110,7 @@ def _array_vals():
yield asarray(1.0, dtype=d)

# Use the latest version of the standard so all functions are included
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version="2023.12")
set_array_api_strict_flags(api_version="2023.12")

for x in _array_vals():
for func_name, types in elementwise_function_input_types.items():
Expand Down
105 changes: 57 additions & 48 deletions array_api_strict/tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,52 @@

import pytest

def test_flags():
# Test defaults
def test_flag_defaults():
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2022.12',
'api_version': '2023.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft'),
}


def test_reset_flags():
with pytest.warns(UserWarning):
set_array_api_strict_flags(
api_version='2021.12',
boolean_indexing=False,
data_dependent_shapes=False,
enabled_extensions=())
reset_array_api_strict_flags()
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2023.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft'),
}

# Test setting flags

def test_setting_flags():
set_array_api_strict_flags(data_dependent_shapes=False)
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2022.12',
'api_version': '2023.12',
'boolean_indexing': True,
'data_dependent_shapes': False,
'enabled_extensions': ('linalg', 'fft'),
}
set_array_api_strict_flags(enabled_extensions=('fft',))
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2022.12',
'api_version': '2023.12',
'boolean_indexing': True,
'data_dependent_shapes': False,
'enabled_extensions': ('fft',),
}

def test_flags_api_version_2021_12():
# Make sure setting the version to 2021.12 disables fft and issues a
# warning.
with pytest.warns(UserWarning) as record:
Expand All @@ -55,27 +74,23 @@ def test_flags():
assert flags == {
'api_version': '2021.12',
'boolean_indexing': True,
'data_dependent_shapes': False,
'enabled_extensions': (),
'data_dependent_shapes': True,
'enabled_extensions': ('linalg',),
}
reset_array_api_strict_flags()

with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2021.12')
def test_flags_api_version_2022_12():
set_array_api_strict_flags(api_version='2022.12')
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2021.12',
'api_version': '2022.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg',),
'enabled_extensions': ('linalg', 'fft'),
}
reset_array_api_strict_flags()

# 2023.12 should issue a warning
with pytest.warns(UserWarning) as record:
set_array_api_strict_flags(api_version='2023.12')
assert len(record) == 1
assert '2023.12' in str(record[0].message)

def test_flags_api_version_2023_12():
set_array_api_strict_flags(api_version='2023.12')
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2023.12',
Expand All @@ -84,6 +99,7 @@ def test_flags():
'enabled_extensions': ('linalg', 'fft'),
}

def test_setting_flags_invalid():
# Test setting flags with invalid values
pytest.raises(ValueError, lambda:
set_array_api_strict_flags(api_version='2020.12'))
Expand All @@ -94,35 +110,15 @@ def test_flags():
api_version='2021.12',
enabled_extensions=('linalg', 'fft')))

# Test resetting flags
with pytest.warns(UserWarning):
set_array_api_strict_flags(
api_version='2021.12',
boolean_indexing=False,
data_dependent_shapes=False,
enabled_extensions=())
reset_array_api_strict_flags()
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2022.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft'),
}

def test_api_version():
# Test defaults
assert xp.__array_api_version__ == '2022.12'
assert xp.__array_api_version__ == '2023.12'

# Test setting the version
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2021.12')
assert xp.__array_api_version__ == '2021.12'
set_array_api_strict_flags(api_version='2022.12')
assert xp.__array_api_version__ == '2022.12'

def test_data_dependent_shapes():
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2023.12') # to enable repeat()

a = asarray([0, 0, 1, 2, 2])
mask = asarray([True, False, True, False, True])
repeats = asarray([1, 1, 2, 2, 2])
Expand Down Expand Up @@ -275,12 +271,16 @@ def test_fft(func_name):
def test_api_version_2023_12(func_name):
func = api_version_2023_12_examples[func_name]

# By default, these functions should error
# By default, these functions should not error
func()

# In 2022.12, these functions should error
set_array_api_strict_flags(api_version='2022.12')
pytest.raises(RuntimeError, func)

with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2023.12')
func()
# Test the behavior gets updated properly
set_array_api_strict_flags(api_version='2023.12')
func()

set_array_api_strict_flags(api_version='2022.12')
pytest.raises(RuntimeError, func)
Expand Down Expand Up @@ -371,16 +371,25 @@ def test_disabled_extensions():
assert 'linalg' not in ns
assert 'fft' not in ns

reset_array_api_strict_flags()
assert 'linalg' in xp.__all__
assert 'fft' in xp.__all__
xp.linalg # No error
xp.fft # No error
ns = {}
exec('from array_api_strict import *', ns)
assert 'linalg' in ns
assert 'fft' in ns

def test_environment_variables():
# Test that the environment variables work as expected
subprocess_tests = [
# ARRAY_API_STRICT_API_VERSION
('''\
import array_api_strict as xp
assert xp.__array_api_version__ == '2022.12'
assert xp.__array_api_version__ == '2023.12'
assert xp.get_array_api_strict_flags()['api_version'] == '2022.12'
assert xp.get_array_api_strict_flags()['api_version'] == '2023.12'
''', {}),
*[
Expand Down
27 changes: 14 additions & 13 deletions array_api_strict/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@

# Technically this is linear_algebra, not linalg, but it's simpler to keep
# both of these tests together
def test_vecdot_2023_12():
# Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >=
# 0 behavior (which is primarily kept for backwards compatibility).


# Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >=
# 0 behavior (which is primarily kept for backwards compatibility).
def test_vecdot_2022_12():
# 2022.12 behavior, which is to apply axis >= 0 after broadcasting
set_array_api_strict_flags(api_version='2022.12')

a = xp.ones((2, 3, 4, 5))
b = xp.ones(( 3, 4, 1))

# 2022.12 behavior, which is to apply axis >= 0 after broadcasting
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0))
assert xp.linalg.vecdot(a, b, axis=1).shape == (2, 4, 5)
assert xp.linalg.vecdot(a, b, axis=2).shape == (2, 3, 5)
Expand All @@ -34,10 +37,13 @@ def test_vecdot_2023_12():
assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5)
assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5)

def test_vecdot_2023_12():
# 2023.12 behavior, which is to only allow axis < 0 and axis >=
# min(x1.ndim, x2.ndim), which is unambiguous
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2023.12')
set_array_api_strict_flags(api_version='2023.12')

a = xp.ones((2, 3, 4, 5))
b = xp.ones(( 3, 4, 1))

pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0))
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=1))
Expand All @@ -56,7 +62,7 @@ def test_cross(api_version):
# This test tests everything that should be the same across all supported
# API versions.

if api_version != '2022.12':
if api_version == '2021.12':
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version=api_version)
else:
Expand Down Expand Up @@ -88,7 +94,7 @@ def test_cross_2022_12(api_version):
# backwards compatibility. Note that unlike vecdot, array_api_strict
# cross() never implemented the "after broadcasting" axis behavior, but
# just reused NumPy cross(), which applies axes before broadcasting.
if api_version != '2022.12':
if api_version == '2021.12':
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version=api_version)
else:
Expand All @@ -104,11 +110,6 @@ def test_cross_2022_12(api_version):
assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5)

def test_cross_2023_12():
# 2023.12 behavior, which is to only allow axis < 0 and axis >=
# min(x1.ndim, x2.ndim), which is unambiguous
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2023.12')

a = xp.ones((3, 2, 4, 5))
b = xp.ones((3, 2, 4, 1))
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0))
Expand Down
20 changes: 15 additions & 5 deletions array_api_strict/tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import array_api_strict as xp

# sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes
# with dtype=None
@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace'])
def test_sum_prod_trace_2023_12(func_name):
# sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes
# with dtype=None
def test_sum_prod_trace_2022_12(func_name):
set_array_api_strict_flags(api_version='2022.12')

if func_name == 'trace':
func = getattr(xp.linalg, func_name)
else:
Expand All @@ -21,8 +23,16 @@ def test_sum_prod_trace_2023_12(func_name):
assert func(a_complex).dtype == xp.complex128
assert func(a_int).dtype == xp.int64

with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version='2023.12')
@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace'])
def test_sum_prod_trace_2023_12(func_name):
a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32)
a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64)
a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32)

if func_name == 'trace':
func = getattr(xp.linalg, func_name)
else:
func = getattr(xp, func_name)

assert func(a_real).dtype == xp.float32
assert func(a_complex).dtype == xp.complex64
Expand Down
Loading

0 comments on commit 05c8b0f

Please sign in to comment.