Skip to content

Commit

Permalink
Chex: Add CI job for oldest supported JAX version
Browse files Browse the repository at this point in the history
Also bump minimum JAX version to 0.4.27, in order to pass tests.

PiperOrigin-RevId: 645438153
  • Loading branch information
Jake VanderPlas authored and ChexDev committed Jun 21, 2024
1 parent 63edbff commit 7aec754
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@ on:

jobs:
build-and-test:
name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
name: "Python ${{ matrix.python-version }} on ${{ matrix.os }} jax=${{ matrix.jax-version}}"
runs-on: "${{ matrix.os }}"

strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest]
jax-version: ["newest"]
include:
- python-version: "3.9"
os: "ubuntu-latest"
jax-version: "0.4.27" # Keep this in sync with version in pyproject.toml

steps:
- uses: "actions/checkout@v2"
Expand All @@ -26,5 +31,5 @@ jobs:
cache: "pip"
cache-dependency-path: '**/requirements*.txt'
- name: Run CI tests
run: bash test.sh
run: JAX_VERSION="${{ matrix.jax-version }}" bash test.sh
shell: bash
4 changes: 2 additions & 2 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
absl-py>=0.9.0
typing_extensions>=4.2.0
jax>=0.4.16
jaxlib>=0.1.37
jax>=0.4.27
jaxlib>=0.4.27
numpy>=1.24.1
setuptools;python_version>="3.12"
toolz>=0.9.0
9 changes: 9 additions & 0 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ pip install flake8 pytest-xdist pylint pylint-exit
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-test.txt

# Install the requested JAX version
if [ "$JAX_VERSION" = "" ]; then
: # use version installed in requirements above
elif [ "$JAX_VERSION" = "newest" ]; then
pip install -U jax jaxlib
else
pip install "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"
fi

# Lint with flake8.
flake8 `find chex -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics

Expand Down

0 comments on commit 7aec754

Please sign in to comment.