diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py
index 55f82d7e42..9498b2a5b3 100644
--- a/skyrl-train/skyrl_train/utils/utils.py
+++ b/skyrl-train/skyrl_train/utils/utils.py
@@ -268,7 +268,15 @@ def initialize_ray(cfg: DictConfig):
         env_vars["VLLM_USE_V1"] = "1"
         env_vars["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
 
-    if not peer_access_supported(cfg):
+    max_num_gpus_per_node = max(
+        [
+            cfg.trainer.placement.policy_num_gpus_per_node,
+            cfg.trainer.placement.critic_num_gpus_per_node,
+            cfg.trainer.placement.ref_num_gpus_per_node,
+            cfg.trainer.placement.reward_num_gpus_per_node,
+        ]
+    )
+    if not peer_access_supported(max_num_gpus_per_node=max_num_gpus_per_node):
         logger.info("Peer access is not supported on this node type, disabling P2P and SHM")
         env_vars["NCCL_P2P_DISABLE"] = "1"
         env_vars["NCCL_SHM_DISABLE"] = "1"
@@ -359,17 +367,9 @@ def run_p2p_access_check():
     return True
 
 
-def peer_access_supported(cfg: DictConfig):
+def peer_access_supported(max_num_gpus_per_node: int):
     # whatever the max num gpus per node is, we can check p2p access if there are at least 2 GPUs
     # if max is 1, p2p access is not supported
-    max_num_gpus_per_node = max(
-        [
-            cfg.trainer.placement.policy_num_gpus_per_node,
-            cfg.trainer.placement.critic_num_gpus_per_node,
-            cfg.trainer.placement.ref_num_gpus_per_node,
-            cfg.trainer.placement.reward_num_gpus_per_node,
-        ]
-    )
     if max_num_gpus_per_node <= 1:
         return False
 
diff --git a/skyrl-train/tests/gpu/gpu_ci/conftest.py b/skyrl-train/tests/gpu/gpu_ci/conftest.py
index 3c83fbcb9f..d726b85e5d 100644
--- a/skyrl-train/tests/gpu/gpu_ci/conftest.py
+++ b/skyrl-train/tests/gpu/gpu_ci/conftest.py
@@ -16,7 +16,7 @@ def ray_init_fixture():
     if ray.is_initialized():
         ray.shutdown()
     env_vars = {}
-    if not peer_access_supported():
+    if not peer_access_supported(max_num_gpus_per_node=2):
         log_once("Disabling NCCL P2P for CI environment")
         env_vars = {"NCCL_P2P_DISABLE": "1", "NCCL_SHM_DISABLE": "1"}
     ray.init(runtime_env={"env_vars": env_vars})