diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index a7b4d63530..e51832f3a6 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -82,6 +82,15 @@ def test_config_flatten(self):
             self.assertIsInstance(key, str)
             self.assertNotIsInstance(value, dict)

+    def test_update_config_from_ray_cluster(self):
+        config = get_template_config()
+        config.cluster.node_num = None
+        config.cluster.gpu_per_node = None
+
+        config._update_config_from_ray_cluster()
+        self.assertEqual(config.cluster.node_num, 2)
+        self.assertEqual(config.cluster.gpu_per_node, 2)
+
     def tearDown(self):
         if os.path.exists(CHECKPOINT_ROOT_DIR):
             shutil.rmtree(CHECKPOINT_ROOT_DIR)
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 228e888c83..27f5b0d455 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -131,7 +131,7 @@ def test_explorer(self):

 def run_serve(config):
     config.check_and_update()
-    run_stage(config, "auto")
+    run_stage(config)


 def run_agent(base_url, model_path: str):
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 649bb348a0..468ab2df53 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -142,9 +142,9 @@ def both(config: Config) -> None:
 }


-def run_stage(config: Config, ray_address: str) -> None:
+def run_stage(config: Config) -> None:
     ray.init(
-        address=ray_address,
+        address=config.cluster.ray_address,
         ignore_reinit_error=True,
         namespace=config.ray_namespace,
         runtime_env={"env_vars": config.get_envs()},
@@ -168,11 +168,9 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
         load_plugins()
     config = load_config(config_path)

-    ray_address = "auto"
-
     if dlc:
         cluster_namespace = f"{config.project}-{config.name}"
-        ray_address = setup_ray_cluster(namespace=cluster_namespace)
+        config.cluster.ray_address = setup_ray_cluster(namespace=cluster_namespace)

     if not is_running():
         raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
@@ -203,7 +201,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
                 if prev_stage_checkpoint is not None:
                     stage_config.model.model_path = prev_stage_checkpoint
                 stage_config.check_and_update()
-                run_stage(stage_config, ray_address=ray_address)
+                run_stage(stage_config)
                 logger.info(
                     "===========================================================\n"
                     f"> Stage {i + 1}/{len(config.stages)} finished.\n"
@@ -212,7 +210,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
                 prev_stage_checkpoint = get_latest_hf_checkpoint_path(stage_config)
         else:
             config.check_and_update()
-            run_stage(config, ray_address=ray_address)
+            run_stage(config)

     finally:
         if dlc:
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 228eb54c4c..56e03c8d0e 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -9,6 +9,7 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional

+import ray
 from omegaconf import OmegaConf

 from trinity.common.constants import (
@@ -360,8 +361,9 @@ class AlgorithmConfig:
 class ClusterConfig:
     """Config for the cluster."""

-    node_num: int = 1
-    gpu_per_node: int = 8
+    ray_address: str = "auto"
+    node_num: Optional[int] = None
+    gpu_per_node: Optional[int] = None


 @Experimental
@@ -611,6 +613,44 @@ def _check_deprecated(self) -> None:
                 "`explorer.runner_num` is deprecated, please use `explorer.runner_per_model` instead."
             )

+    def _update_config_from_ray_cluster(self) -> None:
+        """Update config if `node_num` or `gpu_per_node` is not set."""
+        if self.cluster.node_num is not None and self.cluster.gpu_per_node is not None:
+            return
+
+        # init ray cluster to detect node_num and gpu_per_node
+        was_initialized = ray.is_initialized()
+        if not was_initialized:
+            ray.init(
+                address=self.cluster.ray_address,
+                ignore_reinit_error=True,
+                namespace=self.ray_namespace,
+            )
+
+        alive_nodes = [n for n in ray.nodes() if n["alive"]]
+        if not alive_nodes:
+            raise RuntimeError("Could not find any alive nodes in the Ray cluster.")
+
+        # set node_num
+        if self.cluster.node_num is None:
+            self.cluster.node_num = len(alive_nodes)
+            logger.info(f"Auto-detected and set node_num: {self.cluster.node_num}")
+
+        # set gpu_per_node
+        if self.cluster.gpu_per_node is None:
+            gpu_per_node = 0
+            for node in alive_nodes:
+                node_gpus = node.get("Resources", {}).get("GPU")
+                if node_gpus and node_gpus > 0:
+                    gpu_per_node = int(node_gpus)
+                    break
+
+            self.cluster.gpu_per_node = gpu_per_node
+            logger.info(f"Auto-detected and set gpu_per_node: {self.cluster.gpu_per_node}")
+
+        if not was_initialized:
+            ray.shutdown()
+
     def _check_interval(self) -> None:
         assert self.synchronizer.sync_interval > 0

@@ -901,6 +941,9 @@ def check_and_update(self) -> Config:  # noqa: C901
         if self.ray_namespace is None or len(self.ray_namespace) == 0:
             self.ray_namespace = f"{self.project}/{self.name}"

+        # check cluster information
+        self._update_config_from_ray_cluster()
+
         # check algorithm
         self._check_algorithm()
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index eb360266b4..7f81e06726 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -358,6 +358,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         else:
             rollout_gpu_num = 0

+        assert config.cluster.node_num is not None
+        assert config.cluster.gpu_per_node is not None
         if config.cluster.node_num == 1:
             # for single node scenarios, rollout and training are on the same node
             self.trainer.nnodes = config.cluster.node_num
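
Reviewer note: `_update_config_from_ray_cluster` takes `gpu_per_node` from the first alive node that reports any GPUs, so it implicitly assumes a homogeneous cluster; nodes with differing GPU counts would be silently mis-detected. Below is a minimal standalone sketch of the same detection logic, assuming a running Ray cluster reachable via `address="auto"`; the variable names and the `print` are illustrative only, not part of the patch:

    import ray

    # Connect to the already-running cluster, as the patch does when
    # node_num / gpu_per_node are left unset in the config.
    ray.init(address="auto", ignore_reinit_error=True)

    # ray.nodes() returns one info dict per node; the patch relies on
    # its "alive" and "Resources" keys.
    alive_nodes = [n for n in ray.nodes() if n["alive"]]

    node_num = len(alive_nodes)

    # Take the GPU count of the first node that reports any GPUs.
    gpu_per_node = 0
    for node in alive_nodes:
        node_gpus = node.get("Resources", {}).get("GPU")
        if node_gpus and node_gpus > 0:
            gpu_per_node = int(node_gpus)
            break

    print(f"Detected node_num={node_num}, gpu_per_node={gpu_per_node}")
    ray.shutdown()

On a two-node cluster with two GPUs per node (the setup `test_update_config_from_ray_cluster` expects), this prints `node_num=2, gpu_per_node=2`.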