From 552fe2da7fa05070260666dc6d9ea5ee6cec5115 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 14 Oct 2025 17:23:43 +0800 Subject: [PATCH 1/6] 1. Implement serial save. 2. No longer set `max_model_len` from model config.json --- trinity/common/config.py | 16 +--------------- .../trainer/verl/fsdp_checkpoint_manager.py | 2 ++ .../verl/megatron_checkpoint_manager.py | 1 + trinity/trainer/verl_trainer.py | 19 ++++++++++++++++--- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index 7d97a7a7af..28c9c07d83 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -843,21 +843,7 @@ def _check_model(self) -> None: f"`max_model_len` is set to {model.max_model_len} from `max_prompt_tokens` and `max_response_tokens`." ) else: - from transformers import AutoConfig, AutoTokenizer - from transformers.tokenization_utils_base import LARGE_INTEGER - - tokenizer = AutoTokenizer.from_pretrained(model.model_path) - config = AutoConfig.from_pretrained(model.model_path) - max_model_len = min( - getattr(tokenizer, "model_max_length", LARGE_INTEGER), - getattr(config, "max_position_embeddings", LARGE_INTEGER), - ) - if max_model_len >= LARGE_INTEGER: - max_model_len = MAX_MODEL_LEN - logger.warning( - f"Failed to get `max_model_len` from model {model.model_path}, use {MAX_MODEL_LEN} instead." - ) - model.max_model_len = max_model_len + raise ValueError("Unable to determine `max_model_len`, please set it manually.") # both max_prompt_tokens and max_response_tokens are None if model.max_prompt_tokens is None and model.max_response_tokens is None: diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 62f3ddbfee..c70f0801f1 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -115,6 +115,7 @@ def _save_with_thread( thread.join() def _save(): + ray.get(self.checkpoint_monitor.notify_started.remote()) torch.save(obj, path) log_with_rank( f"Saved {prefix} to {os.path.abspath(path)}", @@ -357,6 +358,7 @@ def save_checkpoint( # noqa: C901 self._save_model_thread.join() def _save_model(): + ray.get(self.checkpoint_monitor.notify_started.remote()) save_model.save_pretrained(hf_local_path, state_dict=state_dict) log_with_rank( f"Saved hf_model to {os.path.abspath(hf_local_path)}", diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py index a3b01cd5a4..c83037be00 100644 --- a/trinity/trainer/verl/megatron_checkpoint_manager.py +++ b/trinity/trainer/verl/megatron_checkpoint_manager.py @@ -125,6 +125,7 @@ def _save_state_dict(self, local_path, global_step): def finalize_save_fn(): # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided + ray.get(self.checkpoint_monitor.notify_started.remote()) log_with_rank( f"Dist checkpointing save completed for {dist_checkpoint_path}", rank=self.rank, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 60e1e0a074..2e32362b01 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -3,6 +3,7 @@ Modified from verl/trainer/ppo/ray_trainer.py """ +import asyncio import os import sys from collections import defaultdict @@ -57,6 +58,9 @@ def __init__(self, default_local_dir: str, default_hdfs_dir: str = None): self.latest_checkpoint_step = 0 self.latest_state_dict_step = 0 + self.condition = asyncio.Condition() + self.saving_count = 0 + 
def update_latest_checkpoint_step(self, step: int): assert step >= self.latest_checkpoint_step if step == self.latest_checkpoint_step: @@ -87,7 +91,7 @@ def update_latest_state_dict_step(self, step: int): with open(self.local_latest_state_dict_iteration, "w") as f: f.write(str(step)) - def register_thread_count( + async def register_thread_count( self, step: int, *, @@ -99,7 +103,7 @@ def register_thread_count( if checkpoint_thread_count != 0: self.checkpoint_counter[step] += checkpoint_thread_count - def monitor_step(self, step: int, is_state_dict: bool = False): + async def monitor_step(self, step: int, is_state_dict: bool = False): if is_state_dict: self.state_dict_steps.add(step) if self.state_dict_counter[step] == 0: @@ -109,7 +113,16 @@ def monitor_step(self, step: int, is_state_dict: bool = False): if self.checkpoint_counter[step] == 0 and self.state_dict_counter[step] == 0: self.update_latest_checkpoint_step(step) - def notify_finished(self, step: int, is_state_dict: bool = False): + async def notify_started(self): + async with self.condition: + while self.saving_count > 0: + await self.condition.wait_for(lambda: self.saving_count == 0) + self.saving_count += 1 + + async def notify_finished(self, step: int, is_state_dict: bool = False): + async with self.condition: + self.saving_count -= 1 + self.condition.notify() if is_state_dict: self.state_dict_counter[step] -= 1 if ( From 3ddf51c4e1da873687985cf852113f76e4069efa Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 14 Oct 2025 19:20:33 +0800 Subject: [PATCH 2/6] fix in unitttest --- tests/cli/launcher_test.py | 2 ++ trinity/common/config.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cli/launcher_test.py b/tests/cli/launcher_test.py index 067efb4a6f..1b8ab142e8 100644 --- a/tests/cli/launcher_test.py +++ b/tests/cli/launcher_test.py @@ -30,6 +30,8 @@ class TestLauncherMain(unittest.TestCase): def setUp(self): + if multiprocessing.get_start_method(allow_none=True) != "spawn": + multiprocessing.set_start_method("spawn", force=True) self._orig_argv = sys.argv.copy() self.config = get_template_config() self.config.checkpoint_root_dir = get_checkpoint_path() diff --git a/trinity/common/config.py b/trinity/common/config.py index 28c9c07d83..2fba12dc61 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -16,7 +16,6 @@ LOG_DIR_ENV_VAR, LOG_LEVEL_ENV_VAR, LOG_NODE_IP_ENV_VAR, - MAX_MODEL_LEN, PLUGIN_DIRS_ENV_VAR, TRAINER_NAME, PromptType, From 3e84d6f7dc8672450f141c661e6f45f384948af2 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 14 Oct 2025 21:01:31 +0800 Subject: [PATCH 3/6] fix unittest --- tests/common/vllm_test.py | 8 +++----- tests/explorer/workflow_test.py | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 2376ee3128..ba288b2369 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -94,12 +94,11 @@ def print_debug(*args): "repeat_times", "enable_history", "use_async", - "max_model_len", ), [ - (2, 2, 2, True, False, None), - (1, 2, 1, False, True, None), - (2, 1, 3, True, True, None), + (2, 2, 2, True, False), + (1, 2, 1, False, True), + (2, 1, 3, True, True), ], ) class ModelWrapperTest(RayUnittestBaseAysnc): @@ -108,7 +107,6 @@ def setUp(self): self.config = get_template_config() self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.model.max_model_len = self.max_model_len 
self.config.explorer.rollout_model.engine_num = self.engine_num self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index c6b6019c21..1710ce9f38 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -466,7 +466,6 @@ def setUp(self): self.config = get_template_config() self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.model.max_model_len = None # self.max_model_len self.config.explorer.rollout_model.engine_num = 1 # self.engine_num self.config.explorer.rollout_model.tensor_parallel_size = 1 # self.tensor_parallel_size self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE From 43f162d8b00e9c8fc93777f16540e5e3ad301b45 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 15 Oct 2025 09:47:58 +0800 Subject: [PATCH 4/6] fix in unittest --- tests/common/vllm_test.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index ba288b2369..6705efee7a 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -187,12 +187,7 @@ async def test_generate( "content": results[0].response_text, } ) - if self.max_model_len is not None: - with self.assertRaises(ValueError): - exp = self.model_wrapper.convert_messages_to_experience(messages) - return - else: - exp = self.model_wrapper.convert_messages_to_experience(messages) + exp = self.model_wrapper.convert_messages_to_experience(messages) tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path) result_dict = tokenizer.apply_chat_template( messages, From 4fac567ad80a6d1fbc061aa1378b26c834d0ff56 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 15 Oct 2025 13:23:18 +0800 Subject: [PATCH 5/6] add save strategy --- .../source/tutorial/trinity_configs.md | 2 +- .../source_zh/tutorial/trinity_configs.md | 2 +- trinity/common/config.py | 3 ++ trinity/common/constants.py | 7 ++++ .../trainer/verl/fsdp_checkpoint_manager.py | 14 ++++++-- .../verl/megatron_checkpoint_manager.py | 5 ++- trinity/trainer/verl_trainer.py | 33 ++++++++++++++++--- 7 files changed, 56 insertions(+), 10 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 731f5c2de4..3977726ca4 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -160,7 +160,7 @@ model: - `model_path`: Path to the model being trained. - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`. -- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not set, it will be inferred from the model configuration. +- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not set, it will default to `max_prompt_tokens` + `max_response_tokens`. However, if either `max_prompt_tokens` or `max_response_tokens` is not set, we will raise an error. - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`. 
- `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`. diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 1f2ed9f114..8b2c4edc1a 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -160,7 +160,7 @@ model: - `model_path`: 被训练模型的路径。 - `critic_model_path`: 可选的独立 critic 模型路径。若为空,则默认为 `model_path`。 -- `max_model_len`: 该模型所支持的单个序列最大 token 数。 +- `max_model_len`: 表示模型所支持的单个序列最大 token 数。如果未设置该值,则会尝试将其默认设为 `max_prompt_tokens + max_response_tokens`。但如果 `max_prompt_tokens` 或 `max_response_tokens` 中有任何一个未设置,代码将会报错。 - `max_prompt_tokens`: 输入 prompt 中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `min_response_tokens`: 模型生成的回复中允许的最小 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 diff --git a/trinity/common/config.py b/trinity/common/config.py index 2fba12dc61..eac7eabe79 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -19,6 +19,7 @@ PLUGIN_DIRS_ENV_VAR, TRAINER_NAME, PromptType, + SaveStrategy, StorageType, SyncMethod, SyncStyle, @@ -470,6 +471,8 @@ class TrainerConfig: actor_grad_clip: Optional[float] = None # TODO: extract more train-related params from underlying trainer engine + save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED + # Only one needs to be set for `trainer_config` and `trainer_config_path` trainer_config: Any = field(default_factory=dict) trainer_config_path: str = "" diff --git a/trinity/common/constants.py b/trinity/common/constants.py index a457729862..ad092603d2 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -104,3 +104,10 @@ class SyncStyle(CaseInsensitiveEnum): FIXED = "fixed" DYNAMIC_BY_TRAINER = "dynamic_by_trainer" DYNAMIC_BY_EXPLORER = "dynamic_by_explorer" + + +class SaveStrategy(CaseInsensitiveEnum): + SINGLE_THREAD = "single_thread" + SINGLE_PROCESS = "single_process" + SINGLE_NODE = "single_node" + UNRESTRICTED = "unrestricted" diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index c70f0801f1..576e1e45fa 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -115,7 +115,10 @@ def _save_with_thread( thread.join() def _save(): - ray.get(self.checkpoint_monitor.notify_started.remote()) + runtime_context = ray.get_runtime_context() + node_id = runtime_context.get_node_id() + job_id = runtime_context.get_job_id() + ray.get(self.checkpoint_monitor.notify_started.remote(node_id=node_id, job_id=job_id)) torch.save(obj, path) log_with_rank( f"Saved {prefix} to {os.path.abspath(path)}", @@ -358,7 +361,14 @@ def save_checkpoint( # noqa: C901 self._save_model_thread.join() def _save_model(): - ray.get(self.checkpoint_monitor.notify_started.remote()) + runtime_context = ray.get_runtime_context() + node_id = runtime_context.get_node_id() + job_id = runtime_context.get_job_id() + ray.get( + self.checkpoint_monitor.notify_started.remote( + node_id=node_id, job_id=job_id + ) + ) save_model.save_pretrained(hf_local_path, state_dict=state_dict) log_with_rank( f"Saved hf_model to {os.path.abspath(hf_local_path)}", diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py 
b/trinity/trainer/verl/megatron_checkpoint_manager.py index c83037be00..b65b943782 100644 --- a/trinity/trainer/verl/megatron_checkpoint_manager.py +++ b/trinity/trainer/verl/megatron_checkpoint_manager.py @@ -125,7 +125,10 @@ def _save_state_dict(self, local_path, global_step): def finalize_save_fn(): # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided - ray.get(self.checkpoint_monitor.notify_started.remote()) + runtime_context = ray.get_runtime_context() + node_id = runtime_context.get_node_id() + job_id = runtime_context.get_job_id() + ray.get(self.checkpoint_monitor.notify_started.remote(node_id=node_id, job_id=job_id)) log_with_rank( f"Dist checkpointing save completed for {dist_checkpoint_path}", rank=self.rank, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 2e32362b01..54dfba2b7a 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -34,6 +34,7 @@ from trinity.algorithm.algorithm import ALGORITHM_TYPE from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config +from trinity.common.constants import SaveStrategy from trinity.common.experience import Experiences from trinity.trainer.trainer import TrainEngineWrapper from trinity.trainer.verl.utils import compute_data_metrics, to_data_proto @@ -41,7 +42,9 @@ class CheckpointMonitor: - def __init__(self, default_local_dir: str, default_hdfs_dir: str = None): + def __init__( + self, save_strategy: SaveStrategy, default_local_dir: str, default_hdfs_dir: str = None + ): self.logger = get_logger("checkpoint_monitor", in_ray_actor=True) self.default_local_dir = default_local_dir self.default_hdfs_dir = default_hdfs_dir @@ -58,7 +61,9 @@ def __init__(self, default_local_dir: str, default_hdfs_dir: str = None): self.latest_checkpoint_step = 0 self.latest_state_dict_step = 0 + self.save_strategy = save_strategy self.condition = asyncio.Condition() + self.current_identifier = 0 self.saving_count = 0 def update_latest_checkpoint_step(self, step: int): @@ -113,16 +118,28 @@ async def monitor_step(self, step: int, is_state_dict: bool = False): if self.checkpoint_counter[step] == 0 and self.state_dict_counter[step] == 0: self.update_latest_checkpoint_step(step) - async def notify_started(self): + async def notify_started(self, node_id: str, job_id: str): + if self.save_strategy == SaveStrategy.SINGLE_THREAD: + identifier = self.current_identifier + 1 + elif self.save_strategy == SaveStrategy.SINGLE_PROCESS: + identifier = f"{node_id}_{job_id}" + elif self.save_strategy == SaveStrategy.SINGLE_NODE: + identifier = node_id + elif self.save_strategy == SaveStrategy.UNRESTRICTED: + return + else: + raise ValueError(f"Invalid save strategy: {self.save_strategy}") + async with self.condition: - while self.saving_count > 0: + if identifier != self.current_identifier and self.saving_count > 0: await self.condition.wait_for(lambda: self.saving_count == 0) + self.current_identifier = identifier self.saving_count += 1 async def notify_finished(self, step: int, is_state_dict: bool = False): async with self.condition: self.saving_count -= 1 - self.condition.notify() + self.condition.notify_all() if is_state_dict: self.state_dict_counter[step] -= 1 if ( @@ -144,6 +161,7 @@ async def notify_finished(self, step: int, is_state_dict: bool = False): def get_actor( cls, namespace: str, + save_strategy: Optional[SaveStrategy] = None, default_local_dir: Optional[str] = None, default_hdfs_dir: Optional[str] = None, ): @@ -154,7 +172,11 @@ def get_actor( 
namespace=namespace, get_if_exists=True, ) - .remote(default_local_dir=default_local_dir, default_hdfs_dir=default_hdfs_dir) + .remote( + save_strategy=save_strategy, + default_local_dir=default_local_dir, + default_hdfs_dir=default_hdfs_dir, + ) ) @@ -204,6 +226,7 @@ def __init__( self.checkpoint_monitor = CheckpointMonitor.get_actor( namespace=global_config.synchronizer.ray_namespace, + save_strategy=global_config.trainer.save_strategy, default_local_dir=config.trainer.default_local_dir, default_hdfs_dir=config.trainer.default_hdfs_dir, ) From 8ec8afaca99b01a6a8db8cef869bd841b136aba9 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 15 Oct 2025 16:48:18 +0800 Subject: [PATCH 6/6] add doc for save strategy --- docs/sphinx_doc/source/tutorial/trinity_configs.md | 8 +++++++- docs/sphinx_doc/source_zh/tutorial/trinity_configs.md | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 3977726ca4..9f020fcf89 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -160,7 +160,7 @@ model: - `model_path`: Path to the model being trained. - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`. -- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not set, it will default to `max_prompt_tokens` + `max_response_tokens`. However, if either `max_prompt_tokens` or `max_response_tokens` is not set, we will raise an error. +- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not specified, the system will attempt to set it to `max_prompt_tokens` + `max_response_tokens`. However, this requires both values to be already set; otherwise, an error will be raised. - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`. - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`. @@ -405,6 +405,7 @@ trainer: trainer_type: 'verl' save_interval: 100 total_steps: 1000 + save_strategy: "unrestricted" trainer_config: null trainer_config_path: '' ``` @@ -413,6 +414,11 @@ trainer: - `trainer_type`: Trainer backend implementation. Currently only supports `verl`. - `save_interval`: Frequency (in steps) at which to save model checkpoints. - `total_steps`: Total number of training steps. +- `save_strategy`: The parallel strategy used when saving the model. Defaults to `unrestricted`. The available options are as follows: + - `single_thread`: Only one thread across the entire system is allowed to save the model; saving tasks from different threads are executed sequentially. + - `single_process`: Only one process across the entire system is allowed to perform saving; multiple threads within that process can handle saving tasks in parallel, while saving operations across different processes are executed sequentially. 
+ - `single_node`: Only one compute node across the entire system is allowed to perform saving; processes and threads within that node can work in parallel, while saving operations across different nodes are executed sequentially. + - `unrestricted`: No restrictions on saving operations; multiple nodes, processes, or threads are allowed to save the model simultaneously. - `trainer_config`: The trainer configuration provided inline. - `trainer_config_path`: The path to the trainer configuration file. Only one of `trainer_config_path` and `trainer_config` should be specified. diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 8b2c4edc1a..1e02ab9443 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -160,7 +160,7 @@ model: - `model_path`: 被训练模型的路径。 - `critic_model_path`: 可选的独立 critic 模型路径。若为空,则默认为 `model_path`。 -- `max_model_len`: 表示模型所支持的单个序列最大 token 数。如果未设置该值,则会尝试将其默认设为 `max_prompt_tokens + max_response_tokens`。但如果 `max_prompt_tokens` 或 `max_response_tokens` 中有任何一个未设置,代码将会报错。 +- `max_model_len`: 表示模型所支持的单个序列最大 token 数。如未指定,系统会尝试将其设为 `max_prompt_tokens` + `max_response_tokens`。但前提是这两个值都必须已设置,否则将引发错误。 - `max_prompt_tokens`: 输入 prompt 中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `min_response_tokens`: 模型生成的回复中允许的最小 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 @@ -405,6 +405,7 @@ trainer: trainer_type: 'verl' save_interval: 100 total_steps: 1000 + save_strategy: "unrestricted" trainer_config: null trainer_config_path: '' ``` @@ -413,6 +414,11 @@ trainer: - `trainer_type`: trainer 后端实现。目前仅支持 `verl`。 - `save_interval`: 保存模型检查点的频率(步)。 - `total_steps`: 总训练步数。 +- `save_strategy`: 模型保存时的并行策略。默认值为`unrestricted`。可选值如下: + - `single_thread`:整个系统中,仅允许一个线程进行模型保存,不同保存线程之间串行执行。 + - `single_process`:整个系统中,仅允许一个进程执行保存,该进程内的多个线程可以并行处理保存任务,不同进程之间串行执行。 + - `single_node`:整个系统中,仅允许一个计算节点执行保存,该节点内的进程和线程可并行工作,不同节点的保存串行执行。 + - `unrestricted`:不限制保存操作,允许多个节点、进程或线程同时保存模型。 - `trainer_config`: 内联提供的 trainer 配置。 - `trainer_config_path`: trainer 配置文件的路径。`trainer_config_path` 和 `trainer_config` 只能指定其一。
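
The serial-save behavior introduced across patches 1 and 5 reduces to a single-writer gate built on `asyncio.Condition`: each saver reports an identifier derived from the configured `save_strategy`, and a saver whose identifier differs from the one currently holding the gate waits until all in-flight saves drain. Below is a minimal, standalone sketch of that pattern. The names (`SaveGate`, `_identifier`) and the simplified `UNRESTRICTED` path are illustrative assumptions, not the actual `CheckpointMonitor` API.

```python
import asyncio
from enum import Enum


class SaveStrategy(str, Enum):
    # Mirrors the documented options; a hypothetical standalone copy,
    # not the trinity.common.constants definition.
    SINGLE_THREAD = "single_thread"
    SINGLE_PROCESS = "single_process"
    SINGLE_NODE = "single_node"
    UNRESTRICTED = "unrestricted"


class SaveGate:
    """Sketch of the save-serialization pattern: savers mapping to a
    different identifier than the current holder wait until all
    in-flight saves have finished."""

    def __init__(self, strategy: SaveStrategy):
        self.strategy = strategy
        self.condition = asyncio.Condition()
        self.current_identifier = 0
        self.saving_count = 0

    def _identifier(self, node_id: str, job_id: str):
        if self.strategy == SaveStrategy.SINGLE_THREAD:
            # Always unique, so every save waits for the previous one.
            return self.current_identifier + 1
        if self.strategy == SaveStrategy.SINGLE_PROCESS:
            return f"{node_id}_{job_id}"
        if self.strategy == SaveStrategy.SINGLE_NODE:
            return node_id
        return None  # UNRESTRICTED: no gating at all

    async def notify_started(self, node_id: str, job_id: str):
        identifier = self._identifier(node_id, job_id)
        if identifier is None:
            return
        async with self.condition:
            # A saver from a different thread/process/node waits until the
            # current batch of saves has drained.
            if identifier != self.current_identifier and self.saving_count > 0:
                await self.condition.wait_for(lambda: self.saving_count == 0)
            self.current_identifier = identifier
            self.saving_count += 1

    async def notify_finished(self):
        if self.strategy == SaveStrategy.UNRESTRICTED:
            return
        async with self.condition:
            self.saving_count -= 1
            self.condition.notify_all()


async def main():
    gate = SaveGate(SaveStrategy.SINGLE_NODE)
    await gate.notify_started(node_id="node-a", job_id="job-1")
    # ... perform the actual checkpoint write here ...
    await gate.notify_finished()


if __name__ == "__main__":
    asyncio.run(main())
```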
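
The new `max_model_len` behavior documented above amounts to: use the explicit value if given, otherwise fall back to `max_prompt_tokens + max_response_tokens`, and raise an error instead of probing the model's config.json. A rough sketch of that rule, written as a hypothetical free-standing helper rather than the actual `Config._check_model` method:

```python
from typing import Optional


def resolve_max_model_len(
    max_model_len: Optional[int],
    max_prompt_tokens: Optional[int],
    max_response_tokens: Optional[int],
) -> int:
    """Prefer an explicit max_model_len; otherwise derive it from the
    prompt/response token budgets; fail loudly if neither is possible."""
    if max_model_len is not None:
        return max_model_len
    if max_prompt_tokens is not None and max_response_tokens is not None:
        return max_prompt_tokens + max_response_tokens
    raise ValueError("Unable to determine `max_model_len`, please set it manually.")
```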