From ebc43f025e3a8faf8ef4d7e54845ca50e0fda401 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Thu, 16 Oct 2025 15:04:18 +0800 Subject: [PATCH 1/5] update config from ray cluster --- trinity/cli/launcher.py | 8 +++---- trinity/common/config.py | 44 +++++++++++++++++++++++++++++++++-- trinity/common/verl_config.py | 2 ++ 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 649bb348a0..024e2fee3f 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -168,11 +168,9 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): load_plugins() config = load_config(config_path) - ray_address = "auto" - if dlc: cluster_namespace = f"{config.project}-{config.name}" - ray_address = setup_ray_cluster(namespace=cluster_namespace) + config.cluster.ray_address = setup_ray_cluster(namespace=cluster_namespace) if not is_running(): raise RuntimeError("Ray is not running, please start it by `ray start --head`.") @@ -203,7 +201,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): if prev_stage_checkpoint is not None: stage_config.model.model_path = prev_stage_checkpoint stage_config.check_and_update() - run_stage(stage_config, ray_address=ray_address) + run_stage(stage_config, ray_address=config.cluster.ray_address) logger.info( "===========================================================\n" f"> Stage {i + 1}/{len(config.stages)} finished.\n" @@ -212,7 +210,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): prev_stage_checkpoint = get_latest_hf_checkpoint_path(stage_config) else: config.check_and_update() - run_stage(config, ray_address=ray_address) + run_stage(config, ray_address=config.cluster.ray_address) finally: if dlc: diff --git a/trinity/common/config.py b/trinity/common/config.py index 228eb54c4c..f3af6b6aa1 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -9,6 +9,7 @@ from enum import Enum from typing import Any, Dict, List, Optional +import ray from omegaconf import OmegaConf from trinity.common.constants import ( @@ -360,8 +361,9 @@ class AlgorithmConfig: class ClusterConfig: """Config for the cluster.""" - node_num: int = 1 - gpu_per_node: int = 8 + ray_address: str = "auto" + node_num: Optional[int] = None + gpu_per_node: Optional[int] = None @Experimental @@ -611,6 +613,41 @@ def _check_deprecated(self) -> None: "`explorer.runner_num` is deprecated, please use `explorer.runner_per_model` instead." 
) + def _update_config_from_ray_cluster(self) -> None: + if self.cluster.node_num is not None and self.cluster.gpu_per_node is not None: + return + + # init ray cluster to detect node_num and gpu_per_node + ray.init( + address=self.cluster.ray_address, + ignore_reinit_error=True, + namespace=self.ray_namespace, + runtime_env={"env_vars": self.get_envs()}, + ) + + alive_nodes = [n for n in ray.nodes() if n["alive"]] + if not alive_nodes: + raise RuntimeError("Could not find any alive nodes in the Ray cluster.") + + # set node_num + if self.cluster.node_num is None: + self.cluster.node_num = len(alive_nodes) + logger.info(f"Auto-detected and set node_num: {self.cluster.node_num}") + + # set gpu_per_node + if self.cluster.gpu_per_node is None: + gpu_per_node = 0 + for node in alive_nodes: + node_gpus = node.get("Resources", {}).get("GPU") + if node_gpus and node_gpus > 0: + gpu_per_node = int(node_gpus) + break + + self.cluster.gpu_per_node = gpu_per_node + logger.info(f"Auto-detected and set gpu_per_node: {self.cluster.gpu_per_node}") + + ray.shutdown() + def _check_interval(self) -> None: assert self.synchronizer.sync_interval > 0 @@ -901,6 +938,9 @@ def check_and_update(self) -> Config: # noqa: C901 if self.ray_namespace is None or len(self.ray_namespace) == 0: self.ray_namespace = f"{self.project}/{self.name}" + # check cluster infomation + self._update_config_from_ray_cluster() + # check algorithm self._check_algorithm() diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index eb360266b4..7f81e06726 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -358,6 +358,8 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 else: rollout_gpu_num = 0 + assert config.cluster.node_num is not None + assert config.cluster.gpu_per_node is not None if config.cluster.node_num == 1: # for single node scenarios, rollout and training are on the same node self.trainer.nnodes = config.cluster.node_num From 29dc2e37ecd1aceb355976837a9cbd8117e2341b Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Thu, 16 Oct 2025 16:11:54 +0800 Subject: [PATCH 2/5] add dlc config and a unittest --- tests/common/config_test.py | 9 +++++++++ trinity/cli/launcher.py | 15 ++++++++++----- trinity/common/config.py | 1 + trinity/utils/dlc_utils.py | 29 +++++++++++++++++++++++++---- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/tests/common/config_test.py b/tests/common/config_test.py index a7b4d63530..147eaa8659 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -82,6 +82,15 @@ def test_config_flatten(self): self.assertIsInstance(key, str) self.assertNotIsInstance(value, dict) + def test_update_config_from_ray_cluster(self): + config = get_template_config() + config.cluster.node_num = None + config.cluster.gpu_per_node = None + + config._update_config_from_ray_cluster() + self.assertTrue(config.cluster.node_num > 0) + self.assertTrue(config.cluster.gpu_per_node > 0) + def tearDown(self): if os.path.exists(CHECKPOINT_ROOT_DIR): shutil.rmtree(CHECKPOINT_ROOT_DIR) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 024e2fee3f..a34370803d 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -142,9 +142,9 @@ def both(config: Config) -> None: } -def run_stage(config: Config, ray_address: str) -> None: +def run_stage(config: Config) -> None: ray.init( - address=ray_address, + address=config.cluster.ray_address, ignore_reinit_error=True, namespace=config.ray_namespace, runtime_env={"env_vars": 
config.get_envs()}, @@ -170,7 +170,12 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): if dlc: cluster_namespace = f"{config.project}-{config.name}" - config.cluster.ray_address = setup_ray_cluster(namespace=cluster_namespace) + cluster_information = setup_ray_cluster(namespace=cluster_namespace) + config.cluster.ray_address = cluster_information["ray_address"] + if config.cluster.node_num is None: + config.cluster.node_num = cluster_information.get("node_num") + if config.cluster.gpu_per_node is None: + config.cluster.gpu_per_node = cluster_information.get("gpu_per_node") if not is_running(): raise RuntimeError("Ray is not running, please start it by `ray start --head`.") @@ -201,7 +206,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): if prev_stage_checkpoint is not None: stage_config.model.model_path = prev_stage_checkpoint stage_config.check_and_update() - run_stage(stage_config, ray_address=config.cluster.ray_address) + run_stage(stage_config) logger.info( "===========================================================\n" f"> Stage {i + 1}/{len(config.stages)} finished.\n" @@ -210,7 +215,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): prev_stage_checkpoint = get_latest_hf_checkpoint_path(stage_config) else: config.check_and_update() - run_stage(config, ray_address=config.cluster.ray_address) + run_stage(config) finally: if dlc: diff --git a/trinity/common/config.py b/trinity/common/config.py index f3af6b6aa1..2a6794ed20 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -614,6 +614,7 @@ def _check_deprecated(self) -> None: ) def _update_config_from_ray_cluster(self) -> None: + """Update config if `node_num` or `gpu_per_node` are not set.""" if self.cluster.node_num is not None and self.cluster.gpu_per_node is not None: return diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py index 63cd874115..0a4ecfc64d 100644 --- a/trinity/utils/dlc_utils.py +++ b/trinity/utils/dlc_utils.py @@ -2,6 +2,7 @@ import subprocess import sys import time +from typing import Dict import ray @@ -64,20 +65,23 @@ def wait_for_ray_worker_nodes(world_size: int) -> None: time.sleep(1) -def setup_ray_cluster(namespace: str) -> str: +def setup_ray_cluster(namespace: str) -> Dict: """Setup a ray cluster in DLC environment. This function will start a ray cluster if it is not running, otherwise it will reuse the existing ray cluster. Returns: - str: The address of the ray cluster. + Dict: + - ray_address: The address of the ray cluster. + - node_num (Optional): The world size of the ray cluster. + - gpu_per_node (Optional): The number of GPUs per node. 
""" env_vars = get_dlc_env_vars() is_master = env_vars["RANK"] == 0 if is_running(): # reuse existing ray cluster - return "auto" + return {"ray_address": "auto"} else: if is_master: cmd = f"ray start --head --port={env_vars['MASTER_PORT']} --node-ip-address={env_vars['MASTER_ADDR']}" @@ -98,11 +102,28 @@ def setup_ray_cluster(namespace: str) -> str: namespace=namespace, ignore_reinit_error=True, ) + + # get gpu_per_node from enviroment variables + gpu_per_node = None + if "PAI_GPU_COUNT" in os.environ: + try: + gpu_per_node = int(os.environ["PAI_GPU_COUNT"]) + except ValueError: + logger.warning("Could not parse PAI_GPU_COUNT as an integer.") + elif "CUDA_VISIBLE_DEVICES" in os.environ: + visible_devices = os.environ["CUDA_VISIBLE_DEVICES"] + if visible_devices: + gpu_per_node = len(visible_devices.split(",")) + if is_master: # master wait for worker nodes to join wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"]) ray.shutdown() - return f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}" + return { + "ray_address": f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}", + "node_num": env_vars["WORLD_SIZE"], + "gpu_per_node": gpu_per_node, + } else: # worker wait on the cluster status actor cluster_status = ( From 1200a79a6badb99d044088500820227bfd23ec61 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Thu, 16 Oct 2025 17:10:24 +0800 Subject: [PATCH 3/5] get gpu_per_node in setup_ray_cluster --- trinity/cli/launcher.py | 2 +- trinity/utils/dlc_utils.py | 17 +++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index a34370803d..da8bceb321 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -171,7 +171,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): if dlc: cluster_namespace = f"{config.project}-{config.name}" cluster_information = setup_ray_cluster(namespace=cluster_namespace) - config.cluster.ray_address = cluster_information["ray_address"] + config.cluster.ray_address = cluster_information.get("ray_address", "auto") if config.cluster.node_num is None: config.cluster.node_num = cluster_information.get("node_num") if config.cluster.gpu_per_node is None: diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py index 0a4ecfc64d..0a32049510 100644 --- a/trinity/utils/dlc_utils.py +++ b/trinity/utils/dlc_utils.py @@ -103,17 +103,14 @@ def setup_ray_cluster(namespace: str) -> Dict: ignore_reinit_error=True, ) - # get gpu_per_node from enviroment variables + # get gpu_per_node from ray cluster gpu_per_node = None - if "PAI_GPU_COUNT" in os.environ: - try: - gpu_per_node = int(os.environ["PAI_GPU_COUNT"]) - except ValueError: - logger.warning("Could not parse PAI_GPU_COUNT as an integer.") - elif "CUDA_VISIBLE_DEVICES" in os.environ: - visible_devices = os.environ["CUDA_VISIBLE_DEVICES"] - if visible_devices: - gpu_per_node = len(visible_devices.split(",")) + alive_nodes = [n for n in ray.nodes() if n["alive"]] + for node in alive_nodes: + node_gpus = node.get("Resources", {}).get("GPU") + if node_gpus and node_gpus > 0: + gpu_per_node = int(node_gpus) + break if is_master: # master wait for worker nodes to join From 6fff5e5c03ed93f6f7cee6dc822772a7cd49134c Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Thu, 16 Oct 2025 19:23:52 +0800 Subject: [PATCH 4/5] remove duplicate code --- tests/common/config_test.py | 4 ++-- trinity/cli/launcher.py | 7 +------ trinity/common/config.py | 16 +++++++++------- trinity/utils/dlc_utils.py | 26 ++++---------------------- 4 files 
changed, 16 insertions(+), 37 deletions(-) diff --git a/tests/common/config_test.py b/tests/common/config_test.py index 147eaa8659..e51832f3a6 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -88,8 +88,8 @@ def test_update_config_from_ray_cluster(self): config.cluster.gpu_per_node = None config._update_config_from_ray_cluster() - self.assertTrue(config.cluster.node_num > 0) - self.assertTrue(config.cluster.gpu_per_node > 0) + self.assertEqual(config.cluster.node_num, 2) + self.assertEqual(config.cluster.gpu_per_node, 2) def tearDown(self): if os.path.exists(CHECKPOINT_ROOT_DIR): diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index da8bceb321..468ab2df53 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -170,12 +170,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): if dlc: cluster_namespace = f"{config.project}-{config.name}" - cluster_information = setup_ray_cluster(namespace=cluster_namespace) - config.cluster.ray_address = cluster_information.get("ray_address", "auto") - if config.cluster.node_num is None: - config.cluster.node_num = cluster_information.get("node_num") - if config.cluster.gpu_per_node is None: - config.cluster.gpu_per_node = cluster_information.get("gpu_per_node") + config.cluster.ray_address = setup_ray_cluster(namespace=cluster_namespace) if not is_running(): raise RuntimeError("Ray is not running, please start it by `ray start --head`.") diff --git a/trinity/common/config.py b/trinity/common/config.py index 2a6794ed20..56e03c8d0e 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -619,12 +619,13 @@ def _update_config_from_ray_cluster(self) -> None: return # init ray cluster to detect node_num and gpu_per_node - ray.init( - address=self.cluster.ray_address, - ignore_reinit_error=True, - namespace=self.ray_namespace, - runtime_env={"env_vars": self.get_envs()}, - ) + was_initialized = ray.is_initialized() + if not was_initialized: + ray.init( + address=self.cluster.ray_address, + ignore_reinit_error=True, + namespace=self.ray_namespace, + ) alive_nodes = [n for n in ray.nodes() if n["alive"]] if not alive_nodes: @@ -647,7 +648,8 @@ def _update_config_from_ray_cluster(self) -> None: self.cluster.gpu_per_node = gpu_per_node logger.info(f"Auto-detected and set gpu_per_node: {self.cluster.gpu_per_node}") - ray.shutdown() + if not was_initialized: + ray.shutdown() def _check_interval(self) -> None: assert self.synchronizer.sync_interval > 0 diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py index 0a32049510..63cd874115 100644 --- a/trinity/utils/dlc_utils.py +++ b/trinity/utils/dlc_utils.py @@ -2,7 +2,6 @@ import subprocess import sys import time -from typing import Dict import ray @@ -65,23 +64,20 @@ def wait_for_ray_worker_nodes(world_size: int) -> None: time.sleep(1) -def setup_ray_cluster(namespace: str) -> Dict: +def setup_ray_cluster(namespace: str) -> str: """Setup a ray cluster in DLC environment. This function will start a ray cluster if it is not running, otherwise it will reuse the existing ray cluster. Returns: - Dict: - - ray_address: The address of the ray cluster. - - node_num (Optional): The world size of the ray cluster. - - gpu_per_node (Optional): The number of GPUs per node. + str: The address of the ray cluster. 
""" env_vars = get_dlc_env_vars() is_master = env_vars["RANK"] == 0 if is_running(): # reuse existing ray cluster - return {"ray_address": "auto"} + return "auto" else: if is_master: cmd = f"ray start --head --port={env_vars['MASTER_PORT']} --node-ip-address={env_vars['MASTER_ADDR']}" @@ -102,25 +98,11 @@ def setup_ray_cluster(namespace: str) -> Dict: namespace=namespace, ignore_reinit_error=True, ) - - # get gpu_per_node from ray cluster - gpu_per_node = None - alive_nodes = [n for n in ray.nodes() if n["alive"]] - for node in alive_nodes: - node_gpus = node.get("Resources", {}).get("GPU") - if node_gpus and node_gpus > 0: - gpu_per_node = int(node_gpus) - break - if is_master: # master wait for worker nodes to join wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"]) ray.shutdown() - return { - "ray_address": f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}", - "node_num": env_vars["WORLD_SIZE"], - "gpu_per_node": gpu_per_node, - } + return f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}" else: # worker wait on the cluster status actor cluster_status = ( From a6097fc8c4067a6c095deda7f8fc427ea2ca6d9a Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Thu, 16 Oct 2025 19:54:01 +0800 Subject: [PATCH 5/5] fix unittest --- tests/explorer/explorer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 228e888c83..27f5b0d455 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -131,7 +131,7 @@ def test_explorer(self): def run_serve(config): config.check_and_update() - run_stage(config, "auto") + run_stage(config) def run_agent(base_url, model_path: str):