Skip to content

Conversation

@liulehui
Copy link
Contributor

@liulehui liulehui commented Nov 25, 2025

Description

  1. Jax dependency is introduced in [train][jax] Enable Jax trainer on GPU #58322
  2. The current test environment is for CUDA 12.1, which limit jax version below 0.4.14.
  3. jax <= 0.4.14 does not support py 3.12.
  4. skip jax test if it runs against py3.12+.

Signed-off-by: Lehui Liu <lehui@anyscale.com>
@liulehui liulehui requested a review from a team as a code owner November 25, 2025 19:35
@liulehui liulehui requested a review from elliot-barn November 25, 2025 19:35
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds pytest.mark.skipif decorators to JAX tests to skip them on Python 3.12 and newer, as the required JAX version is not compatible. The changes are correct and effectively address the issue. I've added one suggestion to refactor the duplicated skipif condition in test_jax_trainer.py into module-level constants to improve code maintainability.

@elliot-barn elliot-barn requested a review from aslonnie November 26, 2025 01:00
@elliot-barn
Copy link
Contributor

Thanks for the quick PR. I'll create a set of dependencies with a more recent jax version for the py3.12 tests in the near future

@justinvyu justinvyu enabled auto-merge (squash) November 26, 2025 01:15
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Nov 26, 2025
@ray-gardener ray-gardener bot added the train Ray Train Related Issue label Nov 26, 2025
@justinvyu justinvyu merged commit 5e206d8 into ray-project:master Nov 26, 2025
7 checks passed
KaisennHu pushed a commit to KaisennHu/ray that referenced this pull request Nov 26, 2025
1. Jax dependency is introduced in
ray-project#58322
2. The current test environment is for CUDA 12.1, which limit jax
version below 0.4.14.
3. jax <= 0.4.14 does not support py 3.12.
4. skip jax test if it runs against py3.12+.

Signed-off-by: Lehui Liu <lehui@anyscale.com>
SheldonTsen pushed a commit to SheldonTsen/ray that referenced this pull request Dec 1, 2025
1. Jax dependency is introduced in
ray-project#58322
2. The current test environment is for CUDA 12.1, which limit jax
version below 0.4.14.
3. jax <= 0.4.14 does not support py 3.12.
4. skip jax test if it runs against py3.12+.

Signed-off-by: Lehui Liu <lehui@anyscale.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

go add ONLY when ready to merge, run all tests train Ray Train Related Issue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants