From 122f2a3c95d9cee4f90aeaf753adb2e8254d1f9b Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Sat, 6 Sep 2025 16:49:28 +0800
Subject: [PATCH 1/3] bug fix in load_plugin && config_manager

---
 trinity/cli/launcher.py                                   |  2 ++
 trinity/common/constants.py                               |  8 --------
 trinity/manager/config_manager.py                         | 12 ++++++++++--
 .../manager/config_registry/buffer_config_manager.py      |  2 +-
 .../manager/config_registry/model_config_manager.py       |  8 ++++----
 5 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 7031d0829d..12fea13e1e 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -20,6 +20,7 @@
 from trinity.trainer.trainer import Trainer
 from trinity.utils.dlc_utils import setup_ray_cluster
 from trinity.utils.log import get_logger
+from trinity.utils.plugin_loader import load_plugins
 
 logger = get_logger(__name__)
 
@@ -124,6 +125,7 @@ def both(config: Config) -> None:
 
 
 def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
+    load_plugins()
     config = load_config(config_path)
     config.check_and_update()
     pprint(config)
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index 531d113965..aa35472c5b 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -56,14 +56,6 @@ class StorageType(CaseInsensitiveEnum):
     FILE = "file"
 
 
-class MonitorType(CaseInsensitiveEnum):
-    """Monitor Type."""
-
-    WANDB = "wandb"
-    TENSORBOARD = "tensorboard"
-    MLFLOW = "mlflow"
-
-
 class SyncMethodEnumMeta(CaseInsensitiveEnumMeta):
     def __call__(cls, value, *args, **kwargs):
         if value == "online":
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index fa3c4da524..d3a6d52315 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -16,6 +16,7 @@
 from trinity.common.constants import StorageType
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 from trinity.manager.config_registry.trainer_config_manager import use_critic
+from trinity.utils.plugin_loader import load_plugins
 
 register_map = {
     "sample_strategy": SAMPLE_STRATEGY,
@@ -29,6 +30,7 @@
 
 class ConfigManager:
     def __init__(self):
+        load_plugins()
         self.unfinished_fields = set()
         CONFIG_GENERATORS.set_unfinished_fields(self.unfinished_fields)
         st.set_page_config(page_title="Trinity-RFT Config Generator", page_icon=":robot:")
@@ -380,7 +382,10 @@ def _generate_verl_config(self):
                     "actor_entropy_from_logits_with_chunking"
                 ],
                 "entropy_checkpointing": st.session_state["actor_entropy_checkpointing"],
-                "checkpoint": {"contents": st.session_state["actor_checkpoint"]},
+                "checkpoint": {
+                    "load_contents": st.session_state["actor_checkpoint"],
+                    "save_contents": st.session_state["actor_checkpoint"],
+                },
                 "optim": {
                     "lr": st.session_state["actor_lr"],
                     "lr_warmup_steps_ratio": st.session_state["actor_lr_warmup_steps_ratio"],
@@ -466,7 +471,10 @@ def _generate_verl_config(self):
                 "shuffle": False,
                 "grad_clip": st.session_state["critic_grad_clip"],
                 "cliprange_value": st.session_state["critic_cliprange_value"],
-                "checkpoint": {"contents": st.session_state["critic_checkpoint"]},
+                "checkpoint": {
+                    "load_contents": st.session_state["critic_checkpoint"],
+                    "save_contents": st.session_state["critic_checkpoint"],
+                },
             }
         else:
             del trainer_config["critic"]
diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py
index 37cf11ea75..b999fa7c71 100644
--- a/trinity/manager/config_registry/buffer_config_manager.py
+++ b/trinity/manager/config_registry/buffer_config_manager.py
@@ -1,6 +1,6 @@
 import streamlit as st
 
-from trinity.buffer.queue import PRIORITY_FUNC
+from trinity.buffer.storage.queue import PRIORITY_FUNC
 from trinity.common.constants import PromptType, StorageType
 from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
 from trinity.common.workflows.workflow import WORKFLOWS
diff --git a/trinity/manager/config_registry/model_config_manager.py b/trinity/manager/config_registry/model_config_manager.py
index 53038676f6..0e8a7a3cc2 100644
--- a/trinity/manager/config_registry/model_config_manager.py
+++ b/trinity/manager/config_registry/model_config_manager.py
@@ -2,9 +2,9 @@
 
 import streamlit as st
 
-from trinity.common.constants import MonitorType
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 from trinity.manager.config_registry.trainer_config_manager import use_critic
+from trinity.utils.monitor import MONITOR
 
 
 def set_total_gpu_num():
@@ -54,11 +54,11 @@ def check_checkpoint_root_dir(unfinished_fields: set, key: str):
         st.warning("Please input an absolute path.")
 
 
-@CONFIG_GENERATORS.register_config(default_value=MonitorType.TENSORBOARD.value)
+@CONFIG_GENERATORS.register_config(default_value="tensorboard")
 def set_monitor_type(**kwargs):
     st.selectbox(
         "Monitor Type",
-        options=[monitor_type.value for monitor_type in MonitorType],
+        options=MONITOR.modules.keys(),
         **kwargs,
     )
@@ -96,7 +96,7 @@ def set_max_response_tokens(**kwargs):
 
 @CONFIG_GENERATORS.register_config(default_value=2048)
 def set_max_model_len(**kwargs):
-    st.number_input("Max Token Length", min_value=1, **kwargs)
+    st.number_input("Max Model Length", min_value=1, **kwargs)
 
 
 # Cluster Config

From b589ef58ee1f1f58e331a4e3bf990c31fb8a94ea Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Mon, 8 Sep 2025 16:07:32 +0800
Subject: [PATCH 2/3] update config manager

---
 trinity/manager/config_manager.py                         | 128 +++++++++++----
 .../config_registry/buffer_config_manager.py              |  53 +++---
 .../config_registry/model_config_manager.py               |  10 ++
 .../config_registry/trainer_config_manager.py             | 155 +++++++++++++-----
 4 files changed, 250 insertions(+), 96 deletions(-)

diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index d3a6d52315..e36c5b28af 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -133,7 +133,7 @@ def beginner_mode(self):
         st.header("Important Configs")
 
         self.get_configs("node_num", "gpu_per_node", "engine_num", "tensor_parallel_size")
-        self.get_configs("total_epochs", "explore_batch_size", "train_batch_size", "repeat_times")
+        self.get_configs("total_epochs", "explore_batch_size", "repeat_times", "train_batch_size")
 
         self.get_configs("storage_type", "max_response_tokens", "max_model_len", "ppo_epochs")
@@ -203,9 +203,10 @@ def _expert_buffer_part(self):
         self.get_configs("use_priority_queue")
         self.get_configs("reuse_cooldown_time", "priority_fn", "priority_decay")
 
-        self.buffer_advanced_tab = st.expander("Advanced Config")
-        with self.buffer_advanced_tab:
-            self.get_configs("buffer_max_retry_times", "max_retry_interval")
+        # TODO: used for SQL storage
+        # self.buffer_advanced_tab = st.expander("Advanced Config")
+        # with self.buffer_advanced_tab:
+        #     self.get_configs("buffer_max_retry_times", "max_retry_interval")
 
     def _expert_explorer_part(self):
         self.get_configs("sync_method", "sync_interval", "sync_timeout")
@@ -267,15 +268,43 @@ def _expert_verl_training_part(self):
         self.get_configs("ppo_epochs", "training_strategy", "resume_mode", "impl_backend")
-        self.get_configs("param_offload", "optimizer_offload", "forward_prefetch")
         self.get_configs("resume_from_path")
 
+        if st.session_state["training_strategy"] == "fsdp":
+            self.get_configs("param_offload", "optimizer_offload", "forward_prefetch")
+        elif st.session_state["training_strategy"] == "fsdp2":
+            self.get_configs("offload_policy", "reshard_after_forward")
+        elif st.session_state["training_strategy"] == "megatron":
+            with st.expander("Megatron Config"):
+                self.get_configs("param_offload", "grad_offload", "optimizer_offload")
+                self.get_configs(
+                    "tensor_model_parallel_size",
+                    "pipeline_model_parallel_size",
+                    "virtual_pipeline_model_parallel_size",
+                )
+                self.get_configs(
+                    "expert_model_parallel_size",
+                    "expert_tensor_parallel_size",
+                    "context_parallel_size",
+                )
+                self.get_configs(
+                    "sequence_parallel",
+                    "use_distributed_optimizer",
+                    "use_dist_checkpointing",
+                    "use_mbridge",
+                )
+                self.get_configs("dist_checkpointing_path")
+                self.get_configs(
+                    "recompute_granularity", "recompute_method", "recompute_num_layers"
+                )
+                self.get_configs("recompute_modules")
+
         with st.expander("Advanced Config"):
             self.get_configs("critic_warmup", "total_training_steps")
             self.get_configs("default_hdfs_dir")
-            self.get_configs("remove_previous_ckpt_in_save", "del_local_ckpt_after_load")
+            self.get_configs("del_local_ckpt_after_load")
             self.get_configs("max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep")
@@ -342,16 +371,53 @@ def _generate_verl_config(self):
         use_dynamic_bsz = "dynamic_bsz" in st.session_state["training_args"]
         use_fused_kernels = "use_fused_kernels" in st.session_state["training_args"]
 
-        if st.session_state["training_strategy"] == "fsdp":
-            fsdp_config = {
-                "wrap_policy": {"min_num_params": 0},
-                "param_offload": st.session_state["param_offload"],
-                "optimizer_offload": st.session_state["optimizer_offload"],
-                "fsdp_size": -1,
-                "forward_prefetch": st.session_state["forward_prefetch"],
+        if st.session_state["training_strategy"] in {"fsdp", "fsdp2"}:
+            distribution_config = {
+                "fsdp_config": {
+                    "fsdp_size": -1,
+                    # for fsdp
+                    "wrap_policy": {"min_num_params": 0},
+                    "param_offload": st.session_state["param_offload"],
+                    "optimizer_offload": st.session_state["optimizer_offload"],
+                    "forward_prefetch": st.session_state["forward_prefetch"],
+                    # for fsdp2
+                    "offload_policy": st.session_state["offload_policy"],
+                    "reshard_after_forward": st.session_state["reshard_after_forward"],
+                }
             }
+        elif st.session_state["training_strategy"] == "megatron":
+            distribution_config = {
+                "megatron": {
+                    "param_offload": st.session_state["param_offload"],
+                    "grad_offload": st.session_state["grad_offload"],
+                    "optimizer_offload": st.session_state["optimizer_offload"],
+                    "tensor_model_parallel_size": st.session_state["tensor_model_parallel_size"],
+                    "pipeline_model_parallel_size": st.session_state[
+                        "pipeline_model_parallel_size"
+                    ],
+                    "virtual_pipeline_model_parallel_size": st.session_state[
+                        "virtual_pipeline_model_parallel_size"
+                    ],
+                    "expert_model_parallel_size": st.session_state["expert_model_parallel_size"],
+                    "expert_tensor_parallel_size": st.session_state["expert_tensor_parallel_size"],
+                    "context_parallel_size": st.session_state["context_parallel_size"],
+                    "sequence_parallel": st.session_state["sequence_parallel"],
+                    "use_distributed_optimizer": st.session_state["use_distributed_optimizer"],
+                    "use_dist_checkpointing": st.session_state["use_dist_checkpointing"],
+                    "dist_checkpointing_path": st.session_state["dist_checkpointing_path"],
+                    "seed": st.session_state["seed"],
+                    # TODO: override_ddp_config
+                    "override_transformer_config": {
+                        "recompute_granularity": st.session_state["recompute_granularity"],
+                        "recompute_modules": st.session_state["recompute_modules"],
+                        "recompute_method": st.session_state["recompute_method"],
+                        "recompute_num_layers": st.session_state["recompute_num_layers"],
+                    },
+                    "use_mbridge": st.session_state["use_mbridge"],
+                }
+            }
         else:
-            fsdp_config = {}
+            distribution_config = {}
 
         ppo_max_token_len_per_gpu = (
             st.session_state["repeat_times"] * st.session_state["max_model_len"]
@@ -374,7 +440,6 @@ def _generate_verl_config(self):
                 "use_dynamic_bsz": use_dynamic_bsz,
                 "ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu,
                 "ppo_epochs": st.session_state["ppo_epochs"],
-                "shuffle": False,
                 "ulysses_sequence_parallel_size": st.session_state[
                     "actor_ulysses_sequence_parallel_size"
                 ],
@@ -390,16 +455,10 @@ def _generate_verl_config(self):
                     "lr": st.session_state["actor_lr"],
                     "lr_warmup_steps_ratio": st.session_state["actor_lr_warmup_steps_ratio"],
                     "warmup_style": st.session_state["actor_warmup_style"],
-                    "total_training_steps": (
-                        -1
-                        if st.session_state["total_training_steps"] is None
-                        else st.session_state["total_training_steps"]
-                    ),
+                    "total_training_steps": (st.session_state["total_training_steps"] or -1),
                 },
-                "fsdp_config": copy.deepcopy(fsdp_config),
             },
             "ref": {
-                "fsdp_config": copy.deepcopy(fsdp_config),
                 "log_prob_micro_batch_size_per_gpu": st.session_state[
                     "ref_log_prob_micro_batch_size_per_gpu"
                 ],
@@ -420,14 +479,15 @@ def _generate_verl_config(self):
                 "resume_mode": st.session_state["resume_mode"],
                 "resume_from_path": st.session_state["resume_from_path"],
                 "default_hdfs_dir": st.session_state["default_hdfs_dir"],
-                "remove_previous_ckpt_in_save": st.session_state["remove_previous_ckpt_in_save"],
                 "del_local_ckpt_after_load": st.session_state["del_local_ckpt_after_load"],
-                "val_before_train": False,
                 "max_actor_ckpt_to_keep": st.session_state["max_actor_ckpt_to_keep"],
                 "max_critic_ckpt_to_keep": st.session_state["max_critic_ckpt_to_keep"],
             },
         }
 
+        trainer_config["actor_rollout_ref"]["actor"].update(copy.deepcopy(distribution_config))
+        trainer_config["actor_rollout_ref"]["ref"].update(copy.deepcopy(distribution_config))
+
         if use_fused_kernels:
             trainer_config["actor_rollout_ref"]["model"]["fused_kernel_options"] = {
                 "impl_backend": st.session_state["impl_backend"],
@@ -441,18 +501,13 @@ def _generate_verl_config(self):
                     "lr": st.session_state["critic_lr"],
                     "lr_warmup_steps_ratio": st.session_state["critic_lr_warmup_steps_ratio"],
                     "warmup_style": st.session_state["critic_warmup_style"],
-                    "total_training_steps": (
-                        -1
-                        if st.session_state["total_training_steps"] is None
-                        else st.session_state["total_training_steps"]
-                    ),
+                    "total_training_steps": (st.session_state["total_training_steps"] or -1),
                 },
                 "model": {
                     "override_config": {},
                     "external_lib": None,
                     "enable_gradient_checkpointing": enable_gradient_checkpointing,
                     "use_remove_padding": use_remove_padding,
-                    "fsdp_config": copy.deepcopy(fsdp_config),
                 },
                 "ppo_mini_batch_size": st.session_state["train_batch_size"],
                 "ppo_micro_batch_size_per_gpu": st.session_state[
@@ -468,7 +523,6 @@ def _generate_verl_config(self):
                     "critic_ulysses_sequence_parallel_size"
                 ],
                 "ppo_epochs": st.session_state["ppo_epochs"],
-                "shuffle": False,
                 "grad_clip": st.session_state["critic_grad_clip"],
                 "cliprange_value": st.session_state["critic_cliprange_value"],
                 "checkpoint": {
@@ -476,6 +530,10 @@ def _generate_verl_config(self):
                     "save_contents": st.session_state["critic_checkpoint"],
                 },
             }
+            if st.session_state["training_strategy"] in {"fsdp", "fsdp2"}:
+                trainer_config["critic"]["model"].update(copy.deepcopy(distribution_config))
+            elif st.session_state["training_strategy"] == "megatron":
+                trainer_config["critic"].update(copy.deepcopy(distribution_config))
         else:
             del trainer_config["critic"]
         return trainer_config
@@ -527,12 +585,14 @@ def _gen_buffer_config(self):
                     "name": "experience_buffer",
                     "storage_type": st.session_state["storage_type"],
                     "path": experience_buffer_path,
-                    "max_retry_interval": st.session_state["max_retry_interval"],
-                    "max_retry_times": st.session_state["buffer_max_retry_times"],
+                    # "max_retry_interval": st.session_state["max_retry_interval"],
+                    # "max_retry_times": st.session_state["buffer_max_retry_times"],
                 },
                 "sft_warmup_steps": st.session_state["sft_warmup_steps"],
            },
        }
+        if st.session_state["train_batch_size"] is None:
+            del buffer_config["train_batch_size"]
        if st.session_state["algorithm_type"] != "dpo":
            experience_buffer = buffer_config["trainer_input"]["experience_buffer"]
            experience_buffer["use_priority_queue"] = st.session_state["use_priority_queue"]
@@ -673,6 +733,8 @@ def generate_config(self):
            "data_processor": {},  # TODO: Add data processor config
            "model": {
                "model_path": st.session_state["model_path"],
+                "max_prompt_tokens": st.session_state["max_prompt_tokens"],
+                "min_response_tokens": st.session_state["min_response_tokens"],
                "max_response_tokens": st.session_state["max_response_tokens"],
                "max_model_len": st.session_state["max_model_len"],
            },
diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py
index b999fa7c71..7c378c70af 100644
--- a/trinity/manager/config_registry/buffer_config_manager.py
+++ b/trinity/manager/config_registry/buffer_config_manager.py
@@ -22,6 +22,22 @@ def set_explore_batch_size(**kwargs):
     )
 
 
+def get_train_batch_size() -> int:
+    return (
+        st.session_state["train_batch_size"]
+        or st.session_state["explore_batch_size"] * st.session_state["repeat_times"]
+    )
+
+
+def get_train_batch_size_per_gpu() -> int:
+    return st.session_state["_train_batch_size_per_gpu"] or max(
+        st.session_state["explore_batch_size"]
+        * st.session_state["repeat_times"]
+        // st.session_state["trainer_gpu_num"],
+        1,
+    )
+
+
 def _str_for_train_batch_size():
     trainer_gpu_num_str = (
         "`gpu_per_node * node_num - engine_num * tensor_parallel_size`"
@@ -29,23 +45,26 @@
         else "`gpu_per_node * node_num`"
     )
     return (
-        f"Usually set to `task_batch_size` * `repeat_times`."
-        f"Please ensure that `train_batch_size` can be divided by "
-        f"{trainer_gpu_num_str} = {st.session_state['trainer_gpu_num']}."
+        f"`train_batch_size` defaults to `task_batch_size` * `repeat_times`.\n\n"
+        f"Please ensure that `train_batch_size` ({get_train_batch_size()}) can be divided by "
+        f"{trainer_gpu_num_str} ({st.session_state['trainer_gpu_num']})."
     )
 
 
 @CONFIG_GENERATORS.register_config(
-    default_value=96,
+    default_value=None,
     visible=lambda: st.session_state["trainer_gpu_num"] > 0,
-    other_configs={"_train_batch_size_per_gpu": 16},
+    other_configs={"_train_batch_size_per_gpu": None},
 )
 def set_train_batch_size(**kwargs):
     key = kwargs.get("key")
     trainer_gpu_num = st.session_state["trainer_gpu_num"]
     st.session_state[key] = (
         st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
+        if st.session_state["_train_batch_size_per_gpu"] is not None
+        else None
     )
+    placeholder = st.session_state["explore_batch_size"] * st.session_state["repeat_times"]
 
     def on_change():
         st.session_state["_train_batch_size_per_gpu"] = max(
@@ -58,13 +77,14 @@ def on_change():
         step=trainer_gpu_num,
         help=_str_for_train_batch_size(),
         on_change=on_change,
+        placeholder=placeholder,
         **kwargs,
     )
 
 
 @CONFIG_GENERATORS.register_check()
 def check_train_batch_size(unfinished_fields: set, key: str):
-    if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
+    if get_train_batch_size() % st.session_state["trainer_gpu_num"] != 0:
         unfinished_fields.add(key)
         st.warning(_str_for_train_batch_size())
@@ -91,25 +111,6 @@ def check_taskset_path(unfinished_fields: set, key: str):
         st.warning("Please input taskset path.")
 
 
-# def _set_temperature(self):
-#     st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0)
-
-# def _set_top_p(self):
-#     st.number_input("Top-p", key="top_p", min_value=0.0, max_value=1.0)
-
-# def _set_top_k(self):
-#     st.number_input(
-#         "Top-k",
-#         key="top_k",
-#         min_value=-1,
-#         max_value=512,
-#         help="Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.",
-#     )
-
-# def _set_logprobs(self):
-#     st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20)
-
-
 @CONFIG_GENERATORS.register_config(
     visible=lambda: st.session_state["taskset_path"]
     and "://" not in st.session_state["taskset_path"],
@@ -138,7 +139,7 @@ def set_taskset_args(**kwargs):
     response_key_col.text_input(
         "Response Key :orange-badge[(Needs review)]", key="taskset_response_key"
     )
-    # self._set_configs_with_st_columns(["temperature", "logprobs"])
+    temperature_col, logprobs_col = st.columns(2)
     temperature_col.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0)
     logprobs_col.number_input("Logprobs", key="logprobs", min_value=0, max_value=20)
diff --git a/trinity/manager/config_registry/model_config_manager.py b/trinity/manager/config_registry/model_config_manager.py
index 0e8a7a3cc2..19c026e798 100644
--- a/trinity/manager/config_registry/model_config_manager.py
+++ b/trinity/manager/config_registry/model_config_manager.py
@@ -89,6 +89,16 @@ def set_critic_model_path(**kwargs):
     )
 
 
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_max_prompt_tokens(**kwargs):
+    st.number_input("Max Prompt Length", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=1)
+def set_min_response_tokens(**kwargs):
+    st.number_input("Min Response Length", min_value=1, **kwargs)
+
+
 @CONFIG_GENERATORS.register_config(default_value=1024)
 def set_max_response_tokens(**kwargs):
     st.number_input("Max Response Length", min_value=1, **kwargs)
diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py
index a81ed84e26..e409a2dbc9 100644
--- a/trinity/manager/config_registry/trainer_config_manager.py
+++ b/trinity/manager/config_registry/trainer_config_manager.py
@@ -1,7 +1,9 @@
 import streamlit as st
 
 from trinity.algorithm.algorithm import ALGORITHM_TYPE
-from trinity.common.constants import SyncMethod
+from trinity.manager.config_registry.buffer_config_manager import (
+    get_train_batch_size_per_gpu,
+)
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 
 
@@ -15,32 +17,11 @@ def set_trainer_type(**kwargs):
     st.selectbox("Trainer Type", ["verl"], **kwargs)
 
 
-@CONFIG_GENERATORS.register_config(default_value=100, other_configs={"_nccl_save_interval": 100})
+@CONFIG_GENERATORS.register_config(default_value=100)
 def set_save_interval(**kwargs):
-    key = kwargs.get("key")
-    if (
-        st.session_state["algorithm_type"] == "dpo"
-        or st.session_state["sync_method"] == SyncMethod.NCCL.value
-    ):
-        st.session_state[key] = st.session_state["_nccl_save_interval"]
-        freeze_save_interval = False
-    else:
-        st.session_state[key] = st.session_state["sync_interval"]
-        freeze_save_interval = True
-
-    def on_change():
-        if (
-            st.session_state["algorithm_type"] == "dpo"
-            or st.session_state["sync_method"] == SyncMethod.NCCL.value
-        ):
-            st.session_state["_nccl_save_interval"] = st.session_state[key]
-
     st.number_input(
         "Save Interval",
         min_value=1,
-        help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`",
-        disabled=freeze_save_interval,
-        on_change=on_change,
         **kwargs,
     )
@@ -95,31 +76,136 @@ def set_ppo_epochs(**kwargs):
 def set_training_strategy(**kwargs):
     st.selectbox(
         "Training Strategy",
-        ["fsdp", "megatron"],
+        ["fsdp", "fsdp2", "megatron"],
         help="megatron is not tested",
         **kwargs,
     )
 
 
 def use_fsdp():
+    return st.session_state["training_strategy"] in {"fsdp", "fsdp2"}
+
+
+def use_fsdp1():
     return st.session_state["training_strategy"] == "fsdp"
 
 
-@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp)
+def use_fsdp2():
+    return st.session_state["training_strategy"] == "fsdp2"
+
+
+@CONFIG_GENERATORS.register_config(default_value=False)  # , visible=use_fsdp1)
 def set_param_offload(**kwargs):
-    st.checkbox("FSDP Param Offload", **kwargs)
+    st.checkbox("Param Offload", **kwargs)
 
 
-@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp)
+@CONFIG_GENERATORS.register_config(default_value=False)
+def set_grad_offload(**kwargs):
+    st.checkbox("Grad Offload", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False)  # , visible=use_fsdp1)
 def set_optimizer_offload(**kwargs):
-    st.checkbox("FSDP Optimizer Offload", **kwargs)
+    st.checkbox("Optimizer Offload", **kwargs)
 
 
-@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp)
+@CONFIG_GENERATORS.register_config(default_value=False)  # , visible=use_fsdp1)
 def set_forward_prefetch(**kwargs):
     st.checkbox("FSDP Forward Prefetch", **kwargs)
 
 
+@CONFIG_GENERATORS.register_config(default_value=False)  # , visible=use_fsdp2)
+def set_offload_policy(**kwargs):
+    st.checkbox("Enable FSDP2 offload_policy", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=True)  # , visible=use_fsdp2)
+def set_reshard_after_forward(**kwargs):
+    st.checkbox("FSDP2 Reshard After Forward", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=1)
+def set_tensor_model_parallel_size(**kwargs):
+    st.number_input("Tensor Model Parallel Size", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=1)
+def set_pipeline_model_parallel_size(**kwargs):
+    st.number_input("Pipeline Model Parallel Size", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_virtual_pipeline_model_parallel_size(**kwargs):
+    st.number_input("Virtual Pipeline Model Parallel Size", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=1)
+def set_expert_model_parallel_size(**kwargs):
+    st.number_input("Expert Model Parallel Size", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_expert_tensor_parallel_size(**kwargs):
+    st.number_input("Expert Tensor Parallel Size", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=1)
+def set_context_parallel_size(**kwargs):
+    st.number_input("Context Parallel Size", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=True)
+def set_sequence_parallel(**kwargs):
+    st.checkbox("Sequence Parallel", **kwargs)
+
+
+# TODO: check parallel settings
+
+
+@CONFIG_GENERATORS.register_config(default_value=True)
+def set_use_distributed_optimizer(**kwargs):
+    st.checkbox("Use Distributed Optimizer", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False)
+def set_use_dist_checkpointing(**kwargs):
+    st.checkbox("Use Distributed Checkpointing", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_dist_checkpointing_path(**kwargs):
+    st.text_input("Distributed Checkpointing Path", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False)
+def set_use_mbridge(**kwargs):
+    st.checkbox("Use MBridge", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_recompute_granularity(**kwargs):
+    st.selectbox("Recompute Granularity", ["selective", "full"], **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=["core_attn"])
+def set_recompute_modules(**kwargs):
+    st.multiselect(
+        "Recompute Modules",
+        ["core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe", "shared_experts"],
+        **kwargs,
+    )
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_recompute_method(**kwargs):
+    st.selectbox("Recompute Method", ["uniform", "block"], **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_recompute_num_layers(**kwargs):
+    st.number_input("Recompute Num Layers", min_value=1, **kwargs)
+
+
 @CONFIG_GENERATORS.register_config(default_value="auto")
 def set_resume_mode(**kwargs):
     st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], **kwargs)
@@ -168,11 +254,6 @@ def set_default_hdfs_dir(**kwargs):
     st.text_input("Default HDFS Dir", **kwargs)
 
 
-@CONFIG_GENERATORS.register_config(default_value=False)
-def set_remove_previous_ckpt_in_save(**kwargs):
-    st.checkbox("Remove Previous Checkpoint in Save", **kwargs)
-
-
 @CONFIG_GENERATORS.register_config(default_value=False)
 def set_del_local_ckpt_after_load(**kwargs):
     st.checkbox("Delete Local Checkpoint After Load", **kwargs)
@@ -226,7 +307,7 @@ def set_target_kl(**kwargs):
 @CONFIG_GENERATORS.register_config(default_value=4)
 def set_actor_ppo_micro_batch_size_per_gpu(**kwargs):
     key = kwargs.get("key")
-    max_value = st.session_state["_train_batch_size_per_gpu"]
+    max_value = get_train_batch_size_per_gpu()
     st.session_state[key] = min(st.session_state[key], max_value)
     st.number_input(
         "Micro Batch Size Per GPU :blue-badge[(Actor)]", min_value=1, max_value=max_value, **kwargs
@@ -236,7 +317,7 @@ def set_ref_log_prob_micro_batch_size_per_gpu(**kwargs):
     key = kwargs.get("key")
-    max_value = st.session_state["_train_batch_size_per_gpu"]
+    max_value = get_train_batch_size_per_gpu()
     st.session_state[key] = min(st.session_state[key], max_value)
     st.number_input(
         "Micro Batch Size Per GPU :blue-badge[(Ref)]", min_value=1, max_value=max_value, **kwargs
     )
@@ -356,7 +437,7 @@ def set_critic_cliprange_value(**kwargs):
 @CONFIG_GENERATORS.register_config(default_value=8, visible=use_critic)
 def set_critic_ppo_micro_batch_size_per_gpu(**kwargs):
     key = kwargs.get("key")
-    max_value = st.session_state["_train_batch_size_per_gpu"]
+    max_value = get_train_batch_size_per_gpu()
     st.session_state[key] = min(st.session_state[key], max_value)
     st.number_input(
         "Micro Batch Size Per GPU :blue-badge[(Critic)]",

From 03695809bb11e1bdd8188a812cbf291d3d5407ee Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Mon, 8 Sep 2025 17:25:15 +0800
Subject: [PATCH 3/3] apply suggestions from gemini && bug fix in plugin loader

---
 trinity/cli/launcher.py                                   |  6 ++++-
 trinity/manager/config_manager.py                         | 13 +++++++----
 .../config_registry/trainer_config_manager.py             | 22 +++++--------------
 trinity/utils/plugin_loader.py                            |  2 ++
 4 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 12fea13e1e..ec32fb3cf0 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -125,13 +125,17 @@ def both(config: Config) -> None:
 
 
 def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
+    if plugin_dir:
+        os.environ[PLUGIN_DIRS_ENV_VAR] = os.pathsep.join(
+            [os.environ.get(PLUGIN_DIRS_ENV_VAR, ""), plugin_dir]
+        )
     load_plugins()
     config = load_config(config_path)
     config.check_and_update()
     pprint(config)
 
     envs = {
-        PLUGIN_DIRS_ENV_VAR: plugin_dir or "",
+        PLUGIN_DIRS_ENV_VAR: os.environ.get(PLUGIN_DIRS_ENV_VAR, ""),
         LOG_DIR_ENV_VAR: config.log.save_dir,
         LOG_LEVEL_ENV_VAR: config.log.level,
         LOG_NODE_IP_ENV_VAR: "1" if config.log.group_by_node else "0",
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index e36c5b28af..14b5f64de9 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -14,6 +14,7 @@
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
 from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY
 from trinity.common.constants import StorageType
+from trinity.manager.config_registry.buffer_config_manager import get_train_batch_size
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 from trinity.manager.config_registry.trainer_config_manager import use_critic
 from trinity.utils.plugin_loader import load_plugins
@@ -371,16 +372,20 @@ def _generate_verl_config(self):
         use_dynamic_bsz = "dynamic_bsz" in st.session_state["training_args"]
         use_fused_kernels = "use_fused_kernels" in st.session_state["training_args"]
 
-        if st.session_state["training_strategy"] in {"fsdp", "fsdp2"}:
+        if st.session_state["training_strategy"] == "fsdp":
             distribution_config = {
                 "fsdp_config": {
                     "fsdp_size": -1,
-                    # for fsdp
                     "wrap_policy": {"min_num_params": 0},
                     "param_offload": st.session_state["param_offload"],
                     "optimizer_offload": st.session_state["optimizer_offload"],
                     "forward_prefetch": st.session_state["forward_prefetch"],
-                    # for fsdp2
+                }
+            }
+        elif st.session_state["training_strategy"] == "fsdp2":
+            distribution_config = {
+                "fsdp_config": {
+                    "fsdp_size": -1,
                     "offload_policy": st.session_state["offload_policy"],
                     "reshard_after_forward": st.session_state["reshard_after_forward"],
                 }
             }
@@ -509,7 +514,7 @@ def _generate_verl_config(self):
                     "enable_gradient_checkpointing": enable_gradient_checkpointing,
                     "use_remove_padding": use_remove_padding,
                 },
-                "ppo_mini_batch_size": st.session_state["train_batch_size"],
+                "ppo_mini_batch_size": get_train_batch_size(),
                 "ppo_micro_batch_size_per_gpu": st.session_state[
                     "critic_ppo_micro_batch_size_per_gpu"
                 ],
diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py
index e409a2dbc9..b7ae8fa066 100644
--- a/trinity/manager/config_registry/trainer_config_manager.py
+++ b/trinity/manager/config_registry/trainer_config_manager.py
@@ -82,19 +82,7 @@ def set_training_strategy(**kwargs):
     )
 
 
-def use_fsdp():
-    return st.session_state["training_strategy"] in {"fsdp", "fsdp2"}
-
-
-def use_fsdp1():
-    return st.session_state["training_strategy"] == "fsdp"
-
-
-def use_fsdp2():
-    return st.session_state["training_strategy"] == "fsdp2"
-
-
-@CONFIG_GENERATORS.register_config(default_value=False)  # , visible=use_fsdp1)
+@CONFIG_GENERATORS.register_config(default_value=False)
 def set_param_offload(**kwargs):
     st.checkbox("Param Offload", **kwargs)
@@ -104,22 +92,22 @@ def set_grad_offload(**kwargs):
     st.checkbox("Grad Offload", **kwargs)
 
 
-@CONFIG_GENERATORS.register_config(default_value=False)  # , visible=use_fsdp1)
+@CONFIG_GENERATORS.register_config(default_value=False)
 def set_optimizer_offload(**kwargs):
     st.checkbox("Optimizer Offload", **kwargs)
 
 
-@CONFIG_GENERATORS.register_config(default_value=False)  # , visible=use_fsdp1)
+@CONFIG_GENERATORS.register_config(default_value=False)
 def set_forward_prefetch(**kwargs):
     st.checkbox("FSDP Forward Prefetch", **kwargs)
 
 
-@CONFIG_GENERATORS.register_config(default_value=False)  # , visible=use_fsdp2)
+@CONFIG_GENERATORS.register_config(default_value=False)
 def set_offload_policy(**kwargs):
     st.checkbox("Enable FSDP2 offload_policy", **kwargs)
 
 
-@CONFIG_GENERATORS.register_config(default_value=True)  # , visible=use_fsdp2)
+@CONFIG_GENERATORS.register_config(default_value=True)
 def set_reshard_after_forward(**kwargs):
     st.checkbox("FSDP2 Reshard After Forward", **kwargs)
diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py
index c3d956f2b1..1f02e14efe 100644
--- a/trinity/utils/plugin_loader.py
+++ b/trinity/utils/plugin_loader.py
@@ -32,6 +32,8 @@ def load_plugin_from_dirs(plugin_dirs: Union[str, List[str]]) -> None:
         plugin_dirs = [plugin_dirs]
     plugin_dirs = set(plugin_dirs)
     for plugin_dir in plugin_dirs:
+        if plugin_dir == "":
+            continue
         if not os.path.exists(plugin_dir):
             logger.error(f"plugin-dir [{plugin_dir}] does not exist.")
             continue