Skip to content

Commit

Permalink
Update JAX and enable more tests
Browse files Browse the repository at this point in the history
Signed-off-by: PragmaTwice <twice@apache.org>
  • Loading branch information
PragmaTwice committed Nov 23, 2024
1 parent 50bc54a commit cd195e2
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 15 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/pkgci_test_pjrt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ jobs:
# install editable into venv
source ${VENV_DIR}/bin/activate
python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_${{ matrix.pjrt_platform }}_plugin
# install jax (must be no larger than 0.4.20, refer to #19223)
# TODO: switch to the latest JAX after #19223 is fixed
python -m pip install jax==0.4.20 jaxlib==0.4.20 'numpy<2'
# install
python -m pip install jax==0.4.35
- name: Run tests
run: |
source ${VENV_DIR}/bin/activate
Expand Down
12 changes: 2 additions & 10 deletions build_tools/testing/run_jax_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,9 @@ diff_jax_test() {
echo "no difference found"
}

# FIXME: due to #19223, we need to use jax no higher than 0.4.20,
# but in such version of jax, 'stablehlo.broadcast_in_dim' op
# will be emitted without attribute 'broadcast_dimensions',
# which leads to an error in IREE PJRT plugin.
# So currently any program with broadcast will fail,
# e.g. test/test_simple.py.
# After #19223 is fixed, we can uncomment the line below.

# diff_jax_test test/test_simple.py

diff_jax_test test/test_add.py
diff_jax_test test/test_degenerate.py
diff_jax_test test/test_simple.py


# FIXME: we can also utilize the native test cases from JAX,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def load_version_info():
pin_versions = dict(pin_pairs)
print(f"requirements.txt pins: {pin_versions}")
# Convert pinned versions to >= for install_requires.
for pin_name in ("iree-compiler", "jaxlib"):
for pin_name in ("iree-base-compiler", "jaxlib"):
pin_version = pin_versions[pin_name]
install_requires.append(f"{pin_name}>={pin_version}")

Expand Down
2 changes: 1 addition & 1 deletion integrations/pjrt/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-f https://iree.dev/pip-release-links.html
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
iree-compiler==20230922.653
iree-base-compiler==3.0.0
jaxlib==0.4.17.dev20230922

0 comments on commit cd195e2

Please sign in to comment.