Skip to content

Commit ef5bed7

Browse files
committed
Add jax.experimental.array_api interface.
1 parent 840b5c5 commit ef5bed7

23 files changed

+1655
-1
lines changed

Diff for: .github/workflows/jax-array-api.yml

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
name: JAX Array API
2+
3+
on:
4+
workflow_dispatch: # allows triggering the workflow run manually
5+
pull_request: # Automatically trigger on pull requests affecting particular files
6+
branches:
7+
- main
8+
paths:
9+
- '**workflows/jax-array-api.yml'
10+
- '**experimental/array_api/**'
11+
12+
jobs:
13+
build:
14+
15+
runs-on: ubuntu-latest
16+
strategy:
17+
matrix:
18+
python-version: [3.11]
19+
20+
steps:
21+
- name: Checkout jax
22+
uses: actions/checkout@v3
23+
- name: Checkout array-api-tests
24+
uses: actions/checkout@v3
25+
with:
26+
repository: data-apis/array-api-tests
27+
ref: '83f0bcdcc5286250dbb26be5d37511702970b4dc' # Latest commit as of 2023-11-15
28+
submodules: 'true'
29+
path: 'array-api-tests'
30+
- name: Fix array-apis bug
31+
# Temporary workaround for https://github.com/data-apis/array-api/issues/631
32+
run: |
33+
sed -i -e 's/\\/\\\\/g' array-api-tests/array-api/spec/API_specification/signatures/*.py
34+
- name: Set up Python ${{ matrix.python-version }}
35+
uses: actions/setup-python@v1
36+
with:
37+
python-version: ${{ matrix.python-version }}
38+
- name: Install dependencies
39+
run: |
40+
python -m pip install .[cpu]
41+
python -m pip install hypothesis!=6.88.4 # 6.88.4 leads to a strange error
42+
python -m pip install -r array-api-tests/requirements.txt
43+
- name: Run the test suite
44+
env:
45+
ARRAY_API_TESTS_MODULE: jax.experimental.array_api
46+
JAX_ENABLE_X64: 'true'
47+
run: |
48+
cd ${GITHUB_WORKSPACE}/array-api-tests
49+
pytest --ci array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/array-api-skips.txt

Diff for: array-api-skips.txt

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Known failures for the array api tests.
2+
3+
# JAX doesn't yet support scalar boolean indexing
4+
array_api_tests/test_array_object.py::test_getitem_masking
5+
6+
# Hypothesis warning
7+
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
8+
9+
# Test suite attempts in-place mutation:
10+
array_api_tests/test_special_cases.py::test_binary
11+
array_api_tests/test_special_cases.py::test_iop
12+
array_api_tests/test_special_cases.py::test_nan_propagation
13+
array_api_tests/test_special_cases.py::test_unary
14+
array_api_tests/test_array_object.py::test_setitem
15+
array_api_tests/test_creation_functions.py::test_asarray_arrays
16+
array_api_tests/test_linalg.py::test_matrix_power
17+
array_api_tests/test_linalg.py::test_solve
18+
19+
# Overflow errors due to hypothesis generating integers that overflow int64
20+
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)]
21+
array_api_tests/test_operators_and_elementwise_functions.py::test_square
22+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x, s)]
23+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)]
24+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x, s)]
25+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)]
26+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)]
27+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)]
28+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x, s)]
29+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)]
30+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x, s)]
31+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)]
32+
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
33+
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
34+
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x, s)]
35+
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x, s)]
36+
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x, s)]
37+
38+
# JAX's NaN sorting doesn't match specification
39+
array_api_tests/test_set_functions.py::test_unique_all
40+
array_api_tests/test_set_functions.py::test_unique_counts
41+
array_api_tests/test_set_functions.py::test_unique_inverse
42+
array_api_tests/test_set_functions.py::test_unique_values
43+
array_api_tests/test_sorting_functions.py::test_argsort

Diff for: docs/jax.experimental.array_api.rst

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
``jax.experimental.array_api`` module
2+
=====================================
3+
4+
.. automodule:: jax.experimental.array_api

Diff for: docs/jax.experimental.rst

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Experimental Modules
1414
.. toctree::
1515
:maxdepth: 1
1616

17+
jax.experimental.array_api
1718
jax.experimental.checkify
1819
jax.experimental.host_callback
1920
jax.experimental.maps

Diff for: jax/BUILD

+12-1
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,17 @@ pytype_library(
842842
deps = [":jax"],
843843
)
844844

845+
pytype_library(
846+
name = "experimental_array_api",
847+
srcs = glob(
848+
[
849+
"experimental/array_api/*.py",
850+
],
851+
),
852+
visibility = [":internal"],
853+
deps = [":jax"],
854+
)
855+
845856
pytype_library(
846857
name = "experimental_sparse",
847858
srcs = glob(
@@ -874,7 +885,7 @@ pytype_library(
874885
"example_libraries/optimizers.py",
875886
],
876887
visibility = ["//visibility:public"],
877-
deps = [":jax"],
888+
deps = [":jax"] + py_deps("numpy"),
878889
)
879890

880891
pytype_library(

Diff for: jax/experimental/array_api/__init__.py

+217
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
This module includes experimental JAX support for the `Python array API standard`_.
17+
Support for this is currently experimental and not fully complete.
18+
19+
Example Usage::
20+
21+
>>> from jax.experimental import array_api as xp
22+
23+
>>> xp.__array_api_version__
24+
'2022.12'
25+
26+
>>> arr = xp.arange(1000)
27+
28+
>>> arr.sum()
29+
Array(499500, dtype=int32)
30+
31+
The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
32+
and implements most of the API listed in the standard.
33+
34+
.. _Python array API standard: https://data-apis.org/array-api/latest/
35+
"""
36+
37+
from __future__ import annotations
38+
39+
from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
40+
41+
from jax.experimental.array_api import linalg as linalg
42+
43+
from jax.experimental.array_api._constants import (
44+
e as e,
45+
inf as inf,
46+
nan as nan,
47+
newaxis as newaxis,
48+
pi as pi,
49+
)
50+
51+
from jax.experimental.array_api._creation_functions import (
52+
arange as arange,
53+
asarray as asarray,
54+
empty as empty,
55+
empty_like as empty_like,
56+
eye as eye,
57+
from_dlpack as from_dlpack,
58+
full as full,
59+
full_like as full_like,
60+
linspace as linspace,
61+
meshgrid as meshgrid,
62+
ones as ones,
63+
ones_like as ones_like,
64+
tril as tril,
65+
triu as triu,
66+
zeros as zeros,
67+
zeros_like as zeros_like,
68+
)
69+
70+
from jax.experimental.array_api._data_type_functions import (
71+
astype as astype,
72+
can_cast as can_cast,
73+
finfo as finfo,
74+
iinfo as iinfo,
75+
isdtype as isdtype,
76+
result_type as result_type,
77+
)
78+
79+
from jax.experimental.array_api._dtypes import (
80+
bool as bool,
81+
int8 as int8,
82+
int16 as int16,
83+
int32 as int32,
84+
int64 as int64,
85+
uint8 as uint8,
86+
uint16 as uint16,
87+
uint32 as uint32,
88+
uint64 as uint64,
89+
float32 as float32,
90+
float64 as float64,
91+
complex64 as complex64,
92+
complex128 as complex128,
93+
)
94+
95+
from jax.experimental.array_api._elementwise_functions import (
96+
abs as abs,
97+
acos as acos,
98+
acosh as acosh,
99+
add as add,
100+
asin as asin,
101+
asinh as asinh,
102+
atan as atan,
103+
atan2 as atan2,
104+
atanh as atanh,
105+
bitwise_and as bitwise_and,
106+
bitwise_invert as bitwise_invert,
107+
bitwise_left_shift as bitwise_left_shift,
108+
bitwise_or as bitwise_or,
109+
bitwise_right_shift as bitwise_right_shift,
110+
bitwise_xor as bitwise_xor,
111+
ceil as ceil,
112+
conj as conj,
113+
cos as cos,
114+
cosh as cosh,
115+
divide as divide,
116+
equal as equal,
117+
exp as exp,
118+
expm1 as expm1,
119+
floor as floor,
120+
floor_divide as floor_divide,
121+
greater as greater,
122+
greater_equal as greater_equal,
123+
imag as imag,
124+
isfinite as isfinite,
125+
isinf as isinf,
126+
isnan as isnan,
127+
less as less,
128+
less_equal as less_equal,
129+
log as log,
130+
log10 as log10,
131+
log1p as log1p,
132+
log2 as log2,
133+
logaddexp as logaddexp,
134+
logical_and as logical_and,
135+
logical_not as logical_not,
136+
logical_or as logical_or,
137+
logical_xor as logical_xor,
138+
multiply as multiply,
139+
negative as negative,
140+
not_equal as not_equal,
141+
positive as positive,
142+
pow as pow,
143+
real as real,
144+
remainder as remainder,
145+
round as round,
146+
sign as sign,
147+
sin as sin,
148+
sinh as sinh,
149+
sqrt as sqrt,
150+
square as square,
151+
subtract as subtract,
152+
tan as tan,
153+
tanh as tanh,
154+
trunc as trunc,
155+
)
156+
157+
from jax.experimental.array_api._indexing_functions import (
158+
take as take,
159+
)
160+
161+
from jax.experimental.array_api._manipulation_functions import (
162+
broadcast_arrays as broadcast_arrays,
163+
broadcast_to as broadcast_to,
164+
concat as concat,
165+
expand_dims as expand_dims,
166+
flip as flip,
167+
permute_dims as permute_dims,
168+
reshape as reshape,
169+
roll as roll,
170+
squeeze as squeeze,
171+
stack as stack,
172+
)
173+
174+
from jax.experimental.array_api._searching_functions import (
175+
argmax as argmax,
176+
argmin as argmin,
177+
nonzero as nonzero,
178+
where as where,
179+
)
180+
181+
from jax.experimental.array_api._set_functions import (
182+
unique_all as unique_all,
183+
unique_counts as unique_counts,
184+
unique_inverse as unique_inverse,
185+
unique_values as unique_values,
186+
)
187+
188+
from jax.experimental.array_api._sorting_functions import (
189+
argsort as argsort,
190+
sort as sort,
191+
)
192+
193+
from jax.experimental.array_api._statistical_functions import (
194+
max as max,
195+
mean as mean,
196+
min as min,
197+
prod as prod,
198+
std as std,
199+
sum as sum,
200+
var as var
201+
)
202+
203+
from jax.experimental.array_api._utility_functions import (
204+
all as all,
205+
any as any,
206+
)
207+
208+
from jax.experimental.array_api._linear_algebra_functions import (
209+
matmul as matmul,
210+
matrix_transpose as matrix_transpose,
211+
tensordot as tensordot,
212+
vecdot as vecdot,
213+
)
214+
215+
from jax.experimental.array_api import _array_methods
216+
_array_methods.add_array_object_methods()
217+
del _array_methods

Diff for: jax/experimental/array_api/_array_methods.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any, Callable, Optional, Union
18+
19+
import jax
20+
from jax._src.array import ArrayImpl
21+
from jax.experimental.array_api._version import __array_api_version__
22+
23+
from jax._src.lib import xla_extension as xe
24+
25+
26+
def _array_namespace(self, /, *, api_version: None | str = None):
27+
if api_version is not None and api_version != __array_api_version__:
28+
raise ValueError(f"{api_version=!r} is not available; "
29+
f"available versions are: {[__array_api_version__]}")
30+
return jax.experimental.array_api
31+
32+
33+
def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
34+
stream: Optional[Union[int, Any]] = None):
35+
if stream is not None:
36+
raise NotImplementedError("stream argument of array.to_device()")
37+
# The type of device is defined by Array.device. In JAX, this is a callable that
38+
# returns a device, so we must handle this case to satisfy the API spec.
39+
return jax.device_put(self, device() if callable(device) else device)
40+
41+
42+
def add_array_object_methods():
43+
# TODO(jakevdp): set on tracers as well?
44+
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
45+
setattr(ArrayImpl, "to_device", _to_device)

0 commit comments

Comments
 (0)