diff --git a/qa/TL0_multigpu/test_body.sh b/qa/TL0_multigpu/test_body.sh index 9aab1b8459d..a877b7d8fd7 100644 --- a/qa/TL0_multigpu/test_body.sh +++ b/qa/TL0_multigpu/test_body.sh @@ -54,7 +54,7 @@ test_jax() { python -m pip uninstall -y jax jaxlib python -m pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - ${python_new_invoke_test} -s jax test_integration_multigpu + CUDA_VISIBLE_DEVICES="0,1" ${python_new_invoke_test} -s jax test_integration_multigpu CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py