Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preliminary 2023.12 support #35

Merged
merged 59 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
e61b50d
Add a @requires_api_version decorator
asmeurer Apr 19, 2024
f49845a
Don't re-enable disabled extensions when setting the api version
asmeurer Apr 19, 2024
71c5231
Add support for setting the api version to 2023.12
asmeurer Apr 19, 2024
c39fdbf
Set the stacklevel in the set_array_api_strict_flags() warnings
asmeurer Apr 19, 2024
31c5a89
Add clip()
asmeurer Apr 19, 2024
6a6719f
Add a TODO note for clip()
asmeurer Apr 19, 2024
aa61833
Merge branch 'main' into 2023.12
asmeurer Apr 19, 2024
77e6177
Remove unused import
asmeurer Apr 19, 2024
04c24d7
Add copysign
asmeurer Apr 19, 2024
c4587a4
Implement cumulative_sum (still needs to be tested)
asmeurer Apr 20, 2024
16b38d3
Add a comment about cumulative_sum and 0-D inputs
asmeurer Apr 22, 2024
b689d43
Add hypot()
asmeurer Apr 22, 2024
9ee08c7
Update elementwise tests for new elementwise functions
asmeurer Apr 22, 2024
e24f55e
Clear trailing whitespace
asmeurer Apr 22, 2024
f5fbf78
Silence warnings output in the tests
asmeurer Apr 22, 2024
3e2d46d
Add missing requires_api_version decorator to hypot()
asmeurer Apr 23, 2024
250ba86
Add maximum and minimum
asmeurer Apr 23, 2024
eb063e2
Add moveaxis
asmeurer Apr 23, 2024
9938059
Add repeat()
asmeurer Apr 23, 2024
095be2f
Require the repeats array to have an integer dtype
asmeurer Apr 23, 2024
1c4460d
Add searchsorted
asmeurer Apr 24, 2024
730e716
Add comment about x1 being 1-D in searchsorted
asmeurer Apr 24, 2024
f26bd49
Add signbit
asmeurer Apr 24, 2024
dc1baad
Add tile()
asmeurer Apr 24, 2024
a30536b
Add unstack()
asmeurer Apr 24, 2024
f247130
Merge branch 'main' into 2023.12
asmeurer Apr 25, 2024
161acaa
Add the inspection APIs
asmeurer Apr 25, 2024
4d3ff6c
Fix test failures
asmeurer Apr 26, 2024
84d2aa5
Always make warnings errors in the tests
asmeurer Apr 26, 2024
05fa0b5
Add tests that the new 2023.12 functions are properly decorated
asmeurer Apr 26, 2024
8333107
Update documentation for 2023.12 support
asmeurer Apr 26, 2024
a437da3
Implement 2023.12 behavior for sum() and prod()
asmeurer Apr 27, 2024
9f954e6
Implement 2023.12 behavior for trace
asmeurer Apr 29, 2024
8572df3
Add a test for sum/trace/prod 2023.12 upcasting behavior
asmeurer Apr 29, 2024
47894ff
Add 2023.12 axis restrictions to vecdot() and cross()
asmeurer Apr 29, 2024
6b43194
Add device flag to astype in 2023.12
asmeurer Apr 29, 2024
3fde5dd
Factor out device checks into a helper function
asmeurer May 1, 2024
1ac5288
Add 2023.12 device and copy keywords to from_dlpack
asmeurer May 1, 2024
dc4684b
Update the signature of __dlpack__ for 2023.12
asmeurer May 1, 2024
647a5f0
Add tests for from_dlpack and __dlpack__ 2023.12 behavior
asmeurer May 1, 2024
306de9b
Add 2023.12 testing to the CI
asmeurer May 1, 2024
44bbdb2
Better error message
asmeurer May 1, 2024
e5225ed
Parameterize the API version in a loop instead of in the matrix
asmeurer May 1, 2024
3e0be7d
Ensure a.mT works even if the linalg extension is disabled
asmeurer May 3, 2024
338ebfe
Don't allow environment variables to be set during test runs
asmeurer May 1, 2024
6a466f4
Use a more robust way to fail the tests if an env var is set
asmeurer May 3, 2024
dd01b12
Fix setting ARRAY_API_STRICT_ENABLED_EXTENSIONS=''
asmeurer May 3, 2024
43b9088
Make extensions give AttributeError when they are disabled
asmeurer May 3, 2024
7bc29d6
Fix setting ARRAY_API_STRICT_API_VERSION to 2021.12
asmeurer May 6, 2024
c770c9b
Add tests for environment variables
asmeurer May 6, 2024
a8f8fdc
More than signature tests are now implemented for 2023.12
asmeurer May 6, 2024
752b706
Add more info to an error message
asmeurer May 16, 2024
7cb3214
Fix some issues with cumulative_sum
asmeurer May 17, 2024
beb95ae
Fix typo
asmeurer May 24, 2024
c721f3d
Remove duplicate __all__ definition from _info.py
asmeurer May 28, 2024
5cf028c
Fix NumPy 1.26 type promotion in copysign
asmeurer Jun 26, 2024
b65e9a5
Merge branch 'main' into 2023.12
asmeurer Jun 26, 2024
5e607c3
Remove NPY_PROMOTION_STATE=weak from the CI
asmeurer Jun 26, 2024
6f8c07f
Trigger CI
asmeurer Jun 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on: [push, pull_request]

env:
PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200"
API_VERSIONS: "2022.12 2023.12"

jobs:
array-api-tests:
Expand Down Expand Up @@ -45,9 +46,9 @@ jobs:
- name: Run the array API testsuite
env:
ARRAY_API_TESTS_MODULE: array_api_strict
# This enables the NEP 50 type promotion behavior (without it a lot of
# tests fail in numpy 1.26 on bad scalar type promotion behavior)
NPY_PROMOTION_STATE: weak
run: |
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS}
# Parameterizing this in the CI matrix is wasteful. Just do a loop here.
for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS}
done
60 changes: 47 additions & 13 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

"""

__all__ = []

# Warning: __array_api_version__ could change globally with
# set_array_api_strict_flags(). This should always be accessed as an
# attribute, like xp.__array_api_version__, or using
# array_api_strict.get_array_api_strict_flags()['api_version'].
from ._flags import API_VERSION as __array_api_version__

__all__ = ["__array_api_version__"]
__all__ += ["__array_api_version__"]

from ._constants import e, inf, nan, pi, newaxis

Expand Down Expand Up @@ -137,7 +139,9 @@
bitwise_right_shift,
bitwise_xor,
ceil,
clip,
conj,
copysign,
cos,
cosh,
divide,
Expand All @@ -148,6 +152,7 @@
floor_divide,
greater,
greater_equal,
hypot,
imag,
isfinite,
isinf,
Expand All @@ -163,6 +168,8 @@
logical_not,
logical_or,
logical_xor,
maximum,
minimum,
multiply,
negative,
not_equal,
Expand All @@ -172,6 +179,7 @@
remainder,
round,
sign,
signbit,
sin,
sinh,
square,
Expand Down Expand Up @@ -199,7 +207,9 @@
"bitwise_right_shift",
"bitwise_xor",
"ceil",
"clip",
"conj",
"copysign",
"cos",
"cosh",
"divide",
Expand All @@ -210,6 +220,7 @@
"floor_divide",
"greater",
"greater_equal",
"hypot",
"imag",
"isfinite",
"isinf",
Expand All @@ -225,6 +236,8 @@
"logical_not",
"logical_or",
"logical_xor",
"maximum",
"minimum",
"multiply",
"negative",
"not_equal",
Expand All @@ -234,6 +247,7 @@
"remainder",
"round",
"sign",
"signbit",
"sin",
"sinh",
"square",
Expand All @@ -248,35 +262,36 @@

__all__ += ["take"]

# linalg is an extension in the array API spec, which is a sub-namespace. Only
# a subset of functions in it are imported into the top-level namespace.
from . import linalg
from ._info import __array_namespace_info__

__all__ += ["linalg"]
__all__ += [
"__array_namespace_info__",
]

from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot

__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]

from . import fft
__all__ += ["fft"]

from ._manipulation_functions import (
concat,
expand_dims,
flip,
moveaxis,
permute_dims,
repeat,
reshape,
roll,
squeeze,
stack,
tile,
unstack,
)

__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"]

from ._searching_functions import argmax, argmin, nonzero, where
from ._searching_functions import argmax, argmin, nonzero, searchsorted, where

__all__ += ["argmax", "argmin", "nonzero", "where"]
__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"]

from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values

Expand All @@ -286,9 +301,9 @@

__all__ += ["argsort", "sort"]

from ._statistical_functions import max, mean, min, prod, std, sum, var
from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var

__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]

from ._utility_functions import all, any

Expand All @@ -308,3 +323,22 @@
from . import _version
__version__ = _version.get_versions()['version']
del _version


# Extensions can be enabled or disabled dynamically. In order to make
# "array_api_strict.linalg" give an AttributeError when it is disabled, we
# use __getattr__. Note that linalg and fft are dynamically added and removed
# from __all__ in set_array_api_strict_flags.

def __getattr__(name):
if name in ['linalg', 'fft']:
if name in get_array_api_strict_flags()['enabled_extensions']:
if name == 'linalg':
from . import _linalg
return _linalg
elif name == 'fft':
from . import _fft
return _fft
else:
raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict")
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
32 changes: 29 additions & 3 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __repr__(self):

CPU_DEVICE = _cpu_device()

_default = object()

class Array:
"""
n-d array object for the array API namespace.
Expand Down Expand Up @@ -437,7 +439,7 @@ def _validate_index(self, key):
"Array API when the array is the sole index."
)
if not get_array_api_strict_flags()['boolean_indexing']:
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict")
raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict")

elif i.dtype in _integer_dtypes and i.ndim != 0:
raise IndexError(
Expand Down Expand Up @@ -525,10 +527,34 @@ def __complex__(self: Array, /) -> complex:
res = self._array.__complex__()
return res

def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
def __dlpack__(
self: Array,
/,
*,
stream: Optional[Union[int, Any]] = None,
max_version: Optional[tuple[int, int]] = _default,
dl_device: Optional[tuple[IntEnum, int]] = _default,
copy: Optional[bool] = _default,
) -> PyCapsule:
"""
Performs the operation __dlpack__.
"""
if get_array_api_strict_flags()['api_version'] < '2023.12':
if max_version is not _default:
raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API")
if dl_device is not _default:
raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API")
if copy is not _default:
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")

# Going to wait for upstream numpy support
if max_version not in [_default, None]:
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
if dl_device not in [_default, None]:
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
if copy not in [_default, None]:
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")

return self._array.__dlpack__(stream=stream)

def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
Expand Down Expand Up @@ -1142,7 +1168,7 @@ def device(self) -> Device:
# Note: mT is new in array API spec (see matrix_transpose)
@property
def mT(self) -> Array:
from .linalg import matrix_transpose
from ._linear_algebra_functions import matrix_transpose
return matrix_transpose(self)

@property
Expand Down
Loading
Loading