diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 00202e857..28adca005 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,13 +10,18 @@ on: jobs: build-and-test: - name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" + name: "Python ${{ matrix.python-version }} on ${{ matrix.os }} jax=${{ matrix.jax-version }}" runs-on: "${{ matrix.os }}" strategy: matrix: python-version: ["3.9", "3.10", "3.11"] os: [ubuntu-latest] + jax-version: [newest] + include: + - python-version: "3.9" + os: "ubuntu-latest" + jax-version: "0.4.27" # Keep this in sync with version in pyproject.toml steps: - uses: "actions/checkout@v2" @@ -26,7 +31,7 @@ jobs: cache: "pip" cache-dependency-path: 'pyproject.toml' - name: Run CI tests - run: bash test.sh + run: JAX_VERSION="${{ matrix.jax-version }}" bash test.sh shell: bash markdown-link-check: name: "Check links in markdown files" diff --git a/pyproject.toml b/pyproject.toml index 040d7b7de..775f8903d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,8 @@ classifiers = [ dependencies = [ "absl-py>=0.7.1", "chex>=0.1.86", - "jax>=0.4.26", - "jaxlib>=0.1.37", + "jax>=0.4.27", # Keep this in sync with version in .github/workflows/tests.yml + "jaxlib>=0.4.27", "numpy>=1.18.0", "etils[epy]", ] diff --git a/test.sh b/test.sh index cea1657e5..6f536f8e1 100755 --- a/test.sh +++ b/test.sh @@ -41,6 +41,15 @@ pip install -q -e ".[test, examples]" pip install -q -e ".[dp-accounting]" pip install -q "dp-accounting>=0.1.1" --no-deps +# Install the requested JAX version +if [ "$JAX_VERSION" = "" ]; then + : # use version installed in requirements above +elif [ "$JAX_VERSION" = "newest" ]; then + pip install -U jax jaxlib +else + pip install "jax==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" +fi + # Ensure optax was not installed by one of the dependencies above, # since if it is, the tests below will be run against that version instead of # the branch build.