Skip to content

Commit 271d31c

Browse files
committed
Add jax.experimental.array_api interface
1 parent d60014c commit 271d31c

25 files changed

+1777
-1
lines changed

Diff for: .github/workflows/jax-array-api.yml

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
name: JAX Array API
2+
3+
on:
4+
workflow_dispatch: # allows triggering the workflow run manually
5+
pull_request: # Automatically trigger on pull requests affecting particular files
6+
branches:
7+
- main
8+
paths:
9+
- '**workflows/jax-array-api.yml'
10+
- '**experimental/array_api/**'
11+
12+
jobs:
13+
build:
14+
15+
runs-on: ubuntu-latest
16+
strategy:
17+
matrix:
18+
python-version: [3.11]
19+
20+
steps:
21+
- name: Checkout jax
22+
uses: actions/checkout@v3
23+
- name: Checkout array-api-tests
24+
uses: actions/checkout@v3
25+
with:
26+
repository: data-apis/array-api-tests
27+
ref: '83f0bcdcc5286250dbb26be5d37511702970b4dc' # Latest commit as of 2023-11-15
28+
submodules: 'true'
29+
path: 'array-api-tests'
30+
- name: Fix array-apis bug
31+
# Temporary workaround for https://github.com/data-apis/array-api/issues/631
32+
run: |
33+
sed -i -e 's/\\/\\\\/g' array-api-tests/array-api/spec/API_specification/signatures/*.py
34+
- name: Set up Python ${{ matrix.python-version }}
35+
uses: actions/setup-python@v1
36+
with:
37+
python-version: ${{ matrix.python-version }}
38+
- name: Install dependencies
39+
run: |
40+
python -m pip install .[cpu]
41+
python -m pip install "hypothesis<6.88.4" # 6.88.4 breaks with a Return-type annotation warning
42+
python -m pip install -r array-api-tests/requirements.txt
43+
- name: Run the test suite
44+
env:
45+
ARRAY_API_TESTS_MODULE: jax.experimental.array_api
46+
JAX_ENABLE_X64: 'true'
47+
run: |
48+
cd ${GITHUB_WORKSPACE}/array-api-tests
49+
pytest --ci array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt

Diff for: docs/jax.experimental.array_api.rst

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
``jax.experimental.array_api`` module
2+
=====================================
3+
4+
.. automodule:: jax.experimental.array_api

Diff for: docs/jax.experimental.rst

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Experimental Modules
1414
.. toctree::
1515
:maxdepth: 1
1616

17+
jax.experimental.array_api
1718
jax.experimental.checkify
1819
jax.experimental.host_callback
1920
jax.experimental.maps

Diff for: jax/BUILD

+12-1
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,17 @@ pytype_library(
841841
deps = [":jax"],
842842
)
843843

844+
pytype_library(
845+
name = "experimental_array_api",
846+
srcs = glob(
847+
[
848+
"experimental/array_api/*.py",
849+
],
850+
),
851+
visibility = [":internal"],
852+
deps = [":jax"],
853+
)
854+
844855
pytype_library(
845856
name = "experimental_sparse",
846857
srcs = glob(
@@ -873,7 +884,7 @@ pytype_library(
873884
"example_libraries/optimizers.py",
874885
],
875886
visibility = ["//visibility:public"],
876-
deps = [":jax"],
887+
deps = [":jax"] + py_deps("numpy"),
877888
)
878889

879890
pytype_library(

Diff for: jax/experimental/array_api/__init__.py

+220
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
This module includes experimental JAX support for the `Python array API standard`_.
17+
Support for this is currently experimental and not fully complete.
18+
19+
Example Usage::
20+
21+
>>> from jax.experimental import array_api as xp
22+
23+
>>> xp.__array_api_version__
24+
'2022.12'
25+
26+
>>> arr = xp.arange(1000)
27+
28+
>>> arr.sum()
29+
Array(499500, dtype=int32)
30+
31+
The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
32+
and implements most of the API listed in the standard.
33+
34+
.. _Python array API standard: https://data-apis.org/array-api/latest/
35+
"""
36+
37+
from __future__ import annotations
38+
39+
from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
40+
41+
from jax.experimental.array_api import (
42+
fft as fft,
43+
linalg as linalg,
44+
)
45+
46+
from jax.experimental.array_api._constants import (
47+
e as e,
48+
inf as inf,
49+
nan as nan,
50+
newaxis as newaxis,
51+
pi as pi,
52+
)
53+
54+
from jax.experimental.array_api._creation_functions import (
55+
arange as arange,
56+
asarray as asarray,
57+
empty as empty,
58+
empty_like as empty_like,
59+
eye as eye,
60+
from_dlpack as from_dlpack,
61+
full as full,
62+
full_like as full_like,
63+
linspace as linspace,
64+
meshgrid as meshgrid,
65+
ones as ones,
66+
ones_like as ones_like,
67+
tril as tril,
68+
triu as triu,
69+
zeros as zeros,
70+
zeros_like as zeros_like,
71+
)
72+
73+
from jax.experimental.array_api._data_type_functions import (
74+
astype as astype,
75+
can_cast as can_cast,
76+
finfo as finfo,
77+
iinfo as iinfo,
78+
isdtype as isdtype,
79+
result_type as result_type,
80+
)
81+
82+
from jax.experimental.array_api._dtypes import (
83+
bool as bool,
84+
int8 as int8,
85+
int16 as int16,
86+
int32 as int32,
87+
int64 as int64,
88+
uint8 as uint8,
89+
uint16 as uint16,
90+
uint32 as uint32,
91+
uint64 as uint64,
92+
float32 as float32,
93+
float64 as float64,
94+
complex64 as complex64,
95+
complex128 as complex128,
96+
)
97+
98+
from jax.experimental.array_api._elementwise_functions import (
99+
abs as abs,
100+
acos as acos,
101+
acosh as acosh,
102+
add as add,
103+
asin as asin,
104+
asinh as asinh,
105+
atan as atan,
106+
atan2 as atan2,
107+
atanh as atanh,
108+
bitwise_and as bitwise_and,
109+
bitwise_invert as bitwise_invert,
110+
bitwise_left_shift as bitwise_left_shift,
111+
bitwise_or as bitwise_or,
112+
bitwise_right_shift as bitwise_right_shift,
113+
bitwise_xor as bitwise_xor,
114+
ceil as ceil,
115+
conj as conj,
116+
cos as cos,
117+
cosh as cosh,
118+
divide as divide,
119+
equal as equal,
120+
exp as exp,
121+
expm1 as expm1,
122+
floor as floor,
123+
floor_divide as floor_divide,
124+
greater as greater,
125+
greater_equal as greater_equal,
126+
imag as imag,
127+
isfinite as isfinite,
128+
isinf as isinf,
129+
isnan as isnan,
130+
less as less,
131+
less_equal as less_equal,
132+
log as log,
133+
log10 as log10,
134+
log1p as log1p,
135+
log2 as log2,
136+
logaddexp as logaddexp,
137+
logical_and as logical_and,
138+
logical_not as logical_not,
139+
logical_or as logical_or,
140+
logical_xor as logical_xor,
141+
multiply as multiply,
142+
negative as negative,
143+
not_equal as not_equal,
144+
positive as positive,
145+
pow as pow,
146+
real as real,
147+
remainder as remainder,
148+
round as round,
149+
sign as sign,
150+
sin as sin,
151+
sinh as sinh,
152+
sqrt as sqrt,
153+
square as square,
154+
subtract as subtract,
155+
tan as tan,
156+
tanh as tanh,
157+
trunc as trunc,
158+
)
159+
160+
from jax.experimental.array_api._indexing_functions import (
161+
take as take,
162+
)
163+
164+
from jax.experimental.array_api._manipulation_functions import (
165+
broadcast_arrays as broadcast_arrays,
166+
broadcast_to as broadcast_to,
167+
concat as concat,
168+
expand_dims as expand_dims,
169+
flip as flip,
170+
permute_dims as permute_dims,
171+
reshape as reshape,
172+
roll as roll,
173+
squeeze as squeeze,
174+
stack as stack,
175+
)
176+
177+
from jax.experimental.array_api._searching_functions import (
178+
argmax as argmax,
179+
argmin as argmin,
180+
nonzero as nonzero,
181+
where as where,
182+
)
183+
184+
from jax.experimental.array_api._set_functions import (
185+
unique_all as unique_all,
186+
unique_counts as unique_counts,
187+
unique_inverse as unique_inverse,
188+
unique_values as unique_values,
189+
)
190+
191+
from jax.experimental.array_api._sorting_functions import (
192+
argsort as argsort,
193+
sort as sort,
194+
)
195+
196+
from jax.experimental.array_api._statistical_functions import (
197+
max as max,
198+
mean as mean,
199+
min as min,
200+
prod as prod,
201+
std as std,
202+
sum as sum,
203+
var as var
204+
)
205+
206+
from jax.experimental.array_api._utility_functions import (
207+
all as all,
208+
any as any,
209+
)
210+
211+
from jax.experimental.array_api._linear_algebra_functions import (
212+
matmul as matmul,
213+
matrix_transpose as matrix_transpose,
214+
tensordot as tensordot,
215+
vecdot as vecdot,
216+
)
217+
218+
from jax.experimental.array_api import _array_methods
219+
_array_methods.add_array_object_methods()
220+
del _array_methods

Diff for: jax/experimental/array_api/_array_methods.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any, Callable, Optional, Union
18+
19+
import jax
20+
from jax._src.array import ArrayImpl
21+
from jax.experimental.array_api._version import __array_api_version__
22+
23+
from jax._src.lib import xla_extension as xe
24+
25+
26+
def _array_namespace(self, /, *, api_version: None | str = None):
27+
if api_version is not None and api_version != __array_api_version__:
28+
raise ValueError(f"{api_version=!r} is not available; "
29+
f"available versions are: {[__array_api_version__]}")
30+
return jax.experimental.array_api
31+
32+
33+
def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
34+
stream: Optional[Union[int, Any]] = None):
35+
if stream is not None:
36+
raise NotImplementedError("stream argument of array.to_device()")
37+
# The type of device is defined by Array.device. In JAX, this is a callable that
38+
# returns a device, so we must handle this case to satisfy the API spec.
39+
return jax.device_put(self, device() if callable(device) else device)
40+
41+
42+
def add_array_object_methods():
43+
# TODO(jakevdp): set on tracers as well?
44+
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
45+
setattr(ArrayImpl, "to_device", _to_device)

Diff for: jax/experimental/array_api/_constants.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
17+
e = np.e
18+
inf = np.inf
19+
nan = np.nan
20+
newaxis = np.newaxis
21+
pi = np.pi

0 commit comments

Comments
 (0)