8 changes: 7 additions & 1 deletion docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -160,7 +160,7 @@ model:

- `model_path`: Path to the model being trained.
- `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not set, it will be inferred from the model configuration.
- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not specified, the system attempts to set it to `max_prompt_tokens` + `max_response_tokens`; this requires both values to already be set, otherwise an error is raised.
- `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
- `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
- `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
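The fallback behavior described above can be sketched as follows (a hypothetical helper written for illustration, not the project's actual validation code):

```python
from typing import Optional

def resolve_max_model_len(
    max_model_len: Optional[int],
    max_prompt_tokens: Optional[int],
    max_response_tokens: Optional[int],
) -> int:
    """Mirror the documented fallback: prefer an explicit value,
    otherwise derive it from the prompt and response budgets."""
    if max_model_len is not None:
        return max_model_len
    if max_prompt_tokens is None or max_response_tokens is None:
        # Neither an explicit length nor both budgets: nothing to infer from.
        raise ValueError("Unable to determine `max_model_len`, please set it manually.")
    return max_prompt_tokens + max_response_tokens
```

For example, with `max_prompt_tokens: 1024` and `max_response_tokens: 2048` but no explicit `max_model_len`, the resolved length would be 3072.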
@@ -405,6 +405,7 @@ trainer:
trainer_type: 'verl'
save_interval: 100
total_steps: 1000
save_strategy: "unrestricted"
trainer_config: null
trainer_config_path: ''
```
@@ -413,6 +414,11 @@ trainer:
- `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
- `save_interval`: Frequency (in steps) at which to save model checkpoints.
- `total_steps`: Total number of training steps.
- `save_strategy`: The parallel strategy used when saving the model. Defaults to `unrestricted`. The available options are as follows:
- `single_thread`: Only one thread across the entire system is allowed to save the model; saving tasks from different threads are executed sequentially.
- `single_process`: Only one process across the entire system is allowed to perform saving; multiple threads within that process can handle saving tasks in parallel, while saving operations across different processes are executed sequentially.
- `single_node`: Only one compute node across the entire system is allowed to perform saving; processes and threads within that node can work in parallel, while saving operations across different nodes are executed sequentially.
- `unrestricted`: No restrictions on saving operations; multiple nodes, processes, or threads are allowed to save the model simultaneously.
- `trainer_config`: The trainer configuration provided inline.
- `trainer_config_path`: The path to the trainer configuration file. Only one of `trainer_config_path` and `trainer_config` should be specified.
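One way to think about the four strategies is that each saver is assigned a concurrency key: savers sharing a key may run in parallel, while savers with different keys are serialized. A minimal sketch under that reading (function and key names are illustrative, not the project's API):

```python
import itertools

_seq = itertools.count()  # monotonically increasing counter for unique keys

def save_group_key(strategy: str, node_id: str, job_id: str) -> str:
    """Return the concurrency key for a saver.
    Savers with equal keys run in parallel; distinct keys are serialized."""
    if strategy == "single_thread":
        return f"save-{next(_seq)}"   # unique per save -> fully serial
    if strategy == "single_process":
        return f"{node_id}_{job_id}"  # parallel only within one process
    if strategy == "single_node":
        return node_id                # parallel only within one node
    if strategy == "unrestricted":
        return "any"                  # everyone shares one key -> no limit
    raise ValueError(f"Invalid save strategy: {strategy}")
```

Under this sketch, two processes on the same node may save concurrently under `single_node` (same key) but not under `single_process` (different keys).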

8 changes: 7 additions & 1 deletion docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -160,7 +160,7 @@ model:

- `model_path`: Path to the model being trained.
- `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
- `max_model_len`: Maximum number of tokens in a single sequence supported by the model.
- `max_model_len`: Maximum number of tokens in a single sequence supported by the model. If not specified, the system attempts to set it to `max_prompt_tokens` + `max_response_tokens`; both values must already be set, otherwise an error is raised.
- `max_prompt_tokens`: Maximum number of tokens allowed in the input prompt. Only effective for the `chat` and `generate` methods in `InferenceModel`.
- `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
- `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
@@ -405,6 +405,7 @@ trainer:
trainer_type: 'verl'
save_interval: 100
total_steps: 1000
save_strategy: "unrestricted"
trainer_config: null
trainer_config_path: ''
```
@@ -413,6 +414,11 @@ trainer:
- `trainer_type`: Trainer backend implementation. Currently only `verl` is supported.
- `save_interval`: Frequency (in steps) at which model checkpoints are saved.
- `total_steps`: Total number of training steps.
- `save_strategy`: Parallel strategy used when saving the model. Defaults to `unrestricted`. Available options:
  - `single_thread`: Only one thread in the entire system may save the model; saves from different threads run sequentially.
  - `single_process`: Only one process in the entire system may perform saving; multiple threads within that process can save in parallel, while saves across different processes run sequentially.
  - `single_node`: Only one compute node in the entire system may perform saving; processes and threads within that node can work in parallel, while saves across different nodes run sequentially.
  - `unrestricted`: No restriction on saving; multiple nodes, processes, or threads may save the model simultaneously.
- `trainer_config`: Trainer configuration provided inline.
- `trainer_config_path`: Path to the trainer configuration file. Only one of `trainer_config_path` and `trainer_config` may be specified.

2 changes: 2 additions & 0 deletions tests/cli/launcher_test.py
@@ -30,6 +30,8 @@

class TestLauncherMain(unittest.TestCase):
def setUp(self):
if multiprocessing.get_start_method(allow_none=True) != "spawn":
multiprocessing.set_start_method("spawn", force=True)
self._orig_argv = sys.argv.copy()
self.config = get_template_config()
self.config.checkpoint_root_dir = get_checkpoint_path()
15 changes: 4 additions & 11 deletions tests/common/vllm_test.py
@@ -94,12 +94,11 @@ def print_debug(*args):
"repeat_times",
"enable_history",
"use_async",
"max_model_len",
),
[
(2, 2, 2, True, False, None),
(1, 2, 1, False, True, None),
(2, 1, 3, True, True, None),
(2, 2, 2, True, False),
(1, 2, 1, False, True),
(2, 1, 3, True, True),
],
)
class ModelWrapperTest(RayUnittestBaseAysnc):
@@ -108,7 +107,6 @@ def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.model.max_model_len = self.max_model_len
self.config.explorer.rollout_model.engine_num = self.engine_num
self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
@@ -189,12 +187,7 @@ async def test_generate(
"content": results[0].response_text,
}
)
if self.max_model_len is not None:
with self.assertRaises(ValueError):
exp = self.model_wrapper.convert_messages_to_experience(messages)
return
else:
exp = self.model_wrapper.convert_messages_to_experience(messages)
exp = self.model_wrapper.convert_messages_to_experience(messages)
tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
result_dict = tokenizer.apply_chat_template(
messages,
1 change: 0 additions & 1 deletion tests/explorer/workflow_test.py
@@ -466,7 +466,6 @@ def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.model.max_model_len = None # self.max_model_len
self.config.explorer.rollout_model.engine_num = 1 # self.engine_num
self.config.explorer.rollout_model.tensor_parallel_size = 1 # self.tensor_parallel_size
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
20 changes: 4 additions & 16 deletions trinity/common/config.py
@@ -16,10 +16,10 @@
LOG_DIR_ENV_VAR,
LOG_LEVEL_ENV_VAR,
LOG_NODE_IP_ENV_VAR,
MAX_MODEL_LEN,
PLUGIN_DIRS_ENV_VAR,
TRAINER_NAME,
PromptType,
SaveStrategy,
StorageType,
SyncMethod,
SyncStyle,
@@ -471,6 +471,8 @@ class TrainerConfig:
actor_grad_clip: Optional[float] = None
# TODO: extract more train-related params from underlying trainer engine

save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED

# Only one needs to be set for `trainer_config` and `trainer_config_path`
trainer_config: Any = field(default_factory=dict)
trainer_config_path: str = ""
@@ -843,21 +845,7 @@ def _check_model(self) -> None:
f"`max_model_len` is set to {model.max_model_len} from `max_prompt_tokens` and `max_response_tokens`."
)
else:
from transformers import AutoConfig, AutoTokenizer
from transformers.tokenization_utils_base import LARGE_INTEGER

tokenizer = AutoTokenizer.from_pretrained(model.model_path)
config = AutoConfig.from_pretrained(model.model_path)
max_model_len = min(
getattr(tokenizer, "model_max_length", LARGE_INTEGER),
getattr(config, "max_position_embeddings", LARGE_INTEGER),
)
if max_model_len >= LARGE_INTEGER:
max_model_len = MAX_MODEL_LEN
logger.warning(
f"Failed to get `max_model_len` from model {model.model_path}, use {MAX_MODEL_LEN} instead."
)
model.max_model_len = max_model_len
raise ValueError("Unable to determine `max_model_len`, please set it manually.")

# both max_prompt_tokens and max_response_tokens are None
if model.max_prompt_tokens is None and model.max_response_tokens is None:
7 changes: 7 additions & 0 deletions trinity/common/constants.py
@@ -104,3 +104,10 @@ class SyncStyle(CaseInsensitiveEnum):
FIXED = "fixed"
DYNAMIC_BY_TRAINER = "dynamic_by_trainer"
DYNAMIC_BY_EXPLORER = "dynamic_by_explorer"


class SaveStrategy(CaseInsensitiveEnum):
SINGLE_THREAD = "single_thread"
SINGLE_PROCESS = "single_process"
SINGLE_NODE = "single_node"
UNRESTRICTED = "unrestricted"
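The `CaseInsensitiveEnum` base class is defined elsewhere in the project; a minimal stand-in together with the new `SaveStrategy` members might look like the sketch below. The `_missing_` hook is an assumption about how the case-insensitive lookup is achieved, not the project's actual implementation:

```python
import enum

class CaseInsensitiveEnum(str, enum.Enum):
    """Enum whose lookup by value ignores case."""

    @classmethod
    def _missing_(cls, value):
        # Called when a direct value lookup fails; retry with a lowercased value.
        if isinstance(value, str):
            lowered = value.lower()
            for member in cls:
                if member.value == lowered:
                    return member
        return None

class SaveStrategy(CaseInsensitiveEnum):
    SINGLE_THREAD = "single_thread"
    SINGLE_PROCESS = "single_process"
    SINGLE_NODE = "single_node"
    UNRESTRICTED = "unrestricted"
```

With this sketch, a YAML value such as `save_strategy: "Single_Node"` would resolve to `SaveStrategy.SINGLE_NODE` regardless of capitalization.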
12 changes: 12 additions & 0 deletions trinity/trainer/verl/fsdp_checkpoint_manager.py
@@ -115,6 +115,10 @@ def _save_with_thread(
thread.join()

def _save():
runtime_context = ray.get_runtime_context()
node_id = runtime_context.get_node_id()
job_id = runtime_context.get_job_id()
ray.get(self.checkpoint_monitor.notify_started.remote(node_id=node_id, job_id=job_id))
torch.save(obj, path)
log_with_rank(
f"Saved {prefix} to {os.path.abspath(path)}",
@@ -357,6 +361,14 @@ def save_checkpoint( # noqa: C901
self._save_model_thread.join()

def _save_model():
runtime_context = ray.get_runtime_context()
node_id = runtime_context.get_node_id()
job_id = runtime_context.get_job_id()
ray.get(
self.checkpoint_monitor.notify_started.remote(
node_id=node_id, job_id=job_id
)
)
save_model.save_pretrained(hf_local_path, state_dict=state_dict)
log_with_rank(
f"Saved hf_model to {os.path.abspath(hf_local_path)}",
4 changes: 4 additions & 0 deletions trinity/trainer/verl/megatron_checkpoint_manager.py
@@ -125,6 +125,10 @@ def _save_state_dict(self, local_path, global_step):

def finalize_save_fn():
# Rank 0 uploads checkpoint to HDFS if hdfs_path is provided
runtime_context = ray.get_runtime_context()
node_id = runtime_context.get_node_id()
job_id = runtime_context.get_job_id()
ray.get(self.checkpoint_monitor.notify_started.remote(node_id=node_id, job_id=job_id))
log_with_rank(
f"Dist checkpointing save completed for {dist_checkpoint_path}",
rank=self.rank,
46 changes: 41 additions & 5 deletions trinity/trainer/verl_trainer.py
@@ -3,6 +3,7 @@

Modified from verl/trainer/ppo/ray_trainer.py
"""
import asyncio
import os
import sys
from collections import defaultdict
@@ -33,14 +34,17 @@
from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import Config
from trinity.common.constants import SaveStrategy
from trinity.common.experience import Experiences
from trinity.trainer.trainer import TrainEngineWrapper
from trinity.trainer.verl.utils import compute_data_metrics, to_data_proto
from trinity.utils.log import get_logger


class CheckpointMonitor:
def __init__(self, default_local_dir: str, default_hdfs_dir: str = None):
def __init__(
self, save_strategy: SaveStrategy, default_local_dir: str, default_hdfs_dir: str = None
):
self.logger = get_logger("checkpoint_monitor", in_ray_actor=True)
self.default_local_dir = default_local_dir
self.default_hdfs_dir = default_hdfs_dir
@@ -57,6 +61,11 @@ def __init__(self, default_local_dir: str, default_hdfs_dir: str = None):
self.latest_checkpoint_step = 0
self.latest_state_dict_step = 0

self.save_strategy = save_strategy
self.condition = asyncio.Condition()
self.current_identifier = 0
self.saving_count = 0

def update_latest_checkpoint_step(self, step: int):
assert step >= self.latest_checkpoint_step
if step == self.latest_checkpoint_step:
@@ -87,7 +96,7 @@ def update_latest_state_dict_step(self, step: int):
with open(self.local_latest_state_dict_iteration, "w") as f:
f.write(str(step))

def register_thread_count(
async def register_thread_count(
self,
step: int,
*,
@@ -99,7 +108,7 @@ def register_thread_count(
if checkpoint_thread_count != 0:
self.checkpoint_counter[step] += checkpoint_thread_count

def monitor_step(self, step: int, is_state_dict: bool = False):
async def monitor_step(self, step: int, is_state_dict: bool = False):
if is_state_dict:
self.state_dict_steps.add(step)
if self.state_dict_counter[step] == 0:
@@ -109,7 +118,28 @@ def monitor_step(self, step: int, is_state_dict: bool = False):
if self.checkpoint_counter[step] == 0 and self.state_dict_counter[step] == 0:
self.update_latest_checkpoint_step(step)

def notify_finished(self, step: int, is_state_dict: bool = False):
async def notify_started(self, node_id: str, job_id: str):
if self.save_strategy == SaveStrategy.SINGLE_THREAD:
identifier = self.current_identifier + 1
elif self.save_strategy == SaveStrategy.SINGLE_PROCESS:
identifier = f"{node_id}_{job_id}"
elif self.save_strategy == SaveStrategy.SINGLE_NODE:
identifier = node_id
elif self.save_strategy == SaveStrategy.UNRESTRICTED:
return
else:
raise ValueError(f"Invalid save strategy: {self.save_strategy}")

async with self.condition:
if identifier != self.current_identifier and self.saving_count > 0:
await self.condition.wait_for(lambda: self.saving_count == 0)
self.current_identifier = identifier
self.saving_count += 1

async def notify_finished(self, step: int, is_state_dict: bool = False):
async with self.condition:
self.saving_count -= 1
self.condition.notify_all()
if is_state_dict:
self.state_dict_counter[step] -= 1
if (
@@ -131,6 +161,7 @@ def notify_finished(self, step: int, is_state_dict: bool = False):
def get_actor(
cls,
namespace: str,
save_strategy: Optional[SaveStrategy] = None,
default_local_dir: Optional[str] = None,
default_hdfs_dir: Optional[str] = None,
):
@@ -141,7 +172,11 @@ def get_actor(
namespace=namespace,
get_if_exists=True,
)
.remote(default_local_dir=default_local_dir, default_hdfs_dir=default_hdfs_dir)
.remote(
save_strategy=save_strategy,
default_local_dir=default_local_dir,
default_hdfs_dir=default_hdfs_dir,
)
)


@@ -191,6 +226,7 @@ def __init__(

self.checkpoint_monitor = CheckpointMonitor.get_actor(
namespace=global_config.synchronizer.ray_namespace,
save_strategy=global_config.trainer.save_strategy,
default_local_dir=config.trainer.default_local_dir,
default_hdfs_dir=config.trainer.default_hdfs_dir,
)
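Stripped of Ray and checkpoint bookkeeping, the gating added in `notify_started`/`notify_finished` reduces to a standard `asyncio.Condition` pattern: savers sharing the current identifier proceed immediately, while a saver with a new identifier waits until all in-flight saves drain. A self-contained sketch (class and variable names are illustrative):

```python
import asyncio

class SaveGate:
    """Minimal sketch of the CheckpointMonitor's save gating."""

    def __init__(self):
        self.condition = asyncio.Condition()
        self.current_identifier = None
        self.saving_count = 0

    async def notify_started(self, identifier: str) -> None:
        async with self.condition:
            # A new identifier must wait for in-flight saves to drain;
            # savers sharing the current identifier proceed immediately.
            if identifier != self.current_identifier and self.saving_count > 0:
                await self.condition.wait_for(lambda: self.saving_count == 0)
            self.current_identifier = identifier
            self.saving_count += 1

    async def notify_finished(self) -> None:
        async with self.condition:
            self.saving_count -= 1
            self.condition.notify_all()

async def demo():
    gate = SaveGate()
    order = []

    async def save(name, identifier, delay):
        await gate.notify_started(identifier)
        order.append(f"start:{name}")
        await asyncio.sleep(delay)
        order.append(f"end:{name}")
        await gate.notify_finished()

    # "a" and "b" share an identifier and run concurrently;
    # "c" has a new identifier and must wait for both to finish.
    await asyncio.gather(
        save("a", "node1", 0.05),
        save("b", "node1", 0.05),
        save("c", "node2", 0.01),
    )
    return order
```

Running `asyncio.run(demo())` shows `start:c` appearing only after both `end:a` and `end:b`, matching the `single_node` semantics for identifiers derived from node IDs.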