Skip to content

Commit

Permalink
Make extensions give AttributeError when they are disabled
Browse files Browse the repository at this point in the history
This is how the test suite and presumably some other codes detect if
extensions are enabled or not.

This also dynamically updates __all__ whenever extensions are enabled or
disabled.
  • Loading branch information
asmeurer committed May 3, 2024
1 parent dd01b12 commit 43b9088
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 47 deletions.
32 changes: 22 additions & 10 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
"""

__all__ = []

# Warning: __array_api_version__ could change globally with
# set_array_api_strict_flags(). This should always be accessed as an
# attribute, like xp.__array_api_version__, or using
# array_api_strict.get_array_api_strict_flags()['api_version'].
from ._flags import API_VERSION as __array_api_version__

__all__ = ["__array_api_version__"]
__all__ += ["__array_api_version__"]

from ._constants import e, inf, nan, pi, newaxis

Expand Down Expand Up @@ -266,19 +268,10 @@
"__array_namespace_info__",
]

# linalg is an extension in the array API spec, which is a sub-namespace. Only
# a subset of functions in it are imported into the top-level namespace.
from . import linalg

__all__ += ["linalg"]

from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot

__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]

from . import fft
__all__ += ["fft"]

from ._manipulation_functions import (
concat,
expand_dims,
Expand Down Expand Up @@ -330,3 +323,22 @@
from . import _version
__version__ = _version.get_versions()['version']
del _version


# Extensions can be enabled or disabled dynamically. In order to make
# "array_api_strict.linalg" give an AttributeError when it is disabled, we
# use __getattr__. Note that linalg and fft are dynamically added and removed
# from __all__ in set_array_api_strict_flags.

def __getattr__(name):
if name in ['linalg', 'fft']:
if name in get_array_api_strict_flags()['enabled_extensions']:
if name == 'linalg':
from . import _linalg
return _linalg
elif name == 'fft':
from . import _fft
return _fft
else:
raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict")
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
File renamed without changes.
7 changes: 7 additions & 0 deletions array_api_strict/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def set_array_api_strict_flags(
else:
ENABLED_EXTENSIONS = tuple([ext for ext in ENABLED_EXTENSIONS if extension_versions[ext] <= API_VERSION])

array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) |
set(array_api_strict.__all__) -
set(default_extensions))

# We have to do this separately or it won't get added as the docstring
set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format(
supported_versions=supported_versions,
Expand Down Expand Up @@ -321,6 +325,9 @@ def set_flags_from_environment():
if enabled_extensions == [""]:
enabled_extensions = []
set_array_api_strict_flags(enabled_extensions=enabled_extensions)
else:
# Needed at first import to add linalg and fft to __all__
set_array_api_strict_flags(enabled_extensions=default_extensions)

set_flags_from_environment()

Expand Down
File renamed without changes.
168 changes: 131 additions & 37 deletions array_api_strict/tests/test_flags.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import sys
import subprocess

from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags,
reset_array_api_strict_flags)
from .._info import (capabilities, default_device, default_dtypes, devices,
dtypes)
from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft,
ihfft, fftfreq, rfftfreq, fftshift, ifftshift)
from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv,
matmul, matrix_norm, matrix_power, matrix_rank, matrix_transpose, outer, pinv,
qr, slogdet, solve, svd, svdvals, tensordot, trace, vecdot, vector_norm)

from .. import (asarray, unique_all, unique_counts, unique_inverse,
unique_values, nonzero, repeat)
Expand Down Expand Up @@ -152,29 +160,29 @@ def test_boolean_indexing():
pytest.raises(RuntimeError, lambda: a[mask])

linalg_examples = {
'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)),
'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])),
'det': lambda: xp.linalg.det(xp.eye(3)),
'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)),
'eigh': lambda: xp.linalg.eigh(xp.eye(3)),
'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)),
'inv': lambda: xp.linalg.inv(xp.eye(3)),
'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)),
'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)),
'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2),
'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)),
'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)),
'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
'pinv': lambda: xp.linalg.pinv(xp.eye(3)),
'qr': lambda: xp.linalg.qr(xp.eye(3)),
'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)),
'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)),
'svd': lambda: xp.linalg.svd(xp.eye(3)),
'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)),
'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)),
'trace': lambda: xp.linalg.trace(xp.eye(3)),
'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])),
'cholesky': lambda: cholesky(xp.eye(3)),
'cross': lambda: cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])),
'det': lambda: det(xp.eye(3)),
'diagonal': lambda: diagonal(xp.eye(3)),
'eigh': lambda: eigh(xp.eye(3)),
'eigvalsh': lambda: eigvalsh(xp.eye(3)),
'inv': lambda: inv(xp.eye(3)),
'matmul': lambda: matmul(xp.eye(3), xp.eye(3)),
'matrix_norm': lambda: matrix_norm(xp.eye(3)),
'matrix_power': lambda: matrix_power(xp.eye(3), 2),
'matrix_rank': lambda: matrix_rank(xp.eye(3)),
'matrix_transpose': lambda: matrix_transpose(xp.eye(3)),
'outer': lambda: outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
'pinv': lambda: pinv(xp.eye(3)),
'qr': lambda: qr(xp.eye(3)),
'slogdet': lambda: slogdet(xp.eye(3)),
'solve': lambda: solve(xp.eye(3), xp.eye(3)),
'svd': lambda: svd(xp.eye(3)),
'svdvals': lambda: svdvals(xp.eye(3)),
'tensordot': lambda: tensordot(xp.eye(3), xp.eye(3)),
'trace': lambda: trace(xp.eye(3)),
'vecdot': lambda: vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
'vector_norm': lambda: vector_norm(xp.asarray([1., 2., 3.])),
}

assert set(linalg_examples) == set(xp.linalg.__all__)
Expand Down Expand Up @@ -210,20 +218,20 @@ def test_linalg(func_name):
main_namespace_func()

fft_examples = {
'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])),
'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])),
'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])),
'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])),
'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])),
'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])),
'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])),
'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])),
'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])),
'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])),
'fftfreq': lambda: xp.fft.fftfreq(4),
'rfftfreq': lambda: xp.fft.rfftfreq(4),
'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])),
'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])),
'fft': lambda: fft(xp.asarray([0j, 1j, 0j, 0j])),
'ifft': lambda: ifft(xp.asarray([0j, 1j, 0j, 0j])),
'fftn': lambda: fftn(xp.asarray([[0j, 1j], [0j, 0j]])),
'ifftn': lambda: ifftn(xp.asarray([[0j, 1j], [0j, 0j]])),
'rfft': lambda: rfft(xp.asarray([0., 1., 0., 0.])),
'irfft': lambda: irfft(xp.asarray([0j, 1j, 0j, 0j])),
'rfftn': lambda: rfftn(xp.asarray([[0., 1.], [0., 0.]])),
'irfftn': lambda: irfftn(xp.asarray([[0j, 1j], [0j, 0j]])),
'hfft': lambda: hfft(xp.asarray([0j, 1j, 0j, 0j])),
'ihfft': lambda: ihfft(xp.asarray([0., 1., 0., 0.])),
'fftfreq': lambda: fftfreq(4),
'rfftfreq': lambda: rfftfreq(4),
'fftshift': lambda: fftshift(xp.asarray([0j, 1j, 0j, 0j])),
'ifftshift': lambda: ifftshift(xp.asarray([0j, 1j, 0j, 0j])),
}

assert set(fft_examples) == set(xp.fft.__all__)
Expand Down Expand Up @@ -276,3 +284,89 @@ def test_api_version_2023_12(func_name):

set_array_api_strict_flags(api_version='2022.12')
pytest.raises(RuntimeError, func)

def test_disabled_extensions():
# Test that xp.extension errors when an extension is disabled, and that
# xp.__all__ is updated properly.

# First test that things are correct on the initial import. Since we have
# already called set_array_api_strict_flags many times throughout running
# the tests, we have to test this in a subprocess.
subprocess_tests = [('''\
import array_api_strict
array_api_strict.linalg # No error
array_api_strict.fft # No error
assert "linalg" in array_api_strict.__all__
assert "fft" in array_api_strict.__all__
assert len(array_api_strict.__all__) == len(set(array_api_strict.__all__))
''', {}),
# Test that the initial population of __all__ works correctly
('''\
from array_api_strict import * # No error
linalg # Should have been imported by the previous line
fft
''', {}),
('''\
from array_api_strict import * # No error
linalg # Should have been imported by the previous line
assert 'fft' not in globals()
''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "linalg"}),
('''\
from array_api_strict import * # No error
fft # Should have been imported by the previous line
assert 'linalg' not in globals()
''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "fft"}),
('''\
from array_api_strict import * # No error
assert 'linalg' not in globals()
assert 'fft' not in globals()
''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": ""}),
]
for test, env in subprocess_tests:
try:
subprocess.run([sys.executable, '-c', test], check=True,
capture_output=True, encoding='utf-8', env=env)
except subprocess.CalledProcessError as e:
print(e.stdout, end='')
# Ensure the exception is shown in the output log
raise AssertionError(e.stderr)

assert 'linalg' in xp.__all__
assert 'fft' in xp.__all__
xp.linalg # No error
xp.fft # No error
ns = {}
exec('from array_api_strict import *', ns)
assert 'linalg' in ns
assert 'fft' in ns

set_array_api_strict_flags(enabled_extensions=('linalg',))
assert 'linalg' in xp.__all__
assert 'fft' not in xp.__all__
xp.linalg # No error
pytest.raises(AttributeError, lambda: xp.fft)
ns = {}
exec('from array_api_strict import *', ns)
assert 'linalg' in ns
assert 'fft' not in ns

set_array_api_strict_flags(enabled_extensions=('fft',))
assert 'linalg' not in xp.__all__
assert 'fft' in xp.__all__
pytest.raises(AttributeError, lambda: xp.linalg)
xp.fft # No error
ns = {}
exec('from array_api_strict import *', ns)
assert 'linalg' not in ns
assert 'fft' in ns

set_array_api_strict_flags(enabled_extensions=())
assert 'linalg' not in xp.__all__
assert 'fft' not in xp.__all__
pytest.raises(AttributeError, lambda: xp.linalg)
pytest.raises(AttributeError, lambda: xp.fft)
ns = {}
exec('from array_api_strict import *', ns)
assert 'linalg' not in ns
assert 'fft' not in ns

0 comments on commit 43b9088

Please sign in to comment.