Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ def test_config_flatten(self):
self.assertIsInstance(key, str)
self.assertNotIsInstance(value, dict)

def test_update_config_from_ray_cluster(self):
    """Unset cluster sizes are auto-detected from the running Ray cluster."""
    config = get_template_config()
    # Blank out both fields so detection is forced to fill them in.
    config.cluster.node_num = None
    config.cluster.gpu_per_node = None

    config._update_config_from_ray_cluster()

    # Expected topology of the test Ray cluster: 2 nodes, 2 GPUs per node.
    detected = (config.cluster.node_num, config.cluster.gpu_per_node)
    self.assertEqual(detected, (2, 2))

def tearDown(self):
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR)
2 changes: 1 addition & 1 deletion tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_explorer(self):

def run_serve(config):
config.check_and_update()
run_stage(config, "auto")
run_stage(config)


def run_agent(base_url, model_path: str):
Expand Down
12 changes: 5 additions & 7 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def both(config: Config) -> None:
}


def run_stage(config: Config, ray_address: str) -> None:
def run_stage(config: Config) -> None:
ray.init(
address=ray_address,
address=config.cluster.ray_address,
ignore_reinit_error=True,
namespace=config.ray_namespace,
runtime_env={"env_vars": config.get_envs()},
Expand All @@ -168,11 +168,9 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
load_plugins()
config = load_config(config_path)

ray_address = "auto"

if dlc:
cluster_namespace = f"{config.project}-{config.name}"
ray_address = setup_ray_cluster(namespace=cluster_namespace)
config.cluster.ray_address = setup_ray_cluster(namespace=cluster_namespace)

if not is_running():
raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
Expand Down Expand Up @@ -203,7 +201,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
if prev_stage_checkpoint is not None:
stage_config.model.model_path = prev_stage_checkpoint
stage_config.check_and_update()
run_stage(stage_config, ray_address=ray_address)
run_stage(stage_config)
logger.info(
"===========================================================\n"
f"> Stage {i + 1}/{len(config.stages)} finished.\n"
Expand All @@ -212,7 +210,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
prev_stage_checkpoint = get_latest_hf_checkpoint_path(stage_config)
else:
config.check_and_update()
run_stage(config, ray_address=ray_address)
run_stage(config)

finally:
if dlc:
Expand Down
47 changes: 45 additions & 2 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional

import ray
from omegaconf import OmegaConf

from trinity.common.constants import (
Expand Down Expand Up @@ -360,8 +361,9 @@ class AlgorithmConfig:
class ClusterConfig:
"""Config for the cluster."""

node_num: int = 1
gpu_per_node: int = 8
ray_address: str = "auto"
node_num: Optional[int] = None
gpu_per_node: Optional[int] = None


@Experimental
Expand Down Expand Up @@ -611,6 +613,44 @@ def _check_deprecated(self) -> None:
"`explorer.runner_num` is deprecated, please use `explorer.runner_per_model` instead."
)

def _update_config_from_ray_cluster(self) -> None:
    """Fill in ``cluster.node_num`` / ``cluster.gpu_per_node`` from the live Ray cluster.

    Runs only when at least one of the two values is unset (``None``).
    May temporarily initialize Ray to inspect the cluster; if it started
    Ray itself, it shuts Ray down again — even when detection fails.

    Raises:
        RuntimeError: if no alive node can be found in the Ray cluster.
    """
    if self.cluster.node_num is not None and self.cluster.gpu_per_node is not None:
        return  # both values configured explicitly; nothing to detect

    # Initialize Ray only if nobody else has, so we know whether we own the session.
    was_initialized = ray.is_initialized()
    if not was_initialized:
        ray.init(
            address=self.cluster.ray_address,
            ignore_reinit_error=True,
            namespace=self.ray_namespace,
        )

    try:
        alive_nodes = [n for n in ray.nodes() if n["alive"]]
        if not alive_nodes:
            raise RuntimeError("Could not find any alive nodes in the Ray cluster.")

        # set node_num
        if self.cluster.node_num is None:
            self.cluster.node_num = len(alive_nodes)
            logger.info(f"Auto-detected and set node_num: {self.cluster.node_num}")

        # set gpu_per_node
        if self.cluster.gpu_per_node is None:
            # Take the GPU count of the first node that reports any GPUs;
            # GPU-less nodes (e.g. a CPU-only head node) are skipped.
            gpu_per_node = 0
            for node in alive_nodes:
                node_gpus = node.get("Resources", {}).get("GPU")
                if node_gpus and node_gpus > 0:
                    gpu_per_node = int(node_gpus)
                    break

            self.cluster.gpu_per_node = gpu_per_node
            logger.info(f"Auto-detected and set gpu_per_node: {self.cluster.gpu_per_node}")
    finally:
        # Leave Ray in the state we found it: shut down only the session we
        # started, and do so even if detection above raised (fixes a leaked
        # Ray session on the no-alive-nodes error path).
        if not was_initialized:
            ray.shutdown()

def _check_interval(self) -> None:
assert self.synchronizer.sync_interval > 0

Expand Down Expand Up @@ -901,6 +941,9 @@ def check_and_update(self) -> Config: # noqa: C901
if self.ray_namespace is None or len(self.ray_namespace) == 0:
self.ray_namespace = f"{self.project}/{self.name}"

# check cluster information
self._update_config_from_ray_cluster()

# check algorithm
self._check_algorithm()

Expand Down
2 changes: 2 additions & 0 deletions trinity/common/verl_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
else:
rollout_gpu_num = 0

assert config.cluster.node_num is not None
assert config.cluster.gpu_per_node is not None
if config.cluster.node_num == 1:
# for single node scenarios, rollout and training are on the same node
self.trainer.nnodes = config.cluster.node_num
Expand Down