
Commit b6e0f62

liulehui authored and KaisennHu committed
[train][jax] Skip py3.12 for jax test (ray-project#58979)
1. The jax dependency was introduced in ray-project#58322.
2. The current test environment is for CUDA 12.1, which limits the jax version to below 0.4.14.
3. jax <= 0.4.14 does not support Python 3.12.
4. Skip the jax tests if they run against Python 3.12+.

Signed-off-by: Lehui Liu <lehui@anyscale.com>
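The guard described above is a standard pytest.mark.skipif check on sys.version_info. A minimal standalone sketch (the test name and body are placeholders, not the actual Ray tests):

import sys

import pytest


@pytest.mark.skipif(
    sys.version_info >= (3, 12),
    reason="Current jax version is not supported in python 3.12+",
)
def test_placeholder_jax_case():
    # Collected on every interpreter, but reported as "skipped" rather than
    # "failed" when running under Python 3.12 or newer.
    assert sys.version_info < (3, 12)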
1 parent 35dfbe2 commit b6e0f62

2 files changed, +14 -0 lines changed

python/ray/train/v2/tests/test_jax_gpu.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,10 @@ def reduce_health_check_interval(monkeypatch):
 
 
 @pytest.mark.skipif(sys.platform == "darwin", reason="JAX GPU not supported on macOS")
+@pytest.mark.skipif(
+    sys.version_info >= (3, 12),
+    reason="Current jax version is not supported in python 3.12+",
+)
 def test_jax_distributed_gpu_training(ray_start_4_cpus_2_gpus, tmp_path):
     """Test multi-GPU JAX distributed training.
 
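The new marker is stacked on top of the existing darwin skip; pytest evaluates each skipif independently, so the test is skipped if either condition holds. A minimal sketch of that stacking behavior, with a placeholder test in place of the real GPU test:

import sys

import pytest


@pytest.mark.skipif(sys.platform == "darwin", reason="JAX GPU not supported on macOS")
@pytest.mark.skipif(
    sys.version_info >= (3, 12),
    reason="Current jax version is not supported in python 3.12+",
)
def test_placeholder_gpu_case():
    # Runs only on non-macOS platforms under Python < 3.12; otherwise pytest
    # skips it and reports the reason from a matching marker.
    assert sys.platform != "darwin" and sys.version_info < (3, 12)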

python/ray/train/v2/tests/test_jax_trainer.py

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,5 @@
+import sys
+
 import pytest
 
 import ray
@@ -71,6 +73,10 @@ def train_func():
     train.report({"result": [str(d) for d in devices]})
 
 
+@pytest.mark.skipif(
+    sys.version_info >= (3, 12),
+    reason="Current jax version is not supported in python 3.12+",
+)
 def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
     trainer = JaxTrainer(
         train_loop_per_worker=train_func,
@@ -101,6 +107,10 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
     assert len(labeled_nodes) == 1
 
 
+@pytest.mark.skipif(
+    sys.version_info >= (3, 12),
+    reason="Current jax version is not supported in python 3.12+",
+)
 def test_minimal_multihost(ray_tpu_multi_host, tmp_path):
     trainer = JaxTrainer(
         train_loop_per_worker=train_func,
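An alternative to decorating each test would be a module-level pytestmark gate, which skips every test collected from the file; this change instead scopes the skip to the individual TPU tests. A sketch of the module-level variant, shown for comparison only (not what the change does):

import sys

import pytest

# Applies the skip to every test collected from this module on Python 3.12+.
pytestmark = pytest.mark.skipif(
    sys.version_info >= (3, 12),
    reason="Current jax version is not supported in python 3.12+",
)


def test_placeholder_trainer_case():
    assert sys.version_info < (3, 12)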
