diff --git a/docs/sphinx_doc/source_zh/tutorial/example_async_mode.md b/docs/sphinx_doc/source_zh/tutorial/example_async_mode.md index 00139883f3..b578780c6d 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_async_mode.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_async_mode.md @@ -98,7 +98,7 @@ synchronizer: trainer: grad_clip: 1.0 use_dynamic_bsz: true - ppo_max_token_len_per_gpu: 16384 + max_token_len_per_gpu: 16384 ulysses_sequence_parallel_size: 1 ``` diff --git a/trinity/common/config.py b/trinity/common/config.py index f96077baf3..ba66634539 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -476,7 +476,7 @@ class TrainerConfig: # trainer configs grad_clip: float = 1.0 use_dynamic_bsz: bool = True - # if None, automatically set to 2 * model.max_model_len / ulysses_sequence_parallel_size + # if None, automatically set to ceil(2 * model.max_model_len / ulysses_sequence_parallel_size) max_token_len_per_gpu: Optional[int] = None ulysses_sequence_parallel_size: int = 1 # sp size # TODO: extract more train-related params from underlying trainer engine diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 31a95d0b76..a5058b456e 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -31,13 +31,13 @@ class ConfigManager: def __init__(self): - load_plugins() + if "_init_config_manager" not in st.session_state: + self.reset_session_state() + load_plugins() self.unfinished_fields = set() CONFIG_GENERATORS.set_unfinished_fields(self.unfinished_fields) st.set_page_config(page_title="Trinity-RFT Config Generator", page_icon=":robot:") st.title("Trinity-RFT Config Generator") - if "_init_config_manager" not in st.session_state: - self.reset_session_state() self.maintain_session_state() mode = st.pills( "Select Mode", @@ -93,10 +93,8 @@ def maintain_list_state(prefix, key_list): self.inference_model_keys = [ "model_path", - "engine_type", "engine_num", "tensor_parallel_size", - "use_v1", "enforce_eager", "enable_prefix_caching", "enable_chunked_prefill", @@ -104,6 +102,7 @@ def maintain_list_state(prefix, key_list): "dtype", "seed", "enable_thinking", + "enable_history", "enable_openai_api", "enable_auto_tool_choice", "tool_call_parser", @@ -115,48 +114,43 @@ def get_configs(self, *config_names: str, columns_spec: List[int] = None): CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec) def beginner_mode(self): - st.header("Essential Configs") - self.get_configs("project", "exp_name", columns_spec=[1, 2]) - - self.get_configs("model_path") - + st.subheader("Global Config") + self.get_configs("project", "exp_name") self.get_configs("checkpoint_root_dir") + self.get_configs("monitor_type", "log_level", "save_interval") - if st.session_state["algorithm_type"] != "dpo": - self.get_configs("taskset_path") - else: - self.get_configs("experience_buffer_path") - - self.get_configs("algorithm_type", "monitor_type") - - st.header("Important Configs") - self.get_configs("node_num", "gpu_per_node", "engine_num", "tensor_parallel_size") - - self.get_configs("total_epochs", "explore_batch_size", "repeat_times", "train_batch_size") + st.subheader("Model Config") + self.get_configs("model_path", "max_model_len", columns_spec=[3, 1]) - self.get_configs("storage_type", "max_response_tokens", "max_model_len", "ppo_epochs") + st.subheader("Algorithm Config") + self.get_configs("algorithm_type", "repeat_times", "actor_lr", "critic_lr") - self.get_configs("sync_interval", "eval_interval", "save_interval") - - if st.session_state["algorithm_type"] != "dpo": - self.get_configs("taskset_args") + st.subheader("Dataset Config") + if st.session_state["algorithm_type"] not in ("dpo", "sft"): + self.get_configs("taskset_path", "explore_batch_size", columns_spec=[3, 1]) else: + self.get_configs("experience_buffer_path", "train_batch_size", columns_spec=[3, 1]) + if st.session_state["algorithm_type"] == "dpo": self.get_configs("dpo_dataset_kwargs") + elif st.session_state["algorithm_type"] == "sft": + self.get_configs("sft_dataset_kwargs") + else: + self.get_configs("taskset_args") + self.get_configs("default_workflow_type", "default_reward_fn_type") + self.get_configs("total_epochs", "total_steps") - self.get_configs( - "default_workflow_type", "default_eval_workflow_type", "default_reward_fn_type" - ) + st.subheader("Resource Config") + self.get_configs("node_num", "gpu_per_node") + self.get_configs("engine_num", "tensor_parallel_size") + self.get_configs("trainer_gpu_num_display", "actor_ulysses_sequence_parallel_size") - self.get_configs( - "actor_ppo_micro_batch_size_per_gpu", - "actor_lr", - "ref_log_prob_micro_batch_size_per_gpu", - ) - - self.get_configs("critic_ppo_micro_batch_size_per_gpu", "critic_lr") + if st.session_state["algorithm_type"] not in ("dpo", "sft"): + st.subheader("Synchronizer Config") + st.caption("Synchronization between trainer and explorer.") + self.get_configs("sync_method", "sync_style", "sync_interval") def _expert_model_part(self): - self.get_configs("project", "exp_name", columns_spec=[1, 2]) + self.get_configs("project", "exp_name") self.get_configs("model_path") self.get_configs("critic_model_path") @@ -167,29 +161,32 @@ def _expert_model_part(self): self.get_configs("max_response_tokens", "max_model_len") def _expert_buffer_part(self): - self.get_configs("total_epochs", "explore_batch_size", "train_batch_size") + self.get_configs("total_epochs", "total_steps", "explore_batch_size", "train_batch_size") self.get_configs( "default_workflow_type", "default_eval_workflow_type", "default_reward_fn_type" ) - self.get_configs("system_prompt") - self.get_configs("reply_prefix") - if st.session_state["algorithm_type"] != "dpo": - with st.expander("Taskset Configs", expanded=True): - self.get_configs("taskset_path") - self.get_configs("taskset_args") - else: + if st.session_state["algorithm_type"] == "dpo": with st.expander("DPO Dataset Configs", expanded=True): self.get_configs("experience_buffer_path") self.get_configs("storage_type") self.get_configs("dpo_dataset_kwargs") + elif st.session_state["algorithm_type"] == "sft": + with st.expander("SFT Dataset Configs", expanded=True): + self.get_configs("experience_buffer_path") + self.get_configs("storage_type") + self.get_configs("sft_dataset_kwargs") + else: + with st.expander("Taskset Configs", expanded=True): + self.get_configs("taskset_path") + self.get_configs("taskset_args") with st.expander("Eval Tasksets Configs", expanded=True): self.get_configs("eval_tasksets") - if st.session_state["algorithm_type"] != "dpo": - with st.expander("Experiences Buffer Configs", expanded=True): + if st.session_state["algorithm_type"] not in ("dpo", "sft"): + with st.expander("Experience Buffer Configs", expanded=True): self.get_configs("storage_type") self.get_configs("experience_buffer_path") self.get_configs("use_priority_queue") @@ -201,24 +198,19 @@ def _expert_buffer_part(self): # self.get_configs("buffer_max_retry_times", "max_retry_interval") def _expert_explorer_part(self): - self.get_configs("sync_method", "sync_interval", "sync_timeout") - - self.get_configs( - "runner_per_model", "max_timeout", "explorer_max_retry_times", "eval_interval" - ) + self.get_configs("sync_method", "sync_style", "sync_interval", "sync_timeout") - self.get_configs("bench_on_latest_checkpoint") + self.get_configs("runner_per_model", "eval_interval") with st.expander("Rollout Model Config", expanded=True): self.get_configs("engine_type", "engine_num", "tensor_parallel_size") self.get_configs("gpu_memory_utilization", "dtype", "seed") - self.get_configs( - "use_v1", "enforce_eager", "enable_prefix_caching", "enable_chunked_prefill" - ) + self.get_configs("enforce_eager", "enable_prefix_caching", "enable_chunked_prefill") - self.get_configs("enable_thinking", "enable_openai_api", "enable_auto_tool_choice") + self.get_configs("enable_thinking", "enable_history") + self.get_configs("enable_openai_api", "enable_auto_tool_choice") self.get_configs("tool_call_parser", "reasoning_parser") with st.expander("Auxiliary Models", expanded=True): @@ -226,8 +218,9 @@ def _expert_explorer_part(self): def _expert_trainer_part(self): self.get_configs("algorithm_type", "repeat_times", "save_interval") - self.get_configs("sample_strategy", "advantage_fn", "entropy_loss_fn") - self.get_configs("policy_loss_fn", "kl_penalty_fn", "kl_loss_fn") + self.get_configs("policy_loss_fn", "advantage_fn", "sample_strategy") + self.get_configs("kl_penalty_fn", "kl_loss_fn", "kl_coef_in_kl_loss_fn") + self.get_configs("entropy_loss_fn", "entropy_coef_in_entropy_loss_fn") with st.expander("Advanced Algorithm Config"): algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"]) @@ -240,6 +233,8 @@ def _expert_trainer_part(self): default_args = register_map[key].get(value).default_args() for sub_key in default_args.keys(): full_key = sub_key + "_in_" + key + if full_key in ("kl_coef_in_kl_loss_fn", "entropy_coef_in_entropy_loss_fn"): + continue config_key_list.append(full_key) idx = 0 @@ -292,7 +287,7 @@ def _expert_verl_training_part(self): self.get_configs("recompute_modules") with st.expander("Advanced Config"): - self.get_configs("critic_warmup", "total_training_steps") + self.get_configs("critic_warmup") self.get_configs("default_hdfs_dir") @@ -302,19 +297,22 @@ def _expert_verl_training_part(self): def _expert_verl_actor_part(self): st.subheader("Actor Model Config") + + self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio") + + self.get_configs("actor_grad_clip", "actor_ulysses_sequence_parallel_size") + self.get_configs( "actor_ppo_micro_batch_size_per_gpu", "ref_log_prob_micro_batch_size_per_gpu", - "actor_ulysses_sequence_parallel_size", - "actor_entropy_from_logits_with_chunking", - "actor_entropy_checkpointing", + "actor_ppo_max_token_len_per_gpu", ) - self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio") + self.get_configs("actor_entropy_from_logits_with_chunking", "actor_entropy_checkpointing") - self.get_configs("actor_grad_clip") + self.get_configs("actor_load_checkpoint") - self.get_configs("actor_load_checkpoint", "actor_save_checkpoint") + self.get_configs("actor_save_checkpoint") def _expert_verl_critic_part(self): st.subheader("Critic Model Config") @@ -434,7 +432,8 @@ def _generate_verl_config(self): "actor_ppo_micro_batch_size_per_gpu" ], "use_dynamic_bsz": use_dynamic_bsz, - "ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu, + "ppo_max_token_len_per_gpu": st.session_state["actor_ppo_max_token_len_per_gpu"] + or ppo_max_token_len_per_gpu, "ppo_epochs": st.session_state["ppo_epochs"], "ulysses_sequence_parallel_size": st.session_state[ "actor_ulysses_sequence_parallel_size" @@ -449,9 +448,6 @@ def _generate_verl_config(self): }, }, "ref": { - "log_prob_micro_batch_size_per_gpu": st.session_state[ - "ref_log_prob_micro_batch_size_per_gpu" - ], "log_prob_use_dynamic_bsz": use_dynamic_bsz, "log_prob_max_token_len_per_gpu": ppo_max_token_len_per_gpu, "ulysses_sequence_parallel_size": st.session_state[ @@ -491,7 +487,6 @@ def _generate_verl_config(self): "lr": st.session_state["critic_lr"], "lr_warmup_steps_ratio": st.session_state["critic_lr_warmup_steps_ratio"], "warmup_style": st.session_state["critic_warmup_style"], - "total_training_steps": (st.session_state["total_training_steps"] or -1), }, "model": { "override_config": {}, @@ -558,33 +553,40 @@ def _gen_algorithm_config(self): def _gen_buffer_config(self): experience_buffer_path = st.session_state["experience_buffer_path"].strip() - if st.session_state["algorithm_type"] != "dpo": + if st.session_state["algorithm_type"] not in ("dpo", "sft"): if ( not experience_buffer_path and st.session_state["storage_type"] == StorageType.SQL.value ): experience_buffer_path = f"sqlite:///{os.path.join(st.session_state['checkpoint_root_dir'], '.cache', st.session_state['project'], st.session_state['exp_name'])}/data.db" + else: + st.session_state["storage_type"] = StorageType.FILE.value buffer_config = { "batch_size": st.session_state["explore_batch_size"], "train_batch_size": st.session_state["train_batch_size"], "total_epochs": st.session_state["total_epochs"], + "total_steps": st.session_state["total_steps"], "explorer_input": {}, "trainer_input": { "experience_buffer": { "name": "experience_buffer", "storage_type": st.session_state["storage_type"], "path": experience_buffer_path, - # "max_retry_interval": st.session_state["max_retry_interval"], - # "max_retry_times": st.session_state["buffer_max_retry_times"], }, }, } if not experience_buffer_path: del buffer_config["trainer_input"]["experience_buffer"]["path"] if st.session_state["train_batch_size"] is None: - del buffer_config["train_batch_size"] - if st.session_state["algorithm_type"] != "dpo": + if st.session_state["algorithm_type"] in ("dpo", "sft"): + buffer_config["train_batch_size"] = ( + st.session_state["explore_batch_size"] * st.session_state["repeat_times"] + ) + del buffer_config["batch_size"] + else: + del buffer_config["train_batch_size"] + if st.session_state["algorithm_type"] not in ("dpo", "sft"): experience_buffer = buffer_config["trainer_input"]["experience_buffer"] experience_buffer["use_priority_queue"] = st.session_state["use_priority_queue"] experience_buffer["reuse_cooldown_time"] = st.session_state["reuse_cooldown_time"] @@ -614,8 +616,6 @@ def _gen_buffer_config(self): "default_workflow_type": st.session_state["default_workflow_type"], "default_eval_workflow_type": st.session_state["default_eval_workflow_type"], "default_reward_fn_type": st.session_state["default_reward_fn_type"], - "system_prompt": st.session_state["system_prompt"], - "reply_prefix": st.session_state["reply_prefix"], } for idx in range(st.session_state["_eval_tasksets_num"]): if st.session_state[f"eval_taskset_{idx}_path"].strip(): @@ -650,25 +650,28 @@ def _gen_buffer_config(self): "chosen_key": st.session_state["dpo_dataset_chosen_key"], "rejected_key": st.session_state["dpo_dataset_rejected_key"], } + elif st.session_state["algorithm_type"] == "sft": + experience_buffer = buffer_config["trainer_input"]["experience_buffer"] + experience_buffer["split"] = st.session_state["sft_dataset_train_split"] + experience_buffer["format"] = { + "prompt_type": st.session_state["sft_dataset_prompt_type"], + "prompt_key": st.session_state["sft_dataset_prompt_key"], + "messages_key": st.session_state["sft_dataset_messages_key"], + } return buffer_config def _gen_explorer_config(self): explorer_config = { "runner_per_model": st.session_state["runner_per_model"], - "max_timeout": st.session_state["max_timeout"], - "max_retry_times": st.session_state["explorer_max_retry_times"], "rollout_model": { key: st.session_state[key] for key in self.inference_model_keys if key != "model_path" - # "max_response_tokens": None, # TODO - # "max_model_len": None, # TODO # "chat_template": None, # TODO: add chat template }, "auxiliary_models": [], "eval_interval": st.session_state["eval_interval"], - "bench_on_latest_checkpoint": st.session_state["bench_on_latest_checkpoint"], } for i in range(st.session_state["_auxiliary_models_num"]): auxiliary_model_config = { @@ -726,7 +729,7 @@ def generate_config(self): "trainer_type": st.session_state["trainer_type"], "save_interval": st.session_state["save_interval"], "enable_preview": st.session_state["enable_preview"], - "actor_grad_clip": st.session_state["actor_grad_clip"], + "grad_clip": st.session_state["actor_grad_clip"], "trainer_config": trainer_config, }, "monitor": { @@ -734,9 +737,13 @@ def generate_config(self): }, "synchronizer": { "sync_method": st.session_state["sync_method"], + "sync_style": st.session_state["sync_style"], "sync_interval": st.session_state["sync_interval"], "sync_timeout": st.session_state["sync_timeout"], }, + "log": { + "level": st.session_state["log_level"], + }, } if use_critic(): @@ -747,11 +754,30 @@ def generate_config(self): ) st.session_state.config_generated = True - st.header("Generated Config File") - buttons = st.container() - save_btn, run_btn = buttons.columns(2, vertical_alignment="bottom") + st.subheader("Generated Config File") + # buttons = st.container() + # save_btn, run_btn = buttons.columns(2, vertical_alignment="bottom") yaml_config = yaml.dump(config, allow_unicode=True, sort_keys=False) - save_btn.download_button( + # save_btn.download_button( + # "Save", + # data=yaml_config, + # file_name=f"{config['project']}-{config['name']}.yaml", + # mime="text/plain", + # icon=":material/download:", + # use_container_width=True, + # ) + # run_btn.button( + # "Run", + # on_click=self.run_config, + # args=( + # buttons, + # yaml_config, + # ), + # icon=":material/terminal:", + # use_container_width=True, + # disabled=st.session_state.is_running, + # ) + st.download_button( "Save", data=yaml_config, file_name=f"{config['project']}-{config['name']}.yaml", @@ -759,17 +785,6 @@ def generate_config(self): icon=":material/download:", use_container_width=True, ) - run_btn.button( - "Run", - on_click=self.run_config, - args=( - buttons, - yaml_config, - ), - icon=":material/terminal:", - use_container_width=True, - disabled=st.session_state.is_running, - ) st.code(yaml_config, language="yaml") def run_config(self, parent, yaml_config: str) -> None: diff --git a/trinity/manager/config_registry/algorithm_config_manager.py b/trinity/manager/config_registry/algorithm_config_manager.py index 7246c2d364..ca1270ab36 100644 --- a/trinity/manager/config_registry/algorithm_config_manager.py +++ b/trinity/manager/config_registry/algorithm_config_manager.py @@ -6,7 +6,7 @@ OPMDAdvantageFn, PPOAdvantageFn, ) -from trinity.algorithm.algorithm import ALGORITHM_TYPE, PPOAlgorithm +from trinity.algorithm.algorithm import ALGORITHM_TYPE, GRPOAlgorithm from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ( ENTROPY_LOSS_FN, EntropyLossFn, @@ -23,15 +23,16 @@ from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, MixSampleStrategy from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num +from trinity.utils.registry import Registry @CONFIG_GENERATORS.register_config( - default_value="ppo", - other_configs={"mode": "both", "_current_default_config": PPOAlgorithm.default_config()}, + default_value="grpo", + other_configs={"mode": "both", "_current_default_config": GRPOAlgorithm.default_config()}, ) def set_algorithm_type(**kwargs): def on_change(): - if st.session_state["algorithm_type"] == "dpo": + if st.session_state["algorithm_type"] in ("dpo", "sft"): st.session_state["mode"] = "train" else: st.session_state["mode"] = "both" @@ -45,25 +46,27 @@ def on_change(): candidates = list(ALGORITHM_TYPE.modules.keys()) st.selectbox( "Algorithm Type", - candidates, + options=candidates, + format_func=lambda x: x.upper(), on_change=on_change, **kwargs, ) @CONFIG_GENERATORS.register_config( - default_value=PPOAlgorithm.default_config()["repeat_times"], + default_value=GRPOAlgorithm.default_config()["repeat_times"], visible=lambda: "repeat_times" in st.session_state["_current_default_config"], other_configs={ "_grouped_adv_repeat_times": 2, "_not_grouped_adv_repeat_times": 1, }, ) -def set_repeat_times(**kwargs): # TODO +def set_repeat_times(**kwargs): key = kwargs.get("key") grouped_adv_algorithms = [ "grpo", - "opmd", # TODO: may add rloo + "opmd", + "rloo", ] if st.session_state["algorithm_type"] in grouped_adv_algorithms: min_repeat_times = 2 @@ -82,7 +85,7 @@ def on_change(): "Repeat Times", min_value=min_repeat_times, help="`repeat_times` is used to set how many experiences each task can generate, " - "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", + "and it must be greater than `1` when `algorithm_type` is `grpo`, `opmd` or 'rloo`.", on_change=on_change, **kwargs, ) @@ -92,15 +95,17 @@ def on_change(): @CONFIG_GENERATORS.register_config( - default_value=PPOAlgorithm.default_config()["sample_strategy"], + default_value=GRPOAlgorithm.default_config()["sample_strategy"], visible=lambda: "sample_strategy" in st.session_state["_current_default_config"], ) def set_sample_strategy(**kwargs): + on_change = _create_on_change_callback("sample_strategy", SAMPLE_STRATEGY, **kwargs) candidates = list(SAMPLE_STRATEGY.modules.keys()) st.selectbox( "Sample Strategy", candidates, help="The sample strategy used to obtain experiences.", + on_change=on_change, **kwargs, ) @@ -124,15 +129,18 @@ def set_expert_data_ratio_in_sample_strategy(**kwargs): @CONFIG_GENERATORS.register_config( - default_value=PPOAlgorithm.default_config()["advantage_fn"], + default_value=GRPOAlgorithm.default_config()["advantage_fn"], visible=lambda: "advantage_fn" in st.session_state["_current_default_config"], ) def set_advantage_fn(**kwargs): + on_change = _create_on_change_callback("advantage_fn", ADVANTAGE_FN, **kwargs) candidates = list(ADVANTAGE_FN.modules.keys()) st.selectbox( "Advantage Function", - candidates, + options=candidates, + format_func=lambda x: x.upper(), help="The advantage function used to compute advantages.", + on_change=on_change, **kwargs, ) @@ -142,7 +150,7 @@ def set_advantage_fn(**kwargs): visible=lambda: st.session_state["advantage_fn"] in {"ppo", "reinforceplusplus"}, ) def set_gamma_in_advantage_fn(**kwargs): - st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs) + st.number_input(r"Gamma :blue-badge[$\gamma$]", help="Discounted factor used in RL", **kwargs) @CONFIG_GENERATORS.register_config( @@ -150,14 +158,18 @@ def set_gamma_in_advantage_fn(**kwargs): visible=lambda: st.session_state["advantage_fn"] == "ppo", ) def set_lam_in_advantage_fn(**kwargs): - st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs) + st.number_input( + r"Lambda :blue-badge[$\lambda$]", + help="Lambda value when computing Generalized Advantage Estimation", + **kwargs, + ) @CONFIG_GENERATORS.register_config( default_value=GRPOAdvantageFn.default_args()["epsilon"], visible=lambda: st.session_state["advantage_fn"] == "grpo", ) -def set_epsilon_in_advantage_fn(**kwargs): # TODO: update help message +def set_epsilon_in_advantage_fn(**kwargs): st.number_input( r"GRPO Epsilon", help=r""" @@ -194,14 +206,17 @@ def set_tau_in_advantage_fn(**kwargs): @CONFIG_GENERATORS.register_config( - default_value=PPOAlgorithm.default_config()["kl_loss_fn"], + default_value=GRPOAlgorithm.default_config()["kl_loss_fn"], visible=lambda: "kl_loss_fn" in st.session_state["_current_default_config"], ) def set_kl_loss_fn(**kwargs): + on_change = _create_on_change_callback("kl_loss_fn", KL_FN, **kwargs) candidates = list(KL_FN.modules.keys()) st.selectbox( "KL Loss Type", - candidates, + options=candidates, + format_func=lambda x: x.upper(), + on_change=on_change, **kwargs, ) @@ -224,14 +239,17 @@ def set_kl_coef_in_kl_loss_fn(**kwargs): @CONFIG_GENERATORS.register_config( - default_value=PPOAlgorithm.default_config()["kl_penalty_fn"], + default_value=GRPOAlgorithm.default_config()["kl_penalty_fn"], visible=lambda: "kl_penalty_fn" in st.session_state["_current_default_config"], ) def set_kl_penalty_fn(**kwargs): + on_change = _create_on_change_callback("kl_penalty_fn", KL_FN, **kwargs) candidates = list(KL_FN.modules.keys()) st.selectbox( "KL Penalty Type", - candidates, + options=candidates, + format_func=lambda x: x.upper(), + on_change=on_change, **kwargs, ) @@ -267,14 +285,17 @@ def set_kl_coef_in_kl_penalty_fn(**kwargs): @CONFIG_GENERATORS.register_config( - default_value=PPOAlgorithm.default_config()["policy_loss_fn"], + default_value=GRPOAlgorithm.default_config()["policy_loss_fn"], visible=lambda: "policy_loss_fn" in st.session_state["_current_default_config"], ) def set_policy_loss_fn(**kwargs): + on_change = _create_on_change_callback("policy_loss_fn", POLICY_LOSS_FN, **kwargs) candidates = list(POLICY_LOSS_FN.modules.keys()) st.selectbox( "Policy Loss Fn", - candidates, + options=candidates, + format_func=lambda x: x.upper(), + on_change=on_change, **kwargs, ) @@ -356,12 +377,18 @@ def set_mu_in_policy_loss_fn(**kwargs): @CONFIG_GENERATORS.register_config( - default_value=PPOAlgorithm.default_config()["entropy_loss_fn"], + default_value=GRPOAlgorithm.default_config()["entropy_loss_fn"], visible=lambda: "entropy_loss_fn" in st.session_state["_current_default_config"], ) def set_entropy_loss_fn(**kwargs): + on_change = _create_on_change_callback("entropy_loss_fn", ENTROPY_LOSS_FN, **kwargs) candidates = list(ENTROPY_LOSS_FN.modules.keys()) - st.selectbox("Entropy Loss Function", candidates, **kwargs) + st.selectbox( + "Entropy Loss Function", + options=candidates, + on_change=on_change, + **kwargs, + ) @CONFIG_GENERATORS.register_config( @@ -376,3 +403,19 @@ def set_entropy_coef_in_entropy_loss_fn(**kwargs): format="%.1e", **kwargs, ) + + +# define on_change +def _create_on_change_callback(key_name: str, registry: Registry, **kwargs): + """Creates an on_change callback to update dependent configs.""" + + def on_change(): + value = st.session_state[kwargs.get("key", key_name)] + value_class = registry.get(value) + if value_class: + default_args = value_class.default_args() + for arg_key, arg_value in default_args.items(): + full_key = f"{arg_key}_in_{key_name}" + st.session_state[full_key] = arg_value + + return on_change diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py index de86e77ac4..b37bf20936 100644 --- a/trinity/manager/config_registry/buffer_config_manager.py +++ b/trinity/manager/config_registry/buffer_config_manager.py @@ -12,6 +12,13 @@ def set_total_epochs(**kwargs): st.number_input("Total Epochs", min_value=1, **kwargs) +@CONFIG_GENERATORS.register_config(default_value=None) +def set_total_steps(**kwargs): + st.number_input( + "Total Steps", min_value=1, help="If set, `Total Epochs` will be ignored", **kwargs + ) + + @CONFIG_GENERATORS.register_config(default_value=96) def set_explore_batch_size(**kwargs): st.number_input( @@ -45,7 +52,7 @@ def _str_for_train_batch_size(): else "`gpu_per_node * node_num`" ) return ( - f"`train_batch_size` defaults to `task_batch_size` * `repeat_times`.\n\n" + f"Number of experiences in a mini-batch; defaults to `task_batch_size` * `repeat_times`.\n\n" f"Please ensure that `train_batch_size` ({get_train_batch_size()}) can be divided by " f"{trainer_gpu_num_str} ({st.session_state['trainer_gpu_num']})." ) @@ -130,8 +137,7 @@ def set_taskset_args(**kwargs): subset_name_col.text_input( "Subset Name :orange-badge[(Needs review)]", key="taskset_subset_name", - help="The subset name used for `datasets.load_datasets`, see " - "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", + help="The subset name used for `datasets.load_datasets`, defaults to `None`", ) split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split") prompt_key_col, response_key_col = st.columns(2) @@ -163,8 +169,7 @@ def _set_eval_taskset_idx(idx): subset_name_col.text_input( "Subset Name :orange-badge[(Needs review)]", key=f"eval_taskset_{idx}_subset_name", - help="The subset name used for `datasets.load_datasets`, see " - "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.", + help="The subset name used for `datasets.load_datasets`, defaults to `None`", ) split_col.text_input( "Eval Split :orange-badge[(Needs review)]", @@ -264,46 +269,25 @@ def set_default_reward_fn_type(**kwargs): ) -@CONFIG_GENERATORS.register_config(default_value=None) -def set_system_prompt(**kwargs): - st.text_area( - "System Prompt", - placeholder="""You are a helpful assistant that solves MATH problems....""", - **kwargs, - ) - - -@CONFIG_GENERATORS.register_config(default_value=None) -def set_reply_prefix(**kwargs): - st.text_area( - "Assistant Reply Prefix", - placeholder="""Assistant reply prefix is used to specify the initial content of model reply, """ - """and a common setting is: \nLet me solve this step by step. """, - **kwargs, - ) - - @CONFIG_GENERATORS.register_config( default_value=StorageType.QUEUE.value, other_configs={ - "_dpo_storage_type": StorageType.FILE.value, - "_not_dpo_storage_type": StorageType.QUEUE.value, + "_offline_dataset_storage_type": StorageType.FILE.value, + "_not_offline_dataset_storage_type": StorageType.QUEUE.value, }, ) def set_storage_type(**kwargs): key = kwargs.get("key") - if st.session_state["algorithm_type"] == "dpo": - st.session_state[key] = st.session_state["_dpo_storage_type"] + if st.session_state["algorithm_type"] in ("dpo", "sft"): + st.session_state[key] = st.session_state["_offline_dataset_storage_type"] storage_candidates = [StorageType.FILE.value, StorageType.SQL.value] else: - st.session_state[key] = st.session_state["_not_dpo_storage_type"] + st.session_state[key] = st.session_state["_not_offline_dataset_storage_type"] storage_candidates = [StorageType.QUEUE.value] def on_change(): - if st.session_state["algorithm_type"] == "dpo": - st.session_state["_dpo_storage_type"] = st.session_state[key] - else: - st.session_state["_not_dpo_storage_type"] = st.session_state[key] + if st.session_state["algorithm_type"] not in ("dpo", "sft"): + st.session_state["_not_offline_dataset_storage_type"] = st.session_state[key] st.selectbox( "Storage Type", @@ -326,7 +310,7 @@ def set_reuse_cooldown_time(**kwargs): "Reuse Cooldown Time", min_value=0.0, max_value=1e5, - help="Leave blank to indicate no reuse", + help="Leave blank to indicate no experience reuse", placeholder=None, **kwargs, ) @@ -357,32 +341,33 @@ def set_priority_decay(**kwargs): @CONFIG_GENERATORS.register_config( default_value="", other_configs={ - "_dpo_experience_buffer_path": "", - "_not_dpo_experience_buffer_path": "", + "_offline_dataset_experience_buffer_path": "", + "_not_offline_dataset_experience_buffer_path": "", }, ) -def set_experience_buffer_path(**kwargs): # TODO +def set_experience_buffer_path(**kwargs): key = kwargs.get("key") - if st.session_state["algorithm_type"] == "dpo": - if st.session_state["taskset_path"] and not st.session_state["_dpo_experience_buffer_path"]: - st.session_state["_dpo_experience_buffer_path"] = st.session_state["taskset_path"] - st.session_state[key] = st.session_state["_dpo_experience_buffer_path"] - title = "DPO Dataset Path" - help_msg = r"""This path to DPO dataset, - -if `storage_type == StorageType.FILE`, this should be a path to a file, - -if `storage_type == StorageType.SQL`, this should be a path to database.""" + if st.session_state["algorithm_type"] in ("dpo", "sft"): + if ( + st.session_state["taskset_path"] + and not st.session_state["_offline_dataset_experience_buffer_path"] + ): + st.session_state["_offline_dataset_experience_buffer_path"] = st.session_state[ + "taskset_path" + ] + st.session_state[key] = st.session_state["_offline_dataset_experience_buffer_path"] + title = "Dataset Path" + help_msg = r"""Path to the dataset.""" else: - st.session_state[key] = st.session_state["_not_dpo_experience_buffer_path"] + st.session_state[key] = st.session_state["_not_offline_dataset_experience_buffer_path"] title = "Experience Buffer Path" help_msg = r"""This path is used for experiences persistent storage, default to `None`.""" def on_change(): - if st.session_state["algorithm_type"] == "dpo": - st.session_state["_dpo_experience_buffer_path"] = st.session_state[key] + if st.session_state["algorithm_type"] in ("dpo", "sft"): + st.session_state["_offline_dataset_experience_buffer_path"] = st.session_state[key] else: - st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[key] + st.session_state["_not_offline_dataset_experience_buffer_path"] = st.session_state[key] st.text_input(title, help=help_msg, on_change=on_change, **kwargs) @@ -393,24 +378,36 @@ def check_experience_buffer_path(unfinished_fields: set, key: str): if not st.session_state[key].strip(): unfinished_fields.add(key) st.warning("Please input DPO dataset path.") + elif st.session_state["algorithm_type"] == "sft": + if not st.session_state[key].strip(): + unfinished_fields.add(key) + st.warning("Please input SFT dataset path.") @CONFIG_GENERATORS.register_config( other_configs={ + "dpo_dataset_subset_name": None, "dpo_dataset_train_split": "train", - "dpo_dataset_prompt_type": PromptType.MESSAGES.value, + "dpo_dataset_prompt_type": PromptType.PLAINTEXT.value, "dpo_dataset_prompt_key": "prompt", "dpo_dataset_chosen_key": "chosen", "dpo_dataset_rejected_key": "rejected", } ) def set_dpo_dataset_kwargs(**kwargs): - dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2) + ( + dpo_dataset_subset_name_col, + dpo_dataset_train_split_col, + dpo_dataset_prompt_type_col, + ) = st.columns(3) + dpo_dataset_subset_name_col.text_input( + "Subset Name :orange-badge[(Needs review)]", key="dpo_dataset_subset_name" + ) dpo_dataset_train_split_col.text_input( - "DPO Dataset Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split" + "Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split" ) dpo_dataset_prompt_type_col.selectbox( - "DPO Dataset Prompt Type :orange-badge[(Needs review)]", + "Prompt Type :orange-badge[(Needs review)]", [prompt_type.value for prompt_type in PromptType], key="dpo_dataset_prompt_type", ) @@ -421,12 +418,58 @@ def set_dpo_dataset_kwargs(**kwargs): dpo_dataset_rejected_key_col, ) = st.columns(3) dpo_dataset_prompt_key_col.text_input( - "DPO Dataset Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key" + "Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key" ) dpo_dataset_chosen_key_col.text_input( - "DPO Dataset Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key" + "Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key" ) dpo_dataset_rejected_key_col.text_input( - "DPO Dataset Rejected Key :orange-badge[(Needs review)]", + "Rejected Key :orange-badge[(Needs review)]", key="dpo_dataset_rejected_key", ) + + +@CONFIG_GENERATORS.register_config( + other_configs={ + "sft_dataset_subset_name": None, + "sft_dataset_train_split": "train", + "sft_dataset_prompt_type": PromptType.MESSAGES.value, + "sft_dataset_prompt_key": "prompt", + "sft_dataset_response_key": "response", + "sft_dataset_messages_key": "messages", + } +) +def set_sft_dataset_kwargs(**kwargs): + ( + sft_dataset_subset_name_col, + sft_dataset_train_split_col, + sft_dataset_prompt_type_col, + ) = st.columns(3) + sft_dataset_subset_name_col.text_input( + "Subset Name :orange-badge[(Needs review)]", key="sft_dataset_subset_name" + ) + sft_dataset_train_split_col.text_input( + "Train Split :orange-badge[(Needs review)]", key="sft_dataset_train_split" + ) + sft_dataset_prompt_type_col.selectbox( + "Prompt Type :orange-badge[(Needs review)]", + [prompt_type.value for prompt_type in PromptType], + key="sft_dataset_prompt_type", + help="When `Prompt Type` is `plaintext`, `Prompt Key` and `Response Key` are effective; when `Prompt Type` is `messages`, `Messages Key` is effective.", + ) + + ( + sft_dataset_prompt_key_col, + sft_dataset_response_key_col, + sft_dataset_messages_key_col, + ) = st.columns(3) + sft_dataset_prompt_key_col.text_input( + "Prompt Key :orange-badge[(Needs review)]", key="sft_dataset_prompt_key" + ) + sft_dataset_response_key_col.text_input( + "Response Key :orange-badge[(Needs review)]", key="sft_dataset_response_key" + ) + sft_dataset_messages_key_col.text_input( + "Messages Key :orange-badge[(Needs review)]", + key="sft_dataset_messages_key", + ) diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py index d85450ec8c..4db9540e31 100644 --- a/trinity/manager/config_registry/explorer_config_manager.py +++ b/trinity/manager/config_registry/explorer_config_manager.py @@ -1,6 +1,6 @@ import streamlit as st -from trinity.common.constants import SyncMethod +from trinity.common.constants import SyncMethod, SyncStyle from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num @@ -14,31 +14,11 @@ def set_runner_per_model(**kwargs): st.number_input("Runner per Model", min_value=1, **kwargs) -@CONFIG_GENERATORS.register_config(default_value=900, visible=explorer_visible) -def set_max_timeout(**kwargs): - st.number_input("Max Timeout", min_value=0, **kwargs) - - -@CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible) -def set_explorer_max_retry_times(**kwargs): - st.number_input("Explorer Max Retry Times", min_value=0, **kwargs) - - @CONFIG_GENERATORS.register_config(default_value=1000, visible=explorer_visible) def set_eval_interval(**kwargs): st.number_input("Eval Interval", min_value=1, **kwargs) -@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) -def set_bench_on_latest_checkpoint(**kwargs): - st.checkbox("Eval on Latest Checkpoint", **kwargs) - - -@CONFIG_GENERATORS.register_config(default_value="vllm_async", visible=explorer_visible) -def set_engine_type(**kwargs): - st.selectbox("Engine Type", ["vllm_async", "vllm"], **kwargs) - - @CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible) def set_engine_num(**kwargs): key = kwargs.get("key") @@ -48,7 +28,7 @@ def set_engine_num(**kwargs): st.session_state[key] = max_engine_num set_trainer_gpu_num() st.number_input( - "Engine Num", + "Explorer Engine Num", min_value=1, max_value=max_engine_num, on_change=set_trainer_gpu_num, @@ -92,11 +72,6 @@ def check_tensor_parallel_size(unfinished_fields: set, key: str): ) -@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) -def set_use_v1(**kwargs): - st.checkbox("Use V1 Engine", **kwargs) - - @CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible) def set_enforce_eager(**kwargs): st.checkbox("Enforce Eager", **kwargs) @@ -127,16 +102,16 @@ def set_seed(**kwargs): st.number_input("Seed", step=1, **kwargs) -# TODO: max_response_tokens -# TODO: max_model_len -# TODO: chat_template - - @CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) def set_enable_thinking(**kwargs): st.checkbox("Enable Thinking For Qwen3", **kwargs) +@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) +def set_enable_history(**kwargs): + st.checkbox("Enable History Recording", **kwargs) + + @CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible) def set_enable_openai_api(**kwargs): st.checkbox("Enable OpenAI API", **kwargs) @@ -174,13 +149,10 @@ def _set_auxiliary_model_idx(idx): if col2.button("✖️", key=f"auxiliary_model_{idx}_del_flag", type="primary"): st.rerun() - engine_type_col, engine_num_col, tensor_parallel_size_col = st.columns(3) + engine_num_col, tensor_parallel_size_col = st.columns(2) total_gpu_num = st.session_state["total_gpu_num"] - engine_type_col.selectbox( - "Engine Type", ["vllm_async"], key=f"auxiliary_model_{idx}_engine_type" - ) engine_num_col.number_input( - "Engine Num", + "Explorer Engine Num", min_value=1, max_value=total_gpu_num - 1, on_change=set_trainer_gpu_num, @@ -207,12 +179,10 @@ def _set_auxiliary_model_idx(idx): seed_col.number_input("Seed", step=1, key=f"auxiliary_model_{idx}_seed") ( - use_v1_col, enforce_eager_col, enable_prefix_caching_col, enable_chunked_prefill_col, - ) = st.columns(4) - use_v1_col.checkbox("Use V1 Engine", key=f"auxiliary_model_{idx}_use_v1") + ) = st.columns(3) enforce_eager_col.checkbox("Enforce Eager", key=f"auxiliary_model_{idx}_enforce_eager") enable_prefix_caching_col.checkbox( "Prefix Caching", key=f"auxiliary_model_{idx}_enable_prefix_caching" @@ -236,7 +206,6 @@ def set_auxiliary_models(**kwargs): st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"] = 1 st.session_state[f"auxiliary_model_{idx}_gpu_memory_utilization"] = 0.9 st.session_state[f"auxiliary_model_{idx}_seed"] = 42 - st.session_state[f"auxiliary_model_{idx}_use_v1"] = True st.session_state[f"auxiliary_model_{idx}_enforce_eager"] = True st.session_state["_auxiliary_models_num"] += 1 set_trainer_gpu_num() @@ -274,34 +243,66 @@ def check_auxiliary_models(unfinished_fields: set, key: str): @CONFIG_GENERATORS.register_config( default_value=SyncMethod.NCCL.value, visible=explorer_visible, - other_configs={"_not_dpo_sync_method": SyncMethod.NCCL.value}, + other_configs={"_not_offline_dataset_sync_method": SyncMethod.NCCL.value}, ) def set_sync_method(**kwargs): key = kwargs.get("key") - if st.session_state["algorithm_type"] == "dpo": + if st.session_state["algorithm_type"] in ("dpo", "sft"): st.session_state[key] = SyncMethod.CHECKPOINT.value disabled = True else: - st.session_state[key] = st.session_state["_not_dpo_sync_method"] + st.session_state[key] = st.session_state["_not_offline_dataset_sync_method"] disabled = False def on_change(): - if st.session_state["algorithm_type"] != "dpo": - st.session_state["_not_dpo_sync_method"] = st.session_state[key] + if st.session_state["algorithm_type"] not in ("dpo", "sft"): + st.session_state["_not_offline_dataset_sync_method"] = st.session_state[key] st.selectbox( "Sync Method", [sync_method.value for sync_method in SyncMethod], - help="""`nccl`: the explorer and trainer sync model weights once every `sync_interval` steps. + help="""`nccl`: The explorer and trainer sync model weights once by NCCL. + +`checkpoint`: The trainer saves the model checkpoint, and the explorer loads it at `sync_interval`. -`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_interval`.""", +`memory`: The trainer and explorer sync model weights in memory.""", disabled=disabled, on_change=on_change, **kwargs, ) -@CONFIG_GENERATORS.register_config(default_value=10, visible=explorer_visible) +@CONFIG_GENERATORS.register_config( + default_value=SyncStyle.FIXED.value, + visible=explorer_visible, + other_configs={"_not_offline_dataset_sync_style": SyncStyle.FIXED.value}, +) +def set_sync_style(**kwargs): + key = kwargs.get("key") + if st.session_state["algorithm_type"] in ("dpo", "sft"): + st.session_state[key] = SyncStyle.CHECKPOINT.value + disabled = True + else: + st.session_state[key] = st.session_state["_not_offline_dataset_sync_style"] + disabled = False + + def on_change(): + if st.session_state["algorithm_type"] not in ("dpo", "sft"): + st.session_state["_not_offline_dataset_sync_style"] = st.session_state[key] + + st.selectbox( + "Sync Style", + [sync_style.value for sync_style in SyncStyle], + help="""`fixed`: The explorer and trainer sync model weights once every `sync_interval` steps. + +`dynamic_by_explorer`: The explorer decides to request a sync after `sync_interval` steps.""", + disabled=disabled, + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value=1, visible=explorer_visible) def set_sync_interval(**kwargs): st.number_input( "Sync Interval", diff --git a/trinity/manager/config_registry/model_config_manager.py b/trinity/manager/config_registry/model_config_manager.py index 19c026e798..a2207d1ac8 100644 --- a/trinity/manager/config_registry/model_config_manager.py +++ b/trinity/manager/config_registry/model_config_manager.py @@ -24,9 +24,12 @@ def set_trainer_gpu_num(): engine_num = st.session_state[f"auxiliary_model_{idx}_engine_num"] tensor_parallel_size = st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"] trainer_gpu_num -= engine_num * tensor_parallel_size - st.session_state["trainer_gpu_num"] = trainer_gpu_num + st.session_state["trainer_gpu_num"] = int(trainer_gpu_num) else: # model == train - st.session_state["trainer_gpu_num"] = st.session_state["total_gpu_num"] + st.session_state["trainer_gpu_num"] = int(st.session_state["total_gpu_num"]) + + # sync number to display + st.session_state["trainer_gpu_num_display"] = st.session_state["trainer_gpu_num"] @CONFIG_GENERATORS.register_config(default_value="Trinity-RFT") @@ -34,7 +37,7 @@ def set_project(**kwargs): st.text_input("Project", **kwargs) -@CONFIG_GENERATORS.register_config(default_value="qwen2.5-1.5B") +@CONFIG_GENERATORS.register_config(default_value="Example") def set_exp_name(**kwargs): st.text_input("Experiment Name", **kwargs) @@ -56,9 +59,21 @@ def check_checkpoint_root_dir(unfinished_fields: set, key: str): @CONFIG_GENERATORS.register_config(default_value="tensorboard") def set_monitor_type(**kwargs): + candidates = list(MONITOR.modules.keys()) st.selectbox( "Monitor Type", - options=MONITOR.modules.keys(), + options=candidates, + format_func=lambda x: x.capitalize(), + help="Set your API_KEY in environment variables if using `Wandb` or `MLFlow`", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config(default_value="INFO") +def set_log_level(**kwargs): + st.selectbox( + "Log Level", + options=["DEBUG", "INFO", "WARNING", "ERROR"], **kwargs, ) @@ -104,7 +119,7 @@ def set_max_response_tokens(**kwargs): st.number_input("Max Response Length", min_value=1, **kwargs) -@CONFIG_GENERATORS.register_config(default_value=2048) +@CONFIG_GENERATORS.register_config(default_value=4096) def set_max_model_len(**kwargs): st.number_input("Max Model Length", min_value=1, **kwargs) diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py index 019ba4678b..060851c154 100644 --- a/trinity/manager/config_registry/trainer_config_manager.py +++ b/trinity/manager/config_registry/trainer_config_manager.py @@ -7,6 +7,20 @@ from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +@CONFIG_GENERATORS.register_config(default_value=6) +def set_trainer_gpu_num_display(**kwargs): + from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num + + st.number_input( + "Trainer GPU Num", + disabled=True, + step=1, + on_change=set_trainer_gpu_num, + help="Automatically calculated based on total GPU number and Explorer configurations", + **kwargs, + ) + + def use_critic(): algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"]) return algorithm.use_critic @@ -28,7 +42,7 @@ def set_save_interval(**kwargs): @CONFIG_GENERATORS.register_config(default_value=True) def set_enable_preview(**kwargs): - st.checkbox("Enable Preview", **kwargs) + st.checkbox("Enable Experience Preview", **kwargs) @CONFIG_GENERATORS.register_config(default_value=1.0) @@ -232,11 +246,6 @@ def set_critic_warmup(**kwargs): st.number_input("Critic Warmup Steps", min_value=0, **kwargs) -@CONFIG_GENERATORS.register_config(default_value=None) -def set_total_training_steps(**kwargs): - st.number_input("Total Training Steps", min_value=1, **kwargs) - - @CONFIG_GENERATORS.register_config(default_value=None) def set_default_hdfs_dir(**kwargs): st.text_input("Default HDFS Dir", **kwargs) @@ -298,7 +307,11 @@ def set_actor_ppo_micro_batch_size_per_gpu(**kwargs): max_value = get_train_batch_size_per_gpu() st.session_state[key] = min(st.session_state[key], max_value) st.number_input( - "Micro Batch Size Per GPU :blue-badge[(Actor)]", min_value=1, max_value=max_value, **kwargs + "Micro Batch Size Per GPU :blue-badge[(Actor)]", + min_value=1, + max_value=max_value, + help="Micro batch size per GPU; effective when `use_dynamic_bsz` is False", + **kwargs, ) @@ -312,6 +325,16 @@ def set_ref_log_prob_micro_batch_size_per_gpu(**kwargs): ) +@CONFIG_GENERATORS.register_config(default_value=16384) +def set_actor_ppo_max_token_len_per_gpu(**kwargs): + st.number_input( + "Max Token Len Per GPU :blue-badge[(Actor)]", + min_value=1, + help="Max token length per GPU for actor model; effective when `use_dynamic_bsz` is True", + **kwargs, + ) + + @CONFIG_GENERATORS.register_config(default_value=1) def set_actor_ulysses_sequence_parallel_size(**kwargs): st.number_input( @@ -387,6 +410,7 @@ def set_critic_lr(**kwargs): min_value=1e-7, max_value=1e-3, format="%.1e", + help="Effective only when using PPO algorithm", **kwargs, )