docs/sphinx_doc/source_zh/tutorial/example_async_mode.md (2 changes: 1 addition & 1 deletion)
@@ -98,7 +98,7 @@ synchronizer:
 trainer:
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 16384
+  max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
 ```
trinity/common/config.py (2 changes: 1 addition & 1 deletion)
@@ -476,7 +476,7 @@ class TrainerConfig:
     # trainer configs
     grad_clip: float = 1.0
     use_dynamic_bsz: bool = True
-    # if None, automatically set to 2 * model.max_model_len / ulysses_sequence_parallel_size
+    # if None, automatically set to ceil(2 * model.max_model_len / ulysses_sequence_parallel_size)
     max_token_len_per_gpu: Optional[int] = None
     ulysses_sequence_parallel_size: int = 1  # sp size
     # TODO: extract more train-related params from underlying trainer engine
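For context, a minimal sketch of the fallback the updated comment describes. The helper below is hypothetical (the real resolution happens inside the underlying trainer engine); the names simply follow the config fields above:

```python
import math
from typing import Optional

def resolve_max_token_len_per_gpu(
    configured: Optional[int],
    max_model_len: int,
    ulysses_sequence_parallel_size: int,
) -> int:
    """Fall back to ceil(2 * max_model_len / sp_size) when the field is None."""
    if configured is not None:
        return configured
    # ceil() avoids silently truncating the token budget when
    # max_model_len is not divisible by the sequence-parallel size.
    return math.ceil(2 * max_model_len / ulysses_sequence_parallel_size)
```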
trinity/manager/config_manager.py (213 changes: 114 additions & 99 deletions)

Large diffs are not rendered by default.

trinity/manager/config_registry/algorithm_config_manager.py (89 changes: 66 additions & 23 deletions)
@@ -6,7 +6,7 @@
     OPMDAdvantageFn,
     PPOAdvantageFn,
 )
-from trinity.algorithm.algorithm import ALGORITHM_TYPE, PPOAlgorithm
+from trinity.algorithm.algorithm import ALGORITHM_TYPE, GRPOAlgorithm
 from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import (
     ENTROPY_LOSS_FN,
     EntropyLossFn,
@@ -23,15 +23,16 @@
 from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, MixSampleStrategy
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num
+from trinity.utils.registry import Registry


 @CONFIG_GENERATORS.register_config(
-    default_value="ppo",
-    other_configs={"mode": "both", "_current_default_config": PPOAlgorithm.default_config()},
+    default_value="grpo",
+    other_configs={"mode": "both", "_current_default_config": GRPOAlgorithm.default_config()},
 )
 def set_algorithm_type(**kwargs):
     def on_change():
-        if st.session_state["algorithm_type"] == "dpo":
+        if st.session_state["algorithm_type"] in ("dpo", "sft"):
             st.session_state["mode"] = "train"
         else:
             st.session_state["mode"] = "both"
@@ -45,25 +46,27 @@ def on_change():
     candidates = list(ALGORITHM_TYPE.modules.keys())
     st.selectbox(
         "Algorithm Type",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
         on_change=on_change,
         **kwargs,
     )


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["repeat_times"],
+    default_value=GRPOAlgorithm.default_config()["repeat_times"],
     visible=lambda: "repeat_times" in st.session_state["_current_default_config"],
     other_configs={
         "_grouped_adv_repeat_times": 2,
         "_not_grouped_adv_repeat_times": 1,
     },
 )
-def set_repeat_times(**kwargs):  # TODO
+def set_repeat_times(**kwargs):
     key = kwargs.get("key")
     grouped_adv_algorithms = [
         "grpo",
-        "opmd",  # TODO: may add rloo
+        "opmd",
+        "rloo",
     ]
     if st.session_state["algorithm_type"] in grouped_adv_algorithms:
         min_repeat_times = 2
@@ -82,7 +85,7 @@ def on_change():
         "Repeat Times",
         min_value=min_repeat_times,
         help="`repeat_times` is used to set how many experiences each task can generate, "
-        "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.",
+        "and it must be greater than `1` when `algorithm_type` is `grpo`, `opmd` or `rloo`.",
         on_change=on_change,
         **kwargs,
     )
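To see why grouped-advantage methods require at least two rollouts per task, here is a small illustrative sketch (GRPO-style group normalization is assumed; this is not the project's actual implementation). Advantages are computed relative to the other rollouts of the same task, so a group of one carries no learning signal:

```python
import statistics

def group_relative_advantages(rewards: list[float]) -> list[float]:
    """Normalize rewards within one task's rollout group (GRPO-style)."""
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards)
    # With repeat_times == 1 the single reward equals the group mean,
    # so every advantage collapses to zero.
    return [(r - mean) / (std + 1e-6) for r in rewards]

print(group_relative_advantages([1.0]))       # [0.0] -- no signal
print(group_relative_advantages([1.0, 0.0]))  # nonzero advantages
```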
@@ -92,15 +95,17 @@


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["sample_strategy"],
+    default_value=GRPOAlgorithm.default_config()["sample_strategy"],
     visible=lambda: "sample_strategy" in st.session_state["_current_default_config"],
 )
 def set_sample_strategy(**kwargs):
+    on_change = _create_on_change_callback("sample_strategy", SAMPLE_STRATEGY, **kwargs)
     candidates = list(SAMPLE_STRATEGY.modules.keys())
     st.selectbox(
         "Sample Strategy",
         candidates,
+        help="The sample strategy used to obtain experiences.",
+        on_change=on_change,
         **kwargs,
     )
@@ -124,15 +129,18 @@ def set_expert_data_ratio_in_sample_strategy(**kwargs):


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["advantage_fn"],
+    default_value=GRPOAlgorithm.default_config()["advantage_fn"],
     visible=lambda: "advantage_fn" in st.session_state["_current_default_config"],
 )
 def set_advantage_fn(**kwargs):
+    on_change = _create_on_change_callback("advantage_fn", ADVANTAGE_FN, **kwargs)
     candidates = list(ADVANTAGE_FN.modules.keys())
     st.selectbox(
         "Advantage Function",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        help="The advantage function used to compute advantages.",
+        on_change=on_change,
         **kwargs,
     )
@@ -142,22 +150,26 @@ def set_advantage_fn(**kwargs):
     visible=lambda: st.session_state["advantage_fn"] in {"ppo", "reinforceplusplus"},
 )
 def set_gamma_in_advantage_fn(**kwargs):
-    st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs)
+    st.number_input(r"Gamma :blue-badge[$\gamma$]", help="Discount factor used in RL", **kwargs)


 @CONFIG_GENERATORS.register_config(
     default_value=PPOAdvantageFn.default_args()["lam"],
     visible=lambda: st.session_state["advantage_fn"] == "ppo",
 )
 def set_lam_in_advantage_fn(**kwargs):
-    st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs)
+    st.number_input(
+        r"Lambda :blue-badge[$\lambda$]",
+        help="Lambda value when computing Generalized Advantage Estimation",
+        **kwargs,
+    )
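For reference, the lambda here is the standard GAE parameter: with TD residuals $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$, the estimator is $\hat{A}_t = \sum_{l \ge 0} (\gamma \lambda)^l \delta_{t+l}$, so $\lambda = 0$ reduces to one-step TD and $\lambda = 1$ to full Monte Carlo returns.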


 @CONFIG_GENERATORS.register_config(
     default_value=GRPOAdvantageFn.default_args()["epsilon"],
     visible=lambda: st.session_state["advantage_fn"] == "grpo",
 )
-def set_epsilon_in_advantage_fn(**kwargs):  # TODO: update help message
+def set_epsilon_in_advantage_fn(**kwargs):
     st.number_input(
         r"GRPO Epsilon",
         help=r"""
@@ -194,14 +206,17 @@


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["kl_loss_fn"],
+    default_value=GRPOAlgorithm.default_config()["kl_loss_fn"],
     visible=lambda: "kl_loss_fn" in st.session_state["_current_default_config"],
 )
 def set_kl_loss_fn(**kwargs):
+    on_change = _create_on_change_callback("kl_loss_fn", KL_FN, **kwargs)
     candidates = list(KL_FN.modules.keys())
     st.selectbox(
         "KL Loss Type",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        on_change=on_change,
         **kwargs,
     )
@@ -224,14 +239,17 @@ def set_kl_coef_in_kl_loss_fn(**kwargs):


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["kl_penalty_fn"],
+    default_value=GRPOAlgorithm.default_config()["kl_penalty_fn"],
     visible=lambda: "kl_penalty_fn" in st.session_state["_current_default_config"],
 )
 def set_kl_penalty_fn(**kwargs):
+    on_change = _create_on_change_callback("kl_penalty_fn", KL_FN, **kwargs)
     candidates = list(KL_FN.modules.keys())
     st.selectbox(
         "KL Penalty Type",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        on_change=on_change,
         **kwargs,
     )
@@ -267,14 +285,17 @@ def set_kl_coef_in_kl_penalty_fn(**kwargs):


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["policy_loss_fn"],
+    default_value=GRPOAlgorithm.default_config()["policy_loss_fn"],
     visible=lambda: "policy_loss_fn" in st.session_state["_current_default_config"],
 )
 def set_policy_loss_fn(**kwargs):
+    on_change = _create_on_change_callback("policy_loss_fn", POLICY_LOSS_FN, **kwargs)
     candidates = list(POLICY_LOSS_FN.modules.keys())
     st.selectbox(
         "Policy Loss Fn",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        on_change=on_change,
         **kwargs,
     )
@@ -356,12 +377,18 @@ def set_mu_in_policy_loss_fn(**kwargs):


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["entropy_loss_fn"],
+    default_value=GRPOAlgorithm.default_config()["entropy_loss_fn"],
     visible=lambda: "entropy_loss_fn" in st.session_state["_current_default_config"],
 )
 def set_entropy_loss_fn(**kwargs):
+    on_change = _create_on_change_callback("entropy_loss_fn", ENTROPY_LOSS_FN, **kwargs)
     candidates = list(ENTROPY_LOSS_FN.modules.keys())
-    st.selectbox("Entropy Loss Function", candidates, **kwargs)
+    st.selectbox(
+        "Entropy Loss Function",
+        options=candidates,
+        on_change=on_change,
+        **kwargs,
+    )


 @CONFIG_GENERATORS.register_config(
@@ -376,3 +403,19 @@ def set_entropy_coef_in_entropy_loss_fn(**kwargs):
         format="%.1e",
         **kwargs,
     )
+
+
+# define on_change
+def _create_on_change_callback(key_name: str, registry: Registry, **kwargs):
+    """Creates an on_change callback to update dependent configs."""
+
+    def on_change():
+        value = st.session_state[kwargs.get("key", key_name)]
+        value_class = registry.get(value)
+        if value_class:
+            default_args = value_class.default_args()
+            for arg_key, arg_value in default_args.items():
+                full_key = f"{arg_key}_in_{key_name}"
+                st.session_state[full_key] = arg_value
+
+    return on_change
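To make the new helper's effect concrete, a hypothetical usage sketch (the key name and default args below are illustrative; the real call sites are the `set_*` functions above):

```python
# Assuming ADVANTAGE_FN maps "grpo" to a class whose default_args()
# returns {"epsilon": 1e-6}, selecting "grpo" in this box would set
# st.session_state["epsilon_in_advantage_fn"] = 1e-6.
on_change = _create_on_change_callback("advantage_fn", ADVANTAGE_FN, key="advantage_fn")
st.selectbox(
    "Advantage Function",
    options=list(ADVANTAGE_FN.modules.keys()),
    key="advantage_fn",
    on_change=on_change,
)
```

This is why each `set_*` generator now builds its `on_change` from the same registry it lists candidates from: switching the selection re-seeds the dependent `*_in_<key>` widgets with that choice's defaults.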