From d7c43feb1d7f4f102b10ce21a25bcad44851a191 Mon Sep 17 00:00:00 2001
From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com>
Date: Wed, 28 May 2025 13:55:01 +0800
Subject: [PATCH 01/28] Init Algorithm Module (#58)
---
trinity/algorithm/__init__.py | 9 +++++
trinity/algorithm/advantage_fn/__init__.py | 0
.../algorithm/advantage_fn/advantage_fn.py | 21 ++++++++++
trinity/algorithm/entropy_loss/__init__.py | 0
trinity/algorithm/kl_loss/__init__.py | 0
trinity/algorithm/policy_loss_fn/__init__.py | 0
.../policy_loss_fn/policy_loss_fn.py | 38 +++++++++++++++++++
7 files changed, 68 insertions(+)
create mode 100644 trinity/algorithm/__init__.py
create mode 100644 trinity/algorithm/advantage_fn/__init__.py
create mode 100644 trinity/algorithm/advantage_fn/advantage_fn.py
create mode 100644 trinity/algorithm/entropy_loss/__init__.py
create mode 100644 trinity/algorithm/kl_loss/__init__.py
create mode 100644 trinity/algorithm/policy_loss_fn/__init__.py
create mode 100644 trinity/algorithm/policy_loss_fn/policy_loss_fn.py
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
new file mode 100644
index 0000000000..f65ec67b47
--- /dev/null
+++ b/trinity/algorithm/__init__.py
@@ -0,0 +1,9 @@
+from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+
+__all__ = [
+ "AdvantageFn",
+ "ADVANTAGE_FN",
+ "PolicyLossFn",
+ "POLICY_LOSS_FN",
+]
diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py
new file mode 100644
index 0000000000..7e965b017c
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/advantage_fn.py
@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Tuple
+
+from trinity.utils.registry import Registry
+
+ADVANTAGE_FN = Registry("advantage_fn")
+
+
+class AdvantageFn(ABC):
+ @abstractmethod
+ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
+ """Calculate advantages from experiences
+
+ Args:
+ exps (`DataProto`): The input experiences.
+ kwargs (`Dict`): The step-level parameters for calculating advantages.
+
+ Returns:
+ `Any`: The experiences with advantages.
+ `Dict`: The metrics for logging.
+ """
diff --git a/trinity/algorithm/entropy_loss/__init__.py b/trinity/algorithm/entropy_loss/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/algorithm/kl_loss/__init__.py b/trinity/algorithm/kl_loss/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
new file mode 100644
index 0000000000..392f80e521
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -0,0 +1,38 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Tuple
+
+import torch
+
+from trinity.utils.registry import Registry
+
+POLICY_LOSS_FN = Registry("policy_loss_fn")
+
+
+class PolicyLossFn(ABC):
+ """
+ Policy Loss Function
+ """
+
+ @abstractmethod
+ def __call__(
+ self,
+ logprob: torch.Tensor,
+ old_logprob: torch.Tensor,
+ action_mask: torch.Tensor,
+ advantages: torch.Tensor,
+ experiences: Any,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ """
+ Args:
+ logprob (`torch.Tensor`): The log probability generated by the policy model.
+ old_logprob (`torch.Tensor`): The log probability generated by the reference model.
+ action_mask (`torch.Tensor`): The action mask.
+ advantages (`torch.Tensor`): The advantages.
+ experiences (`DataProto`): The input experiences.
+ kwargs (`Dict`): The step-level parameters for calculating the policy loss.
+
+ Returns:
+ `torch.Tensor`: Policy loss
+ `Dict`: The metrics for logging.
+ """
From 5cd6cb674a416b766ef093a96b31c0b342445a1d Mon Sep 17 00:00:00 2001
From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com>
Date: Wed, 28 May 2025 21:28:22 +0800
Subject: [PATCH 02/28] Add Policy Loss Functions (#62)
---
tests/template/config.yaml | 3 +
trinity/algorithm/__init__.py | 2 +-
trinity/algorithm/policy_loss_fn/__init__.py | 14 +
trinity/algorithm/policy_loss_fn/dpo_loss.py | 67 +++++
.../policy_loss_fn/opmd_policy_loss.py | 35 +++
.../policy_loss_fn/policy_loss_fn.py | 8 +
.../policy_loss_fn/ppo_policy_loss.py | 64 +++++
trinity/algorithm/policy_loss_fn/sft_loss.py | 35 +++
trinity/algorithm/utils.py | 14 +
trinity/common/config.py | 17 +-
trinity/common/verl_config.py | 4 +
trinity/trainer/trainer.py | 8 +-
trinity/trainer/verl/dp_actor.py | 250 +++++-------------
trinity/trainer/verl/fsdp_workers.py | 7 +-
trinity/trainer/verl_trainer.py | 14 +-
trinity/utils/registry.py | 2 +
16 files changed, 338 insertions(+), 206 deletions(-)
create mode 100644 trinity/algorithm/policy_loss_fn/dpo_loss.py
create mode 100644 trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
create mode 100644 trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
create mode 100644 trinity/algorithm/policy_loss_fn/sft_loss.py
create mode 100644 trinity/algorithm/utils.py
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 09b6f9ca0d..a83a82655f 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -5,6 +5,9 @@ checkpoint_root_dir: ''
algorithm:
algorithm_type: ppo
repeat_times: 1
+ policy_loss_fn: ppo
+ policy_loss_fn_args:
+ clip_range: 0.2
model:
model_path: ''
max_prompt_tokens: 2048
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index f65ec67b47..51d3da8317 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -1,5 +1,5 @@
from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
__all__ = [
"AdvantageFn",
diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py
index e69de29bb2..66dce16cab 100644
--- a/trinity/algorithm/policy_loss_fn/__init__.py
+++ b/trinity/algorithm/policy_loss_fn/__init__.py
@@ -0,0 +1,14 @@
+from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
+from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
+from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
+
+__all__ = [
+ "POLICY_LOSS_FN",
+ "PolicyLossFn",
+ "PPOPolicyLossFn",
+ "OPMDPolicyLossFn",
+ "DPOLossFn",
+ "SFTLossFn",
+]
diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py
new file mode 100644
index 0000000000..3a9ea92f5c
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -0,0 +1,67 @@
+"""DPO loss function."""
+
+from typing import Any, Dict, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_sum
+
+
+@POLICY_LOSS_FN.register_module("dpo")
+class DPOLossFn(PolicyLossFn):
+ def __init__(
+ self,
+ beta: float = 0.1,
+ label_smoothing: float = 0.0,
+ ) -> None:
+ self.beta = beta
+ self.label_smoothing = label_smoothing
+
+ def __call__(
+ self,
+ logprob: torch.Tensor,
+ old_logprob: torch.Tensor,
+ action_mask: torch.Tensor,
+ advantages: torch.Tensor,
+ experiences: Any,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ chosen_logprob = logprob[::2]
+ rejected_logprob = logprob[1::2]
+ chosen_mask = action_mask[::2]
+ rejected_mask = action_mask[1::2]
+ chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask)
+ rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask)
+
+ chosen_ref_logprob = old_logprob[::2]
+ rejected_ref_logprob = old_logprob[1::2]
+ chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask)
+ rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask)
+
+ chosen_ratios = chosen_logprob_sum - chosen_ref_logprob_sum
+ rejected_ratios = rejected_logprob_sum - rejected_ref_logprob_sum
+ logits = chosen_ratios - rejected_ratios
+ # TODO: support other loss functions
+ losses = (
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+ )
+ loss = losses.mean()
+ chosen_reward = self.beta * chosen_ratios.detach().mean().item()
+ rejected_reward = self.beta * rejected_ratios.detach().mean().item()
+ accuracy_mean = (chosen_ratios.detach() > rejected_ratios.detach()).float().mean().item()
+ return loss, {
+ "chosen_reward": chosen_reward,
+ "rejected_reward": rejected_reward,
+ "accuracy_mean": accuracy_mean,
+ "dpo_loss": loss.detach().item(),
+ }
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {
+ "beta": 0.1,
+ "label_smoothing": 0.0,
+ }
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
new file mode 100644
index 0000000000..dd521f9ee0
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -0,0 +1,35 @@
+"""PPO policy loss function.
+
+Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
+"""
+
+from typing import Any, Dict, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("opmd")
+class OPMDPolicyLossFn(PolicyLossFn):
+ def __init__(self, tau: float = 1.0) -> None:
+ self.tau = tau
+
+ def __call__(
+ self,
+ logprob: torch.Tensor,
+ old_logprob: torch.Tensor,
+ action_mask: torch.Tensor,
+ advantages: torch.Tensor,
+ experiences: Any,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ pg_losses = -advantages * logprob
+ opmd_loss = masked_mean(pg_losses, action_mask)
+ opmd_loss = opmd_loss / (1.0 + self.tau) # for regularization (w.r.t. current pi_theta)
+ return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {"tau": 1.0}
diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
index 392f80e521..eb02c49b46 100644
--- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
+++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -36,3 +36,11 @@ def __call__(
`torch.Tensor`: Policy loss
`Dict`: The metrics for logging.
"""
+
+ @classmethod
+ @abstractmethod
+ def default_args(cls) -> Dict:
+ """
+ Returns:
+ `Dict`: The default init arguments for the policy loss function.
+ """
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
new file mode 100644
index 0000000000..9831f048d6
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -0,0 +1,64 @@
+"""PPO policy loss function.
+
+Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
+"""
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("ppo")
+class PPOPolicyLossFn(PolicyLossFn):
+ def __init__(
+ self,
+ clip_range: Optional[float] = None,
+ clip_range_low: Optional[float] = None,
+ clip_range_high: Optional[float] = None,
+ ) -> None:
+ if clip_range_low is None:
+ self.clip_range_low = clip_range
+ else:
+ self.clip_range_low = clip_range_low
+ if clip_range_high is None:
+ self.clip_range_high = clip_range
+ else:
+ self.clip_range_high = clip_range_high
+ assert self.clip_range_low is not None, "clip_range_low must be specified."
+ assert self.clip_range_high is not None, "clip_range_high must be specified."
+
+ def __call__(
+ self,
+ logprob: torch.Tensor,
+ old_logprob: torch.Tensor,
+ action_mask: torch.Tensor,
+ advantages: torch.Tensor,
+ experiences: Any,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ negative_approx_kl = logprob - old_logprob
+ ratio = torch.exp(negative_approx_kl)
+ ppo_kl = masked_mean(-negative_approx_kl, action_mask)
+
+ pg_losses = -advantages * ratio
+ pg_losses2 = -advantages * torch.clamp(
+ ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore
+ )
+
+ pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), action_mask)
+ pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
+ metrics = {
+ "pg_clipfrac": pg_clipfrac.detach().item(),
+ "ppo_kl": ppo_kl.detach().item(),
+ "pg_loss": pg_loss.detach().item(),
+ }
+ return pg_loss, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {
+ "clip_range": 0.2,
+ }
diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py
new file mode 100644
index 0000000000..c04f775fa3
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -0,0 +1,35 @@
+"""SFT loss function."""
+
+from typing import Any, Dict, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.utils import masked_mean
+
+
+@POLICY_LOSS_FN.register_module("sft")
+class SFTLossFn(PolicyLossFn):
+ def __init__(self, use_token_level_loss: bool = True) -> None:
+ self.use_token_level_loss = use_token_level_loss
+
+ def __call__(
+ self,
+ logprob: torch.Tensor,
+ old_logprob: torch.Tensor,
+ action_mask: torch.Tensor,
+ advantages: torch.Tensor,
+ experiences: Any,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ if self.use_token_level_loss:
+ sft_loss = masked_mean(-logprob, action_mask)
+ else:
+ sft_loss = masked_mean(-logprob, action_mask, axis=1).mean()
+ return sft_loss, {"sft_loss": sft_loss.detach().item()}
+
+ @classmethod
+ def default_args(cls):
+ return {
+ "use_token_level_loss": True,
+ }
diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py
new file mode 100644
index 0000000000..d5cfb72d8c
--- /dev/null
+++ b/trinity/algorithm/utils.py
@@ -0,0 +1,14 @@
+"""Common utils for algorithm module.
+
+Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py
+"""
+
+
+def masked_sum(values, mask, axis=None):
+ """Compute mean of tensor with a masked values."""
+ return (values * mask).sum(axis=axis)
+
+
+def masked_mean(values, mask, axis=None):
+ """Compute mean of tensor with a masked values."""
+ return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index e0660ab03a..9c3b582618 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -175,7 +175,10 @@ class AlgorithmConfig:
repeat_times: int = 1
gamma: Optional[float] = None
lam: Optional[float] = None
- # TODO: add more algorithm params here
+
+ policy_loss_fn: str = "ppo"
+ # If not set, use PolicyLossFn.default_args()
+ policy_loss_fn_args: Optional[dict] = None
@dataclass
@@ -466,6 +469,15 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.pad_token_id = 0
self.buffer.tokenizer_path = self.model.model_path
+ def _check_algorithm(self) -> None:
+ from trinity.algorithm import POLICY_LOSS_FN
+
+ policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
+ if policy_fn_cls is None:
+ raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}")
+ if self.algorithm.policy_loss_fn_args is None:
+ self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args()
+
def check_and_update(self) -> None: # noqa: C901
"""Check and update the config."""
self._check_deprecated()
@@ -489,6 +501,9 @@ def check_and_update(self) -> None: # noqa: C901
if not self.model.critic_model_path:
self.model.critic_model_path = self.model.model_path
+ # check algorithm
+ self._check_algorithm()
+
# check explorer
if (
self.explorer.rollout_model.engine_type != "vllm_async"
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index e5d0d9d55f..a4d9a6e8d9 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -190,6 +190,10 @@ class Algorithm:
kl_penalty: str = "kl"
kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)
+ # ! DO NOT SET THE FLOWING PARAMETERS
+ policy_loss_fn: str = "ppo"
+ policy_loss_fn_args: Optional[dict] = None
+
@dataclass
class Trainer:
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 36d23e7628..876ca2835f 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -13,7 +13,7 @@
import ray
from trinity.buffer import get_buffer_reader
-from trinity.common.config import Config
+from trinity.common.config import AlgorithmConfig, Config
from trinity.common.constants import AlgorithmType, SyncMethod
from trinity.common.experience import Experiences
from trinity.utils.log import get_logger
@@ -73,7 +73,7 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
Returns:
bool: Whether to continue training.
"""
- self.engine.set_mode(algo_type)
+ self.engine.set_algorithm(self.config.algorithm)
if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy:
strategy = self.config.buffer.trainer_input.read_experience_strategy
else:
@@ -157,8 +157,8 @@ def sync_weight(self) -> None:
"""Sync the model weight."""
@abstractmethod
- def set_mode(self, algo_type: AlgorithmType) -> None:
- """Set training mode."""
+ def set_algorithm(self, algorithm_config: AlgorithmConfig) -> None:
+ """Set training algorithm config."""
@abstractmethod
def shutdown(self) -> None:
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 246cd1f21c..b598bb6dad 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -30,6 +30,8 @@
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.workers.actor import BasePPOActor
+from trinity.algorithm import POLICY_LOSS_FN
+from trinity.common.config import AlgorithmConfig
from trinity.common.constants import AlgorithmType
from trinity.trainer.verl import core_algos
@@ -54,9 +56,13 @@ def __init__(
self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
self.algorithm_type = AlgorithmType.PPO
+ self.policy_loss_fn = None
- def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO):
- self.algorithm_type = algorithm_type
+ def set_algorithm(self, algorithm_config: AlgorithmConfig):
+ self.algorithm_type = algorithm_config.algorithm_type
+ self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
+ **algorithm_config.policy_loss_fn_args
+ )
def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -129,27 +135,6 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor,
use_cache=False,
) # prevent model thinks we are generating
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
- if self.algorithm_type.is_sft(): # SFT
- loss_fct = nn.CrossEntropyLoss(reduction="none")
- loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)
- if self.use_ulysses_sp:
- loss = gather_outpus_and_unpad(
- loss, gather_dim=0, unpad_dim=0, padding_size=pad_size
- )
- response_mask = attention_mask[:, -response_length:].bool()
- # pad back to (bsz, seqlen)
- full_loss = pad_input(
- hidden_states=loss.unsqueeze(-1),
- indices=indices,
- batch=batch_size,
- seqlen=seqlen,
- ).squeeze(-1)
- full_loss = torch.where(
- response_mask, full_loss[:, -response_length - 1 : -1], 0.0
- )
- full_loss = full_loss.sum(-1) / response_mask.sum(-1)
- full_loss = full_loss.mean()
- return full_loss
logits_rmpad.div_(temperature)
@@ -201,21 +186,6 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor,
use_cache=False,
) # prevent model thinks we are generating
logits = output.logits
- if self.algorithm_type.is_sft():
- loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
- response_mask = attention_mask[:, -response_length:].bool()
- response_labels = torch.where(
- response_mask, input_ids[:, -response_length:], -100
- )
- response_logits = logits[:, -response_length - 1 : -1, :]
- loss = loss_fct(
- response_logits.reshape(-1, response_logits.shape[-1]),
- response_labels.reshape(-1),
- )
- loss = loss.view(response_labels.shape)
- loss = loss.sum(-1) / response_mask.sum(-1)
- loss = loss.mean()
- return loss
logits.div_(temperature)
logits = logits[
:, -response_length - 1 : -1, :
@@ -308,57 +278,25 @@ def update_policy(self, data: DataProto): # noqa: C901
temperature = data.meta_info[
"temperature"
] # temperature must be in the data.meta_info to avoid slient error
-
- algorithm_type: AlgorithmType = self.config.get("algorithm_type", AlgorithmType.PPO)
- if self.algorithm_type.is_rft():
- select_keys = [
- "responses",
- "input_ids",
- "attention_mask",
- "position_ids",
- "old_log_probs",
- "advantages",
- "response_mask",
- ]
- if self.config.use_kl_loss:
- select_keys.append("ref_log_prob")
-
- if algorithm_type == AlgorithmType.PAIRWISE_OPMD:
- select_keys.append("token_level_scores")
- elif self.algorithm_type.is_dpo():
- select_keys = [
- "attention_mask",
- "input_ids",
- "position_ids",
- "response_mask",
- "responses",
- "ref_log_prob",
- ]
- else: # sft
- select_keys = [
- "attention_mask",
- "input_ids",
- "position_ids",
- "response_mask",
- "responses",
- ]
- use_uid = self.config.get("use_uid", False)
-
+ select_keys = [
+ "responses",
+ "input_ids",
+ "attention_mask",
+ "position_ids",
+ "old_log_probs",
+ "advantages",
+ "response_mask",
+ ]
+ if self.config.use_kl_loss:
+ select_keys.append("ref_log_prob")
batch = data.select(batch_keys=select_keys).batch
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
- if has_multi_modal_inputs or ((algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid):
- # TODO: for now, we treat algorithm_type == AlgorithmType.PAIRWISE_OPMD in the same way that
- # has_multi_modal_inputs was treated originally (to handle non_tensor_select_keys);
- # need to double check if this is the best approach.
+ if has_multi_modal_inputs:
num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
- non_tensor_select_keys = []
- if has_multi_modal_inputs:
- non_tensor_select_keys.append("multi_modal_inputs")
- if (algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid:
- non_tensor_select_keys.append("uid")
+ non_tensor_select_keys = ["multi_modal_inputs"]
dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
else:
dataloader = batch.split(self.config.ppo_mini_batch_size)
@@ -373,9 +311,7 @@ def update_policy(self, data: DataProto): # noqa: C901
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
- if has_multi_modal_inputs or (
- (algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid
- ):
+ if has_multi_modal_inputs:
self.gradient_accumulation = (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
)
@@ -412,93 +348,48 @@ def update_policy(self, data: DataProto): # noqa: C901
data = data.to(
torch.cuda.current_device()
) # actor device is cpu when using offload
+ responses = data["responses"]
+ response_length = responses.size(1)
+ attention_mask = data["attention_mask"]
+ # response_mask = attention_mask[:, -response_length:]
+ response_mask = data["response_mask"]
+ assert response_mask.shape == attention_mask[:, -response_length:].shape
+ old_log_prob = data["old_log_probs"]
+ advantages = data["advantages"]
+ entropy_coeff = self.config.entropy_coeff
+
+ # all return: (bsz, response_length)
+ entropy, log_prob = self._forward_micro_batch(
+ micro_batch=data, temperature=temperature
+ )
- # TODO: it is better to unify the returns of several modes (sft, dpo)
- if self.algorithm_type.is_sft():
- policy_loss = self._forward_micro_batch(
- micro_batch=data, temperature=temperature
- )
+ pg_loss, metric = self.policy_loss_fn( # type: ignore
+ logprob=log_prob,
+ old_logprob=old_log_prob,
+ action_mask=response_mask,
+ advantages=advantages,
+ experiences=data,
+ )
- elif self.algorithm_type.is_dpo():
- response_mask = data["response_mask"]
+ # compute entropy loss from entropy
+ entropy_loss = verl_F.masked_mean(entropy, response_mask)
- _, log_prob = self._forward_micro_batch(
- micro_batch=data, temperature=temperature
- )
- if self.config.use_kl_loss:
- ref_log_prob = data["ref_log_prob"]
- else:
- ref_log_prob = None
-
- (
- policy_loss,
- chosen_reward,
- rejected_reward,
- ) = core_algos.compute_policy_loss_dpo(
- log_prob=log_prob,
- ref_log_prob=ref_log_prob,
- eos_mask=response_mask,
- beta=self.config.kl_loss_coef,
- # label_smoothing=self.config.label_smoothing # TODO: add configs for dpo
- )
+ # compute policy loss
+ policy_loss = pg_loss - entropy_loss * entropy_coeff
- else: # rft
- responses = data["responses"]
- response_length = responses.size(1)
- attention_mask = data["attention_mask"]
- # response_mask = attention_mask[:, -response_length:]
- response_mask = data["response_mask"]
- assert response_mask.shape == attention_mask[:, -response_length:].shape
- old_log_prob = data["old_log_probs"]
- advantages = data["advantages"]
-
- clip_ratio = self.config.clip_ratio
- entropy_coeff = self.config.entropy_coeff
-
- tau = self.config.get("tau", 1.0)
- token_level_scores = None
- index = None
- if algorithm_type == AlgorithmType.PAIRWISE_OPMD:
- token_level_scores = data["token_level_scores"]
- if use_uid:
- index = data["uid"]
-
- # all return: (bsz, response_length)
- entropy, log_prob = self._forward_micro_batch(
- micro_batch=data, temperature=temperature
+ if self.config.use_kl_loss:
+ ref_log_prob = data["ref_log_prob"]
+ # compute kl loss
+ kld = core_algos.kl_penalty(
+ logprob=log_prob,
+ ref_logprob=ref_log_prob,
+ kl_penalty=self.config.kl_loss_type,
)
+ kl_loss = masked_mean(kld, response_mask)
- pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
- old_log_prob=old_log_prob,
- log_prob=log_prob,
- eos_mask=response_mask,
- algorithm_type=algorithm_type,
- advantages=advantages,
- cliprange=clip_ratio,
- # for opmd / pairwise_opmd
- tau=tau,
- token_level_scores=token_level_scores,
- index=index,
- )
- # compute entropy loss from entropy
- entropy_loss = verl_F.masked_mean(entropy, response_mask)
-
- # compute policy loss
- policy_loss = pg_loss - entropy_loss * entropy_coeff
-
- if self.config.use_kl_loss:
- ref_log_prob = data["ref_log_prob"]
- # compute kl loss
- kld = core_algos.kl_penalty(
- logprob=log_prob,
- ref_logprob=ref_log_prob,
- kl_penalty=self.config.kl_loss_type,
- )
- kl_loss = masked_mean(kld, response_mask)
-
- policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
- metrics["actor/kl_loss"] = kl_loss.detach().item()
- metrics["actor/kl_coef"] = self.config.kl_loss_coef
+ policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
+ metrics["actor/kl_loss"] = kl_loss.detach().item()
+ metrics["actor/kl_coef"] = self.config.kl_loss_coef
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
@@ -507,28 +398,9 @@ def update_policy(self, data: DataProto): # noqa: C901
loss = policy_loss / self.gradient_accumulation
loss.backward()
- if self.algorithm_type.is_rft():
- data = {
- "actor/entropy_loss": entropy_loss.detach().item(),
- "actor/pg_loss": pg_loss.detach().item(),
- "actor/pg_clipfrac": pg_clipfrac.detach().item(),
- "actor/ppo_kl": ppo_kl.detach().item(),
- }
- elif self.algorithm_type.is_dpo():
- data = {
- "dpo/loss": policy_loss.detach().item(),
- "dpo/loss_mean": loss.detach().item(),
- "dpo/chosen_reward": chosen_reward.detach().mean().item(),
- "dpo/rejected_reward": rejected_reward.detach().mean().item(),
- "dpo/accuracy_mean": (chosen_reward > rejected_reward)
- .float()
- .mean()
- .item(),
- }
- else:
- data = {
- "sft/loss": loss.detach().item(),
- }
+ data = {f"actor/{key}": value for key, value in metric.items()}
+ # TODO: refactor entropy loss
+ data["actor/entropy_loss"] = entropy_loss.detach().item()
append_to_dict(metrics, data)
grad_norm = self._optimizer_step()
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 26b640e871..c0af427b4a 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -50,7 +50,8 @@
from verl.utils.model import compute_position_id_with_mask
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
-from trinity.common.constants import AlgorithmType, SyncMethod
+from trinity.common.config import AlgorithmConfig
+from trinity.common.constants import SyncMethod
from trinity.utils.distributed import init_process_group, is_ipv6_address
logger = logging.getLogger(__file__)
@@ -623,8 +624,8 @@ def sync_weight(self):
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def set_mode(self, algo_type: AlgorithmType = AlgorithmType.PPO):
- self.actor.set_mode(algo_type)
+ def set_algorithm(self, algo_config: AlgorithmConfig):
+ self.actor.set_algorithm(algo_config)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 7590d6075b..5324a13f7c 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -13,7 +13,7 @@
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs
-from trinity.common.config import Config
+from trinity.common.config import AlgorithmConfig, Config
from trinity.common.constants import AlgorithmType
from trinity.common.experience import Experiences
from trinity.trainer.trainer import TrainEngineWrapper
@@ -125,9 +125,7 @@ def __init__(
ray_worker_group_cls,
)
self.init_workers()
- self.algorithm_type = (
- AlgorithmType.PPO
- ) # TODO: initialize algorithm_type according to config
+ self.algorithm_type = AlgorithmType.PPO
self.logger = Monitor(
project=config.trainer.project_name,
name=config.trainer.experiment_name,
@@ -499,11 +497,11 @@ def save_checkpoint(self) -> None:
def sync_weight(self) -> None:
self.actor_rollout_wg.sync_weight()
- def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO) -> None:
- self.actor_rollout_wg.set_mode(algorithm_type)
- if self.algorithm_type.is_sft() and (not algorithm_type.is_sft()):
+ def set_algorithm(self, algorithm_config: AlgorithmConfig) -> None:
+ self.actor_rollout_wg.set_algorithm(algorithm_config)
+ if self.algorithm_type.is_sft() and (not algorithm_config.algorithm_type.is_sft()):
self.sft_to_rft()
- self.algorithm_type = algorithm_type
+ self.algorithm_type = algorithm_config.algorithm_type
def sft_to_rft(self) -> None:
# load from hdfs
diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py
index 70fb2930c9..b31f6872bd 100644
--- a/trinity/utils/registry.py
+++ b/trinity/utils/registry.py
@@ -22,6 +22,8 @@
logger = get_logger(__name__)
+# TODO: support lazy load
+# e.g. @MODULES.register_module("name", lazy=True)
class Registry(object):
"""This class is used to register some modules to registry by a repo
name."""
From fe217aa154dc89a49b2688afc623482f8c276818 Mon Sep 17 00:00:00 2001
From: Yanxi Chen <153061753+yanxi-chen@users.noreply.github.com>
Date: Tue, 3 Jun 2025 13:32:39 +0800
Subject: [PATCH 03/28] Refactor advantage computation, and delete
RayPPOTrainer.fit (#61)
---
tests/template/config.yaml | 5 +
trinity/algorithm/__init__.py | 2 +-
trinity/algorithm/advantage_fn/__init__.py | 20 +
.../algorithm/advantage_fn/advantage_fn.py | 10 +-
.../algorithm/advantage_fn/grpo_advantage.py | 42 +++
.../algorithm/advantage_fn/opmd_advantage.py | 45 +++
.../algorithm/advantage_fn/ppo_advantage.py | 50 +++
.../reinforce_plus_plus_advantage.py | 42 +++
.../algorithm/advantage_fn/remax_advantage.py | 40 ++
.../algorithm/advantage_fn/rloo_advantage.py | 40 ++
trinity/common/config.py | 14 +-
trinity/common/verl_config.py | 24 +-
trinity/trainer/verl/core_algos.py | 6 +-
trinity/trainer/verl/ray_trainer.py | 344 ------------------
trinity/trainer/verl_trainer.py | 45 +--
15 files changed, 342 insertions(+), 387 deletions(-)
create mode 100644 trinity/algorithm/advantage_fn/grpo_advantage.py
create mode 100644 trinity/algorithm/advantage_fn/opmd_advantage.py
create mode 100644 trinity/algorithm/advantage_fn/ppo_advantage.py
create mode 100644 trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
create mode 100644 trinity/algorithm/advantage_fn/remax_advantage.py
create mode 100644 trinity/algorithm/advantage_fn/rloo_advantage.py
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index a83a82655f..c83d938c66 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -8,6 +8,11 @@ algorithm:
policy_loss_fn: ppo
policy_loss_fn_args:
clip_range: 0.2
+ advantage_fn_type: ppo_adv_fn
+ advantage_fn_args:
+ gamma: 1.0
+ lam: 1.0
+
model:
model_path: ''
max_prompt_tokens: 2048
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index 51d3da8317..170507663f 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -1,4 +1,4 @@
-from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
__all__ = [
diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py
index e69de29bb2..7bcf682e4b 100644
--- a/trinity/algorithm/advantage_fn/__init__.py
+++ b/trinity/algorithm/advantage_fn/__init__.py
@@ -0,0 +1,20 @@
+from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn
+from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn
+from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn
+from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import (
+ REINFORCEPLUSPLUSAdvantageFn,
+)
+from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn
+from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn
+
+__all__ = [
+ "ADVANTAGE_FN",
+ "AdvantageFn",
+ "PPOAdvantageFn",
+ "GRPOAdvantageFn",
+ "REINFORCEPLUSPLUSAdvantageFn",
+ "REMAXAdvantageFn",
+ "RLOOAdvantageFn",
+ "OPMDAdvantageFn",
+]
diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py
index 7e965b017c..21e3668a53 100644
--- a/trinity/algorithm/advantage_fn/advantage_fn.py
+++ b/trinity/algorithm/advantage_fn/advantage_fn.py
@@ -16,6 +16,14 @@ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
kwargs (`Dict`): The step-level parameters for calculating advantages.
Returns:
- `Any`: The experiences with advantages.
+ `DataProto`: The experiences with advantages.
`Dict`: The metrics for logging.
"""
+
+ @classmethod
+ @abstractmethod
+ def default_args(cls) -> Dict:
+ """
+ Returns:
+ `Dict`: The default init arguments for the advantage function.
+ """
diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py
new file mode 100644
index 0000000000..89a8282752
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -0,0 +1,42 @@
+"""GRPO advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("grpo_adv_fn")
+class GRPOAdvantageFn(AdvantageFn):
+ """GRPO advantage computation"""
+
+ def __init__(self) -> None:
+ pass
+
+ def __call__(
+ self,
+ exps: DataProto,
+ **kwargs,
+ ) -> Tuple[DataProto, Dict]:
+ advantages, returns = core_algos.compute_grpo_outcome_advantage(
+ token_level_rewards=exps.batch["token_level_rewards"],
+ eos_mask=exps.batch["response_mask"],
+ index=exps.non_tensor_batch["uid"],
+ )
+ exps.batch["advantages"] = advantages
+ exps.batch["returns"] = returns
+
+ metrics = {
+ # TODO: add meaningful metrics
+ }
+
+ return exps, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {}
diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py
new file mode 100644
index 0000000000..abf74686d3
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -0,0 +1,45 @@
+"""OPMD advantage computation
+
+Adapted from compute_advantage_opmd in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("opmd_adv_fn")
+class OPMDAdvantageFn(AdvantageFn):
+ """OPMD advantage computation"""
+
+ def __init__(self) -> None:
+ pass
+
+ def __call__(
+ self,
+ exps: DataProto,
+ **kwargs,
+ ) -> Tuple[DataProto, Dict]:
+ advantages, returns = core_algos.compute_opmd_outcome_advantage(
+ token_level_rewards=exps.batch["token_level_rewards"],
+ eos_mask=exps.batch["response_mask"],
+ # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
+ index=exps.non_tensor_batch["uid"],
+ opmd_baseline="mean",
+ tau=1.0,
+ )
+ exps.batch["advantages"] = advantages
+ exps.batch["returns"] = returns
+
+ metrics = {
+ # TODO: add meaningful metrics
+ }
+
+ return exps, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {}
diff --git a/trinity/algorithm/advantage_fn/ppo_advantage.py b/trinity/algorithm/advantage_fn/ppo_advantage.py
new file mode 100644
index 0000000000..5afd51311c
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/ppo_advantage.py
@@ -0,0 +1,50 @@
+"""PPO's GAE advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("ppo_adv_fn")
+class PPOAdvantageFn(AdvantageFn):
+ def __init__(
+ self,
+ gamma: float = 1.0,
+ lam: float = 1.0,
+ ) -> None:
+ self.gamma = gamma
+ self.lam = lam
+
+ def __call__(
+ self,
+ exps: DataProto,
+ **kwargs,
+ ) -> Tuple[DataProto, Dict]:
+ advantages, returns = core_algos.compute_gae_advantage_return(
+ token_level_rewards=exps.batch["token_level_rewards"],
+ values=exps.batch["values"],
+ eos_mask=exps.batch["response_mask"],
+ gamma=self.gamma,
+ lam=self.lam,
+ )
+ exps.batch["advantages"] = advantages
+ exps.batch["returns"] = returns
+
+ metrics = {
+ # TODO: add meaningful metrics
+ }
+
+ return exps, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {
+ "gamma": 1.0,
+ "lam": 1.0,
+ }
diff --git a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
new file mode 100644
index 0000000000..9c668f7640
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
@@ -0,0 +1,42 @@
+"""REINFORCE++ advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("reinforceplusplus_adv_fn")
+class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn):
+ def __init__(self, gamma: float = 1.0) -> None:
+ self.gamma = gamma
+
+ def __call__(
+ self,
+ exps: DataProto,
+ **kwargs,
+ ) -> Tuple[DataProto, Dict]:
+ advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
+ token_level_rewards=exps.batch["token_level_rewards"],
+ eos_mask=exps.batch["response_mask"],
+ gamma=self.gamma,
+ )
+ exps.batch["advantages"] = advantages
+ exps.batch["returns"] = returns
+
+ metrics = {
+ # TODO: add meaningful metrics
+ }
+
+ return exps, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {
+ "gamma": 1.0,
+ }
diff --git a/trinity/algorithm/advantage_fn/remax_advantage.py b/trinity/algorithm/advantage_fn/remax_advantage.py
new file mode 100644
index 0000000000..05a13d7d60
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/remax_advantage.py
@@ -0,0 +1,40 @@
+"""REMAX advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("remax_adv_fn")
+class REMAXAdvantageFn(AdvantageFn):
+ def __init__(self) -> None:
+ pass
+
+ def __call__(
+ self,
+ exps: DataProto,
+ **kwargs,
+ ) -> Tuple[DataProto, Dict]:
+ advantages, returns = core_algos.compute_remax_outcome_advantage(
+ token_level_rewards=exps.batch["token_level_rewards"],
+ reward_baselines=exps.batch["reward_baselines"],
+ eos_mask=exps.batch["response_mask"],
+ )
+ exps.batch["advantages"] = advantages
+ exps.batch["returns"] = returns
+
+ metrics = {
+ # TODO: add meaningful metrics
+ }
+
+ return exps, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {}
diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py
new file mode 100644
index 0000000000..3da61c9da4
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/rloo_advantage.py
@@ -0,0 +1,40 @@
+"""RLOO advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("rloo_adv_fn")
+class RLOOAdvantageFn(AdvantageFn):
+ def __init__(self) -> None:
+ pass
+
+ def __call__(
+ self,
+ exps: DataProto,
+ **kwargs,
+ ) -> Tuple[DataProto, Dict]:
+ advantages, returns = core_algos.compute_rloo_outcome_advantage(
+ token_level_rewards=exps.batch["token_level_rewards"],
+ eos_mask=exps.batch["response_mask"],
+ index=exps.non_tensor_batch["uid"],
+ )
+ exps.batch["advantages"] = advantages
+ exps.batch["returns"] = returns
+
+ metrics = {
+ # TODO: add meaningful metrics
+ }
+
+ return exps, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {}
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 9c3b582618..794202bab0 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -173,13 +173,15 @@ class AlgorithmConfig:
algorithm_type: AlgorithmType = AlgorithmType.PPO
# for GRPO-like algorithms, repeat each task for `repeat_times` times
repeat_times: int = 1
- gamma: Optional[float] = None
- lam: Optional[float] = None
policy_loss_fn: str = "ppo"
# If not set, use PolicyLossFn.default_args()
policy_loss_fn_args: Optional[dict] = None
+ advantage_fn_type: str = "ppo_adv_fn"
+ # If not set, use AdvantageFn.default_args()
+ advantage_fn_args: Optional[dict] = None
+
@dataclass
class ClusterConfig:
@@ -470,7 +472,7 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.tokenizer_path = self.model.model_path
def _check_algorithm(self) -> None:
- from trinity.algorithm import POLICY_LOSS_FN
+ from trinity.algorithm import ADVANTAGE_FN, POLICY_LOSS_FN
policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
if policy_fn_cls is None:
@@ -478,6 +480,12 @@ def _check_algorithm(self) -> None:
if self.algorithm.policy_loss_fn_args is None:
self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args()
+ advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn_type)
+ if advantage_fn_cls is None:
+ raise ValueError(f"Invalid advantage_fn_type: {self.algorithm.advantage_fn_type}")
+ if self.algorithm.advantage_fn_args is None:
+ self.algorithm.advantage_fn_args = advantage_fn_cls.default_args()
+
def check_and_update(self) -> None: # noqa: C901
"""Check and update the config."""
self._check_deprecated()
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index a4d9a6e8d9..fb9f810dee 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -182,6 +182,9 @@ class KL_Ctrl:
@dataclass
class Algorithm:
+ # ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
+ # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
+ # if they are really needed (e.g., for GAE advantage/returns computation)
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "gae"
@@ -190,7 +193,7 @@ class Algorithm:
kl_penalty: str = "kl"
kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)
- # ! DO NOT SET THE FLOWING PARAMETERS
+ # ! DO NOT SET THE FOLLOWING PARAMETERS
policy_loss_fn: str = "ppo"
policy_loss_fn_args: Optional[dict] = None
@@ -315,17 +318,22 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio
# Algorithm related config
- if config.algorithm.gamma is not None:
- self.algorithm.gamma = config.algorithm.gamma
- if config.algorithm.lam is not None:
- self.algorithm.lam = config.algorithm.lam
+ adv_fn_args = config.algorithm.advantage_fn_args
+ if adv_fn_args is not None and "gamma" in adv_fn_args:
+ self.algorithm.gamma = adv_fn_args["gamma"]
+ if adv_fn_args is not None and "lam" in adv_fn_args:
+ self.algorithm.lam = adv_fn_args["lam"]
self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type
if config.algorithm.algorithm_type == AlgorithmType.PPO:
- logger.info("Using GAE `adv_estimator` for PPO")
+ logger.info("Setting `adv_estimator` to 'gae' for PPO")
self.algorithm.adv_estimator = AdvantageEstimator.GAE.value
- elif config.algorithm.algorithm_type == AlgorithmType.GRPO:
- logger.info("Using GRPO `adv_estimator` for GRPO")
+ elif config.algorithm.algorithm_type in (AlgorithmType.GRPO, AlgorithmType.OPMD):
+ logger.info("Setting `adv_estimator` to 'grpo' for GRPO/OPMD")
self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value
+ # TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
+ # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper).
+ # Need to double check whether this is indeed the case,
+ # and see if adv_estimator can be removed completely.
if self.actor_rollout_ref.actor.algorithm_type.is_dpo(): # for DPO
if not self.actor_rollout_ref.actor.use_kl_loss:
diff --git a/trinity/trainer/verl/core_algos.py b/trinity/trainer/verl/core_algos.py
index 20cffc9962..f104e0f4f4 100644
--- a/trinity/trainer/verl/core_algos.py
+++ b/trinity/trainer/verl/core_algos.py
@@ -139,8 +139,8 @@ def compute_gae_advantage_return(
token_level_rewards: torch.Tensor,
values: torch.Tensor,
eos_mask: torch.Tensor,
- gamma: torch.Tensor,
- lam: torch.Tensor,
+ gamma: float,
+ lam: float,
):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
@@ -283,7 +283,7 @@ def compute_rloo_outcome_advantage(
def compute_reinforce_plus_plus_outcome_advantage(
- token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor
+ token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: float
):
"""
Compute advantage for REINFORCE++.
diff --git a/trinity/trainer/verl/ray_trainer.py b/trinity/trainer/verl/ray_trainer.py
index 7073319db0..5d883d05bb 100644
--- a/trinity/trainer/verl/ray_trainer.py
+++ b/trinity/trainer/verl/ray_trainer.py
@@ -16,18 +16,14 @@
"""
import os
-import uuid
from contextlib import contextmanager
-from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
-from pprint import pprint
from typing import Dict, Type
import numpy as np
import ray
import torch
-import tqdm
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from torch.utils.data import RandomSampler, SequentialSampler
@@ -41,12 +37,6 @@
RayWorkerGroup,
)
from verl.single_controller.ray.base import create_colocated_worker_cls
-from verl.trainer.ppo.metric_utils import (
- compute_data_metrics,
- compute_throughout_metrics,
- compute_timing_metrics,
- reduce_metrics,
-)
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.seqlen_balancing import (
@@ -206,116 +196,6 @@ def compute_response_mask(data: DataProto):
return attention_mask[:, -response_length:]
-def compute_advantage(data: DataProto, **kwargs):
- """Extend verl's original compute_advantage with OPMD"""
-
- algorithm_type: AlgorithmType = kwargs.get("algorithm_type", AlgorithmType.PPO)
-
- if algorithm_type == AlgorithmType.OPMD:
- tau = kwargs.get("tau", 1.0)
- opmd_baseline = kwargs.get("opmd_baseline", "mean")
-
- return compute_advantage_opmd(
- data=data,
- tau=tau,
- opmd_baseline=opmd_baseline,
- )
-
- elif algorithm_type == AlgorithmType.PAIRWISE_OPMD:
- data.batch["advantages"] = None
- data.batch["returns"] = None
- return data
-
- elif algorithm_type.is_rft():
- adv_estimator = kwargs.get("adv_estimator", None)
- gamma = kwargs.get("gamma", 1.0)
- lam = kwargs.get("lam", 1.0)
- num_repeat = kwargs.get("num_repeat", 1)
-
- return compute_advantage_ppo(
- data=data,
- adv_estimator=adv_estimator,
- gamma=gamma,
- lam=lam,
- num_repeat=num_repeat,
- )
-
- else:
- raise ValueError(f"Get invalid algorithm_type '{algorithm_type}'.")
-
-
-def compute_advantage_opmd(data: DataProto, tau=1.0, opmd_baseline="mean"):
- # Modified from GRPO version
- token_level_rewards = data.batch["token_level_rewards"]
- index = data.non_tensor_batch["uid"]
- responses = data.batch["responses"]
- response_length = responses.size(-1)
- attention_mask = data.batch["attention_mask"]
- response_mask = attention_mask[:, -response_length:]
- advantages, returns = core_algos.compute_opmd_outcome_advantage(
- token_level_rewards=token_level_rewards,
- eos_mask=response_mask,
- index=index,
- opmd_baseline=opmd_baseline,
- tau=tau,
- )
- data.batch["advantages"] = advantages
- data.batch["returns"] = returns
-
- return data
-
-
-def compute_advantage_ppo(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
- # prepare response group
- # TODO: add other ways to estimate advantages
- if adv_estimator == AdvantageEstimator.GAE:
- advantages, returns = core_algos.compute_gae_advantage_return(
- token_level_rewards=data.batch["token_level_rewards"],
- values=data.batch["values"],
- eos_mask=data.batch["response_mask"],
- gamma=gamma,
- lam=lam,
- )
- data.batch["advantages"] = advantages
- data.batch["returns"] = returns
- elif adv_estimator == AdvantageEstimator.GRPO:
- advantages, returns = core_algos.compute_grpo_outcome_advantage(
- token_level_rewards=data.batch["token_level_rewards"],
- eos_mask=data.batch["response_mask"],
- index=data.non_tensor_batch["uid"],
- )
- data.batch["advantages"] = advantages
- data.batch["returns"] = returns
- elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
- advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
- token_level_rewards=data.batch["token_level_rewards"],
- eos_mask=data.batch["response_mask"],
- gamma=gamma,
- )
- data.batch["advantages"] = advantages
- data.batch["returns"] = returns
- elif adv_estimator == AdvantageEstimator.REMAX:
- advantages, returns = core_algos.compute_remax_outcome_advantage(
- token_level_rewards=data.batch["token_level_rewards"],
- reward_baselines=data.batch["reward_baselines"],
- eos_mask=data.batch["response_mask"],
- )
-
- data.batch["advantages"] = advantages
- data.batch["returns"] = returns
- elif adv_estimator == AdvantageEstimator.RLOO:
- advantages, returns = core_algos.compute_rloo_outcome_advantage(
- token_level_rewards=data.batch["token_level_rewards"],
- eos_mask=data.batch["response_mask"],
- index=data.non_tensor_batch["uid"],
- )
- data.batch["advantages"] = advantages
- data.batch["returns"] = returns
- else:
- raise NotImplementedError
- return data
-
-
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
@@ -934,227 +814,3 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle
seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
)
metrics.update(global_balance_stats)
-
- def fit(self): # noqa: C901
- """
- The training loop of PPO.
- The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
- The light-weight advantage computation is done on the driver process.
- """
- from omegaconf import OmegaConf
- from verl.utils.tracking import Tracking
-
- logger = Tracking(
- project_name=self.config.trainer.project_name,
- experiment_name=self.config.trainer.experiment_name,
- default_backend=self.config.trainer.logger,
- config=OmegaConf.to_container(self.config, resolve=True),
- )
-
- self.global_steps = 0
-
- # load checkpoint before doing anything
- self._load_checkpoint()
-
- # perform validation before training
- # currently, we only support validation using the reward_function.
- if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
- val_metrics = self._validate()
- pprint(f"Initial validation metrics: {val_metrics}")
- logger.log(data=val_metrics, step=self.global_steps)
- if self.config.trainer.get("val_only", False):
- return
-
- # add tqdm
- progress_bar = tqdm(
- total=self.total_training_steps, initial=self.global_steps, desc="Training Progress"
- )
-
- # we start from step 1
- self.global_steps += 1
- last_val_metrics = None
-
- for epoch in range(self.config.trainer.total_epochs):
- for batch_dict in self.train_dataloader:
- metrics = {}
- timing_raw = {}
-
- batch: DataProto = DataProto.from_single_dict(batch_dict)
-
- # pop those keys for generation
- if "multi_modal_inputs" in batch.non_tensor_batch.keys():
- gen_batch = batch.pop(
- batch_keys=["input_ids", "attention_mask", "position_ids"],
- non_tensor_batch_keys=[
- "raw_prompt_ids",
- "multi_modal_data",
- "multi_modal_inputs",
- ],
- )
- else:
- gen_batch = batch.pop(
- batch_keys=["input_ids", "attention_mask", "position_ids"],
- non_tensor_batch_keys=["raw_prompt_ids"],
- )
-
- is_last_step = self.global_steps >= self.total_training_steps
-
- with _timer("step", timing_raw):
- # generate a batch
- with _timer("gen", timing_raw):
- gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
-
- if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
- with _timer("gen_max", timing_raw):
- gen_baseline_batch = deepcopy(gen_batch)
- gen_baseline_batch.meta_info["do_sample"] = False
- gen_baseline_output = self.actor_rollout_wg.generate_sequences(
- gen_baseline_batch
- )
-
- batch = batch.union(gen_baseline_output)
- reward_baseline_tensor = self.reward_fn(batch)
- reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
-
- batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
-
- batch.batch["reward_baselines"] = reward_baseline_tensor
-
- del gen_baseline_batch, gen_baseline_output
-
- batch.non_tensor_batch["uid"] = np.array(
- [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
- )
- # repeat to align with repeated responses in rollout
- batch = batch.repeat(
- repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
- )
- batch = batch.union(gen_batch_output)
-
- batch.batch["response_mask"] = compute_response_mask(batch)
-
- # balance the number of valid tokens on each dp rank.
- # Note that this breaks the order of data inside the batch.
- # Please take care when you implement group based adv computation such as GRPO and rloo
- if self.config.trainer.balance_batch:
- self._balance_batch(batch, metrics=metrics)
-
- # compute global_valid tokens
- batch.meta_info["global_token_num"] = torch.sum(
- batch.batch["attention_mask"], dim=-1
- ).tolist()
-
- # recompute old_log_probs
- with _timer("old_log_prob", timing_raw):
- old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
- batch = batch.union(old_log_prob)
-
- if self.use_reference_policy:
- # compute reference log_prob
- with _timer("ref", timing_raw):
- ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
- batch = batch.union(ref_log_prob)
-
- # compute values
- if self.use_critic:
- with _timer("values", timing_raw):
- values = self.critic_wg.compute_values(batch)
- batch = batch.union(values)
-
- with _timer("adv", timing_raw):
- # compute scores. Support both model and function-based.
- # We first compute the scores using reward model. Then, we call reward_fn to combine
- # the results from reward model and rule-based results.
- if self.use_rm:
- # we first compute reward model score
- reward_tensor = self.rm_wg.compute_rm_score(batch)
- batch = batch.union(reward_tensor)
-
- # we combine with rule-based rm
- reward_tensor = self.reward_fn(batch)
- batch.batch["token_level_scores"] = reward_tensor
-
- # compute rewards. apply_kl_penalty if available
- if not self.config.actor_rollout_ref.actor.get("use_kl_loss", False):
- batch, kl_metrics = apply_kl_penalty(
- batch,
- kl_ctrl=self.kl_ctrl,
- kl_penalty=self.config.algorithm.kl_penalty,
- )
- metrics.update(kl_metrics)
- else:
- batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
-
- # compute advantages, executed on the driver process
- algorithm_type = self.config.actor_rollout_ref.actor.get(
- "algorithm_type", AlgorithmType.PPO
- )
- tau = self.config.actor_rollout_ref.actor.get("tau", 1.0)
- opmd_baseline = self.config.actor_rollout_ref.actor.get(
- "opmd_baseline", "mean"
- )
- batch = compute_advantage(
- batch,
- algorithm_type=algorithm_type,
- adv_estimator=self.config.algorithm.adv_estimator,
- gamma=self.config.algorithm.gamma,
- lam=self.config.algorithm.lam,
- num_repeat=self.config.actor_rollout_ref.rollout.n,
- # additional config params for OPMD
- tau=tau,
- opmd_baseline=opmd_baseline,
- )
-
- # update critic
- if self.use_critic:
- with _timer("update_critic", timing_raw):
- critic_output = self.critic_wg.update_critic(batch)
- critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
- metrics.update(critic_output_metrics)
-
- # implement critic warmup
- if self.config.trainer.critic_warmup <= self.global_steps:
- # update actor
- with _timer("update_actor", timing_raw):
- actor_output = self.actor_rollout_wg.update_actor(batch)
- actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
- metrics.update(actor_output_metrics)
-
- # validate
- if (
- self.val_reward_fn is not None
- and self.config.trainer.test_freq > 0
- and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
- ):
- with _timer("testing", timing_raw):
- val_metrics: dict = self._validate()
- if is_last_step:
- last_val_metrics = val_metrics
- metrics.update(val_metrics)
-
- if self.config.trainer.save_freq > 0 and (
- is_last_step or self.global_steps % self.config.trainer.save_freq == 0
- ):
- with _timer("save_checkpoint", timing_raw):
- self._save_checkpoint()
-
- # collect metrics
- metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
- metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
-
- # Implement actual tflpo and theoretical tflpo
- n_gpus = self.resource_pool_manager.get_n_gpus()
- metrics.update(
- compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)
- )
-
- # TODO: make a canonical logger that supports various backend
- logger.log(data=metrics, step=self.global_steps)
-
- if is_last_step:
- pprint(f"Final validation metrics: {last_val_metrics}")
- progress_bar.close()
- return
-
- progress_bar.update(1)
- self.global_steps += 1
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 5324a13f7c..b6397adde7 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -4,15 +4,24 @@
Modified from verl/trainer/ppo/ray_trainer.py
"""
import os
+from pprint import pprint
from typing import Tuple
+import numpy as np
import pandas as pd
import ray
import torch
from omegaconf import OmegaConf
+from verl.trainer.ppo.metric_utils import (
+ compute_data_metrics,
+ compute_throughout_metrics,
+ compute_timing_metrics,
+ reduce_metrics,
+)
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs
+from trinity.algorithm import ADVANTAGE_FN
from trinity.common.config import AlgorithmConfig, Config
from trinity.common.constants import AlgorithmType
from trinity.common.experience import Experiences
@@ -25,14 +34,7 @@
Role,
_timer,
apply_kl_penalty,
- compute_advantage,
- compute_data_metrics,
- compute_throughout_metrics,
- compute_timing_metrics,
find_latest_ckpt_path,
- np,
- pprint,
- reduce_metrics,
)
from trinity.utils.monitor import Monitor
@@ -126,6 +128,14 @@ def __init__(
)
self.init_workers()
self.algorithm_type = AlgorithmType.PPO
+
+ # specify advantage function for various rft algorithms
+ algo_config = global_config.algorithm
+ if algo_config.algorithm_type.is_rft():
+ adv_fn_type = algo_config.advantage_fn_type
+ adv_fn_args = algo_config.advantage_fn_args
+ self.advantage_fn = ADVANTAGE_FN.get(adv_fn_type)(**adv_fn_args)
+
self.logger = Monitor(
project=config.trainer.project_name,
name=config.trainer.experiment_name,
@@ -377,26 +387,7 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# compute advantages, executed on the driver process
- kwargs = {}
- algorithm_type = self.config.actor_rollout_ref.actor.get(
- "algorithm_type", AlgorithmType.PPO
- )
- if algorithm_type == AlgorithmType.OPMD:
- tau = self.config.actor_rollout_ref.actor.get("tau", 0.0)
- opmd_baseline = self.config.actor_rollout_ref.actor.get("opmd_baseline", "mean")
- kwargs = {
- "algorithm_type": algorithm_type,
- "tau": tau,
- "opmd_baseline": opmd_baseline,
- }
- batch = compute_advantage(
- batch,
- adv_estimator=self.config.algorithm.adv_estimator,
- gamma=self.config.algorithm.gamma,
- lam=self.config.algorithm.lam,
- num_repeat=self.config.actor_rollout_ref.rollout.n,
- **kwargs,
- )
+ batch, _ = self.advantage_fn(batch)
# update critic
if self.use_critic:
From 9d582e8802dd7911f9be46a83eb17dba20439323 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Wed, 4 Jun 2025 20:54:55 +0800
Subject: [PATCH 04/28] Add unittest && bug fix (#65)
---
pyproject.toml | 2 +-
tests/template/config.yaml | 4 +-
tests/template/data/sft_for_gsm8k/sft.jsonl | 32 ++++++
tests/tools.py | 47 ++++++++
tests/trainer/trainer_test.py | 108 +++++++++++++++++-
trinity/algorithm/policy_loss_fn/dpo_loss.py | 19 +--
.../policy_loss_fn/opmd_policy_loss.py | 15 ++-
.../policy_loss_fn/policy_loss_fn.py | 15 ++-
.../policy_loss_fn/ppo_policy_loss.py | 13 ++-
trinity/algorithm/policy_loss_fn/sft_loss.py | 11 +-
trinity/common/config.py | 5 +-
trinity/trainer/trainer.py | 12 +-
trinity/trainer/verl/dp_actor.py | 30 +++--
trinity/trainer/verl_trainer.py | 54 ++++-----
14 files changed, 295 insertions(+), 72 deletions(-)
create mode 100644 tests/template/data/sft_for_gsm8k/sft.jsonl
diff --git a/pyproject.toml b/pyproject.toml
index dcf86f8349..022c9a8ffe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
dependencies = [
"verl==0.3.0.post1",
"ray[default]>=2.45.0",
- "vllm>=0.8.5",
+ "vllm==0.8.5.post1",
"tensordict==0.6.2",
"wandb",
"omegaconf",
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index c83d938c66..3a767df243 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -18,8 +18,8 @@ model:
max_prompt_tokens: 2048
max_response_tokens: 2048
cluster: # 2 for explorer, 2 for trainer
- node_num: 1
- gpu_per_node: 4
+ node_num: 2
+ gpu_per_node: 2
buffer:
total_epochs: 1
batch_size: 4
diff --git a/tests/template/data/sft_for_gsm8k/sft.jsonl b/tests/template/data/sft_for_gsm8k/sft.jsonl
new file mode 100644
index 0000000000..a8d6972103
--- /dev/null
+++ b/tests/template/data/sft_for_gsm8k/sft.jsonl
@@ -0,0 +1,32 @@
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
diff --git a/tests/tools.py b/tests/tools.py
index 2e34438d66..3111839a37 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -13,6 +13,7 @@
StorageConfig,
load_config,
)
+from trinity.common.constants import PromptType
def get_template_config() -> Config:
@@ -59,6 +60,47 @@ def get_unittest_dataset_config(
default_workflow_type="math_workflow",
default_reward_fn_type="countdown_reward",
)
+ elif dataset_name == "gsm8k":
+ return StorageConfig(
+ name=dataset_name,
+ path="openai/gsm8k",
+ split=split,
+ subset_name="main",
+ format=FormatConfig(
+ prompt_key="question",
+ response_key="answer",
+ ),
+ rollout_args=GenerationConfig(
+ n=1,
+ temperature=1.0,
+ logprobs=0,
+ ),
+ default_workflow_type="math_workflow",
+ default_reward_fn_type="math_reward",
+ )
+ elif dataset_name == "sft_for_gsm8k":
+ return StorageConfig(
+ name=dataset_name,
+ path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"),
+ split="train",
+ format=FormatConfig(
+ prompt_type=PromptType.PLAINTEXT,
+ prompt_key="prompt",
+ response_key="response",
+ ),
+ )
+ elif dataset_name == "dpo":
+ return StorageConfig(
+ name=dataset_name,
+ path="HumanLLMs/Human-Like-DPO-Dataset",
+ split="train",
+ format=FormatConfig(
+ prompt_type=PromptType.PLAINTEXT,
+ prompt_key="prompt",
+ chosen_key="chosen",
+ rejected_key="rejected",
+ ),
+ )
else:
raise ValueError(f"Unknown dataset name: {dataset_name}")
@@ -104,6 +146,11 @@ def metric_steps(self, metric_name: str) -> List[int]:
raise ValueError(f"Metric '{metric_name}' does not exist.")
return list(self._metrics[metric_name].keys())
+ def metric_values(self, metric_name: str) -> List:
+ if not self.metric_exist(metric_name):
+ raise ValueError(f"Metric '{metric_name}' does not exist.")
+ return list(self._metrics[metric_name].values())
+
def metric_list(self, metric_prefix: str) -> List[str]:
return [name for name in self._metrics if name.startswith(metric_prefix)]
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index ac73e46c8d..55f63ae856 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -14,8 +14,8 @@
get_template_config,
get_unittest_dataset_config,
)
-from trinity.cli.launcher import bench, both
-from trinity.common.constants import MonitorType, SyncMethod
+from trinity.cli.launcher import bench, both, train
+from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod
class BaseTrainerCase(RayUnittestBase):
@@ -109,3 +109,107 @@ def test_trainer(self):
def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8K(BaseTrainerCase):
+ def test_trainer(self):
+ """Test GSM8K."""
+ # test both mode
+ self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+ self.config.algorithm.repeat_times = 4
+ # self.config.algorithm.repeat_times = 8 # TODO: used for real testing
+ self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+ self.config.algorithm.advantage_fn_args = {}
+ # self.config.buffer.batch_size = 96 # TODO: used for real testing
+ self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+ self.config.check_and_update()
+ self.config.trainer.trainer_config.trainer.total_training_steps = 4
+ self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+ self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+ both(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+ actor_metrics = parser.metric_list("actor")
+ self.assertTrue(len(actor_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+ response_metrics = parser.metric_list("response_length")
+ self.assertTrue(len(response_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+ # TODO: used for real testing
+ # rewards = parser.metric_values("critic/rewards/mean")
+ # self.assertTrue(0.4 < rewards[0] < 0.55)
+ # self.assertTrue(0.4 < rewards[1] < 0.55)
+ # self.assertTrue(0.6 < rewards[2] < 0.7)
+ # self.assertTrue(0.6 < rewards[3] < 0.7)
+ ray.shutdown(_exiting_interpreter=True)
+ # check checkpoint
+
+ def tearDown(self):
+ # remove dir only when the test passed
+ shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8KWithSFT(BaseTrainerCase):
+ def test_trainer(self):
+ """Test GSM8K With SFT."""
+ # test both mode
+ self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+ self.config.algorithm.repeat_times = 4
+ self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+ self.config.algorithm.advantage_fn_args = {}
+ self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+ self.config.buffer.trainer_input.sft_warmup_steps = 2
+ self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
+ "sft_for_gsm8k"
+ )
+ self.config.check_and_update()
+ self.config.trainer.trainer_config.trainer.total_training_steps = 4
+ self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+ self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+ both(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
+ actor_metrics = parser.metric_list("actor")
+ self.assertTrue(len(actor_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) # SFT
+ self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 4) # RFT
+ response_metrics = parser.metric_list("response_length")
+ self.assertTrue(len(response_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+ ray.shutdown(_exiting_interpreter=True)
+ # check checkpoint
+
+ def tearDown(self):
+ # remove dir only when the test passed
+ shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerDPO(BaseTrainerCase):
+ def test_trainer(self):
+ """Test DPO."""
+ # test both mode
+ self.config.mode = "train"
+ self.config.algorithm.algorithm_type = AlgorithmType.DPO
+ self.config.algorithm.policy_loss_fn = "dpo"
+ self.config.algorithm.policy_loss_fn_args = {}
+ # self.config.buffer.batch_size = 32
+ self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo")
+ self.config.check_and_update()
+ self.config.trainer.trainer_config.trainer.total_training_steps = 4
+ self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+ self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 5e-7
+ train(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+ actor_metrics = parser.metric_list("actor")
+ self.assertTrue(len(actor_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+ ray.shutdown(_exiting_interpreter=True)
+ # check checkpoint
+
+ def tearDown(self):
+ # remove dir only when the test passed
+ shutil.rmtree(self.config.checkpoint_job_dir)
diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py
index 3a9ea92f5c..7dfbb7141d 100644
--- a/trinity/algorithm/policy_loss_fn/dpo_loss.py
+++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -1,6 +1,6 @@
"""DPO loss function."""
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
import torch
import torch.nn.functional as F
@@ -19,13 +19,11 @@ def __init__(
self.beta = beta
self.label_smoothing = label_smoothing
- def __call__(
+ def __call__( # type: ignore
self,
logprob: torch.Tensor,
- old_logprob: torch.Tensor,
+ ref_logprob: torch.Tensor,
action_mask: torch.Tensor,
- advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
chosen_logprob = logprob[::2]
@@ -35,8 +33,8 @@ def __call__(
chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask)
rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask)
- chosen_ref_logprob = old_logprob[::2]
- rejected_ref_logprob = old_logprob[1::2]
+ chosen_ref_logprob = ref_logprob[::2]
+ rejected_ref_logprob = ref_logprob[1::2]
chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask)
rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask)
@@ -65,3 +63,10 @@ def default_args(cls) -> Dict:
"beta": 0.1,
"label_smoothing": 0.0,
}
+
+ @property
+ def select_keys(self) -> List[str]:
+ return [
+ "ref_logprob",
+ "action_mask",
+ ]
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
index dd521f9ee0..e9457c55d1 100644
--- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -3,7 +3,7 @@
Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
import torch
@@ -16,13 +16,12 @@ class OPMDPolicyLossFn(PolicyLossFn):
def __init__(self, tau: float = 1.0) -> None:
self.tau = tau
- def __call__(
+ def __call__( # type: ignore
self,
logprob: torch.Tensor,
- old_logprob: torch.Tensor,
+ old_logprob: torch.Tensor, # NOT USED!
action_mask: torch.Tensor,
advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
pg_losses = -advantages * logprob
@@ -33,3 +32,11 @@ def __call__(
@classmethod
def default_args(cls) -> Dict:
return {"tau": 1.0}
+
+ @property
+ def select_keys(self) -> List[str]:
+ return [
+ "old_logprob",
+ "action_mask",
+ "advantages",
+ ]
diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
index eb02c49b46..6c1a29b3e9 100644
--- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
+++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
import torch
@@ -17,10 +17,6 @@ class PolicyLossFn(ABC):
def __call__(
self,
logprob: torch.Tensor,
- old_logprob: torch.Tensor,
- action_mask: torch.Tensor,
- advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
"""
@@ -29,7 +25,6 @@ def __call__(
old_logprob (`torch.Tensor`): The log probability generated by the reference model.
action_mask (`torch.Tensor`): The action mask.
advantages (`torch.Tensor`): The advantages.
- experiences (`DataProto`): The input experiences.
kwargs (`Dict`): The step-level parameters for calculating the policy loss.
Returns:
@@ -44,3 +39,11 @@ def default_args(cls) -> Dict:
Returns:
`Dict`: The default init arguments for the policy loss function.
"""
+
+ @property
+ @abstractmethod
+ def select_keys(self) -> List[str]:
+ """
+ Returns:
+ `List[str]`: The keys to select from input data.
+ """
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index 9831f048d6..5c735d4d6a 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -3,7 +3,7 @@
Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
-from typing import Any, Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
import torch
@@ -30,13 +30,12 @@ def __init__(
assert self.clip_range_low is not None, "clip_range_low must be specified."
assert self.clip_range_high is not None, "clip_range_high must be specified."
- def __call__(
+ def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
negative_approx_kl = logprob - old_logprob
@@ -62,3 +61,11 @@ def default_args(cls) -> Dict:
return {
"clip_range": 0.2,
}
+
+ @property
+ def select_keys(self) -> List[str]:
+ return [
+ "old_logprob",
+ "action_mask",
+ "advantages",
+ ]
diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py
index c04f775fa3..dd1c75a4a2 100644
--- a/trinity/algorithm/policy_loss_fn/sft_loss.py
+++ b/trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -1,6 +1,6 @@
"""SFT loss function."""
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
import torch
@@ -13,13 +13,10 @@ class SFTLossFn(PolicyLossFn):
def __init__(self, use_token_level_loss: bool = True) -> None:
self.use_token_level_loss = use_token_level_loss
- def __call__(
+ def __call__( # type: ignore
self,
logprob: torch.Tensor,
- old_logprob: torch.Tensor,
action_mask: torch.Tensor,
- advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
if self.use_token_level_loss:
@@ -33,3 +30,7 @@ def default_args(cls):
return {
"use_token_level_loss": True,
}
+
+ @property
+ def select_keys(self) -> List[str]:
+ return ["action_mask"]
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 794202bab0..5d294abdfd 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -182,6 +182,9 @@ class AlgorithmConfig:
# If not set, use AdvantageFn.default_args()
advantage_fn_args: Optional[dict] = None
+ # used for SFT
+ use_token_level_loss: bool = True
+
@dataclass
class ClusterConfig:
@@ -452,7 +455,7 @@ def _check_buffer(self) -> None: # noqa: C901
and self.buffer.trainer_input.sft_warmup_dataset is None
):
raise ValueError(
- "buffer.trainer_input.sft_warmup_dataset is required when buffer.trainer_input.sft_warmup_steps > 0"
+ "`buffer.trainer_input.sft_warmup_dataset` is required when `buffer.trainer_input.sft_warmup_steps` > 0"
)
if self.buffer.trainer_input.sft_warmup_dataset is not None:
self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 876ca2835f..7208f83fb4 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -73,7 +73,17 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
Returns:
bool: Whether to continue training.
"""
- self.engine.set_algorithm(self.config.algorithm)
+ if algo_type.is_sft():
+ algorithm_config = AlgorithmConfig(
+ algorithm_type=AlgorithmType.SFT,
+ policy_loss_fn="sft",
+ policy_loss_fn_args={
+ "use_token_level_loss": self.config.algorithm.use_token_level_loss
+ },
+ )
+ self.engine.set_algorithm(algorithm_config)
+ else:
+ self.engine.set_algorithm(self.config.algorithm)
if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy:
strategy = self.config.buffer.trainer_input.read_experience_strategy
else:
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index b598bb6dad..97cd186c36 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -279,16 +279,25 @@ def update_policy(self, data: DataProto): # noqa: C901
"temperature"
] # temperature must be in the data.meta_info to avoid slient error
select_keys = [
- "responses",
"input_ids",
- "attention_mask",
"position_ids",
- "old_log_probs",
- "advantages",
+ "attention_mask",
+ "responses",
"response_mask",
]
+ select_keys_verl2trinity = {
+ "old_log_probs": "old_logprob",
+ "ref_log_prob": "ref_logprob",
+ "response_mask": "action_mask",
+ "advantages": "advantages",
+ }
+ select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()}
+ for trinity_key in self.policy_loss_fn.select_keys:
+ verl_key = select_keys_trinity2verl[trinity_key]
+ select_keys.append(verl_key)
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
+ select_keys = list(set(select_keys))
batch = data.select(batch_keys=select_keys).batch
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
@@ -351,11 +360,8 @@ def update_policy(self, data: DataProto): # noqa: C901
responses = data["responses"]
response_length = responses.size(1)
attention_mask = data["attention_mask"]
- # response_mask = attention_mask[:, -response_length:]
response_mask = data["response_mask"]
assert response_mask.shape == attention_mask[:, -response_length:].shape
- old_log_prob = data["old_log_probs"]
- advantages = data["advantages"]
entropy_coeff = self.config.entropy_coeff
# all return: (bsz, response_length)
@@ -363,12 +369,14 @@ def update_policy(self, data: DataProto): # noqa: C901
micro_batch=data, temperature=temperature
)
+ kwargs = {
+ select_keys_verl2trinity[verl_key]: value
+ for verl_key, value in data.items()
+ if verl_key in select_keys_verl2trinity
+ }
pg_loss, metric = self.policy_loss_fn( # type: ignore
logprob=log_prob,
- old_logprob=old_log_prob,
- action_mask=response_mask,
- advantages=advantages,
- experiences=data,
+ **kwargs,
)
# compute entropy loss from entropy
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index b6397adde7..ca02b6c288 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -4,6 +4,7 @@
Modified from verl/trainer/ppo/ray_trainer.py
"""
import os
+import sys
from pprint import pprint
from typing import Tuple
@@ -169,29 +170,14 @@ def prepare(self):
return
# we start from step 1
- self.global_steps += 1
def _create_dataloader(self):
self.train_dataloader = _InternalDataLoader(self.config)
# TODO: compute total training steps
- # if self.algorithm_type.is_dpo():
- # train_batch_size = self.config.buffer.read_batch_size
- # total_epochs = self.config.trainer.total_epochs
- # from math import ceil
-
- # self.total_training_steps = ceil(
- # self.train_dataloader.size() // train_batch_size * total_epochs
- # )
- # if not self.config.actor_rollout_ref.actor.optim.total_training_steps > 0:
- # self.config.actor_rollout_ref.actor.optim.total_training_steps = (
- # self.total_training_steps
- # )
- # if not self.config.critic.optim.total_training_steps > 0:
- # self.config.critic.optim.total_training_steps = self.total_training_steps
- # else:
- self.total_training_steps = float("inf")
+ self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]:
+ self.global_steps += 1
metrics = {}
timing_raw = {}
@@ -251,12 +237,23 @@ def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]:
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
- self.global_steps += 1
- return True, self.global_steps - 1
+ if self.global_steps >= self.total_training_steps:
+ if (
+ self.config.trainer.save_freq > 0
+ and self.global_steps % self.config.trainer.save_freq != 0
+ ):
+ with _timer("save_checkpoint", timing_raw):
+ self._save_checkpoint()
+ # stop training
+ return False, self.global_steps
+ else:
+ # continue
+ return True, self.global_steps
def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]:
if self.sft_warmup_step_num >= self.config.trainer.sft_warmup_steps:
- return False, self.global_steps - 1
+ return False, self.global_steps
+ self.global_steps += 1
metrics = {}
timing_raw = {}
@@ -308,18 +305,19 @@ def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]:
# TODO: log as sft metrics
self.logger.log(data=metrics, step=self.global_steps)
self.sft_warmup_step_num += 1
- self.global_steps += 1
+ train_status = True
if self.sft_warmup_step_num == self.config.trainer.sft_warmup_steps:
self.logger.log(
data={"sft_warmup_steps": self.sft_warmup_step_num},
- step=self.global_steps - 1,
+ step=self.global_steps,
)
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
- return False, self.global_steps - 1
- return True, self.global_steps - 1
+ train_status = False
+ return train_status, self.global_steps
def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
+ self.global_steps += 1
metrics = {}
timing_raw = {}
@@ -426,20 +424,18 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
# TODO: make a canonical logger that supports various backend
self.logger.log(data=metrics, step=self.global_steps)
- self.global_steps += 1
-
if self.global_steps >= self.total_training_steps:
if (
self.config.trainer.save_freq > 0
- and (self.global_steps - 1) % self.config.trainer.save_freq != 0
+ and self.global_steps % self.config.trainer.save_freq != 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# stop training
- return False, self.global_steps - 1
+ return False, self.global_steps
else:
# continue
- return True, self.global_steps - 1
+ return True, self.global_steps
def _log_single_experience(
self, experiences: Experiences, idx: int, skip_special_tokens: bool
From 732d801f169c81a3cb0c6f3007ea7bc314ef1f99 Mon Sep 17 00:00:00 2001
From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com>
Date: Thu, 5 Jun 2025 13:51:27 +0800
Subject: [PATCH 05/28] Add KL/Entorpy Fn (#64)
---
.../sphinx_doc/source/tutorial/example_dpo.md | 5 +-
examples/dpo_humanlike/dpo.yaml | 5 +-
tests/template/config.yaml | 4 +-
tests/trainer/trainer_test.py | 16 +-
trinity/algorithm/__init__.py | 6 +
.../algorithm/advantage_fn/grpo_advantage.py | 2 +-
.../algorithm/advantage_fn/opmd_advantage.py | 2 +-
.../algorithm/advantage_fn/ppo_advantage.py | 2 +-
.../reinforce_plus_plus_advantage.py | 2 +-
.../algorithm/advantage_fn/remax_advantage.py | 2 +-
.../algorithm/advantage_fn/rloo_advantage.py | 2 +-
trinity/algorithm/entropy_loss/__init__.py | 0
trinity/algorithm/entropy_loss_fn/__init__.py | 9 +
.../entropy_loss_fn/entropy_loss_fn.py | 63 +++++++
trinity/algorithm/kl_fn/__init__.py | 3 +
trinity/algorithm/kl_fn/kl_fn.py | 157 ++++++++++++++++++
trinity/algorithm/kl_loss/__init__.py | 0
trinity/algorithm/utils.py | 8 +
trinity/common/config.py | 49 +++++-
trinity/common/verl_config.py | 11 +-
trinity/trainer/trainer.py | 4 +
trinity/trainer/verl/dp_actor.py | 59 ++++---
trinity/trainer/verl_trainer.py | 25 +--
23 files changed, 361 insertions(+), 75 deletions(-)
delete mode 100644 trinity/algorithm/entropy_loss/__init__.py
create mode 100644 trinity/algorithm/entropy_loss_fn/__init__.py
create mode 100644 trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
create mode 100644 trinity/algorithm/kl_fn/__init__.py
create mode 100644 trinity/algorithm/kl_fn/kl_fn.py
delete mode 100644 trinity/algorithm/kl_loss/__init__.py
diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md
index a6f70f5e62..44543ff2bc 100644
--- a/docs/sphinx_doc/source/tutorial/example_dpo.md
+++ b/docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -48,6 +48,9 @@ name:
mode: train
algorithm:
algorithm_type: dpo
+ kl_loss_fn: k1
+ kl_loss_fn_args:
+ kl_coef: 0.1 # value of beta in DPO
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
model:
model_path: /PATH/TO/MODEL/
@@ -70,8 +73,6 @@ buffer:
trainer:
trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
save_interval: 30
- actor_use_kl_loss: True
- actor_kl_loss_coef: 0.1 # value of beta in DPO
```
### Run the Experiment
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index 8cd3dbe0c8..0a0864b8ef 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -3,6 +3,9 @@ name: "trinity_dpo"
mode: train
algorithm:
algorithm_type: dpo
+ kl_loss_fn: k1
+ kl_loss_fn_args:
+ kl_coef: 0.1
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
model:
model_path: /PATH/TO/MODEL
@@ -34,5 +37,3 @@ trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
save_interval: 30
- actor_use_kl_loss: True
- actor_kl_loss_coef: 0.1
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 3a767df243..98180fff48 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -8,10 +8,12 @@ algorithm:
policy_loss_fn: ppo
policy_loss_fn_args:
clip_range: 0.2
- advantage_fn_type: ppo_adv_fn
+ advantage_fn: ppo
advantage_fn_args:
gamma: 1.0
lam: 1.0
+ kl_penalty_fn: k3
+ kl_loss_fn: k2
model:
model_path: ''
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 55f63ae856..e83b443c4b 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -67,6 +67,10 @@ def test_trainer(self):
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
+ actor_kl_metrics = parser.metric_list("actor/kl")
+ self.assertTrue(len(actor_kl_metrics) > 0)
+ critic_kl_metrics = parser.metric_list("critic/kl")
+ self.assertTrue(len(critic_kl_metrics) > 0)
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
@@ -86,7 +90,7 @@ def test_trainer(self):
)
self.assertTrue(os.path.exists(checkpoint_step_4))
self.assertTrue(os.path.exists(checkpoint_step_8))
-
+ # TODO: Reinit will fail when using v1 engine, find a way to fix it
ray.init(ignore_reinit_error=True)
# test bench mode
self.config.mode = "bench"
@@ -118,7 +122,7 @@ def test_trainer(self):
self.config.algorithm.algorithm_type = AlgorithmType.GRPO
self.config.algorithm.repeat_times = 4
# self.config.algorithm.repeat_times = 8 # TODO: used for real testing
- self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+ self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.advantage_fn_args = {}
# self.config.buffer.batch_size = 96 # TODO: used for real testing
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
@@ -143,8 +147,6 @@ def test_trainer(self):
# self.assertTrue(0.4 < rewards[1] < 0.55)
# self.assertTrue(0.6 < rewards[2] < 0.7)
# self.assertTrue(0.6 < rewards[3] < 0.7)
- ray.shutdown(_exiting_interpreter=True)
- # check checkpoint
def tearDown(self):
# remove dir only when the test passed
@@ -157,7 +159,7 @@ def test_trainer(self):
# test both mode
self.config.algorithm.algorithm_type = AlgorithmType.GRPO
self.config.algorithm.repeat_times = 4
- self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+ self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.advantage_fn_args = {}
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.buffer.trainer_input.sft_warmup_steps = 2
@@ -180,8 +182,6 @@ def test_trainer(self):
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
- ray.shutdown(_exiting_interpreter=True)
- # check checkpoint
def tearDown(self):
# remove dir only when the test passed
@@ -207,8 +207,6 @@ def test_trainer(self):
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
- ray.shutdown(_exiting_interpreter=True)
- # check checkpoint
def tearDown(self):
# remove dir only when the test passed
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index 170507663f..101364c57c 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -1,4 +1,6 @@
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
+from trinity.algorithm.kl_fn import KL_FN, KLFn
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
__all__ = [
@@ -6,4 +8,8 @@
"ADVANTAGE_FN",
"PolicyLossFn",
"POLICY_LOSS_FN",
+ "KLFn",
+ "KL_FN",
+ "EntropyLossFn",
+ "ENTROPY_LOSS_FN",
]
diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py
index 89a8282752..37f824de4f 100644
--- a/trinity/algorithm/advantage_fn/grpo_advantage.py
+++ b/trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -11,7 +11,7 @@
from trinity.trainer.verl import core_algos
-@ADVANTAGE_FN.register_module("grpo_adv_fn")
+@ADVANTAGE_FN.register_module("grpo")
class GRPOAdvantageFn(AdvantageFn):
"""GRPO advantage computation"""
diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py
index abf74686d3..e9e0eb090f 100644
--- a/trinity/algorithm/advantage_fn/opmd_advantage.py
+++ b/trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -11,7 +11,7 @@
from trinity.trainer.verl import core_algos
-@ADVANTAGE_FN.register_module("opmd_adv_fn")
+@ADVANTAGE_FN.register_module("opmd")
class OPMDAdvantageFn(AdvantageFn):
"""OPMD advantage computation"""
diff --git a/trinity/algorithm/advantage_fn/ppo_advantage.py b/trinity/algorithm/advantage_fn/ppo_advantage.py
index 5afd51311c..896deca116 100644
--- a/trinity/algorithm/advantage_fn/ppo_advantage.py
+++ b/trinity/algorithm/advantage_fn/ppo_advantage.py
@@ -11,7 +11,7 @@
from trinity.trainer.verl import core_algos
-@ADVANTAGE_FN.register_module("ppo_adv_fn")
+@ADVANTAGE_FN.register_module("ppo")
class PPOAdvantageFn(AdvantageFn):
def __init__(
self,
diff --git a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
index 9c668f7640..d53052c83f 100644
--- a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
+++ b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
@@ -11,7 +11,7 @@
from trinity.trainer.verl import core_algos
-@ADVANTAGE_FN.register_module("reinforceplusplus_adv_fn")
+@ADVANTAGE_FN.register_module("reinforceplusplus")
class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn):
def __init__(self, gamma: float = 1.0) -> None:
self.gamma = gamma
diff --git a/trinity/algorithm/advantage_fn/remax_advantage.py b/trinity/algorithm/advantage_fn/remax_advantage.py
index 05a13d7d60..516213c0c2 100644
--- a/trinity/algorithm/advantage_fn/remax_advantage.py
+++ b/trinity/algorithm/advantage_fn/remax_advantage.py
@@ -11,7 +11,7 @@
from trinity.trainer.verl import core_algos
-@ADVANTAGE_FN.register_module("remax_adv_fn")
+@ADVANTAGE_FN.register_module("remax")
class REMAXAdvantageFn(AdvantageFn):
def __init__(self) -> None:
pass
diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py
index 3da61c9da4..c88276e836 100644
--- a/trinity/algorithm/advantage_fn/rloo_advantage.py
+++ b/trinity/algorithm/advantage_fn/rloo_advantage.py
@@ -11,7 +11,7 @@
from trinity.trainer.verl import core_algos
-@ADVANTAGE_FN.register_module("rloo_adv_fn")
+@ADVANTAGE_FN.register_module("rloo")
class RLOOAdvantageFn(AdvantageFn):
def __init__(self) -> None:
pass
diff --git a/trinity/algorithm/entropy_loss/__init__.py b/trinity/algorithm/entropy_loss/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/trinity/algorithm/entropy_loss_fn/__init__.py b/trinity/algorithm/entropy_loss_fn/__init__.py
new file mode 100644
index 0000000000..d932b94fde
--- /dev/null
+++ b/trinity/algorithm/entropy_loss_fn/__init__.py
@@ -0,0 +1,9 @@
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import (
+ ENTROPY_LOSS_FN,
+ EntropyLossFn,
+)
+
+__all__ = [
+ "EntropyLossFn",
+ "ENTROPY_LOSS_FN",
+]
diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
new file mode 100644
index 0000000000..4df9272ca0
--- /dev/null
+++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -0,0 +1,63 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Tuple
+
+import torch
+
+from trinity.algorithm.utils import masked_mean
+from trinity.utils.registry import Registry
+
+ENTROPY_LOSS_FN = Registry("entropy_loss_fn")
+
+
+class EntropyLossFn(ABC):
+ """
+ Entropy loss function.
+ """
+
+ @abstractmethod
+ def __call__(
+ self,
+ entropy: torch.Tensor,
+ action_mask: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ """
+ Args:
+ entropy (`torch.Tensor`): The entropy generated by the policy model.
+ action_mask (`torch.Tensor`): The action mask.
+
+ Returns:
+ `torch.Tensor`: The calculated entropy loss.
+ `Dict`: The metrics for logging
+ """
+
+ @classmethod
+ @abstractmethod
+ def default_args(cls) -> Dict:
+ """
+ Returns:
+ `Dict`: The default arguments for the entropy loss function.
+ """
+
+
+@ENTROPY_LOSS_FN.register_module("basic")
+class BasicEntropyLossFn(EntropyLossFn):
+ """
+ Basic entropy loss function.
+ """
+
+ def __init__(self, entropy_coef: float):
+ self.entropy_coef = entropy_coef
+
+ def __call__(
+ self,
+ entropy: torch.Tensor,
+ action_mask: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ entropy_loss = masked_mean(entropy, action_mask)
+ return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {"entropy_coef": 0.0}
diff --git a/trinity/algorithm/kl_fn/__init__.py b/trinity/algorithm/kl_fn/__init__.py
new file mode 100644
index 0000000000..875c620442
--- /dev/null
+++ b/trinity/algorithm/kl_fn/__init__.py
@@ -0,0 +1,3 @@
+from trinity.algorithm.kl_fn.kl_fn import KL_FN, KLFn
+
+__all__ = ["KLFn", "KL_FN"]
diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py
new file mode 100644
index 0000000000..3901ea7f3c
--- /dev/null
+++ b/trinity/algorithm/kl_fn/kl_fn.py
@@ -0,0 +1,157 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+
+from trinity.algorithm.utils import masked_mean
+from trinity.utils.registry import Registry
+
+KL_FN = Registry("kl_fn")
+
+
+class KLFn(ABC):
+ """
+ KL controller.
+ """
+
+ def __init__(
+ self,
+ adaptive: bool = False,
+ kl_coef: float = 0.001,
+ target_kl: Optional[float] = None,
+ horizon: Optional[float] = None,
+ ) -> None:
+ self.kl_coef = kl_coef
+ self.adaptive = adaptive
+ self.target_kl = target_kl
+ self.horizon = horizon
+ if adaptive and (target_kl is None or horizon is None):
+ raise ValueError("Target KL and horizon must be provided for adaptive KL.")
+
+ def update_kl_coef(self, current_kl: float, batch_size: int) -> None:
+ """Update kl coefficient."""
+ if self.adaptive:
+ target_kl = self.target_kl
+ proportional_error = torch.clip(current_kl / target_kl - 1, -0.2, 0.2).item() # type: ignore
+ multiplier = 1 + proportional_error * batch_size / self.horizon
+ self.kl_coef *= multiplier
+
+ def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]:
+ """Apply KL penalty to reward. Only support DataProto input for now."""
+ responses = experiences.batch["responses"]
+ response_length = responses.size(1)
+ token_level_scores = experiences.batch["token_level_scores"]
+ batch_size = experiences.batch.batch_size[0]
+ attention_mask = experiences.batch["attention_mask"]
+ response_mask = experiences.batch["response_mask"]
+ assert response_mask.shape == attention_mask[:, -response_length:].shape
+ logprob = experiences.batch["old_log_probs"]
+ ref_logprob = experiences.batch["ref_log_prob"]
+
+ if "ref_log_prob" in experiences.batch.keys():
+ kl = self.calculate_kl(logprob, ref_logprob)
+ kl = kl * response_mask
+ kl_coef = self.kl_coef
+ experiences.batch["token_level_rewards"] = token_level_scores - kl_coef * kl
+ else:
+ kl_coef = 0.0
+ kl = torch.zeros_like(response_mask, dtype=torch.float32)
+ experiences.batch["token_level_rewards"] = token_level_scores
+
+ current_kl = masked_mean(kl, mask=response_mask, axis=-1).mean(dim=0).item()
+ self.update_kl_coef(current_kl=current_kl, batch_size=batch_size)
+
+ metrics = {
+ "kl": current_kl,
+ "kl_coef": kl_coef,
+ }
+
+ return experiences, metrics
+
+ def calculate_kl_loss(
+ self,
+ logprob: torch.Tensor,
+ ref_logprob: torch.Tensor,
+ response_mask: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict]:
+ """Compute KL loss."""
+ kl = self.calculate_kl(logprob, ref_logprob)
+ kl_loss = masked_mean(kl, response_mask)
+ metrics = {
+ "kl_loss": kl_loss.detach().item(),
+ "kl_coef": self.kl_coef,
+ }
+ return kl_loss * self.kl_coef, metrics
+
+ @abstractmethod
+ def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+ """Compute KL divergence between logprob and ref_logprob."""
+
+ @classmethod
+ def default_args(cls):
+ """Get the default initialization arguments."""
+ return {"adaptive": False, "kl_coef": 0.001}
+
+
+@KL_FN.register_module("none")
+class DummyFn(KLFn):
+ """
+ Dummy KL function.
+ """
+
+ def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+ return torch.zeros_like(logprob)
+
+ def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]:
+ experiences.batch["token_level_rewards"] = experiences.batch["token_level_scores"]
+ return experiences, {}
+
+ def calculate_kl_loss(
+ self,
+ logprob: torch.Tensor,
+ ref_logprob: torch.Tensor,
+ response_mask: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict]:
+ # return a zero tensor
+ return torch.tensor(0.0), {}
+
+
+@KL_FN.register_module("k1")
+class K1Fn(KLFn):
+ """
+ KL K1 function.
+ """
+
+ def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+ return logprob - ref_logprob
+
+
+@KL_FN.register_module("k2")
+class K2Fn(KLFn):
+ """
+ KL K2 function.
+ """
+
+ def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+ return (logprob - ref_logprob).square() * 0.5
+
+
+@KL_FN.register_module("k3")
+class K3Fn(KLFn):
+ """
+ KL K3 function.
+ """
+
+ def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+ logr = ref_logprob - logprob
+ return logr.exp() - 1 - logr
+
+
+@KL_FN.register_module("abs")
+class AbsFn(KLFn):
+ """
+ KL Abs function.
+ """
+
+ def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
+ return torch.abs(logprob - ref_logprob)
diff --git a/trinity/algorithm/kl_loss/__init__.py b/trinity/algorithm/kl_loss/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py
index d5cfb72d8c..01356cc066 100644
--- a/trinity/algorithm/utils.py
+++ b/trinity/algorithm/utils.py
@@ -12,3 +12,11 @@ def masked_sum(values, mask, axis=None):
def masked_mean(values, mask, axis=None):
"""Compute mean of tensor with a masked values."""
return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
+
+
+def prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) -> dict:
+ if dst_metrics is None:
+ dst_metrics = {}
+ for k, v in src_metrics.items():
+ dst_metrics[f"{prefix}/{k}"] = v
+ return dst_metrics
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 5d294abdfd..91c7790571 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -178,11 +178,24 @@ class AlgorithmConfig:
# If not set, use PolicyLossFn.default_args()
policy_loss_fn_args: Optional[dict] = None
- advantage_fn_type: str = "ppo_adv_fn"
+ advantage_fn: str = "ppo"
# If not set, use AdvantageFn.default_args()
advantage_fn_args: Optional[dict] = None
- # used for SFT
+ kl_penalty_fn: str = "none" # set to "none" to disable kl penalty in reward
+ # If not set, use kl_penalty_fn.default_args()
+ kl_penalty_fn_args: Optional[dict] = None
+
+ kl_loss_fn: str = "k2" # set to "none" to disable kl loss
+ # If not set, use kl_loss_fn.default_args()
+ kl_loss_fn_args: Optional[dict] = None
+
+ entropy_loss_fn: str = "basic"
+ # If not set, use entropy_loss_fn.default_args()
+ entropy_loss_fn_args: Optional[dict] = None
+
+ # used for SFT warmup
+ # TODO: move this to SFT warmup
use_token_level_loss: bool = True
@@ -271,9 +284,6 @@ class TrainerConfig:
enable_preview: bool = True # enable rollout preview in wandb
# trainer configs
- actor_use_kl_loss: Optional[bool] = None
- actor_kl_loss_coef: Optional[float] = None
- actor_entropy_coef: Optional[float] = None
actor_grad_clip: Optional[float] = None
actor_clip_ratio: Optional[float] = None
# TODO: extract more train-related params from underlying trainer engine
@@ -475,7 +485,12 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.tokenizer_path = self.model.model_path
def _check_algorithm(self) -> None:
- from trinity.algorithm import ADVANTAGE_FN, POLICY_LOSS_FN
+ from trinity.algorithm import (
+ ADVANTAGE_FN,
+ ENTROPY_LOSS_FN,
+ KL_FN,
+ POLICY_LOSS_FN,
+ )
policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
if policy_fn_cls is None:
@@ -483,12 +498,30 @@ def _check_algorithm(self) -> None:
if self.algorithm.policy_loss_fn_args is None:
self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args()
- advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn_type)
+ advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn)
if advantage_fn_cls is None:
- raise ValueError(f"Invalid advantage_fn_type: {self.algorithm.advantage_fn_type}")
+ raise ValueError(f"Invalid advantage_fn: {self.algorithm.advantage_fn}")
if self.algorithm.advantage_fn_args is None:
self.algorithm.advantage_fn_args = advantage_fn_cls.default_args()
+ kl_loss_fn_cls = KL_FN.get(self.algorithm.kl_loss_fn)
+ if kl_loss_fn_cls is None:
+ raise ValueError(f"Invalid kl_loss_fn: {self.algorithm.kl_loss_fn}")
+ if self.algorithm.kl_loss_fn_args is None:
+ self.algorithm.kl_loss_fn_args = kl_loss_fn_cls.default_args()
+
+ kl_penalty_fn_cls = KL_FN.get(self.algorithm.kl_penalty_fn)
+ if kl_penalty_fn_cls is None:
+ raise ValueError(f"Invalid kl_penalty_fn: {self.algorithm.kl_penalty_fn}")
+ if self.algorithm.kl_penalty_fn_args is None:
+ self.algorithm.kl_penalty_fn_args = kl_penalty_fn_cls.default_args()
+
+ entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.algorithm.entropy_loss_fn)
+ if entropy_loss_fn_cls is None:
+ raise ValueError(f"Invalid entropy_loss_fn: {self.algorithm.entropy_loss_fn}")
+ if self.algorithm.entropy_loss_fn_args is None:
+ self.algorithm.entropy_loss_fn_args = entropy_loss_fn_cls.default_args()
+
def check_and_update(self) -> None: # noqa: C901
"""Check and update the config."""
self._check_deprecated()
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index fb9f810dee..e8180f4718 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -306,12 +306,6 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
self.critic.ppo_mini_batch_size = config.buffer.batch_size
self.critic.rollout_n = self.actor_rollout_ref.rollout.n
- if config.trainer.actor_use_kl_loss is not None:
- self.actor_rollout_ref.actor.use_kl_loss = config.trainer.actor_use_kl_loss
- if config.trainer.actor_kl_loss_coef is not None:
- self.actor_rollout_ref.actor.kl_loss_coef = config.trainer.actor_kl_loss_coef
- if config.trainer.actor_entropy_coef is not None:
- self.actor_rollout_ref.actor.entropy_coeff = config.trainer.actor_entropy_coef
if config.trainer.actor_grad_clip is not None:
self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip
if config.trainer.actor_clip_ratio is not None:
@@ -330,6 +324,11 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
elif config.algorithm.algorithm_type in (AlgorithmType.GRPO, AlgorithmType.OPMD):
logger.info("Setting `adv_estimator` to 'grpo' for GRPO/OPMD")
self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value
+ self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
+ self.actor_rollout_ref.actor.kl_loss_coef = config.algorithm.kl_loss_fn_args["kl_coef"] # type: ignore
+ self.actor_rollout_ref.actor.entropy_coeff = config.algorithm.entropy_loss_fn_args[ # type: ignore
+ "entropy_coef"
+ ]
# TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
# True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper).
# Need to double check whether this is indeed the case,
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 7208f83fb4..c2ad5fec96 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -80,6 +80,10 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
policy_loss_fn_args={
"use_token_level_loss": self.config.algorithm.use_token_level_loss
},
+ kl_loss_fn="none",
+ kl_loss_fn_args={},
+ entropy_loss_fn="basic",
+ entropy_loss_fn_args=self.config.algorithm.entropy_loss_fn_args,
)
self.engine.set_algorithm(algorithm_config)
else:
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 97cd186c36..a7705fc6a0 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -26,14 +26,14 @@
from verl import DataProto
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
-from verl.utils.torch_functional import logprobs_from_logits, masked_mean
+from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.workers.actor import BasePPOActor
-from trinity.algorithm import POLICY_LOSS_FN
+from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
+from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import AlgorithmConfig
from trinity.common.constants import AlgorithmType
-from trinity.trainer.verl import core_algos
__all__ = ["DataParallelPPOActor"]
@@ -63,6 +63,10 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig):
self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
**algorithm_config.policy_loss_fn_args
)
+ self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args)
+ self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)(
+ **algorithm_config.entropy_loss_fn_args
+ )
def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -347,6 +351,8 @@ def update_policy(self, data: DataProto): # noqa: C901
self.actor_optimizer.zero_grad()
for data in micro_batches:
+ micro_batch_metrics = {}
+
# Support all hardwares
if isinstance(data, DataProto):
data = {
@@ -362,7 +368,6 @@ def update_policy(self, data: DataProto): # noqa: C901
attention_mask = data["attention_mask"]
response_mask = data["response_mask"]
assert response_mask.shape == attention_mask[:, -response_length:].shape
- entropy_coeff = self.config.entropy_coeff
# all return: (bsz, response_length)
entropy, log_prob = self._forward_micro_batch(
@@ -374,30 +379,37 @@ def update_policy(self, data: DataProto): # noqa: C901
for verl_key, value in data.items()
if verl_key in select_keys_verl2trinity
}
- pg_loss, metric = self.policy_loss_fn( # type: ignore
+ pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore
logprob=log_prob,
**kwargs,
)
+ prefix_metrics(
+ src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
+ )
# compute entropy loss from entropy
- entropy_loss = verl_F.masked_mean(entropy, response_mask)
+ entropy_loss, entropy_loss_metrics = self.entropy_loss_fn(
+ entropy=entropy,
+ action_mask=response_mask,
+ )
+ prefix_metrics(
+ src_metrics=entropy_loss_metrics,
+ prefix="actor",
+ dst_metrics=micro_batch_metrics,
+ )
# compute policy loss
- policy_loss = pg_loss - entropy_loss * entropy_coeff
-
- if self.config.use_kl_loss:
- ref_log_prob = data["ref_log_prob"]
- # compute kl loss
- kld = core_algos.kl_penalty(
- logprob=log_prob,
- ref_logprob=ref_log_prob,
- kl_penalty=self.config.kl_loss_type,
- )
- kl_loss = masked_mean(kld, response_mask)
+ policy_loss = pg_loss - entropy_loss
- policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
- metrics["actor/kl_loss"] = kl_loss.detach().item()
- metrics["actor/kl_coef"] = self.config.kl_loss_coef
+ kl_loss, kl_loss_metrics = self.kl_loss_fn.calculate_kl_loss(
+ logprob=log_prob,
+ ref_logprob=data["ref_log_prob"],
+ response_mask=response_mask,
+ )
+ prefix_metrics(
+ src_metrics=kl_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
+ )
+ policy_loss = policy_loss + kl_loss
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
@@ -406,13 +418,10 @@ def update_policy(self, data: DataProto): # noqa: C901
loss = policy_loss / self.gradient_accumulation
loss.backward()
- data = {f"actor/{key}": value for key, value in metric.items()}
- # TODO: refactor entropy loss
- data["actor/entropy_loss"] = entropy_loss.detach().item()
- append_to_dict(metrics, data)
+ append_to_dict(metrics, micro_batch_metrics)
grad_norm = self._optimizer_step()
data = {"actor/grad_norm": grad_norm.detach().item()}
- append_to_dict(metrics, data)
+ append_to_dict(metrics, data)
self.actor_optimizer.zero_grad()
return metrics
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index ca02b6c288..83e3480dc3 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -22,7 +22,8 @@
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs
-from trinity.algorithm import ADVANTAGE_FN
+from trinity.algorithm import ADVANTAGE_FN, KL_FN
+from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import AlgorithmConfig, Config
from trinity.common.constants import AlgorithmType
from trinity.common.experience import Experiences
@@ -34,7 +35,6 @@
ResourcePoolManager,
Role,
_timer,
- apply_kl_penalty,
find_latest_ckpt_path,
)
from trinity.utils.monitor import Monitor
@@ -133,9 +133,10 @@ def __init__(
# specify advantage function for various rft algorithms
algo_config = global_config.algorithm
if algo_config.algorithm_type.is_rft():
- adv_fn_type = algo_config.advantage_fn_type
- adv_fn_args = algo_config.advantage_fn_args
- self.advantage_fn = ADVANTAGE_FN.get(adv_fn_type)(**adv_fn_args)
+ self.advantage_fn = ADVANTAGE_FN.get(algo_config.advantage_fn)(
+ **algo_config.advantage_fn_args
+ )
+ self.kl_fn = KL_FN.get(algo_config.kl_penalty_fn)(**algo_config.kl_penalty_fn_args)
self.logger = Monitor(
project=config.trainer.project_name,
@@ -373,17 +374,9 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
batch = batch.union(values)
with _timer("adv", timing_raw):
- # compute rewards. apply_kl_penalty if available
- if not self.config.actor_rollout_ref.actor.get("use_kl_loss", False):
- batch, kl_metrics = apply_kl_penalty(
- batch,
- kl_ctrl=self.kl_ctrl,
- kl_penalty=self.config.algorithm.kl_penalty,
- )
- metrics.update(kl_metrics)
- else:
- batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
-
+ # compute kl penalty
+ batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch)
+ metrics.update(prefix_metrics(kl_metrics, prefix="critic"))
# compute advantages, executed on the driver process
batch, _ = self.advantage_fn(batch)
From 2d8f0c1464e8640785fa8483332fd41bfd56610a Mon Sep 17 00:00:00 2001
From: Yanxi Chen <153061753+yanxi-chen@users.noreply.github.com>
Date: Thu, 5 Jun 2025 19:38:42 +0800
Subject: [PATCH 06/28] Refactor advantage computation (cont.) (#68)
---
.../algorithm/advantage_fn/grpo_advantage.py | 65 +++++++++++---
.../algorithm/advantage_fn/opmd_advantage.py | 85 +++++++++++++++----
.../algorithm/advantage_fn/ppo_advantage.py | 54 ++++++++++--
.../reinforce_plus_plus_advantage.py | 38 +++++++--
.../algorithm/advantage_fn/remax_advantage.py | 40 +++++++--
.../algorithm/advantage_fn/rloo_advantage.py | 53 ++++++++++--
trinity/algorithm/kl_fn/kl_fn.py | 10 ++-
.../policy_loss_fn/opmd_policy_loss.py | 5 +-
trinity/algorithm/utils.py | 40 +++++++++
9 files changed, 323 insertions(+), 67 deletions(-)
diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py
index 37f824de4f..553af6d065 100644
--- a/trinity/algorithm/advantage_fn/grpo_advantage.py
+++ b/trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -1,35 +1,74 @@
"""GRPO advantage computation
-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
+from collections import defaultdict
from typing import Dict, Tuple
+import torch
from verl import DataProto
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
@ADVANTAGE_FN.register_module("grpo")
class GRPOAdvantageFn(AdvantageFn):
"""GRPO advantage computation"""
- def __init__(self) -> None:
- pass
+ def __init__(
+ self,
+ epsilon: float = 1e-6,
+ ) -> None:
+ self.epsilon = epsilon
def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
- advantages, returns = core_algos.compute_grpo_outcome_advantage(
- token_level_rewards=exps.batch["token_level_rewards"],
- eos_mask=exps.batch["response_mask"],
- index=exps.non_tensor_batch["uid"],
- )
- exps.batch["advantages"] = advantages
- exps.batch["returns"] = returns
+ """
+ Compute advantage for GRPO, operating only on Outcome reward
+ (with only one scalar reward for each response).
+
+ token_level_rewards: `(torch.Tensor)`
+ shape: (bs, response_length)
+ eos_mask: `(torch.Tensor)`
+ shape: (bs, response_length)
+ scores: `(torch.Tensor)`
+ shape: (bs, response_length)
+ """
+ token_level_rewards = exps.batch["token_level_rewards"]
+ eos_mask = exps.batch["response_mask"]
+ index = exps.non_tensor_batch["uid"]
+ epsilon = self.epsilon
+
+ response_length = token_level_rewards.shape[-1]
+ scores = token_level_rewards.sum(dim=-1)
+
+ id2score = defaultdict(list)
+ id2mean = {}
+ id2std = {}
+
+ with torch.no_grad():
+ bsz = scores.shape[0]
+ for i in range(bsz):
+ id2score[index[i]].append(scores[i])
+ for idx in id2score:
+ if len(id2score[idx]) == 1:
+ id2mean[idx] = torch.tensor(0.0)
+ id2std[idx] = torch.tensor(1.0)
+ elif len(id2score[idx]) > 1:
+ id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
+ id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
+ else:
+ raise ValueError(f"no score in prompt index: {idx}")
+ for i in range(bsz):
+ scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+ scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+ exps.batch["advantages"] = scores
+ exps.batch["returns"] = scores
metrics = {
# TODO: add meaningful metrics
@@ -39,4 +78,6 @@ def __call__(
@classmethod
def default_args(cls) -> Dict:
- return {}
+ return {
+ "epsilon": 1e-6,
+ }
diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py
index e9e0eb090f..b27e2c9ab0 100644
--- a/trinity/algorithm/advantage_fn/opmd_advantage.py
+++ b/trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -1,38 +1,84 @@
-"""OPMD advantage computation
-
-Adapted from compute_advantage_opmd in original ray_trainer.py
-"""
+"""OPMD advantage computation"""
+from collections import defaultdict
from typing import Dict, Tuple
+import torch
from verl import DataProto
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
@ADVANTAGE_FN.register_module("opmd")
class OPMDAdvantageFn(AdvantageFn):
"""OPMD advantage computation"""
- def __init__(self) -> None:
- pass
+ def __init__(
+ self,
+ opmd_baseline: str = "mean",
+ tau: float = 1.0,
+ ) -> None:
+ self.opmd_baseline = opmd_baseline
+ self.tau = tau
def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
- advantages, returns = core_algos.compute_opmd_outcome_advantage(
- token_level_rewards=exps.batch["token_level_rewards"],
- eos_mask=exps.batch["response_mask"],
- # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
- index=exps.non_tensor_batch["uid"],
- opmd_baseline="mean",
- tau=1.0,
- )
- exps.batch["advantages"] = advantages
- exps.batch["returns"] = returns
+ """Modified from compute_grpo_outcome_advantage
+
+ Compute advantage for OPMD, operating only on Outcome reward
+ (with only one scalar reward for each response).
+
+ token_level_rewards: `(torch.Tensor)`
+ shape: (bs, response_length)
+ eos_mask: `(torch.Tensor)`
+ shape: (bs, response_length)
+ scores: `(torch.Tensor)`
+ shape: (bs, response_length)
+ """
+ token_level_rewards = exps.batch["token_level_rewards"]
+ eos_mask = exps.batch["response_mask"]
+ # TODO (yanxi): confirm consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
+ index = exps.non_tensor_batch["uid"]
+ opmd_baseline = self.opmd_baseline
+ tau = self.tau
+
+ response_length = token_level_rewards.shape[-1]
+ scores = token_level_rewards.sum(dim=-1)
+
+ id2score = defaultdict(list)
+ id2baseline = {}
+
+ with torch.no_grad():
+ bsz = scores.shape[0]
+ for i in range(bsz):
+ id2score[index[i]].append(scores[i])
+ for idx in id2score:
+ if len(id2score[idx]) == 1:
+ id2baseline[idx] = torch.tensor(0.0)
+ # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
+ elif len(id2score[idx]) > 1:
+ if opmd_baseline == "mean":
+ id2baseline[idx] = torch.mean(torch.tensor(id2score[idx]))
+ elif opmd_baseline == "logavgexp":
+ rewards_tensor = torch.tensor(id2score[idx])
+ # here we use the fact that logavgexp(x) = logsumexp(x) - log(len(x))
+ id2baseline[idx] = tau * (
+ torch.logsumexp(rewards_tensor / tau, dim=-1)
+ - torch.log(torch.tensor(len(id2score[idx])))
+ )
+ else:
+ raise NotImplementedError
+ else:
+ raise ValueError(f"no score in prompt index: {idx}")
+ for i in range(bsz):
+ scores[i] = scores[i] - id2baseline[index[i]]
+ scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+ exps.batch["advantages"] = scores
+ exps.batch["returns"] = scores
metrics = {
# TODO: add meaningful metrics
@@ -42,4 +88,7 @@ def __call__(
@classmethod
def default_args(cls) -> Dict:
- return {}
+ return {
+ "opmd_baseline": "mean",
+ "tau": 1.0,
+ }
diff --git a/trinity/algorithm/advantage_fn/ppo_advantage.py b/trinity/algorithm/advantage_fn/ppo_advantage.py
index 896deca116..31fda4454c 100644
--- a/trinity/algorithm/advantage_fn/ppo_advantage.py
+++ b/trinity/algorithm/advantage_fn/ppo_advantage.py
@@ -1,14 +1,15 @@
"""PPO's GAE advantage computation
-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
from typing import Dict, Tuple
+import torch
from verl import DataProto
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
+from trinity.algorithm.utils import masked_whiten
@ADVANTAGE_FN.register_module("ppo")
@@ -26,13 +27,48 @@ def __call__(
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
- advantages, returns = core_algos.compute_gae_advantage_return(
- token_level_rewards=exps.batch["token_level_rewards"],
- values=exps.batch["values"],
- eos_mask=exps.batch["response_mask"],
- gamma=self.gamma,
- lam=self.lam,
- )
+ """
+ token_level_rewards: `(torch.Tensor)`
+ shape: (bs, response_length)
+ values: `(torch.Tensor)`
+ shape: (bs, response_length)
+ eos_mask: `(torch.Tensor)`
+ shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
+ gamma: `(float)`
+ discounted factor used in RL
+ lam: `(float)`
+ lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
+ advantages: `(torch.Tensor)`
+ shape: (bs, response_length)
+ returns: `(torch.Tensor)`
+ shape: (bs, response_length)
+ """
+ token_level_rewards = exps.batch["token_level_rewards"]
+ values = exps.batch["values"]
+ eos_mask = exps.batch["response_mask"]
+ gamma = self.gamma
+ lam = self.lam
+
+ with torch.no_grad():
+ lastgaelam = 0
+ advantages_reversed = []
+ gen_len = token_level_rewards.shape[-1]
+
+ # values = values * eos_mask TODO: may use in multi-turn
+ for t in reversed(range(gen_len)):
+ nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
+ delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
+
+ lastgaelam = delta + gamma * lam * lastgaelam
+ # lastgaelam = torch.where( # TODO: may use in multi-turn
+ # eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam
+ # )
+ advantages_reversed.append(lastgaelam)
+ advantages = torch.stack(advantages_reversed[::-1], dim=1)
+
+ returns = advantages + values
+ advantages = masked_whiten(advantages, eos_mask)
+
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns
diff --git a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
index d53052c83f..eb63c3605b 100644
--- a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
+++ b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
@@ -1,14 +1,15 @@
"""REINFORCE++ advantage computation
-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
from typing import Dict, Tuple
+import torch
from verl import DataProto
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
+from trinity.algorithm.utils import masked_whiten
@ADVANTAGE_FN.register_module("reinforceplusplus")
@@ -21,11 +22,34 @@ def __call__(
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
- advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
- token_level_rewards=exps.batch["token_level_rewards"],
- eos_mask=exps.batch["response_mask"],
- gamma=self.gamma,
- )
+ """
+ Compute advantage for REINFORCE++.
+ This implementation is based on the paper: https://arxiv.org/abs/2501.03262
+
+ token_level_rewards: `(torch.Tensor)`
+ shape: (bs, response_length)
+ eos_mask: `(torch.Tensor)`
+ shape: (bs, response_length)
+ advantages: `(torch.Tensor)`
+ shape: (bs, response_length)
+ returns: `(torch.Tensor)`
+ shape: (bs, response_length)
+ """
+ token_level_rewards = exps.batch["token_level_rewards"]
+ eos_mask = exps.batch["response_mask"]
+ gamma = self.gamma
+
+ with torch.no_grad():
+ returns = torch.zeros_like(token_level_rewards)
+ running_return = 0
+
+ for t in reversed(range(token_level_rewards.shape[1])):
+ running_return = token_level_rewards[:, t] + gamma * running_return
+ returns[:, t] = running_return
+
+ advantages = masked_whiten(returns, eos_mask)
+ advantages = advantages * eos_mask
+
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns
diff --git a/trinity/algorithm/advantage_fn/remax_advantage.py b/trinity/algorithm/advantage_fn/remax_advantage.py
index 516213c0c2..07f92d91a0 100644
--- a/trinity/algorithm/advantage_fn/remax_advantage.py
+++ b/trinity/algorithm/advantage_fn/remax_advantage.py
@@ -1,14 +1,14 @@
"""REMAX advantage computation
-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
from typing import Dict, Tuple
+import torch
from verl import DataProto
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
@ADVANTAGE_FN.register_module("remax")
@@ -21,11 +21,37 @@ def __call__(
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
- advantages, returns = core_algos.compute_remax_outcome_advantage(
- token_level_rewards=exps.batch["token_level_rewards"],
- reward_baselines=exps.batch["reward_baselines"],
- eos_mask=exps.batch["response_mask"],
- )
+ """
+ Compute advantage for ReMax, operating only on Outcome reward
+ (with only one scalar reward for each response).
+ This implementation is based on the paper: https://arxiv.org/abs/2310.10505
+
+ token_level_rewards: `(torch.Tensor)`
+ shape: (bs, response_length)
+ reward_baselines: `(torch.Tensor)`
+ shape: (bs,)
+ eos_mask: `(torch.Tensor)`
+ shape: (bs, response_length)
+ advantages: `(torch.Tensor)`
+ shape: (bs, response_length)
+ returns: `(torch.Tensor)`
+ shape: (bs, response_length)
+ """
+ token_level_rewards = exps.batch["token_level_rewards"]
+ reward_baselines = exps.batch["reward_baselines"]
+ eos_mask = exps.batch["response_mask"]
+
+ response_length = token_level_rewards.shape[-1]
+ token_level_rewards.sum(dim=-1)
+
+ with torch.no_grad():
+ returns = (
+ (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+ )
+ advantages = (
+ returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
+ )
+
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns
diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py
index c88276e836..fb2680a68b 100644
--- a/trinity/algorithm/advantage_fn/rloo_advantage.py
+++ b/trinity/algorithm/advantage_fn/rloo_advantage.py
@@ -1,14 +1,15 @@
"""RLOO advantage computation
-Adapted from compute_advantage_ppo in original ray_trainer.py
+Ref: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
+from collections import defaultdict
from typing import Dict, Tuple
+import torch
from verl import DataProto
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.trainer.verl import core_algos
@ADVANTAGE_FN.register_module("rloo")
@@ -21,13 +22,47 @@ def __call__(
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
- advantages, returns = core_algos.compute_rloo_outcome_advantage(
- token_level_rewards=exps.batch["token_level_rewards"],
- eos_mask=exps.batch["response_mask"],
- index=exps.non_tensor_batch["uid"],
- )
- exps.batch["advantages"] = advantages
- exps.batch["returns"] = returns
+ """
+ Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
+
+ token_level_rewards: `(torch.Tensor)`
+ shape: (bs, response_length)
+ eos_mask: `(torch.Tensor)`
+ shape: (bs, response_length)
+ scores: `(torch.Tensor)`
+ shape: (bs, response_length)
+ """
+ token_level_rewards = exps.batch["token_level_rewards"]
+ eos_mask = exps.batch["response_mask"]
+ index = exps.non_tensor_batch["uid"]
+
+ response_length = token_level_rewards.shape[-1]
+ scores = token_level_rewards.sum(dim=-1)
+
+ id2score = defaultdict(list)
+ id2mean = {}
+
+ with torch.no_grad():
+ bsz = scores.shape[0]
+ for i in range(bsz):
+ id2score[index[i]].append(scores[i])
+ for idx in id2score:
+ if len(id2score[idx]) == 1:
+ id2mean[idx] = torch.tensor(0.0)
+ elif len(id2score[idx]) > 1:
+ id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
+ else:
+ raise ValueError(f"no score in prompt index: {idx}")
+ for i in range(bsz):
+ response_num = len(id2score[index[i]])
+ if response_num > 1:
+ scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[
+ index[i]
+ ] * response_num / (response_num - 1)
+ scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+ exps.batch["advantages"] = scores
+ exps.batch["returns"] = scores
metrics = {
# TODO: add meaningful metrics
diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py
index 3901ea7f3c..95d2915a84 100644
--- a/trinity/algorithm/kl_fn/kl_fn.py
+++ b/trinity/algorithm/kl_fn/kl_fn.py
@@ -1,3 +1,11 @@
+"""KL penalty and loss.
+
+Ref:
+https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
+https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py
+https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/utils.py
+"""
+
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple
@@ -11,7 +19,7 @@
class KLFn(ABC):
"""
- KL controller.
+ KL penalty and loss.
"""
def __init__(
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
index e9457c55d1..042d26b341 100644
--- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -1,7 +1,4 @@
-"""PPO policy loss function.
-
-Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
-"""
+"""OPMD policy loss function."""
from typing import Dict, List, Tuple
diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py
index 01356cc066..8660a6376c 100644
--- a/trinity/algorithm/utils.py
+++ b/trinity/algorithm/utils.py
@@ -3,6 +3,8 @@
Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py
"""
+import torch
+
def masked_sum(values, mask, axis=None):
"""Compute mean of tensor with a masked values."""
@@ -14,6 +16,44 @@ def masked_mean(values, mask, axis=None):
return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
+def masked_var(values, mask, unbiased=True):
+ """Compute variance of tensor with masked values."""
+ mean = masked_mean(values, mask)
+ centered_values = values - mean
+ variance = masked_mean(centered_values**2, mask)
+ if unbiased:
+ mask_sum = mask.sum()
+ if mask_sum == 0:
+ raise ValueError("At least one element in the mask has to be 1.")
+ # note that if mask_sum == 1, then there is a division by zero issue
+ # to avoid it you just need to use a larger minibatch_size
+ if mask_sum == 1:
+ raise ValueError("The sum of the mask is one, which can cause a division by zero.")
+ bessel_correction = mask_sum / (mask_sum - 1)
+ variance = variance * bessel_correction
+ return variance
+
+
+def masked_whiten(values, mask, shift_mean=True):
+ """
+ Whiten `values` by normalizing with mean and variance computed over `mask`.
+
+ Args:
+ values (torch.Tensor): Input tensor.
+ mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats.
+ shift_mean (bool): If True (default), output is zero-mean;
+ if False, the original mean is re-added after scaling.
+
+ Returns:
+ torch.Tensor: Whitened tensor of same shape as `values`.
+ """
+ mean, var = masked_mean(values, mask), masked_var(values, mask)
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
+ if not shift_mean:
+ whitened += mean
+ return whitened
+
+
def prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) -> dict:
if dst_metrics is None:
dst_metrics = {}
From fec7f3cca75c625cae8404a6e5962b25e929c80d Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Tue, 10 Jun 2025 20:55:31 +0800
Subject: [PATCH 07/28] Refactor train step (#69)
---
tests/buffer/queue_test.py | 4 +-
tests/buffer/sql_test.py | 4 +-
tests/explorer/runner_pool_test.py | 4 +-
tests/trainer/trainer_test.py | 10 +-
trinity/algorithm/algorithm.py | 186 ++++
trinity/algorithm/algorithm_manager.py | 34 +
.../entropy_loss_fn/entropy_loss_fn.py | 22 +
trinity/algorithm/kl_fn/kl_fn.py | 2 +-
trinity/buffer/buffer.py | 6 +-
trinity/buffer/reader/file_reader.py | 7 +-
trinity/buffer/schema/sql_schema.py | 25 +-
trinity/buffer/writer/sql_writer.py | 4 +-
trinity/cli/launcher.py | 30 +-
trinity/common/config.py | 110 ++-
trinity/common/constants.py | 28 -
trinity/common/verl_config.py | 20 +-
trinity/explorer/explorer.py | 11 +
trinity/trainer/trainer.py | 97 +--
trinity/trainer/verl/core_algos.py | 717 ---------------
trinity/trainer/verl/dp_actor.py | 16 +-
trinity/trainer/verl/ray_trainer.py | 816 ------------------
trinity/trainer/verl_trainer.py | 399 ++++-----
22 files changed, 568 insertions(+), 1984 deletions(-)
create mode 100644 trinity/algorithm/algorithm.py
create mode 100644 trinity/algorithm/algorithm_manager.py
delete mode 100644 trinity/trainer/verl/core_algos.py
delete mode 100644 trinity/trainer/verl/ray_trainer.py
diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index e06b133256..ce141909bc 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -4,7 +4,7 @@
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import AlgorithmType, StorageType
+from trinity.common.constants import StorageType
from trinity.common.experience import Experience
@@ -15,7 +15,7 @@ def test_queue_buffer(self):
read_batch_size = 4
meta = StorageConfig(
name="test_buffer",
- algorithm_type=AlgorithmType.PPO,
+ algorithm_type="ppo",
storage_type=StorageType.QUEUE,
)
config = BufferConfig(
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index 61ebc46315..5620c38f8e 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -6,7 +6,7 @@
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import AlgorithmType, StorageType
+from trinity.common.constants import StorageType
from trinity.common.experience import Experience
db_path = os.path.join(os.path.dirname(__file__), "test.db")
@@ -19,7 +19,7 @@ def test_create_sql_buffer(self) -> None:
read_batch_size = 4
meta = StorageConfig(
name="test_buffer",
- algorithm_type=AlgorithmType.PPO,
+ algorithm_type="ppo",
path=f"sqlite:///{db_path}",
storage_type=StorageType.SQL,
)
diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py
index 036339e747..8a6e262a90 100644
--- a/tests/explorer/runner_pool_test.py
+++ b/tests/explorer/runner_pool_test.py
@@ -10,7 +10,7 @@
from tests.tools import get_unittest_dataset_config
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.common.config import InferenceModelConfig, StorageConfig, load_config
-from trinity.common.constants import AlgorithmType, StorageType
+from trinity.common.constants import StorageType
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel
from trinity.common.workflows import Task
@@ -105,7 +105,7 @@ def setUp(self):
) = StorageConfig(
name="test",
storage_type=StorageType.QUEUE,
- algorithm_type=AlgorithmType.PPO,
+ algorithm_type="ppo",
)
self.queue = QueueReader(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index e83b443c4b..5b2795d952 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -15,7 +15,7 @@
get_unittest_dataset_config,
)
from trinity.cli.launcher import bench, both, train
-from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod
+from trinity.common.constants import MonitorType, SyncMethod
class BaseTrainerCase(RayUnittestBase):
@@ -119,7 +119,7 @@ class TestTrainerGSM8K(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K."""
# test both mode
- self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+ self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.repeat_times = 4
# self.config.algorithm.repeat_times = 8 # TODO: used for real testing
self.config.algorithm.advantage_fn = "grpo"
@@ -157,7 +157,7 @@ class TestTrainerGSM8KWithSFT(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K With SFT."""
# test both mode
- self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+ self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.repeat_times = 4
self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.advantage_fn_args = {}
@@ -174,7 +174,7 @@ def test_trainer(self):
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
- self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) # SFT
@@ -193,7 +193,7 @@ def test_trainer(self):
"""Test DPO."""
# test both mode
self.config.mode = "train"
- self.config.algorithm.algorithm_type = AlgorithmType.DPO
+ self.config.algorithm.algorithm_type = "dpo"
self.config.algorithm.policy_loss_fn = "dpo"
self.config.algorithm.policy_loss_fn_args = {}
# self.config.buffer.batch_size = 32
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
new file mode 100644
index 0000000000..f94798fe85
--- /dev/null
+++ b/trinity/algorithm/algorithm.py
@@ -0,0 +1,186 @@
+# -*- coding: utf-8 -*-
+"""Algorithm classes."""
+
+from abc import ABC, ABCMeta
+from typing import Dict
+
+from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel
+from trinity.common.config import Config
+from trinity.common.constants import SyncMethod
+from trinity.common.experience import Experience, Experiences
+from trinity.utils.log import get_logger
+from trinity.utils.registry import Registry
+
+logger = get_logger(__name__)
+
+ALGORITHM_TYPE = Registry("algorithm")
+
+
+class ConstantMeta(ABCMeta):
+ def __setattr__(cls, name, value):
+ if name in cls.__dict__:
+ raise AttributeError(f"{name} is already defined in {cls.__name__}")
+ return super().__setattr__(name, value)
+
+
+class AlgorithmType(ABC, metaclass=ConstantMeta):
+ use_critic: bool
+ use_reference: bool
+ use_advantage: bool
+ use_rollout: bool
+ can_balance_batch: bool
+ schema: type
+
+ @classmethod
+ def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
+ return Experiences.gather_experiences(exps, pad_token_id)
+
+ @classmethod
+ def get_default_config(cls) -> Dict:
+ raise NotImplementedError
+
+ @classmethod
+ def name(cls) -> str:
+ return cls._name
+
+ @classmethod
+ def check_config(cls, config: Config) -> None:
+ pass
+
+
+@ALGORITHM_TYPE.register_module("sft")
+class SFTAlgorithm(AlgorithmType):
+ """SFT Algorithm."""
+
+ use_critic: bool = False
+ use_reference: bool = False
+ use_advantage: bool = False
+ use_rollout: bool = False
+ can_balance_batch: bool = True
+ schema: type = SFTDataModel
+
+ @classmethod
+ def get_default_config(cls) -> Dict:
+ return {
+ "policy_loss_fn": "sft",
+ "kl_loss_fn": "none",
+ "entropy_loss_fn": "none",
+ }
+
+
+@ALGORITHM_TYPE.register_module("ppo")
+class PPOAlgorithm(AlgorithmType):
+ """PPO Algorithm."""
+
+ use_critic: bool = True
+ use_reference: bool = True
+ use_advantage: bool = True
+ use_rollout: bool = True
+ can_balance_batch: bool = True
+ schema: type = ExperienceModel
+
+ @classmethod
+ def get_default_config(cls) -> Dict:
+ return {
+ "repeat_times": 1,
+ "policy_loss_fn": "ppo",
+ "advantage_fn": "ppo",
+ "kl_penalty_fn": "none",
+ "kl_loss_fn": "k2",
+ "entropy_loss_fn": "basic",
+ }
+
+
+@ALGORITHM_TYPE.register_module("grpo")
+class GRPOAlgorithm(AlgorithmType):
+ """GRPO algorithm."""
+
+ use_critic: bool = False
+ use_reference: bool = True
+ use_advantage: bool = True
+ use_rollout: bool = True
+ can_balance_batch: bool = True
+ schema: type = ExperienceModel
+
+ @classmethod
+ def get_default_config(cls) -> Dict:
+ return {
+ "repeat_times": 2,
+ "policy_loss_fn": "ppo",
+ "advantage_fn": "grpo",
+ "kl_penalty_fn": "none",
+ "kl_loss_fn": "k2",
+ "entropy_loss_fn": "basic",
+ }
+
+
+@ALGORITHM_TYPE.register_module("opmd")
+class OPMDAlgorithm(AlgorithmType):
+ """OPMD algorithm."""
+
+ use_critic: bool = False
+ use_reference: bool = True
+ use_advantage: bool = True
+ use_rollout: bool = True
+ can_balance_batch: bool = True
+ schema: type = ExperienceModel
+
+ @classmethod
+ def get_default_config(cls) -> Dict:
+ return {
+ "repeat_times": 2,
+ "policy_loss_fn": "opmd",
+ "advantage_fn": "opmd",
+ "kl_penalty_fn": "none",
+ "kl_loss_fn": "k2",
+ "entropy_loss_fn": "basic",
+ }
+
+
+@ALGORITHM_TYPE.register_module("dpo")
+class DPOAlgorithm(AlgorithmType):
+ """DPO algorithm."""
+
+ use_critic: bool = False
+ use_reference: bool = True
+ use_advantage: bool = False
+ use_rollout: bool = False
+ can_balance_batch: bool = False
+ schema: type = DPODataModel
+
+ @classmethod
+ def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
+ return Experiences.gather_dpo_experiences(exps, pad_token_id)
+
+ @classmethod
+ def get_default_config(cls) -> Dict:
+ return {
+ "repeat_times": 2, # fake repeat times
+ "policy_loss_fn": "dpo",
+ "kl_loss_fn": "k2",
+ "entropy_loss_fn": "basic",
+ }
+
+ @classmethod
+ def check_config(cls, config: Config) -> None:
+ if config.model == "train":
+ if (
+ config.buffer.trainer_input.experience_buffer is None
+ or not config.buffer.trainer_input.experience_buffer.path
+ ):
+ raise ValueError(
+ "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == dpo`"
+ )
+ elif config.mode in ["both", "explore"]:
+ raise ValueError(f"DPO does not support `{config.mode}` mode")
+
+ if config.synchronizer.sync_method != SyncMethod.CHECKPOINT:
+ config.synchronizer.sync_method = SyncMethod.CHECKPOINT
+ logger.warning(
+ "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
+ )
+ if config.algorithm.repeat_times != 2:
+ config.algorithm.repeat_times = 2
+ logger.warning(
+ "DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2."
+ ) # no need to warn
diff --git a/trinity/algorithm/algorithm_manager.py b/trinity/algorithm/algorithm_manager.py
new file mode 100644
index 0000000000..3c2983c80b
--- /dev/null
+++ b/trinity/algorithm/algorithm_manager.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+"""AlgorithmManager for switching between SFT and RFT."""
+
+from trinity.algorithm.algorithm import ALGORITHM_TYPE
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN
+from trinity.algorithm.kl_fn.kl_fn import KL_FN
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
+from trinity.common.config import AlgorithmConfig, Config
+
+
+class AlgorithmManager:
+ def __init__(self, config: Config):
+ self.config = config
+ sft_type = ALGORITHM_TYPE.get("sft")
+ sft_default_config = sft_type.get_default_config()
+ self.sft_algorithm_config = AlgorithmConfig(
+ algorithm_type="sft",
+ **sft_default_config,
+ )
+ policy_fn_cls = POLICY_LOSS_FN.get(self.sft_algorithm_config.policy_loss_fn)
+ self.sft_algorithm_config.policy_loss_fn_args = policy_fn_cls.default_args()
+ kl_loss_fn_cls = KL_FN.get(self.sft_algorithm_config.kl_loss_fn)
+ self.sft_algorithm_config.kl_loss_fn_args = kl_loss_fn_cls.default_args()
+ entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.sft_algorithm_config.entropy_loss_fn)
+ self.sft_algorithm_config.entropy_loss_fn_args = entropy_loss_fn_cls.default_args()
+
+ def get_current_algorithm_config(self, global_steps: int):
+ if global_steps <= self.config.buffer.trainer_input.sft_warmup_steps:
+ return self.sft_algorithm_config
+ else:
+ return self.config.algorithm
+
+ def need_save(self, global_steps: int):
+ return global_steps == self.config.buffer.trainer_input.sft_warmup_steps
diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
index 4df9272ca0..cf102dd6b7 100644
--- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
+++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -61,3 +61,25 @@ def __call__(
@classmethod
def default_args(cls) -> Dict:
return {"entropy_coef": 0.0}
+
+
+@ENTROPY_LOSS_FN.register_module("none")
+class DummyEntropyLossFn(EntropyLossFn):
+ """
+ Dummy entropy loss function.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(
+ self,
+ entropy: torch.Tensor,
+ action_mask: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ return torch.tensor(0.0), {}
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {}
diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py
index 95d2915a84..62ed48cd49 100644
--- a/trinity/algorithm/kl_fn/kl_fn.py
+++ b/trinity/algorithm/kl_fn/kl_fn.py
@@ -102,7 +102,7 @@ def default_args(cls):
@KL_FN.register_module("none")
-class DummyFn(KLFn):
+class DummyKLFn(KLFn):
"""
Dummy KL function.
"""
diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py
index 09ff663c47..9d77dbb379 100644
--- a/trinity/buffer/buffer.py
+++ b/trinity/buffer/buffer.py
@@ -41,9 +41,9 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
elif storage_config.storage_type == StorageType.FILE:
from trinity.buffer.reader.file_reader import FILE_READERS
- file_read_type = storage_config.algorithm_type
- if file_read_type is not None:
- file_read_type = file_read_type.value
+ algorithm_type = storage_config.algorithm_type
+ if algorithm_type is not None:
+ file_read_type = algorithm_type
else:
file_read_type = "rollout"
return FILE_READERS.get(file_read_type)(storage_config, buffer_config)
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 316d3ae297..58b762d3f2 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -6,9 +6,10 @@
import transformers
from datasets import load_dataset
+from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm
from trinity.buffer.buffer_reader import BufferReader
from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType
+from trinity.common.constants import PromptType, ReadStrategy, TaskType
from trinity.common.experience import Experience
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
@@ -17,7 +18,7 @@
FILE_READERS = Registry("file_readers")
-@FILE_READERS.register_module(AlgorithmType.SFT.value)
+@FILE_READERS.register_module(SFTAlgorithm.name())
class SFTDataReader(BufferReader):
"""Reader for SFT file data."""
@@ -96,7 +97,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
return exp_list
-@FILE_READERS.register_module(AlgorithmType.DPO.value)
+@FILE_READERS.register_module(DPOAlgorithm.name())
class DPODataReader(BufferReader):
def __init__(self, meta: StorageConfig, config: BufferConfig):
self.split = meta.split
diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py
index db2e4ca137..21289c7768 100644
--- a/trinity/buffer/schema/sql_schema.py
+++ b/trinity/buffer/schema/sql_schema.py
@@ -5,9 +5,7 @@
from sqlalchemy import Column, Float, Integer, LargeBinary, String
from sqlalchemy.ext.declarative import declarative_base
-from trinity.common.constants import AlgorithmType
from trinity.common.experience import Experience
-from trinity.common.models.utils import tokenize_and_mask_messages_hf
Base = declarative_base()
@@ -85,6 +83,8 @@ def from_messages(
chat_template: Optional[str] = None,
) -> "SFTDataModel":
"""Convert a list of messages into a single instance of SFT data."""
+ from trinity.common.models.utils import tokenize_and_mask_messages_hf
+
token_ids, action_mask = tokenize_and_mask_messages_hf(
tokenizer=tokenizer,
messages=messages,
@@ -125,22 +125,15 @@ def to_experience(self) -> Experience:
return exp
-SCHEMA_MAPPING = {
- None: TaskModel,
- AlgorithmType.SFT: SFTDataModel,
- AlgorithmType.PPO: ExperienceModel,
- AlgorithmType.GRPO: ExperienceModel,
- AlgorithmType.OPMD: ExperienceModel,
- AlgorithmType.DPO: DPODataModel,
-}
-
-
-def create_dynamic_table(algorithm_type: Union[AlgorithmType | None], table_name: str) -> Any:
+def create_dynamic_table(algorithm_type: Union[str | None], table_name: str) -> Any:
"""Create a dynamic table based on the provided algorithm type and table name."""
- if algorithm_type not in SCHEMA_MAPPING:
- raise ValueError(f"Unknown schema: {algorithm_type}")
+ if algorithm_type is None:
+ base_class = TaskModel
+ else:
+ from trinity.algorithm.algorithm import ALGORITHM_TYPE
- base_class = SCHEMA_MAPPING[algorithm_type]
+ algorithm = ALGORITHM_TYPE.get(algorithm_type)
+ base_class = algorithm.schema
table_attrs = {
"__tablename__": table_name,
diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py
index 7464064037..e0b0bdf640 100644
--- a/trinity/buffer/writer/sql_writer.py
+++ b/trinity/buffer/writer/sql_writer.py
@@ -5,6 +5,7 @@
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
+from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.buffer.buffer_writer import BufferWriter
from trinity.buffer.schema import Base, create_dynamic_table
from trinity.buffer.utils import retry_session
@@ -22,7 +23,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
assert meta.storage_type == StorageType.SQL
# we only support write RFT algorithm buffer for now
# TODO: support other algorithms
- assert meta.algorithm_type.is_rft, "Only RFT buffer is supported for writing."
+ algorithm = ALGORITHM_TYPE.get(meta.algorithm_type)
+ assert algorithm.use_rollout, "Only RFT buffer is supported for writing."
self.engine = create_engine(meta.path, poolclass=NullPool)
self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name)
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 9dfe4df8ee..6a01bfb688 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -8,7 +8,6 @@
import ray
from trinity.common.config import Config, load_config
-from trinity.common.constants import AlgorithmType
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
from trinity.utils.log import get_logger
@@ -49,20 +48,8 @@ def train(config: Config) -> None:
trainer = Trainer.remote(config)
ray.get(trainer.prepare.remote())
- if config.buffer.trainer_input.sft_warmup_steps > 0:
- while True:
- train_continue, train_step_num = ray.get(
- trainer.train_one_period.remote(AlgorithmType.SFT)
- )
- if train_step_num <= config.buffer.trainer_input.sft_warmup_steps:
- logger.info(f"SFT warmup step {train_step_num} finished.")
- if not train_continue:
- logger.info("SFT warmup finished.")
- break
-
- algo_type = config.algorithm.algorithm_type
try:
- ray.get(trainer.train.remote(algo_type))
+ ray.get(trainer.train.remote())
logger.info("Train finished.")
ray.get(trainer.shutdown.remote())
except Exception as e:
@@ -93,23 +80,10 @@ def both(config: Config) -> None:
# sync weight before training start
ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
- if config.buffer.trainer_input.sft_warmup_steps > 0:
- while True:
- train_continue, train_step_num = ray.get(
- trainer.train_one_period.remote(AlgorithmType.SFT)
- )
- if train_step_num <= config.buffer.trainer_input.sft_warmup_steps:
- logger.info(f"SFT warmup step {train_step_num} finished.")
- if not train_continue:
- logger.info("SFT warmup finished.")
- break
- ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
-
- algo_type = config.algorithm.algorithm_type
while True:
try:
ref_explore = explorer.explore_one_period.remote()
- ref_train = trainer.train_one_period.remote(algo_type)
+ ref_train = trainer.train_one_period.remote()
explore_continue, explore_step_num = ray.get(ref_explore)
train_continue, train_step_num = ray.get(ref_train)
if not explore_continue:
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 91c7790571..dd863edbd3 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -7,7 +7,6 @@
from omegaconf import OmegaConf
from trinity.common.constants import (
- AlgorithmType,
MonitorType,
PromptType,
ReadStrategy,
@@ -84,7 +83,7 @@ class StorageConfig:
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
# ! DO NOT SET, automatically set from algorithm.algorithm_type
- algorithm_type: Optional[AlgorithmType] = None
+ algorithm_type: Optional[str] = None
# ! DO NOT SET, automatically set from buffer.total_epochs
total_epochs: int = 1 # automatically set
@@ -170,27 +169,27 @@ class InferenceModelConfig:
class AlgorithmConfig:
"""Config for algorithm."""
- algorithm_type: AlgorithmType = AlgorithmType.PPO
+ algorithm_type: str = "ppo"
# for GRPO-like algorithms, repeat each task for `repeat_times` times
repeat_times: int = 1
- policy_loss_fn: str = "ppo"
+ policy_loss_fn: Optional[str] = None # "ppo"
# If not set, use PolicyLossFn.default_args()
policy_loss_fn_args: Optional[dict] = None
- advantage_fn: str = "ppo"
+ advantage_fn: Optional[str] = None # "ppo"
# If not set, use AdvantageFn.default_args()
advantage_fn_args: Optional[dict] = None
- kl_penalty_fn: str = "none" # set to "none" to disable kl penalty in reward
+ kl_penalty_fn: Optional[str] = None # "none" # set to "none" to disable kl penalty in reward
# If not set, use kl_penalty_fn.default_args()
kl_penalty_fn_args: Optional[dict] = None
- kl_loss_fn: str = "k2" # set to "none" to disable kl loss
+ kl_loss_fn: Optional[str] = None # "k2" # set to "none" to disable kl loss
# If not set, use kl_loss_fn.default_args()
kl_loss_fn_args: Optional[dict] = None
- entropy_loss_fn: str = "basic"
+ entropy_loss_fn: Optional[str] = None # "basic"
# If not set, use entropy_loss_fn.default_args()
entropy_loss_fn_args: Optional[dict] = None
@@ -198,6 +197,15 @@ class AlgorithmConfig:
# TODO: move this to SFT warmup
use_token_level_loss: bool = True
+ # do not set
+ algorithm_manager: Optional[Any] = None
+
+ def get_current_algorithm_config(self, global_steps: int):
+ return self.algorithm_manager.get_current_algorithm_config(global_steps)
+
+ def need_save(self, global_steps: int):
+ return self.algorithm_manager.need_save(global_steps)
+
@dataclass
class ClusterConfig:
@@ -351,32 +359,25 @@ def _check_deprecated(self) -> None:
def _check_interval(self) -> None:
assert self.synchronizer.sync_interval > 0
- # check eval_interval
- if (
- self.mode != "bench"
- and self.algorithm.algorithm_type != AlgorithmType.DPO
- and self.explorer.eval_interval % self.synchronizer.sync_interval != 0
- ):
- self.explorer.eval_interval = (
- max(self.explorer.eval_interval // self.synchronizer.sync_interval, 1)
- ) * self.synchronizer.sync_interval
- logger.warning(
- f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}."
- )
-
- # check save_interval
- if (
- self.mode != "bench"
- and self.algorithm.algorithm_type != AlgorithmType.DPO
- and self.synchronizer.sync_method == SyncMethod.CHECKPOINT
- ):
- if self.trainer.save_interval != self.synchronizer.sync_interval:
+ if self.mode != "bench" and self.algorithm.algorithm_type != "dpo": # TODO
+ # check eval_interval
+ if self.explorer.eval_interval % self.synchronizer.sync_interval != 0:
+ self.explorer.eval_interval = (
+ max(self.explorer.eval_interval // self.synchronizer.sync_interval, 1)
+ ) * self.synchronizer.sync_interval
logger.warning(
- f"When `algorithm.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
- f"`trainer.save_interval` will be set to "
- f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`."
+ f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}."
)
- self.trainer.save_interval = self.synchronizer.sync_interval
+
+ # check save_interval
+ if self.synchronizer.sync_method == SyncMethod.CHECKPOINT:
+ if self.trainer.save_interval != self.synchronizer.sync_interval:
+ logger.warning(
+ f"When `algorithm.algorithm_type` != `dpo` and `synchronizer.sync_method` == `checkpoint`, "
+ f"`trainer.save_interval` will be set to "
+ f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`."
+ )
+ self.trainer.save_interval = self.synchronizer.sync_interval
def _check_buffer(self) -> None: # noqa: C901
# check explorer_input
@@ -440,14 +441,7 @@ def _check_buffer(self) -> None: # noqa: C901
f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}"
)
elif self.mode == "train": # TODO: to be check
- if self.algorithm.algorithm_type.is_dpo():
- if (
- self.buffer.trainer_input.experience_buffer is None
- or not self.buffer.trainer_input.experience_buffer.path
- ):
- raise ValueError(
- "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == AlgorithmType.DPO`"
- )
+ pass
if self.buffer.trainer_input.experience_buffer is not None:
self.buffer.trainer_input.experience_buffer.algorithm_type = (
self.algorithm.algorithm_type
@@ -468,7 +462,7 @@ def _check_buffer(self) -> None: # noqa: C901
"`buffer.trainer_input.sft_warmup_dataset` is required when `buffer.trainer_input.sft_warmup_steps` > 0"
)
if self.buffer.trainer_input.sft_warmup_dataset is not None:
- self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
+ self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = "sft" # TODO
# set read_batch_size / pad_token_id / tokenizer_path
self.buffer.read_batch_size = self.buffer.batch_size * self.algorithm.repeat_times
@@ -491,6 +485,21 @@ def _check_algorithm(self) -> None:
KL_FN,
POLICY_LOSS_FN,
)
+ from trinity.algorithm.algorithm import ALGORITHM_TYPE
+
+ algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type)
+ algorithm.check_config(self)
+ default_config = {
+ "policy_loss_fn": "ppo",
+ "advantage_fn": "ppo",
+ "kl_penalty_fn": "none",
+ "kl_loss_fn": "k2",
+ "entropy_loss_fn": "basic",
+ }
+ default_config.update(algorithm.get_default_config())
+ for key, value in default_config.items():
+ if getattr(self.algorithm, key, None) is None:
+ setattr(self.algorithm, key, value)
policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
if policy_fn_cls is None:
@@ -526,11 +535,12 @@ def check_and_update(self) -> None: # noqa: C901
"""Check and update the config."""
self._check_deprecated()
+ # check algorithm
+ self._check_algorithm()
+
# check mode
if self.mode not in ["explore", "train", "both", "bench"]:
raise ValueError(f"Invalid mode: {self.mode}")
- if self.algorithm.algorithm_type == AlgorithmType.DPO and self.mode == "both":
- raise ValueError("DPO does not support `both` mode")
# prepare for the checkpoint directory
if not os.path.isabs(self.checkpoint_root_dir):
@@ -545,9 +555,6 @@ def check_and_update(self) -> None: # noqa: C901
if not self.model.critic_model_path:
self.model.critic_model_path = self.model.model_path
- # check algorithm
- self._check_algorithm()
-
# check explorer
if (
self.explorer.rollout_model.engine_type != "vllm_async"
@@ -572,17 +579,6 @@ def check_and_update(self) -> None: # noqa: C901
logger.warning(
f"`{self.mode}` mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
- if (
- self.algorithm.algorithm_type == AlgorithmType.DPO
- and self.synchronizer.sync_method != SyncMethod.CHECKPOINT
- ):
- self.synchronizer.sync_method = SyncMethod.CHECKPOINT
- logger.warning(
- "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
- )
- if self.algorithm.algorithm_type == AlgorithmType.DPO and self.algorithm.repeat_times != 2:
- self.algorithm.repeat_times = 2
- logger.warning("DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2.")
self._check_interval()
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index 860cd39027..47b04f853b 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -62,34 +62,6 @@ class StorageType(CaseInsensitiveEnum):
FILE = "file"
-class AlgorithmType(CaseInsensitiveEnum):
- """Algorithm Type."""
-
- SFT = "sft"
- PPO = "ppo"
- GRPO = "grpo"
- OPMD = "opmd"
- PAIRWISE_OPMD = "pairwise_opmd"
- DPO = "dpo"
-
- def is_rft(self) -> bool:
- """Check if the algorithm is RFT."""
- return self in [
- AlgorithmType.PPO,
- AlgorithmType.GRPO,
- AlgorithmType.OPMD,
- AlgorithmType.PAIRWISE_OPMD,
- ]
-
- def is_sft(self) -> bool:
- """Check if the algorithm is SFT."""
- return self == AlgorithmType.SFT
-
- def is_dpo(self) -> bool:
- """Check if the algorithm is DPO."""
- return self == AlgorithmType.DPO
-
-
class MonitorType(CaseInsensitiveEnum):
"""Monitor Type."""
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index e8180f4718..644fe9a8f5 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -4,9 +4,8 @@
from omegaconf import OmegaConf
+from trinity.algorithm.algorithm import DPOAlgorithm
from trinity.common.config import BufferConfig, Config, SynchronizerConfig
-from trinity.common.constants import AlgorithmType
-from trinity.trainer.verl.ray_trainer import AdvantageEstimator
from trinity.utils.log import get_logger
logger = get_logger(__name__)
@@ -79,7 +78,7 @@ class Actor:
checkpoint: Checkpoint = field(default_factory=Checkpoint)
optim: Optim = field(default_factory=Optim)
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
- algorithm_type: AlgorithmType = AlgorithmType.PPO
+ algorithm_type: str = "ppo" # TODO
tau: float = 0.001 # strength of regularization w.r.t. old / ref policy
opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd
use_uid: bool = False # True / False, applicable to pairwise_opmd
@@ -95,8 +94,15 @@ class Ref:
ulysses_sequence_parallel_size: int = 1
+@dataclass
+class _ValKwargs:
+ do_sample: bool = False
+
+
@dataclass
class Rollout:
+ # do not set
+ val_kwargs: _ValKwargs = field(default_factory=_ValKwargs)
temperature: float = 1.0
n: int = 1 # > 1 for grpo
@@ -318,12 +324,6 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
if adv_fn_args is not None and "lam" in adv_fn_args:
self.algorithm.lam = adv_fn_args["lam"]
self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type
- if config.algorithm.algorithm_type == AlgorithmType.PPO:
- logger.info("Setting `adv_estimator` to 'gae' for PPO")
- self.algorithm.adv_estimator = AdvantageEstimator.GAE.value
- elif config.algorithm.algorithm_type in (AlgorithmType.GRPO, AlgorithmType.OPMD):
- logger.info("Setting `adv_estimator` to 'grpo' for GRPO/OPMD")
- self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value
self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
self.actor_rollout_ref.actor.kl_loss_coef = config.algorithm.kl_loss_fn_args["kl_coef"] # type: ignore
self.actor_rollout_ref.actor.entropy_coeff = config.algorithm.entropy_loss_fn_args[ # type: ignore
@@ -334,7 +334,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
# Need to double check whether this is indeed the case,
# and see if adv_estimator can be removed completely.
- if self.actor_rollout_ref.actor.algorithm_type.is_dpo(): # for DPO
+ if isinstance(self.actor_rollout_ref.actor.algorithm_type, DPOAlgorithm): # for DPO
if not self.actor_rollout_ref.actor.use_kl_loss:
self.actor_rollout_ref.actor.use_kl_loss = True
logger.warning("DPO must use KL loss.")
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 9c3cc414c7..37257f71ce 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -8,6 +8,7 @@
import ray
import torch
+from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.buffer import get_buffer_writer
from trinity.buffer.buffer import get_buffer_reader
from trinity.common.config import Config
@@ -33,6 +34,7 @@ def __init__(self, config: Config):
explorer_meta = self.cache.load_explorer()
self.step_num = explorer_meta.get("latest_iteration", 0)
self.config = config
+ self.algorithm_manager = AlgorithmManager(config)
self.models, self.auxiliary_models = create_inference_models(config)
if self.config.mode != "bench":
self.experience_buffer = get_buffer_writer(
@@ -177,6 +179,15 @@ def explore_one_period(self) -> Tuple[bool, int]:
explore_status: whether there are more tasks to explore.
explore_step_num: the number of explore steps
"""
+ # skip for sft
+ algo_config = self.algorithm_manager.get_current_algorithm_config(self.step_num + 1)
+ if algo_config.algorithm_type == "sft":
+ for _ in range(self.config.synchronizer.sync_interval):
+ self.step_num += 1
+ if self.algorithm_manager.need_save(self.step_num):
+ break
+ return True, self.step_num
+
task_num_per_period = self.config.synchronizer.sync_interval * self.config.buffer.batch_size
st = time.time()
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index c2ad5fec96..95859685ee 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -12,10 +12,11 @@
import ray
+from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm
+from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.buffer import get_buffer_reader
-from trinity.common.config import AlgorithmConfig, Config
-from trinity.common.constants import AlgorithmType, SyncMethod
-from trinity.common.experience import Experiences
+from trinity.common.config import Config
+from trinity.common.constants import SyncMethod
from trinity.utils.log import get_logger
@@ -26,6 +27,7 @@ class Trainer:
def __init__(self, config: Config) -> None:
self.config = config
self.logger = get_logger(__name__)
+ self.algorithm_manager = AlgorithmManager(config)
self.train_buffer = get_buffer_reader(
self.config.buffer.trainer_input.experience_buffer, # type: ignore
self.config.buffer,
@@ -44,86 +46,54 @@ def prepare(self) -> None:
"""Prepare the trainer."""
self.engine.prepare()
- def train(self, algo_type: AlgorithmType = AlgorithmType.PPO):
+ def train(self):
"""Train the model."""
while True:
- train_status, _ = self.train_step(algo_type)
+ train_status, _ = self.train_step()
if not train_status:
break
- def train_one_period(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
+ def train_one_period(self) -> Tuple[bool, int]:
"""Train for one period. Each period contains `sync_interval` steps.
Returns:
train_status: Whether to continue training.
train_step_num: The number of training steps"""
for _ in range(self.config.synchronizer.sync_interval):
- train_status, train_step_num = self.train_step(algo_type)
+ train_status, train_step_num = self.train_step()
if not train_status:
return False, train_step_num
self.logger.info(f"Train step {train_step_num} finished.")
return True, train_step_num
- def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
+ def train_step(self) -> Tuple[bool, int]:
"""Train one step.
- Args:
- algo_type (AlgorithmType): The type of data to be used for training.
- Defaults to AlgorithmType.PPO.
-
Returns:
bool: Whether to continue training.
"""
- if algo_type.is_sft():
- algorithm_config = AlgorithmConfig(
- algorithm_type=AlgorithmType.SFT,
- policy_loss_fn="sft",
- policy_loss_fn_args={
- "use_token_level_loss": self.config.algorithm.use_token_level_loss
- },
- kl_loss_fn="none",
- kl_loss_fn_args={},
- entropy_loss_fn="basic",
- entropy_loss_fn_args=self.config.algorithm.entropy_loss_fn_args,
- )
- self.engine.set_algorithm(algorithm_config)
- else:
- self.engine.set_algorithm(self.config.algorithm)
- if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy:
+ algo_config = self.algorithm_manager.get_current_algorithm_config(
+ self.engine.train_step_num + 1
+ )
+ algo_type = algo_config.algorithm_type
+ algorithm = ALGORITHM_TYPE.get(algo_type)
+ if algorithm.use_rollout:
strategy = self.config.buffer.trainer_input.read_experience_strategy
else:
strategy = None
try:
- if algo_type.is_sft():
+ if algorithm == SFTAlgorithm:
exps = self.sft_warmup_buffer.read()
else:
exps = self.train_buffer.read(strategy=strategy)
except StopIteration:
self.logger.warning("No more data to train. Stop training.")
- return False, 0 # TODO: get the actual step number
-
- if algo_type.is_sft():
- return self.engine.train_sft_step(
- Experiences.gather_experiences(
- exps,
- pad_token_id=self.config.buffer.pad_token_id, # type: ignore
- )
- )
- elif algo_type.is_rft():
- return self.engine.train_rft_step(
- Experiences.gather_experiences(
- exps,
- pad_token_id=self.config.buffer.pad_token_id, # type: ignore
- )
- )
- elif algo_type.is_dpo():
- return self.engine.train_dpo_step(
- Experiences.gather_dpo_experiences(
- exps,
- pad_token_id=self.config.buffer.pad_token_id, # type: ignore
- )
- )
- else:
- raise ValueError(f"Unsupported algorithm type: {algo_type}")
+ return False, self.engine.train_step_num
+
+ experiences = algorithm.gather_experience(
+ exps,
+ pad_token_id=self.config.buffer.pad_token_id, # type: ignore
+ )
+ return self.engine.train_step(experiences)
def sync_weight(self) -> None:
"""Sync the model weight."""
@@ -136,7 +106,7 @@ def flush_log(self, step: int) -> None:
def shutdown(self) -> None:
# if checkpoint not saved, save the last checkpoint
- step_num = self.engine.global_steps - 1
+ step_num = self.engine.train_step_num
path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}")
if not os.path.isdir(path) or len(os.listdir(path)) == 0:
self.engine.save_checkpoint()
@@ -150,17 +120,14 @@ class TrainEngineWrapper(ABC):
def prepare(self) -> None:
"""Do some preparation before training started."""
+ @property
@abstractmethod
- def train_rft_step(self, experiences) -> Tuple[bool, int]:
- """Train on the RFT data."""
+ def train_step_num(self) -> int:
+ """Get the current training step number."""
@abstractmethod
- def train_sft_step(self, experiences) -> Tuple[bool, int]:
- """Train on the SFT data."""
-
- @abstractmethod
- def train_dpo_step(self, experiences) -> Tuple[bool, int]:
- """Train on the DPO data."""
+ def train_step(self, experiences) -> Tuple[bool, int]:
+ """Training."""
@abstractmethod
def save_checkpoint(self) -> None:
@@ -170,10 +137,6 @@ def save_checkpoint(self) -> None:
def sync_weight(self) -> None:
"""Sync the model weight."""
- @abstractmethod
- def set_algorithm(self, algorithm_config: AlgorithmConfig) -> None:
- """Set training algorithm config."""
-
@abstractmethod
def shutdown(self) -> None:
"""Shutdown the engine."""
diff --git a/trinity/trainer/verl/core_algos.py b/trinity/trainer/verl/core_algos.py
deleted file mode 100644
index f104e0f4f4..0000000000
--- a/trinity/trainer/verl/core_algos.py
+++ /dev/null
@@ -1,717 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Modified from core_algos.py
-"""
-
-from abc import ABC, abstractmethod
-from collections import defaultdict
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-import verl.utils.torch_functional as verl_F
-
-from trinity.common.constants import AlgorithmType
-
-
-class KLController(ABC):
- @abstractmethod
- def update(self, current_kl, n_steps):
- """update value"""
-
-
-class AdaptiveKLController(KLController):
- """
- Adaptive KL controller described in the paper:
- https://arxiv.org/pdf/1909.08593.pdf
- """
-
- def __init__(self, init_kl_coef, target_kl, horizon):
- self.value = init_kl_coef
- self.target = target_kl
- self.horizon = horizon
-
- def update(self, current_kl, n_steps):
- target = self.target
- proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
- mult = 1 + proportional_error * n_steps / self.horizon
- self.value *= mult
-
-
-class FixedKLController(KLController):
- """Fixed KL controller."""
-
- def __init__(self, kl_coef):
- self.value = kl_coef
-
- def update(self, current_kl, n_steps):
- pass
-
-
-def get_kl_controller(kl_config):
- if kl_config.type == "fixed":
- return FixedKLController(kl_coef=kl_config.kl_coef)
- elif kl_config.type == "adaptive":
- assert kl_config.horizon > 0, f"horizon must be larger than 0. Got {kl_config.horizon}"
- return AdaptiveKLController(
- init_kl_coef=kl_config.kl_coef,
- target_kl=kl_config.target_kl,
- horizon=kl_config.horizon,
- )
- else:
- raise ValueError("Unknown kl_ctrl type")
-
-
-def compute_opmd_outcome_advantage(
- token_level_rewards: torch.Tensor,
- eos_mask: torch.Tensor,
- index: torch.Tensor,
- opmd_baseline: str = "mean",
- tau: float = 1.0,
-):
- """Modified from compute_grpo_outcome_advantage
-
- Compute advantage for OPMD, operating only on Outcome reward
- (with only one scalar reward for each response).
- Args:
- token_level_rewards: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
-
- Returns:
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- Returns: `(torch.Tensor)`
- shape: (bs, response_length)
- """
- response_length = token_level_rewards.shape[-1]
- scores = token_level_rewards.sum(dim=-1)
-
- id2score = defaultdict(list)
- id2baseline = {}
-
- with torch.no_grad():
- bsz = scores.shape[0]
- for i in range(bsz):
- id2score[index[i]].append(scores[i])
- for idx in id2score:
- if len(id2score[idx]) == 1:
- id2baseline[idx] = torch.tensor(0.0)
- # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
- elif len(id2score[idx]) > 1:
- if opmd_baseline == "mean":
- id2baseline[idx] = torch.mean(torch.tensor(id2score[idx]))
- elif opmd_baseline == "logavgexp":
- rewards_tensor = torch.tensor(id2score[idx])
- # NOTE: we use the fact that logavgexp(x) = logsumexp(x) - log(len(x)).
- # Hopefully the logsumexp calculation is numerically stable (as claimed by PyTorch's doc)
- # in cases where tau is small...
- id2baseline[idx] = tau * (
- torch.logsumexp(rewards_tensor / tau, dim=-1)
- - torch.log(torch.tensor(len(id2score[idx])))
- )
- else:
- raise NotImplementedError
- else:
- raise ValueError(f"no score in prompt index: {idx}")
- for i in range(bsz):
- scores[i] = scores[i] - id2baseline[index[i]]
- scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
-
- return scores, scores
-
-
-def compute_gae_advantage_return(
- token_level_rewards: torch.Tensor,
- values: torch.Tensor,
- eos_mask: torch.Tensor,
- gamma: float,
- lam: float,
-):
- """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
-
- Args:
- token_level_rewards: `(torch.Tensor)`
- shape: (bs, response_length)
- values: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
- gamma: `(float)`
- discounted factor used in RL
- lam: `(float)`
- lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
-
- Returns:
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- Returns: `(torch.Tensor)`
- shape: (bs, response_length)
-
- """
- with torch.no_grad():
- lastgaelam = 0
- advantages_reversed = []
- gen_len = token_level_rewards.shape[-1]
-
- # values = values * eos_mask TODO: may use in multi-turn
- for t in reversed(range(gen_len)):
- nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
- delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
-
- lastgaelam = delta + gamma * lam * lastgaelam
- # lastgaelam = torch.where( # TODO: may use in multi-turn
- # eos_mask[:, t] == 1, delta + gamma * lam * lastgaelam, lastgaelam
- # )
- advantages_reversed.append(lastgaelam)
- advantages = torch.stack(advantages_reversed[::-1], dim=1)
-
- returns = advantages + values
- advantages = verl_F.masked_whiten(advantages, eos_mask)
- return advantages, returns
-
-
-# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
-def compute_grpo_outcome_advantage(
- token_level_rewards: torch.Tensor,
- eos_mask: torch.Tensor,
- index: torch.Tensor,
- epsilon: float = 1e-6,
-):
- """
- Compute advantage for GRPO, operating only on Outcome reward
- (with only one scalar reward for each response).
- Args:
- token_level_rewards: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
-
- Returns:
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- Returns: `(torch.Tensor)`
- shape: (bs, response_length)
- """
- response_length = token_level_rewards.shape[-1]
- scores = token_level_rewards.sum(dim=-1)
-
- id2score = defaultdict(list)
- id2mean = {}
- id2std = {}
-
- with torch.no_grad():
- bsz = scores.shape[0]
- for i in range(bsz):
- id2score[index[i]].append(scores[i])
- for idx in id2score:
- if len(id2score[idx]) == 1:
- id2mean[idx] = torch.tensor(0.0)
- id2std[idx] = torch.tensor(1.0)
- elif len(id2score[idx]) > 1:
- id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
- id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
- else:
- raise ValueError(f"no score in prompt index: {idx}")
- for i in range(bsz):
- scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
- scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
-
- return scores, scores
-
-
-def compute_rloo_outcome_advantage(
- token_level_rewards: torch.Tensor,
- eos_mask: torch.Tensor,
- index: torch.Tensor,
- epsilon: float = 1e-6,
-):
- """
- Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
- Args:
- token_level_rewards: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
-
- Returns:
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- Returns: `(torch.Tensor)`
- shape: (bs, response_length)
- """
- response_length = token_level_rewards.shape[-1]
- scores = token_level_rewards.sum(dim=-1)
-
- id2score = defaultdict(list)
- id2mean = {}
-
- with torch.no_grad():
- bsz = scores.shape[0]
- for i in range(bsz):
- id2score[index[i]].append(scores[i])
- for idx in id2score:
- if len(id2score[idx]) == 1:
- id2mean[idx] = torch.tensor(0.0)
- elif len(id2score[idx]) > 1:
- id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
- else:
- raise ValueError(f"no score in prompt index: {idx}")
- for i in range(bsz):
- response_num = len(id2score[index[i]])
- if response_num > 1:
- scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[
- index[i]
- ] * response_num / (response_num - 1)
- scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
-
- return scores, scores
-
-
-def compute_reinforce_plus_plus_outcome_advantage(
- token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: float
-):
- """
- Compute advantage for REINFORCE++.
- This implementation is based on the paper: https://arxiv.org/abs/2501.03262
- Args:
- token_level_rewards: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
-
- Returns:
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- Returns: `(torch.Tensor)`
- shape: (bs, response_length)
- """
-
- with torch.no_grad():
- returns = torch.zeros_like(token_level_rewards)
- running_return = 0
-
- for t in reversed(range(token_level_rewards.shape[1])):
- running_return = token_level_rewards[:, t] + gamma * running_return
- returns[:, t] = running_return
-
- advantages = verl_F.masked_whiten(returns, eos_mask)
- advantages = advantages * eos_mask
-
- return advantages, returns
-
-
-def compute_remax_outcome_advantage(
- token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor
-):
- """
- Compute advantage for ReMax, operating only on Outcome reward
- This implementation is based on the paper: https://arxiv.org/abs/2310.10505
-
- (with only one scalar reward for each response).
- Args:
- token_level_rewards: `(torch.Tensor)`
- shape: (bs, response_length)
- reward_baselines: `(torch.Tensor)`
- shape: (bs,)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
-
- Returns:
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- Returns: `(torch.Tensor)`
- shape: (bs, response_length)
- """
- response_length = token_level_rewards.shape[-1]
- token_level_rewards.sum(dim=-1)
-
- with torch.no_grad():
- returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
- advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
-
- return advantages, returns
-
-
-def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
- kl = old_log_prob - ref_log_prob
- return token_level_scores - kl * kl_ratio
-
-
-def compute_policy_loss(old_log_prob, log_prob, eos_mask, **kwargs):
- """Compute policy loss for PPO / OPMD / pairwise OPMD"""
-
- algorithm_type: AlgorithmType = kwargs.get("algorithm_type", AlgorithmType.PPO)
-
- if algorithm_type == AlgorithmType.OPMD:
- advantages = kwargs.get("advantages")
- tau = kwargs.get("tau")
- return compute_policy_loss_opmd(
- old_log_prob=old_log_prob,
- log_prob=log_prob,
- advantages=advantages,
- eos_mask=eos_mask,
- tau=tau,
- )
-
- elif algorithm_type == AlgorithmType.PAIRWISE_OPMD:
- token_level_scores = kwargs.get("token_level_scores")
- index = kwargs.get("index")
- tau = kwargs.get("tau")
- return compute_policy_loss_pairwise_opmd(
- old_log_prob=old_log_prob,
- log_prob=log_prob,
- token_level_scores=token_level_scores,
- eos_mask=eos_mask,
- index=index,
- tau=tau,
- )
-
- elif algorithm_type.is_rft():
- advantages = kwargs.get("advantages")
- cliprange = kwargs.get("cliprange")
- return compute_policy_loss_ppo(
- old_log_prob=old_log_prob,
- log_prob=log_prob,
- advantages=advantages,
- eos_mask=eos_mask,
- cliprange=cliprange,
- )
-
- else:
- raise NotImplementedError(f"Get invalid algorithm_type '{algorithm_type}'.")
-
-
-def compute_policy_loss_dpo(
- log_prob, ref_log_prob, eos_mask, loss_type="sigmoid", beta=0.1, label_smoothing=0.0
-):
- """Compute policy loss for DPO (Direct Preference Optimization)
-
- Ref: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L918
-
- Args:
- log_prob: `(torch.Tensor)`
- The log probabilities of the chosen responses from the policy model.
- ref_log_prob: `(torch.Tensor)`
- The log probabilities of the chosen responses from the reference model.
- loss_type: `(str)`
- Default: "sigmoid"
- The type of loss function to use.
- beta: `(float)`
- Default: 0.1
- A temperature parameter that controls the sharpness of the preference signal.
- Higher values make the loss more sensitive to small differences in log probabilities.
- label_smoothing: `(float)`
- Default: 0.0
- A parameter to encode uncertainty about the labels. Adds a small amount of smoothing to the loss
- to avoid overconfident predictions.
-
- Returns:
- dpo_loss: `a scalar torch.Tensor`
- chosen_diff: `(torch.Tensor)`
- rejected_diff: `(torch.Tensor)`
- """
- # log_prob: chosen, rejected, chosen, rejected, ...
- chosen_log_prob, rejected_log_prob = log_prob[::2], log_prob[1::2]
- chosen_mask, rejected_mask = eos_mask[::2], eos_mask[1::2]
- chosen_log_prob_sum = (chosen_log_prob * chosen_mask).sum(-1)
- rejected_log_prob_sum = (rejected_log_prob * rejected_mask).sum(-1)
-
- if ref_log_prob is None:
- raise NotImplementedError("DPO requires valid ref_log_prob")
- chosen_ref_log_prob, rejected_ref_log_prob = ref_log_prob[::2], ref_log_prob[1::2]
- chosen_ref_log_prob_sum = (chosen_ref_log_prob * chosen_mask).sum(-1)
- rejected_ref_log_prob_sum = (rejected_ref_log_prob * rejected_mask).sum(-1)
-
- # compute logits
- chosen_ratios = chosen_log_prob_sum - chosen_ref_log_prob_sum
- rejected_ratios = rejected_log_prob_sum - rejected_ref_log_prob_sum
- logits = chosen_ratios - rejected_ratios
-
- if loss_type == "sigmoid":
- losses = (
- -F.logsigmoid(beta * logits) * (1 - label_smoothing)
- - F.logsigmoid(-beta * logits) * label_smoothing
- )
- loss = losses.mean()
-
- else:
- raise NotImplementedError(f"loss_type {loss_type} is not supported in DPO")
-
- chosen_reward = beta * chosen_ratios.detach()
- rejected_reward = beta * rejected_ratios.detach()
- return loss, chosen_reward, rejected_reward
-
-
-def compute_policy_loss_pairwise_opmd(
- old_log_prob, log_prob, token_level_scores, eos_mask, index, tau
-):
- """Compute policy loss for pairwise_opmd
-
- NOTE: NOT TESTED YET
-
- TODO: allow using old_log_prob; for now we just discard it.
-
- NOTE: use token_level_scores rather than token_level_rewards, because we're not sure yet
- whether this algorithm is compatible with kl penalty as negative reward
-
- Args:
- old_log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- token_level_scores: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
- index: `(torch.Tensor)` or None (when use_uid is False)
- tau: `float`
-
- Returns:
- opmd_loss: `a scalar torch.Tensor`
- pairwise_opmd loss
- pg_clipfrac: (float)
- a float number indicating the fraction of policy gradient loss being clipped
- ppo_kl: (float) ... (TODO, confirm that this is only used for logging stats)
-
- """
-
- # dummy computation
- log_prob_diff = log_prob - log_prob
- pg_clipfrac = verl_F.masked_mean(torch.gt(log_prob_diff, log_prob_diff).float(), eos_mask)
- ppo_kl = verl_F.masked_mean(-log_prob_diff, eos_mask)
-
- # loss for pairwise_opmd
- scores = token_level_scores.sum(dim=-1)
- action_level_log_prob = (log_prob * eos_mask).sum(dim=-1)
- diffs = scores - tau * (action_level_log_prob - action_level_log_prob.detach())
-
- if index is None:
- normalizer = eos_mask.sum() * max(1.0, tau)
- opmd_loss = (diffs - diffs.mean()).square().sum() / normalizer
- else:
- opmd_loss = None
- unique_index = list(set(index.tolist()))
- for idx in unique_index:
- subdiff = diffs[index == idx]
- if subdiff.shape[0] == 1:
- continue
- # subloss = len(subdiff) * subdiff.square().sum() - subdiff.sum().square()
- subloss = (subdiff - subdiff.mean()).square().sum()
- if opmd_loss is None:
- opmd_loss = subloss
- else:
- opmd_loss = opmd_loss + subloss
- normalizer = eos_mask.sum() * max(1.0, tau)
- opmd_loss = opmd_loss / normalizer
-
- # NOTE: return pg_clipfrac and ppo_kl merely for compatibility with original compute_policy_loss
- return opmd_loss, pg_clipfrac, ppo_kl
-
-
-def compute_policy_loss_opmd(old_log_prob, log_prob, advantages, eos_mask, tau):
- """The OPMD counterpart of verl's original compute_policy_loss (now renamed as compute_policy_loss_ppo)
-
- Args:
- old_log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
- tau: `float`
-
- Returns:
- opmd_loss: `a scalar torch.Tensor`
- opmd loss
- pg_clipfrac: (float)
- a float number indicating the fraction of policy gradient loss being clipped
- ppo_kl: (float) ... (TODO, confirm that this is only used for logging stats)
-
- """
- log_prob_diff = log_prob - old_log_prob
- pg_clipfrac = verl_F.masked_mean(
- torch.gt(log_prob_diff, log_prob_diff).float(), eos_mask
- ) # meaningless
- ppo_kl = verl_F.masked_mean(-log_prob_diff, eos_mask)
-
- # --- version 0: kimi-opmd ---
-
- # # the original quadratic loss in OPMD can be reformulated as follows
- # pg_losses = -advantages * log_prob
- # pg_loss = verl_F.masked_sum(pg_losses, eos_mask)
-
- # reg_losses = (log_prob_diff * eos_mask).sum(dim=-1).square()
- # reg_loss = reg_losses.sum()
-
- # opmd_loss = (pg_loss + 0.5 * tau * reg_loss) / eos_mask.sum()
- # # NOTE: this implementation uses batch-wise normalization;
- # # would it be beneficial to use trajectory-wise or group-wise normalization?
-
- # opmd_loss = opmd_loss / max(1.0, tau) # for stability when tau is large
-
- # --- version 1: min-opmd (minimalistic, but theoretically grounded) ---
-
- pg_losses = -advantages * log_prob
- opmd_loss = verl_F.masked_mean(pg_losses, eos_mask)
- opmd_loss = opmd_loss / (1.0 + tau) # for regularization (w.r.t. current pi_theta)
-
- # NOTE: return pg_clipfrac and ppo_kl merely for compatibility with original compute_policy_loss
- return opmd_loss, pg_clipfrac, ppo_kl
-
-
-def compute_policy_loss_ppo(old_log_prob, log_prob, advantages, eos_mask, cliprange):
- """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
-
- Args:
- old_log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
- cliprange: (float)
- The clip range used in PPO. See https://arxiv.org/abs/1707.06347
-
- Returns:
- pg_loss: `a scalar torch.Tensor`
- policy gradient loss computed via PPO
- pg_clipfrac: (float)
- a float number indicating the fraction of policy gradient loss being clipped
-
- """
- negative_approx_kl = log_prob - old_log_prob
- ratio = torch.exp(negative_approx_kl)
- ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
-
- pg_losses = -advantages * ratio
- pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
-
- pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
- pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
- return pg_loss, pg_clipfrac, ppo_kl
-
-
-def compute_policy_loss_sft(log_prob, eos_mask):
- """Simple way to compute SFT loss, unified with PG loss
-
- Args:
- log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
-
- Returns:
- sft_loss: `a scalar torch.Tensor`
- pg_clipfrac: dummy value, merely for compatibility
- ppo_kl: dummy value, merely for compatibility
-
- """
- log_prob_diff = log_prob - log_prob.detach()
- pg_clipfrac = verl_F.masked_mean(torch.gt(log_prob_diff, log_prob_diff).float(), eos_mask)
- ppo_kl = verl_F.masked_mean(-log_prob_diff, eos_mask)
-
- sft_loss = verl_F.masked_mean(-log_prob, eos_mask)
-
- # Return pg_clipfrac and ppo_kl merely for compatibility with original compute_policy_loss
- return sft_loss, pg_clipfrac, ppo_kl
-
-
-def compute_entropy_loss(logits, eos_mask):
- """Compute Categorical entropy loss
-
- Args:
- logits: `(torch.Tensor)`
- shape: (bs, response_length, vocab_size)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
-
- Returns:
- entropy: a scalar torch.Tensor
-
- """
- # compute entropy
- entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
- entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
- return entropy_loss
-
-
-def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
- """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
-
- Args:
- vpreds (`torch.FloatTensor`):
- Predicted values of the value head, shape (`batch_size`, `response_length`)
- values (`torch.FloatTensor`):
- Old values of value head, shape (`batch_size`, `response_length`)
- returns: (`torch.FloatTensor`):
- Ground truth returns, shape (`batch_size`, `response_length`)
-
- Returns:
- vf_loss: a scalar (`torch.FloatTensor`):
- value function loss
- vf_clipfrac: a float
- The ratio of vf being clipped
-
- """
- vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
- vf_losses1 = (vpreds - returns) ** 2
- vf_losses2 = (vpredclipped - returns) ** 2
- vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
- vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
- return vf_loss, vf_clipfrac
-
-
-def kl_penalty(
- logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty
-) -> torch.FloatTensor:
- """Compute KL divergence given logprob and ref_logprob.
- Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
-
- Args:
- logprob:
- ref_logprob:
-
- Returns:
-
- """
- if kl_penalty == "kl":
- return logprob - ref_logprob
-
- if kl_penalty == "abs":
- return (logprob - ref_logprob).abs()
-
- if kl_penalty == "mse":
- return 0.5 * (logprob - ref_logprob).square()
-
- # J. Schulman. Approximating kl divergence, 2020.
- # # URL http://joschu.net/blog/kl-approx.html.
- if kl_penalty == "low_var_kl":
- kl = ref_logprob - logprob
- ratio = torch.exp(kl)
- kld = (ratio - kl - 1).contiguous()
- return torch.clamp(kld, min=-10, max=10)
-
- if kl_penalty == "full":
- # so, here logprob and ref_logprob should contain the logits for every token in vocabulary
- raise NotImplementedError
-
- raise NotImplementedError
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index a7705fc6a0..595084ac02 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -31,9 +31,9 @@
from verl.workers.actor import BasePPOActor
from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
+from trinity.algorithm.kl_fn.kl_fn import DummyKLFn
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import AlgorithmConfig
-from trinity.common.constants import AlgorithmType
__all__ = ["DataParallelPPOActor"]
@@ -55,11 +55,11 @@ def __init__(
self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
- self.algorithm_type = AlgorithmType.PPO
self.policy_loss_fn = None
+ self.kl_loss_fn = None
+ self.entropy_loss_fn = None
def set_algorithm(self, algorithm_config: AlgorithmConfig):
- self.algorithm_type = algorithm_config.algorithm_type
self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
**algorithm_config.policy_loss_fn_args
)
@@ -299,7 +299,7 @@ def update_policy(self, data: DataProto): # noqa: C901
for trinity_key in self.policy_loss_fn.select_keys:
verl_key = select_keys_trinity2verl[trinity_key]
select_keys.append(verl_key)
- if self.config.use_kl_loss:
+ if not isinstance(self.kl_loss_fn, DummyKLFn):
select_keys.append("ref_log_prob")
select_keys = list(set(select_keys))
batch = data.select(batch_keys=select_keys).batch
@@ -388,7 +388,7 @@ def update_policy(self, data: DataProto): # noqa: C901
)
# compute entropy loss from entropy
- entropy_loss, entropy_loss_metrics = self.entropy_loss_fn(
+ entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore
entropy=entropy,
action_mask=response_mask,
)
@@ -403,11 +403,13 @@ def update_policy(self, data: DataProto): # noqa: C901
kl_loss, kl_loss_metrics = self.kl_loss_fn.calculate_kl_loss(
logprob=log_prob,
- ref_logprob=data["ref_log_prob"],
+ ref_logprob=data.get("ref_log_prob", None),
response_mask=response_mask,
)
prefix_metrics(
- src_metrics=kl_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
+ src_metrics=kl_loss_metrics,
+ prefix="actor",
+ dst_metrics=micro_batch_metrics,
)
policy_loss = policy_loss + kl_loss
diff --git a/trinity/trainer/verl/ray_trainer.py b/trinity/trainer/verl/ray_trainer.py
deleted file mode 100644
index 5d883d05bb..0000000000
--- a/trinity/trainer/verl/ray_trainer.py
+++ /dev/null
@@ -1,816 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Modified from ray_trainer.py
-"""
-
-import os
-from contextlib import contextmanager
-from dataclasses import dataclass, field
-from enum import Enum
-from typing import Dict, Type
-
-import numpy as np
-import ray
-import torch
-from codetiming import Timer
-from omegaconf import OmegaConf, open_dict
-from torch.utils.data import RandomSampler, SequentialSampler
-from torchdata.stateful_dataloader import StatefulDataLoader
-from verl import DataProto
-from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
-from verl.single_controller.base import Worker
-from verl.single_controller.ray import (
- RayClassWithInitArgs,
- RayResourcePool,
- RayWorkerGroup,
-)
-from verl.single_controller.ray.base import create_colocated_worker_cls
-from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
-from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
-from verl.utils.seqlen_balancing import (
- get_seqlen_balanced_partitions,
- log_seqlen_unbalance,
-)
-from verl.utils.torch_functional import masked_mean
-from verl.utils.tracking import ValidationGenerationsLogger
-
-from trinity.common.constants import AlgorithmType
-from trinity.trainer.verl import core_algos
-
-WorkerType = Type[Worker]
-
-
-class Role(Enum):
- """
- To create more roles dynamically, you can subclass Role and add new members
- """
-
- Actor = 0
- Rollout = 1
- ActorRollout = 2
- Critic = 3
- RefPolicy = 4
- RewardModel = 5
- ActorRolloutRef = 6
-
-
-class AdvantageEstimator(str, Enum):
- """
- Using an enumeration class to avoid spelling errors in adv_estimator
- """
-
- GAE = "gae"
- GRPO = "grpo"
- REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
- REMAX = "remax"
- RLOO = "rloo"
-
-
-@dataclass
-class ResourcePoolManager:
- """
- Define a resource pool specification. Resource pool will be initialized first.
- Mapping
- """
-
- resource_pool_spec: dict[str, list[int]]
- mapping: dict[Role, str]
- resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
-
- def create_resource_pool(self):
- for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
- # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
- # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.
- # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models
- resource_pool = RayResourcePool(
- process_on_nodes=process_on_nodes,
- use_gpu=True,
- max_colocate_count=1,
- name_prefix=resource_pool_name,
- )
- self.resource_pool_dict[resource_pool_name] = resource_pool
-
- self._check_resource_available()
-
- def get_resource_pool(self, role: Role) -> RayResourcePool:
- """Get the resource pool of the worker_cls"""
- return self.resource_pool_dict[self.mapping[role]]
-
- def get_n_gpus(self) -> int:
- """Get the number of gpus in this cluster."""
- return sum(
- [
- n_gpus
- for process_on_nodes in self.resource_pool_spec.values()
- for n_gpus in process_on_nodes
- ]
- )
-
- def _check_resource_available(self):
- """Check if the resource pool can be satisfied in this ray cluster."""
- node_available_resources = ray.state.available_resources_per_node()
- node_available_gpus = {
- node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()
- }
-
- # check total required gpus can be satisfied
- total_available_gpus = sum(node_available_gpus.values())
- total_required_gpus = sum(
- [
- n_gpus
- for process_on_nodes in self.resource_pool_spec.values()
- for n_gpus in process_on_nodes
- ]
- )
- if total_available_gpus < total_required_gpus:
- raise ValueError(
- f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}"
- )
-
- # check each resource pool can be satisfied, O(#resource_pools * #nodes)
- for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
- num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)
- for node, available_gpus in node_available_gpus.items():
- if available_gpus >= num_gpus:
- node_available_gpus[node] -= num_gpus
- num_nodes -= 1
- if num_nodes == 0:
- break
- if num_nodes > 0:
- raise ValueError(
- f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this ray cluster"
- )
-
-
-def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
- responses = data.batch["responses"]
- response_length = responses.size(1)
- token_level_scores = data.batch["token_level_scores"]
- batch_size = data.batch.batch_size[0]
- attention_mask = data.batch["attention_mask"]
- # response_mask = attention_mask[:, -response_length:]
- response_mask = data.batch["response_mask"]
- assert response_mask.shape == attention_mask[:, -response_length:].shape
-
- # compute kl between ref_policy and current policy
- if "ref_log_prob" in data.batch.keys():
- kld = core_algos.kl_penalty(
- data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
- ) # (batch_size, response_length)
- kld = kld * response_mask
- beta = kl_ctrl.value
- else:
- beta = 0
- kld = torch.zeros_like(response_mask, dtype=torch.float32)
-
- token_level_rewards = token_level_scores - beta * kld
-
- current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
- current_kl = torch.mean(current_kl, dim=0).item()
-
- # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
- kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
- data.batch["token_level_rewards"] = token_level_rewards
-
- metrics = {"critic/kl": current_kl, "critic/kl_coeff": beta}
-
- return data, metrics
-
-
-def compute_response_mask(data: DataProto):
- responses = data.batch["responses"]
- response_length = responses.size(1)
- attention_mask = data.batch["attention_mask"]
- return attention_mask[:, -response_length:]
-
-
-@contextmanager
-def _timer(name: str, timing_raw: Dict[str, float]):
- with Timer(name=name, logger=None) as timer:
- yield
- timing_raw[name] = timer.last
-
-
-class RayPPOTrainer(object):
- """
- Note that this trainer runs on the driver process on a single CPU/GPU node.
- """
-
- # TODO: support each role have individual ray_worker_group_cls,
- # i.e., support different backend of different role
- def __init__(
- self,
- config,
- tokenizer,
- role_worker_mapping: dict[Role, WorkerType],
- resource_pool_manager: ResourcePoolManager,
- ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
- processor=None,
- reward_fn=None,
- val_reward_fn=None,
- ):
- # assert torch.cuda.is_available(), 'cuda must be available on driver'
-
- self.tokenizer = tokenizer
- self.processor = processor
- self.config = config
- self.reward_fn = reward_fn
- self.val_reward_fn = val_reward_fn
-
- self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
- assert self.hybrid_engine, "Currently, only support hybrid engine"
-
- if self.hybrid_engine:
- assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"
-
- self.role_worker_mapping = role_worker_mapping
- self.resource_pool_manager = resource_pool_manager
- self.use_reference_policy = Role.RefPolicy in role_worker_mapping
- self.use_rm = Role.RewardModel in role_worker_mapping
- self.ray_worker_group_cls = ray_worker_group_cls
- self.validation_generations_logger = ValidationGenerationsLogger()
-
- # define KL control
- if self.use_reference_policy:
- self.kl_ctrl = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
- else:
- self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.0)
-
- if (
- self.config.actor_rollout_ref.actor.get("algorithm_type", AlgorithmType.PPO)
- != AlgorithmType.PPO
- ):
- self.use_critic = False
- elif self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
- self.use_critic = True
- elif self.config.algorithm.adv_estimator in [
- AdvantageEstimator.GRPO,
- AdvantageEstimator.REINFORCE_PLUS_PLUS,
- AdvantageEstimator.REMAX,
- AdvantageEstimator.RLOO,
- ]:
- self.use_critic = False
- else:
- raise NotImplementedError
-
- self._validate_config()
- self._create_dataloader()
-
- def _validate_config(self): # noqa: C901
- config = self.config
- # number of GPUs total
- n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
-
- # 1. Check total batch size for data correctness
- real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
- assert (
- real_train_batch_size % n_gpus == 0
- ), f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
-
- # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
- # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
- def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
- if mbs is None and mbs_per_gpu is None:
- raise ValueError(
- f"[{name}] Please set at least one of '{name}.micro_batch_size' or "
- f"'{name}.micro_batch_size_per_gpu'."
- )
-
- if mbs is not None and mbs_per_gpu is not None:
- raise ValueError(
- f"[{name}] You have set both '{name}.micro_batch_size' AND "
- f"'{name}.micro_batch_size_per_gpu'. Please remove '{name}.micro_batch_size' "
- f"because only '*_micro_batch_size_per_gpu' is supported (the former is deprecated)."
- )
-
- if not config.actor_rollout_ref.actor.use_dynamic_bsz:
- # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
- check_mutually_exclusive(
- config.actor_rollout_ref.actor.ppo_micro_batch_size,
- config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
- "actor_rollout_ref.actor",
- )
-
- # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
- check_mutually_exclusive(
- config.actor_rollout_ref.ref.log_prob_micro_batch_size,
- config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
- "actor_rollout_ref.ref",
- )
-
- # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
- check_mutually_exclusive(
- config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
- config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
- "actor_rollout_ref.rollout",
- )
-
- if self.use_critic and not config.critic.use_dynamic_bsz:
- # Check for critic micro-batch size conflicts
- check_mutually_exclusive(
- config.critic.ppo_micro_batch_size,
- config.critic.ppo_micro_batch_size_per_gpu,
- "critic",
- )
-
- # Check for reward model micro-batch size conflicts
- if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
- check_mutually_exclusive(
- config.reward_model.micro_batch_size,
- config.reward_model.micro_batch_size_per_gpu,
- "reward_model",
- )
-
- # Actor
- # if NOT dynamic_bsz, we must ensure:
- # ppo_mini_batch_size is divisible by ppo_micro_batch_size
- # ppo_micro_batch_size * sequence_parallel_size >= n_gpus
- if not config.actor_rollout_ref.actor.use_dynamic_bsz:
- assert (
- config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
- )
- sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
- if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
- assert (
- config.actor_rollout_ref.actor.ppo_mini_batch_size
- % config.actor_rollout_ref.actor.ppo_micro_batch_size
- == 0
- )
- assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus
-
- # critic
- if self.use_critic and not config.critic.use_dynamic_bsz:
- assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
- sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
- if config.critic.ppo_micro_batch_size is not None:
- assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
- assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
-
- # Check if use_remove_padding is enabled when using sequence parallelism for fsdp
- if config.actor_rollout_ref.actor.strategy == "fsdp":
- if (
- config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
- or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
- ):
- assert (
- config.actor_rollout_ref.model.use_remove_padding
- ), "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
-
- if self.use_critic and config.critic.strategy == "fsdp":
- if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
- assert (
- config.critic.model.use_remove_padding
- ), "When using sequence parallelism for critic, you must enable `use_remove_padding`."
-
- if config.data.get("val_batch_size", None) is not None:
- print(
- "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves."
- )
-
- print("[validate_config] All configuration checks passed successfully!")
-
- def _create_dataloader(self):
- # TODO: we have to make sure the batch size is divisible by the dp size
- self.train_dataset = RLHFDataset(
- parquet_files=self.config.data.train_files,
- tokenizer=self.tokenizer,
- processor=self.processor,
- prompt_key=self.config.data.prompt_key,
- image_key=self.config.data.get("image_key", "images"),
- max_prompt_length=self.config.data.max_prompt_length,
- filter_prompts=True,
- return_raw_chat=self.config.data.get("return_raw_chat", False),
- truncation="error",
- )
- # use sampler for better ckpt resume
- if self.config.data.shuffle:
- train_dataloader_generator = torch.Generator()
- train_dataloader_generator.manual_seed(self.config.data.get("seed", 1))
- sampler = RandomSampler(
- data_source=self.train_dataset, generator=train_dataloader_generator
- )
- else:
- sampler = SequentialSampler(data_source=self.train_dataset)
-
- self.train_dataloader = StatefulDataLoader(
- dataset=self.train_dataset,
- batch_size=self.config.data.train_batch_size,
- num_workers=8,
- drop_last=True,
- collate_fn=collate_fn,
- sampler=sampler,
- )
-
- self.val_dataset = RLHFDataset(
- parquet_files=self.config.data.val_files,
- tokenizer=self.tokenizer,
- processor=self.processor,
- prompt_key=self.config.data.prompt_key,
- image_key=self.config.data.get("image_key", "images"),
- max_prompt_length=self.config.data.max_prompt_length,
- filter_prompts=True,
- return_raw_chat=self.config.data.get("return_raw_chat", False),
- truncation="error",
- )
- self.val_dataloader = StatefulDataLoader(
- dataset=self.val_dataset,
- # Validation datasets are sent to inference engines as a whole batch,
- # which will schedule the memory themselves.
- batch_size=len(self.val_dataset),
- num_workers=8,
- shuffle=False,
- drop_last=False,
- collate_fn=collate_fn,
- )
-
- assert len(self.train_dataloader) >= 1
- assert (
- len(self.val_dataloader) == 1
- ), "Validation dataloader must have a single batch, which inference engines will schedule the memory themselves."
-
- print(f"Size of train dataloader: {len(self.train_dataloader)}")
-
- # inject total_training_steps to actor/critic optim_config. This is hacky.
- total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
-
- if self.config.trainer.total_training_steps is not None:
- total_training_steps = self.config.trainer.total_training_steps
-
- self.total_training_steps = total_training_steps
- print(f"Total training steps: {self.total_training_steps}")
-
- OmegaConf.set_struct(self.config, True)
- with open_dict(self.config):
- self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
- self.config.critic.optim.total_training_steps = total_training_steps
-
- def _maybe_log_val_generations(self, inputs, outputs, scores):
- """Log a table of validation samples to the configured logger (wandb or swanlab)"""
-
- generations_to_log = self.config.trainer.val_generations_to_log_to_wandb
-
- if generations_to_log == 0:
- return
-
- import numpy as np
-
- # Create tuples of (input, output, score) and sort by input text
- samples = list(zip(inputs, outputs, scores))
- samples.sort(key=lambda x: x[0]) # Sort by input text
-
- # Use fixed random seed for deterministic shuffling
- rng = np.random.RandomState(42)
- rng.shuffle(samples)
-
- # Take first N samples after shuffling
- samples = samples[:generations_to_log]
-
- # Log to each configured logger
- self.validation_generations_logger.log(
- self.config.trainer.logger, samples, self.global_steps
- )
-
- def _validate(self):
- reward_tensor_lst = []
- data_source_lst = []
-
- # Lists to collect samples for the table
- sample_inputs = []
- sample_outputs = []
- sample_scores = []
-
- for test_data in self.val_dataloader:
- test_batch = DataProto.from_single_dict(test_data)
-
- # we only do validation on rule-based rm
- if (
- self.config.reward_model.enable
- and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model"
- ):
- return {}
-
- # Store original inputs
- input_ids = test_batch.batch["input_ids"]
- input_texts = [
- self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids
- ]
- sample_inputs.extend(input_texts)
-
- if "multi_modal_inputs" in test_batch.non_tensor_batch.keys():
- test_gen_batch = test_batch.pop(
- batch_keys=["input_ids", "attention_mask", "position_ids"],
- non_tensor_batch_keys=[
- "raw_prompt_ids",
- "multi_modal_data",
- "multi_modal_inputs",
- ],
- )
- else:
- test_gen_batch = test_batch.pop(
- batch_keys=["input_ids", "attention_mask", "position_ids"],
- non_tensor_batch_keys=["raw_prompt_ids"],
- )
-
- test_gen_batch.meta_info = {
- "eos_token_id": self.tokenizer.eos_token_id,
- "pad_token_id": self.tokenizer.pad_token_id,
- "recompute_log_prob": False,
- "do_sample": False,
- "validate": True,
- }
-
- # pad to be divisible by dp_size
- test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(
- test_gen_batch, self.actor_rollout_wg.world_size
- )
- test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(
- test_gen_batch_padded
- )
- # unpad
- test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
- print("validation generation end")
-
- # Store generated outputs
- output_ids = test_output_gen_batch.batch["responses"]
- output_texts = [
- self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids
- ]
- sample_outputs.extend(output_texts)
-
- test_batch = test_batch.union(test_output_gen_batch)
-
- # evaluate using reward_function
- reward_tensor = self.val_reward_fn(test_batch)
-
- # Store scores
- scores = reward_tensor.sum(-1).cpu().tolist()
- sample_scores.extend(scores)
-
- reward_tensor_lst.append(reward_tensor)
- data_source_lst.append(
- test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])
- )
-
- self._maybe_log_val_generations(
- inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores
- )
-
- reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,)
- data_sources = np.concatenate(data_source_lst, axis=0)
-
- # evaluate test_score based on data source
- data_source_reward = {}
- for i in range(reward_tensor.shape[0]):
- data_source = data_sources[i]
- if data_source not in data_source_reward:
- data_source_reward[data_source] = []
- data_source_reward[data_source].append(reward_tensor[i].item())
-
- metric_dict = {}
- for data_source, rewards in data_source_reward.items():
- metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards)
-
- return metric_dict
-
- def init_workers(self):
- """Init resource pool and worker group"""
- self.resource_pool_manager.create_resource_pool()
-
- self.resource_pool_to_cls = {
- pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()
- }
-
- # create actor and rollout
- if self.hybrid_engine:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
- actor_rollout_cls = RayClassWithInitArgs(
- cls=self.role_worker_mapping[Role.ActorRollout],
- config=self.config.actor_rollout_ref,
- role="actor",
- )
- self.resource_pool_to_cls[resource_pool]["actor"] = actor_rollout_cls
- else:
- raise NotImplementedError
-
- # create critic
- if self.use_critic:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
- critic_cls = RayClassWithInitArgs(
- cls=self.role_worker_mapping[Role.Critic], config=self.config.critic
- )
- self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
-
- # create reference policy if needed
- if self.use_reference_policy:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
- ref_policy_cls = RayClassWithInitArgs(
- self.role_worker_mapping[Role.RefPolicy],
- config=self.config.actor_rollout_ref,
- role="ref",
- )
- self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
-
- # create a reward model if reward_fn is None
- if self.use_rm:
- # we create a RM here
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
- rm_cls = RayClassWithInitArgs(
- self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model
- )
- self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
-
- # initialize WorkerGroup
- # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
- # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
- # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
- all_wg = {}
- self.wg_dicts = []
- for resource_pool, class_dict in self.resource_pool_to_cls.items():
- worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
- wg_dict = self.ray_worker_group_cls(
- resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls
- )
- spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
- all_wg.update(spawn_wg)
- # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
- self.wg_dicts.append(wg_dict)
-
- if self.use_critic:
- self.critic_wg = all_wg["critic"]
- self.critic_wg.init_model()
-
- if self.use_reference_policy:
- self.ref_policy_wg = all_wg["ref"]
- self.ref_policy_wg.init_model()
-
- if self.use_rm:
- self.rm_wg = all_wg["rm"]
- self.rm_wg.init_model()
-
- # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
- self.actor_rollout_wg = all_wg["actor"]
- self.actor_rollout_wg.init_model()
-
- def _save_checkpoint(self):
- # path: given_path + `/global_step_{global_steps}` + `/actor`
- local_global_step_folder = os.path.join(
- self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
- )
-
- print(f"local_global_step_folder: {local_global_step_folder}")
- actor_local_path = os.path.join(local_global_step_folder, "actor")
-
- actor_remote_path = (
- None
- if self.config.trainer.default_hdfs_dir is None
- else os.path.join(
- self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor"
- )
- )
-
- remove_previous_ckpt_in_save = self.config.trainer.get(
- "remove_previous_ckpt_in_save", False
- )
- if remove_previous_ckpt_in_save:
- print(
- "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead"
- )
- max_actor_ckpt_to_keep = (
- self.config.trainer.get("max_actor_ckpt_to_keep", None)
- if not remove_previous_ckpt_in_save
- else 1
- )
- max_critic_ckpt_to_keep = (
- self.config.trainer.get("max_critic_ckpt_to_keep", None)
- if not remove_previous_ckpt_in_save
- else 1
- )
-
- self.actor_rollout_wg.save_checkpoint(
- actor_local_path,
- actor_remote_path,
- self.global_steps,
- max_ckpt_to_keep=max_actor_ckpt_to_keep,
- )
-
- if self.use_critic:
- critic_local_path = os.path.join(local_global_step_folder, "critic")
- critic_remote_path = (
- None
- if self.config.trainer.default_hdfs_dir is None
- else os.path.join(
- self.config.trainer.default_hdfs_dir,
- f"global_step_{self.global_steps}",
- "critic",
- )
- )
- self.critic_wg.save_checkpoint(
- critic_local_path,
- critic_remote_path,
- self.global_steps,
- max_ckpt_to_keep=max_critic_ckpt_to_keep,
- )
-
- # save dataloader
- dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
- dataloader_state_dict = self.train_dataloader.state_dict()
- torch.save(dataloader_state_dict, dataloader_local_path)
-
- # latest checkpointed iteration tracker (for atomic usage)
- local_latest_checkpointed_iteration = os.path.join(
- self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
- )
- with open(local_latest_checkpointed_iteration, "w") as f:
- f.write(str(self.global_steps))
-
- def _load_checkpoint(self):
- if self.config.trainer.resume_mode == "disable":
- return 0
-
- # load from hdfs
- if self.config.trainer.default_hdfs_dir is not None:
- raise NotImplementedError("load from hdfs is not implemented yet")
- else:
- checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
- if not os.path.isabs(checkpoint_folder):
- working_dir = os.getcwd()
- checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
- global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
-
- # find global_step_folder
- if self.config.trainer.resume_mode == "auto":
- if global_step_folder is None:
- print("Training from scratch")
- return 0
- else:
- if self.config.trainer.resume_mode == "resume_path":
- assert isinstance(
- self.config.trainer.resume_from_path, str
- ), "resume ckpt must be str type"
- assert (
- "global_step_" in self.config.trainer.resume_from_path
- ), "resume ckpt must specify the global_steps"
- global_step_folder = self.config.trainer.resume_from_path
- if not os.path.isabs(global_step_folder):
- working_dir = os.getcwd()
- global_step_folder = os.path.join(working_dir, global_step_folder)
- print(f"Load from checkpoint folder: {global_step_folder}")
- # set global step
- self.global_steps = int(global_step_folder.split("global_step_")[-1])
-
- print(f"Setting global step to {self.global_steps}")
- print(f"Resuming from {global_step_folder}")
-
- actor_path = os.path.join(global_step_folder, "actor")
- critic_path = os.path.join(global_step_folder, "critic")
- # load actor
- self.actor_rollout_wg.load_checkpoint(
- actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
- )
- # load critic
- if self.use_critic:
- self.critic_wg.load_checkpoint(
- critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
- )
-
- # load dataloader,
- # TODO: from remote not implemented yet
- dataloader_local_path = os.path.join(global_step_folder, "data.pt")
- if os.path.exists(dataloader_local_path):
- dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)
- self.train_dataloader.load_state_dict(dataloader_state_dict)
- else:
- print(
- f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch"
- )
-
- def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
- """Reorder the data on single controller such that each dp rank gets similar total tokens"""
- attention_mask = batch.batch["attention_mask"]
- batch_size = attention_mask.shape[0]
- global_seqlen_lst = (
- batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()
- ) # (train_batch_size,)
- world_size = self.actor_rollout_wg.world_size
- global_partition_lst = get_seqlen_balanced_partitions(
- global_seqlen_lst, k_partitions=world_size, equal_size=True
- )
- # reorder based on index. The data will be automatically equally partitioned by dispatch function
- global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
- batch.reorder(global_idx)
- global_balance_stats = log_seqlen_unbalance(
- seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
- )
- metrics.update(global_balance_stats)
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 83e3480dc3..d040c329dd 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -19,24 +19,27 @@
compute_timing_metrics,
reduce_metrics,
)
-from verl.utils import hf_tokenizer
-from verl.utils.fs import copy_local_path_from_hdfs
-
-from trinity.algorithm import ADVANTAGE_FN, KL_FN
-from trinity.algorithm.utils import prefix_metrics
-from trinity.common.config import AlgorithmConfig, Config
-from trinity.common.constants import AlgorithmType
-from trinity.common.experience import Experiences
-from trinity.trainer.trainer import TrainEngineWrapper
-from trinity.trainer.verl.ray_trainer import (
+from verl.trainer.ppo.ray_trainer import (
DataProto,
+ RayClassWithInitArgs,
RayPPOTrainer,
RayWorkerGroup,
ResourcePoolManager,
Role,
_timer,
+ create_colocated_worker_cls,
find_latest_ckpt_path,
)
+from verl.utils import hf_tokenizer
+from verl.utils.fs import copy_local_path_from_hdfs
+
+from trinity.algorithm import ADVANTAGE_FN, KL_FN
+from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm
+from trinity.algorithm.algorithm_manager import AlgorithmManager
+from trinity.algorithm.utils import prefix_metrics
+from trinity.common.config import Config
+from trinity.common.experience import Experiences
+from trinity.trainer.trainer import TrainEngineWrapper
from trinity.utils.monitor import Monitor
@@ -119,6 +122,19 @@ def __init__(
resource_pool_manager = ResourcePoolManager(
resource_pool_spec=resource_pool_spec, mapping=mapping
)
+ self.algorithm_config = global_config.algorithm
+ self.algorithm = None
+ self.algorithm_manager = AlgorithmManager(global_config)
+
+ # specify advantage function for various rft algorithms
+ algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type)
+ if algorithm.use_advantage:
+ self.advantage_fn = ADVANTAGE_FN.get(self.algorithm_config.advantage_fn)(
+ **self.algorithm_config.advantage_fn_args
+ )
+ self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)(
+ **self.algorithm_config.kl_penalty_fn_args
+ )
super().__init__(
config,
@@ -128,15 +144,6 @@ def __init__(
ray_worker_group_cls,
)
self.init_workers()
- self.algorithm_type = AlgorithmType.PPO
-
- # specify advantage function for various rft algorithms
- algo_config = global_config.algorithm
- if algo_config.algorithm_type.is_rft():
- self.advantage_fn = ADVANTAGE_FN.get(algo_config.advantage_fn)(
- **algo_config.advantage_fn_args
- )
- self.kl_fn = KL_FN.get(algo_config.kl_penalty_fn)(**algo_config.kl_penalty_fn_args)
self.logger = Monitor(
project=config.trainer.project_name,
@@ -146,20 +153,109 @@ def __init__(
)
self.reset_experiences_example_table()
+ def _validate_config(self): # TODO
+ algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type)
+ self.use_critic = algorithm.use_critic
+ super()._validate_config()
+
+ def init_workers(self):
+ """Init resource pool and worker group"""
+ self.resource_pool_manager.create_resource_pool()
+
+ self.resource_pool_to_cls = {
+ pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()
+ }
+
+ # create actor and rollout
+ if self.hybrid_engine:
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
+ actor_rollout_cls = RayClassWithInitArgs(
+ cls=self.role_worker_mapping[Role.ActorRollout],
+ config=self.config.actor_rollout_ref,
+ role="actor",
+ )
+ self.resource_pool_to_cls[resource_pool]["actor"] = actor_rollout_cls
+ else:
+ raise NotImplementedError
+
+ # create critic
+ if self.use_critic:
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
+ critic_cls = RayClassWithInitArgs(
+ cls=self.role_worker_mapping[Role.Critic], config=self.config.critic
+ )
+ self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
+
+ # create reference policy if needed
+ if self.use_reference_policy:
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
+ ref_policy_cls = RayClassWithInitArgs(
+ self.role_worker_mapping[Role.RefPolicy],
+ config=self.config.actor_rollout_ref,
+ role="ref",
+ )
+ self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
+
+ # create a reward model if reward_fn is None
+ if self.use_rm:
+ # we create a RM here
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
+ rm_cls = RayClassWithInitArgs(
+ self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model
+ )
+ self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
+
+ # initialize WorkerGroup
+ # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
+ # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
+ # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
+ all_wg = {}
+ self.wg_dicts = []
+ for resource_pool, class_dict in self.resource_pool_to_cls.items():
+ worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
+ wg_dict = self.ray_worker_group_cls(
+ resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls
+ )
+ spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
+ all_wg.update(spawn_wg)
+ # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
+ self.wg_dicts.append(wg_dict)
+
+ if self.use_critic:
+ self.critic_wg = all_wg["critic"]
+ self.critic_wg.init_model()
+
+ if self.use_reference_policy:
+ self.ref_policy_wg = all_wg["ref"]
+ self.ref_policy_wg.init_model()
+
+ if self.use_rm:
+ self.rm_wg = all_wg["rm"]
+ self.rm_wg.init_model()
+
+ # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
+ self.actor_rollout_wg = all_wg["actor"]
+ self.actor_rollout_wg.init_model()
+
def reset_experiences_example_table(self):
self.experiences_example_table = pd.DataFrame(
columns=["step", "reward", "prompt", "response"]
)
+ @property
+ def train_step_num(self) -> int:
+ return self.global_steps
+
def prepare(self):
self.actor_rollout_wg.setup_weight_sync_group()
+ # The global step counter, initialized to 0
+ # It represents the total number of training steps completed so far
+ # We increment this counter at the beginning of each training step
self.global_steps = 0
- self.sft_warmup_step_num = 0
# load checkpoint before doing anything
self._load_checkpoint()
- self.sft_warmup_step_num = min(self.global_steps, self.config.trainer.sft_warmup_steps)
# perform validation before training
# currently, we only support validation using the reward_function.
@@ -170,190 +266,60 @@ def prepare(self):
if self.config.trainer.get("val_only", False):
return
- # we start from step 1
-
def _create_dataloader(self):
self.train_dataloader = _InternalDataLoader(self.config)
# TODO: compute total training steps
self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
- def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]:
- self.global_steps += 1
- metrics = {}
- timing_raw = {}
-
- with _timer("step", timing_raw):
- # generate a batch
- attention_mask = experiences.attention_masks
- cumsum = torch.cumsum(attention_mask, dim=-1)
- position_ids = torch.clip(cumsum - 1, 0, None).long()
-
- batch = DataProto.from_single_dict(
- {
- "uid": np.array(experiences.run_ids), # useless
- "position_ids": position_ids,
- "input_ids": experiences.tokens.long(),
- "responses": experiences.tokens[:, experiences.prompt_length :].long(),
- "attention_mask": attention_mask.long(),
- "response_mask": (
- experiences.action_masks[:, experiences.prompt_length :].long()
- if hasattr(experiences, "action_masks")
- and experiences.action_masks is not None
- else attention_mask[:, experiences.prompt_length :].long()
- ),
- }
- )
- batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
-
- # self._balance_batch(batch, metrics=metrics) # _balance_batch will shuffle the batch, which will break DPO
- # TODO: implement a new _balance_batch for DPO
-
- # compute global_valid tokens
- batch.meta_info["global_token_num"] = torch.sum(
- batch.batch["attention_mask"], dim=-1
- ).tolist()
-
- if self.use_reference_policy:
- # compute reference log_prob
- with _timer("ref", timing_raw):
- ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
- batch = batch.union(ref_log_prob)
-
- # update actor
- with _timer("update_actor", timing_raw):
- actor_output = self.actor_rollout_wg.update_actor(batch)
- actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
- metrics.update(actor_output_metrics)
-
- # collect metrics
- metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
-
- self.logger.log(data=metrics, step=self.global_steps)
-
- # save checkpoint
- if (
- self.config.trainer.save_freq > 0
- and self.global_steps % self.config.trainer.save_freq == 0
- ):
- with _timer("save_checkpoint", timing_raw):
- self._save_checkpoint()
-
- if self.global_steps >= self.total_training_steps:
- if (
- self.config.trainer.save_freq > 0
- and self.global_steps % self.config.trainer.save_freq != 0
- ):
- with _timer("save_checkpoint", timing_raw):
- self._save_checkpoint()
- # stop training
- return False, self.global_steps
- else:
- # continue
- return True, self.global_steps
-
- def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]:
- if self.sft_warmup_step_num >= self.config.trainer.sft_warmup_steps:
- return False, self.global_steps
- self.global_steps += 1
- metrics = {}
- timing_raw = {}
-
- with _timer("step", timing_raw):
- # generate a batch
- attention_mask = experiences.attention_masks
- cumsum = torch.cumsum(attention_mask, dim=-1)
- position_ids = torch.clip(cumsum - 1, 0, None).long()
-
- batch = DataProto.from_single_dict(
- {
- "uid": np.array(experiences.run_ids),
- "position_ids": position_ids,
- "input_ids": experiences.tokens.long(),
- "responses": experiences.tokens[:, experiences.prompt_length :].long(),
- "attention_mask": attention_mask.long(),
- "response_mask": (
- experiences.action_masks[:, experiences.prompt_length :].long()
- if hasattr(experiences, "action_masks")
- and experiences.action_masks is not None
- else attention_mask[:, experiences.prompt_length :].long()
- ),
- }
- )
- batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
-
- self._balance_batch(batch, metrics=metrics) # TODO this may affect multi-turn
-
- # compute global_valid tokens
- batch.meta_info["global_token_num"] = torch.sum(
- batch.batch["attention_mask"], dim=-1
- ).tolist()
-
- if self.use_reference_policy:
- # compute reference log_prob
- with _timer("ref", timing_raw):
- ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
- batch = batch.union(ref_log_prob)
-
- # update actor
- with _timer("update_actor", timing_raw):
- actor_output = self.actor_rollout_wg.update_actor(batch)
- actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
- metrics.update(actor_output_metrics)
-
- # collect metrics
- metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
-
- # TODO: log as sft metrics
- self.logger.log(data=metrics, step=self.global_steps)
- self.sft_warmup_step_num += 1
- train_status = True
- if self.sft_warmup_step_num == self.config.trainer.sft_warmup_steps:
- self.logger.log(
- data={"sft_warmup_steps": self.sft_warmup_step_num},
- step=self.global_steps,
- )
- with _timer("save_checkpoint", timing_raw):
- self._save_checkpoint()
- train_status = False
- return train_status, self.global_steps
-
- def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
+ def train_step(self, experiences: Experiences) -> Tuple[bool, int]:
self.global_steps += 1
metrics = {}
timing_raw = {}
+ algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps)
+ algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
+ if self.algorithm != algorithm:
+ self.actor_rollout_wg.set_algorithm(algorithm_config)
+ if self.algorithm == SFTAlgorithm:
+ self.sft_to_rft()
+ self.algorithm = algorithm
with _timer("step", timing_raw):
# Convert rewards to token_level_rewards
attention_mask = experiences.attention_masks
- token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
cumsum = torch.cumsum(attention_mask, dim=-1)
- eos_mask_idx = cumsum.argmax(dim=-1)
position_ids = torch.clip(cumsum - 1, 0, None).long()
- token_level_rewards[
- torch.arange(experiences.batch_size), eos_mask_idx
- ] = experiences.rewards
- token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
-
- batch = DataProto.from_single_dict(
- {
- "uid": np.array(experiences.run_ids),
- "position_ids": position_ids,
- "input_ids": experiences.tokens.long(),
- "responses": experiences.tokens[:, experiences.prompt_length :].long(),
- "attention_mask": attention_mask.long(),
- "response_mask": (
- experiences.action_masks[:, experiences.prompt_length :].long()
- if hasattr(experiences, "action_masks")
- and experiences.action_masks is not None
- else attention_mask[:, experiences.prompt_length :].long()
- ),
- "token_level_scores": token_level_rewards,
- "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore
- }
- )
+ batch_dict = {
+ "uid": np.array(experiences.run_ids),
+ "position_ids": position_ids,
+ "input_ids": experiences.tokens.long(),
+ "responses": experiences.tokens[:, experiences.prompt_length :].long(),
+ "attention_mask": attention_mask.long(),
+ "response_mask": (
+ experiences.action_masks[:, experiences.prompt_length :].long()
+ if hasattr(experiences, "action_masks") and experiences.action_masks is not None
+ else attention_mask[:, experiences.prompt_length :].long()
+ ),
+ }
+ if self.algorithm.use_advantage:
+ token_level_rewards = torch.zeros(
+ attention_mask.shape, dtype=experiences.rewards.dtype
+ )
+ eos_mask_idx = cumsum.argmax(dim=-1)
+ token_level_rewards[
+ torch.arange(experiences.batch_size), eos_mask_idx
+ ] = experiences.rewards
+ token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
+ batch_dict.update(
+ {
+ "token_level_scores": token_level_rewards,
+ "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore
+ }
+ )
+
+ batch = DataProto.from_single_dict(batch_dict)
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
- if self.config.trainer.balance_batch:
+ if self.algorithm.can_balance_batch and self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics) # TODO this may affect multi-turn
# compute global_valid tokens
@@ -361,34 +327,37 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
batch.batch["attention_mask"], dim=-1
).tolist()
- if self.use_reference_policy:
+ if self.algorithm.use_reference: # ref_logprob may not be used
# compute reference log_prob
with _timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
- # compute values
- if self.use_critic:
+ if self.algorithm.use_critic:
with _timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
- with _timer("adv", timing_raw):
- # compute kl penalty
- batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch)
- metrics.update(prefix_metrics(kl_metrics, prefix="critic"))
- # compute advantages, executed on the driver process
- batch, _ = self.advantage_fn(batch)
+ if self.algorithm.use_advantage:
+ with _timer("adv", timing_raw):
+ # compute kl penalty
+ batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch)
+ metrics.update(prefix_metrics(kl_metrics, prefix="critic"))
+ # compute advantages, executed on the driver process
+ batch, _ = self.advantage_fn(batch)
- # update critic
- if self.use_critic:
+ # update critic
+ if self.algorithm.use_critic:
with _timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
- if self.config.trainer.critic_warmup <= self.global_steps:
+ if (
+ not self.algorithm.use_critic
+ or self.config.trainer.critic_warmup <= self.global_steps
+ ):
# update actor
with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
@@ -404,31 +373,29 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
self._save_checkpoint()
# collect metrics
- metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
+ if self.algorithm.use_advantage: # TODO
+ metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(
compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)
)
- if self.config.enable_preview:
+ if self.algorithm.use_advantage and self.config.enable_preview: # TODO
self._log_experiences(experiences)
# TODO: make a canonical logger that supports various backend
self.logger.log(data=metrics, step=self.global_steps)
- if self.global_steps >= self.total_training_steps:
+ train_status = self.global_steps < self.total_training_steps
+ if not train_status or self.algorithm_manager.need_save(self.global_steps):
if (
- self.config.trainer.save_freq > 0
- and self.global_steps % self.config.trainer.save_freq != 0
+ self.config.trainer.save_freq == 0
+ or self.global_steps % self.config.trainer.save_freq != 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
- # stop training
- return False, self.global_steps
- else:
- # continue
- return True, self.global_steps
+ return train_status, self.global_steps
def _log_single_experience(
self, experiences: Experiences, idx: int, skip_special_tokens: bool
@@ -477,12 +444,6 @@ def save_checkpoint(self) -> None:
def sync_weight(self) -> None:
self.actor_rollout_wg.sync_weight()
- def set_algorithm(self, algorithm_config: AlgorithmConfig) -> None:
- self.actor_rollout_wg.set_algorithm(algorithm_config)
- if self.algorithm_type.is_sft() and (not algorithm_config.algorithm_type.is_sft()):
- self.sft_to_rft()
- self.algorithm_type = algorithm_config.algorithm_type
-
def sft_to_rft(self) -> None:
# load from hdfs
if self.config.trainer.default_hdfs_dir is not None:
@@ -513,9 +474,9 @@ def sft_to_rft(self) -> None:
global_step_folder = os.path.join(working_dir, global_step_folder)
print(f"Load from checkpoint folder: {global_step_folder}")
# set global step
- self.global_steps = int(global_step_folder.split("global_step_")[-1])
+ global_steps = int(global_step_folder.split("global_step_")[-1])
+ assert self.global_steps == global_steps + 1
- print(f"Setting global step to {self.global_steps}")
print(f"Resuming from {global_step_folder}")
actor_path = os.path.join(global_step_folder, "actor")
From 48f596aedd479a9d1c92b4357285f7a59d955b28 Mon Sep 17 00:00:00 2001
From: weijie <34210233+shiweijiezero@users.noreply.github.com>
Date: Wed, 11 Jun 2025 17:04:11 +0800
Subject: [PATCH 08/28] Fix EntropyLossFn (#77)
Co-authored-by: weijie
---
.../source/tutorial/trinity_programming_guide.md | 2 +-
trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py | 10 +---------
2 files changed, 2 insertions(+), 10 deletions(-)
diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index 2e4daeab0b..7119b1af35 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -299,7 +299,7 @@ pip install -e .[dev]
# pip install -e .\[dev\]
# Run code style checks
-pre-commit --all-files
+pre-commit run --all-files
# Commit the code after all checks pass
git commit -am "create example workflow"
diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
index cf102dd6b7..e575caa449 100644
--- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
+++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -32,12 +32,12 @@ def __call__(
"""
@classmethod
- @abstractmethod
def default_args(cls) -> Dict:
"""
Returns:
`Dict`: The default arguments for the entropy loss function.
"""
+ return {"entropy_coef": 0.0}
@ENTROPY_LOSS_FN.register_module("basic")
@@ -58,10 +58,6 @@ def __call__(
entropy_loss = masked_mean(entropy, action_mask)
return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}
- @classmethod
- def default_args(cls) -> Dict:
- return {"entropy_coef": 0.0}
-
@ENTROPY_LOSS_FN.register_module("none")
class DummyEntropyLossFn(EntropyLossFn):
@@ -79,7 +75,3 @@ def __call__(
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
return torch.tensor(0.0), {}
-
- @classmethod
- def default_args(cls) -> Dict:
- return {}
From 9e729965701f8c8bceaa90af35cd365eecbcef01 Mon Sep 17 00:00:00 2001
From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com>
Date: Wed, 11 Jun 2025 18:15:02 +0800
Subject: [PATCH 09/28] Fix Conflicts with main (#75)
---
.../source/tutorial/trinity_configs.md | 2 +
.../tutorial/trinity_programming_guide.md | 150 +-
tests/buffer/sql_test.py | 4 +
tests/explorer/explorer_test.py | 3 +-
tests/explorer/runner_pool_test.py | 6 +-
tests/explorer/workflow_test.py | 44 +-
tests/trainer/trainer_test.py | 4 +-
tests/utils/__init__.py | 0
tests/utils/plugin_test.py | 34 +
tests/utils/plugins/__init__.py | 0
tests/utils/plugins/my_workflow.py | 12 +
trinity/buffer/buffer.py | 2 +-
trinity/buffer/db_wrapper.py | 105 ++
trinity/buffer/queue.py | 1 +
trinity/buffer/reader/file_reader.py | 7 +-
trinity/buffer/reader/queue_reader.py | 8 +-
trinity/buffer/reader/sql_reader.py | 68 +-
trinity/buffer/writer/sql_writer.py | 32 +-
trinity/cli/launcher.py | 19 +-
trinity/common/config.py | 11 +-
trinity/common/constants.py | 13 +-
trinity/common/models/__init__.py | 7 +-
trinity/common/models/model.py | 10 +-
trinity/common/models/vllm_async_model.py | 18 +-
trinity/common/workflows/__init__.py | 3 +-
trinity/common/workflows/workflow.py | 7 +-
trinity/explorer/explorer.py | 4 +-
trinity/manager/config_manager.py | 1468 +++--------------
trinity/manager/config_registry/__init__.py | 13 +
.../config_registry/buffer_config_manager.py | 433 +++++
.../config_registry/config_registry.py | 209 +++
.../explorer_config_manager.py | 298 ++++
.../config_registry/model_config_manager.py | 206 +++
.../config_registry/trainer_config_manager.py | 450 +++++
trinity/plugins/__init__.py | 1 +
trinity/trainer/verl/dp_actor.py | 5 -
trinity/trainer/verl_trainer.py | 4 +-
trinity/utils/dlc_utils.py | 56 +-
trinity/utils/eval_utils.py | 13 +-
trinity/utils/monitor.py | 51 +-
trinity/utils/plugin_loader.py | 65 +
trinity/utils/registry.py | 99 +-
42 files changed, 2451 insertions(+), 1494 deletions(-)
create mode 100644 tests/utils/__init__.py
create mode 100644 tests/utils/plugin_test.py
create mode 100644 tests/utils/plugins/__init__.py
create mode 100644 tests/utils/plugins/my_workflow.py
create mode 100644 trinity/buffer/db_wrapper.py
create mode 100644 trinity/manager/config_registry/__init__.py
create mode 100644 trinity/manager/config_registry/buffer_config_manager.py
create mode 100644 trinity/manager/config_registry/config_registry.py
create mode 100644 trinity/manager/config_registry/explorer_config_manager.py
create mode 100644 trinity/manager/config_registry/model_config_manager.py
create mode 100644 trinity/manager/config_registry/trainer_config_manager.py
create mode 100644 trinity/plugins/__init__.py
create mode 100644 trinity/utils/plugin_loader.py
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 09377e1f66..8cb8856fbc 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -223,6 +223,8 @@ The configuration for each task dataset is defined as follows:
- `temperature`: The temperature for sampling.
- `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used.
- `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
+- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.
+
### Trainer Input
diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index 7119b1af35..4d158f86b9 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -45,6 +45,15 @@ To handle differences in `Task` contents, Trinity-RFT provides a unified `Task`
- **`raw_task`** (`Dict`): An record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`.
- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`.
+ - **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field.
+
+```{tip}
+`workflow`, `workflow_args` and `raw_task` provide different levels of customization.
+
+- `workflow` provides the global settings for all tasks that uses the same workflow. (Global Level)
+- `workflow_args` can be set for each task dataset, allowing different task datasets using the same workflow to behave differently. (Dataset Level)
+- `raw_task` provides the ability to customize the behavior of each task, which is most flexible. (Data Sample Level)
+```
In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:
@@ -111,7 +120,7 @@ During initialization, `Workflow` receives the following parameters:
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
```
-Here’s an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
+Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
```python
class ExampleWorkflow(Workflow):
@@ -188,6 +197,25 @@ class ExampleWorkflow(Workflow):
pass
```
+For workflows that are prepared to be contributed to Trinity-RFT project, you need to place the above code in `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`. And add the following line to `trinity/common/workflows/__init__.py`:
+
+```python
+# existing import lines
+from .example_workflow import ExampleWorkflow
+
+__all__ = [
+ # existing __all__ lines
+ "ExampleWorkflow",
+]
+```
+
+For workflows that are not intended to be contributed to Trinity-RFT project, you can just place the above code in `trinity/plugins`. Trinity-RFT will automatically detect and load all custom modules in this folder.
+
+```{tip}
+You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `/trinity/plugins` as the default directory.
+```
+
+
#### Avoid Re-initialization
For heavy workflows, re-initializing every time can incurs extra computational costs.
@@ -286,6 +314,126 @@ trinity run --config
---
+## Adding New Config Entries for the Config Generator (Advanced)
+
+### Step 0: Understanding Streamlit
+
+Before adding new parameters to the Config Generator page, it is essential to familiarize yourself with the relevant API and mechanisms of [Streamlit](https://docs.streamlit.io/develop/api-reference). This project primarily utilizes various input components from Streamlit and employs `st.session_state` to store user-input parameters.
+
+### Step 1: Implement New Config Entries
+
+To illustrate the process of creating a new parameter setting for the Config Generator page, we will use `train_batch_size` as an example.
+
+1. Determine the appropriate scope for the parameter. Currently, parameters are categorized into four files:
+ - `trinity/manager/config_registry/buffer_config_manager.py`
+ - `trinity/manager/config_registry/explorer_config_manager.py`
+ - `trinity/manager/config_registry/model_config_manager.py`
+ - `trinity/manager/config_registry/trainer_config_manager.py`
+
+ In this case, `train_batch_size` should be placed in the `buffer_config_manager.py` file.
+
+2. Create a parameter setting function using Streamlit. The function name must follow the convention of starting with 'set_', and the remainder of the name becomes the config name.
+
+3. Decorate the parameter setting function with the `CONFIG_GENERATORS.register_config` decorator. This decorator requires the following information:
+ - Default value of the parameter
+ - Visibility condition (if applicable)
+ - Additional config parameters (if needed)
+
+```{note}
+The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=config_name` as an argument to the registered configuration function. Ensure that your function accepts this keyword argument.
+```
+
+For `train_batch_size`, we will use the following settings:
+- Default value: 96
+- Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0`
+- Additional config: `{"_train_batch_size_per_gpu": 16}`
+
+
+Here's the complete code for the `train_batch_size` parameter:
+
+```python
+@CONFIG_GENERATORS.register_config(
+ default_value=96,
+ visible=lambda: st.session_state["trainer_gpu_num"] > 0,
+ other_configs={"_train_batch_size_per_gpu": 16},
+)
+def set_train_batch_size(**kwargs):
+ key = kwargs.get("key")
+ trainer_gpu_num = st.session_state["trainer_gpu_num"]
+ st.session_state[key] = (
+ st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
+ )
+
+ def on_change():
+ st.session_state["_train_batch_size_per_gpu"] = max(
+ st.session_state[key] // st.session_state["trainer_gpu_num"], 1
+ )
+
+ st.number_input(
+ "Train Batch Size",
+ min_value=trainer_gpu_num,
+ step=trainer_gpu_num,
+ help=_str_for_train_batch_size(),
+ on_change=on_change,
+ **kwargs,
+ )
+```
+
+If the parameter requires validation, create a check function. For `train_batch_size`, we need to ensure it is divisible by `trainer_gpu_num`. If not, a warning should be displayed, and the parameter should be added to `unfinished_fields`.
+
+Decorate the check function with the `CONFIG_GENERATORS.register_check` decorator:
+
+```python
+@CONFIG_GENERATORS.register_check()
+def check_train_batch_size(unfinished_fields: set, key: str):
+ if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
+ unfinished_fields.add(key)
+ st.warning(_str_for_train_batch_size())
+```
+
+```{note}
+The `CONFIG_GENERATORS.register_check` decorator automatically receives `key=config_name` and `unfinished_fields=self.unfinished_fields` as arguments. Ensure your function accepts these keyword arguments.
+```
+
+### Step 2: Integrating New Parameters into `config_manager.py`
+
+To successfully integrate new parameters into the `config_manager.py` file, please adhere to the following procedure:
+
+1. Parameter Categorization:
+ Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes:
+ - Beginner Mode: Comprises "Essential Configs" and "Important Configs" sections.
+ - Expert Mode: Includes "Model", "Buffer", "Explorer and Synchronizer", and "Trainer" sections.
+
+2. Parameter Addition:
+ Incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class.
+
+ Example:
+ ```python
+ class ConfigManager:
+ def _expert_buffer_part(self):
+ self.get_configs("total_epochs", "train_batch_size")
+ ```
+
+3. YAML File Integration:
+ Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the `generate_config` function and its associated sub-functions.
+
+4. Parameter Value Assignment:
+ Utilize `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML.
+
+ Example:
+ ```python
+ class ConfigManager:
+ def _gen_buffer_config(self):
+ buffer_config = {
+ "batch_size": st.session_state["train_batch_size"],
+ # Additional configuration parameters
+ }
+ ```
+
+By meticulously following these steps, you can ensure that new parameters are successfully added to the Config Generator page and properly integrated into the configuration system. This process maintains the integrity and functionality of the configuration management framework.
+
+---
+
## Check Code Style
Before submitting the code, make sure it passes the code style check. Follow these steps:
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index 5620c38f8e..751f1a0c30 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -1,6 +1,7 @@
import os
import unittest
+import ray
import torch
from trinity.buffer.reader.sql_reader import SQLReader
@@ -22,6 +23,7 @@ def test_create_sql_buffer(self) -> None:
algorithm_type="ppo",
path=f"sqlite:///{db_path}",
storage_type=StorageType.SQL,
+ wrap_in_ray=True,
)
config = BufferConfig(
max_retry_times=3,
@@ -45,3 +47,5 @@ def test_create_sql_buffer(self) -> None:
for _ in range(total_num // read_batch_size):
exps = sql_reader.read()
self.assertEqual(len(exps), read_batch_size)
+ db_wrapper = ray.get_actor("sql-test_buffer")
+ self.assertIsNotNone(db_wrapper)
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 74b5d400e5..b0f354c4fa 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -12,7 +12,6 @@
get_unittest_dataset_config,
)
from trinity.cli.launcher import explore
-from trinity.common.constants import MonitorType
class BaseExplorerCase(RayUnittestBase):
@@ -23,7 +22,7 @@ def setUp(self):
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm_async"
self.config.algorithm.repeat_times = 2
- self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+ self.config.monitor.monitor_type = "tensorboard"
self.config.project = "Trinity-unittest"
self.config.checkpoint_root_dir = get_checkpoint_path()
self.config.synchronizer.sync_interval = 2
diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py
index 8a6e262a90..4c0e0349f5 100644
--- a/tests/explorer/runner_pool_test.py
+++ b/tests/explorer/runner_pool_test.py
@@ -2,7 +2,7 @@
import os
import time
import unittest
-from typing import List
+from typing import List, Tuple
import ray
import torch
@@ -87,8 +87,8 @@ def init_process_group(
def has_api_server(self) -> bool:
return True
- def api_server_ready(self) -> str:
- return "http://localhosts:12345"
+ def api_server_ready(self) -> Tuple[str, str]:
+ return "http://localhosts:12345", "placeholder"
class RunnerPoolTest(unittest.TestCase):
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 8cce2f9e85..0812fb5e6e 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -5,7 +5,7 @@
from unittest.mock import MagicMock
from tests.tools import get_unittest_dataset_config
-from trinity.common.workflows import MathWorkflow
+from trinity.common.workflows import MathWorkflow, Workflow
from trinity.common.workflows.workflow import Task
@@ -15,6 +15,33 @@ class MockResponse:
reward: float = 0.0
+class DummyWorkflow(Workflow):
+ def __init__(self, model, task: Task, auxiliary_models=None):
+ super().__init__(model, task, auxiliary_models)
+ self.obj = task.raw_task
+ self.output_format = task.workflow_args["output_format"]
+
+ @property
+ def resettable(self):
+ return True
+
+ def reset(self, task: Task):
+ self.obj = task.raw_task
+ self.output_format = task.workflow_args["output_format"]
+
+ def run(self):
+ if self.output_format == "json":
+ import json
+
+ return [json.dumps(self.obj)]
+ elif self.output_format == "yaml":
+ import yaml
+
+ return [yaml.safe_dump(self.obj)]
+ else:
+ raise ValueError("Invalid output format")
+
+
class WorkflowTest(unittest.TestCase):
def test_math_workflow(self) -> None:
model = MagicMock()
@@ -150,3 +177,18 @@ def test_gsm8k_workflow(self) -> None:
self.assertEqual(experiences[1].reward, -0.1)
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, 1.1)
+
+ def test_workflow_resettable(self) -> None:
+ model = MagicMock()
+ json_task = Task(
+ workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "json"}
+ )
+ yaml_task = Task(
+ workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "yaml"}
+ )
+ workflow = json_task.to_workflow(model)
+ answer = workflow.run()
+ self.assertEqual(answer[0], '{"a": 1}')
+ workflow.reset(yaml_task)
+ answer = workflow.run()
+ self.assertEqual(answer[0], "a: 1\n")
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 5b2795d952..bf064785cd 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -15,7 +15,7 @@
get_unittest_dataset_config,
)
from trinity.cli.launcher import bench, both, train
-from trinity.common.constants import MonitorType, SyncMethod
+from trinity.common.constants import SyncMethod
class BaseTrainerCase(RayUnittestBase):
@@ -30,7 +30,7 @@ def setUp(self):
self.config.explorer.rollout_model.use_v1 = False
self.config.project = "Trainer-unittest"
self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
- self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+ self.config.monitor.monitor_type = "tensorboard"
self.config.checkpoint_root_dir = get_checkpoint_path()
self.config.synchronizer.sync_interval = 2
self.config.synchronizer.sync_method = SyncMethod.NCCL
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py
new file mode 100644
index 0000000000..01aa2f3967
--- /dev/null
+++ b/tests/utils/plugin_test.py
@@ -0,0 +1,34 @@
+import unittest
+from pathlib import Path
+
+import ray
+
+from trinity.common.workflows import WORKFLOWS
+from trinity.utils.plugin_loader import load_plugins
+
+
+@ray.remote
+class PluginActor:
+ def run(self):
+ my_plugin_cls = WORKFLOWS.get("my_workflow")
+ return my_plugin_cls(None, None).run()
+
+
+class TestPluginLoader(unittest.TestCase):
+ def test_load_plugins(self):
+ ray.init(ignore_reinit_error=True)
+ my_plugin_cls = WORKFLOWS.get("my_workflow")
+ self.assertIsNone(my_plugin_cls)
+ load_plugins(Path(__file__).resolve().parent / "plugins")
+ my_plugin_cls = WORKFLOWS.get("my_workflow")
+ self.assertIsNotNone(my_plugin_cls)
+ my_plugin = my_plugin_cls(None, None, None)
+ self.assertTrue(my_plugin.__module__.startswith("trinity.plugins"))
+ res = my_plugin.run()
+ self.assertEqual(res[0], "Hello world")
+ self.assertEqual(res[1], "Hi")
+ remote_plugin = PluginActor.remote()
+ remote_res = ray.get(remote_plugin.run.remote())
+ self.assertEqual(remote_res[0], "Hello world")
+ self.assertEqual(remote_res[1], "Hi")
+ ray.shutdown(_exiting_interpreter=True)
diff --git a/tests/utils/plugins/__init__.py b/tests/utils/plugins/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/utils/plugins/my_workflow.py b/tests/utils/plugins/my_workflow.py
new file mode 100644
index 0000000000..b999590a01
--- /dev/null
+++ b/tests/utils/plugins/my_workflow.py
@@ -0,0 +1,12 @@
+from typing import List
+
+from trinity.common.workflows import WORKFLOWS, Workflow
+
+
+@WORKFLOWS.register_module("my_workflow")
+class MyWorkflow(Workflow):
+ def __init__(self, model, task, auxiliary_models=None):
+ super().__init__(model, task, auxiliary_models)
+
+ def run(self) -> List:
+ return ["Hello world", "Hi"]
diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py
index 9d77dbb379..90f658f07c 100644
--- a/trinity/buffer/buffer.py
+++ b/trinity/buffer/buffer.py
@@ -46,7 +46,7 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
file_read_type = algorithm_type
else:
file_read_type = "rollout"
- return FILE_READERS.get(file_read_type)(storage_config, buffer_config)
+ return FILE_READERS.get(file_read_type)(storage_config, buffer_config) # type: ignore
else:
raise ValueError(f"{storage_config.storage_type} not supported.")
diff --git a/trinity/buffer/db_wrapper.py b/trinity/buffer/db_wrapper.py
new file mode 100644
index 0000000000..977aaae493
--- /dev/null
+++ b/trinity/buffer/db_wrapper.py
@@ -0,0 +1,105 @@
+import time
+from typing import List, Optional
+
+import ray
+from sqlalchemy import asc, create_engine, desc
+from sqlalchemy.exc import OperationalError
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import NullPool
+
+from trinity.buffer.schema import Base, create_dynamic_table
+from trinity.buffer.utils import retry_session
+from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.constants import ReadStrategy
+from trinity.utils.log import get_logger
+
+
+class DBWrapper:
+ """
+ A wrapper of a SQL database.
+
+ If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor,
+ and provide a remote interface to the local database.
+
+ For databases that do not support multi-processing read/write (e.g. sqlite, duckdb), we
+ recommend setting `wrap_in_ray` to `True`
+ """
+
+ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
+ self.logger = get_logger(__name__)
+ self.engine = create_engine(storage_config.path, poolclass=NullPool)
+ self.table_model_cls = create_dynamic_table(
+ storage_config.algorithm_type, storage_config.name
+ )
+
+ try:
+ Base.metadata.create_all(self.engine, checkfirst=True)
+ except OperationalError:
+ self.logger.warning("Failed to create database, assuming it already exists.")
+
+ self.session = sessionmaker(bind=self.engine)
+ self.batch_size = config.read_batch_size
+ self.max_retry_times = config.max_retry_times
+ self.max_retry_interval = config.max_retry_interval
+
+ @classmethod
+ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
+ if storage_config.wrap_in_ray:
+ return (
+ ray.remote(cls)
+ .options(
+ name=f"sql-{storage_config.name}",
+ get_if_exists=True,
+ )
+ .remote(storage_config, config)
+ )
+ else:
+ return cls(storage_config, config)
+
+ def write(self, data: list) -> None:
+ with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
+ experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
+ session.add_all(experience_models)
+
+ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
+ if strategy is None:
+ strategy = ReadStrategy.LFU
+
+ if strategy == ReadStrategy.LFU:
+ sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id))
+
+ elif strategy == ReadStrategy.LRU:
+ sortOrder = (desc(self.table_model_cls.id),)
+
+ elif strategy == ReadStrategy.PRIORITY:
+ sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id))
+
+ else:
+ raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")
+
+ exp_list = []
+ while len(exp_list) < self.batch_size:
+ if len(exp_list):
+ self.logger.info("waiting for experiences...")
+ time.sleep(1)
+ with retry_session(
+ self.session, self.max_retry_times, self.max_retry_interval
+ ) as session:
+ # get a batch of experiences from the database
+ experiences = (
+ session.query(self.table_model_cls)
+ .filter(self.table_model_cls.reward.isnot(None))
+ .order_by(*sortOrder) # TODO: very slow
+ .limit(self.batch_size - len(exp_list))
+ .with_for_update()
+ .all()
+ )
+ # update the consumed field
+ for exp in experiences:
+ exp.consumed += 1
+ exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
+ self.logger.info(f"get {len(exp_list)} experiences:")
+ self.logger.info(f"reward = {[exp.reward for exp in exp_list]}")
+ self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
+ self.logger.info(f"first response_text = {exp_list[0].response_text}")
+ return exp_list
diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py
index a360182f07..8490c44506 100644
--- a/trinity/buffer/queue.py
+++ b/trinity/buffer/queue.py
@@ -23,6 +23,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
if storage_config.path is not None and len(storage_config.path) > 0:
sql_config = deepcopy(storage_config)
sql_config.storage_type = StorageType.SQL
+ sql_config.wrap_in_ray = False
self.sql_writer = SQLWriter(sql_config, self.config)
else:
self.sql_writer = None
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 58b762d3f2..8ba0ccea31 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -196,8 +196,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.reward_fn_key = meta.format.reward_fn_key
self.task_type = meta.task_type
- self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)
- self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)
+ self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore
+ self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type) # type: ignore
self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1
def __len__(self):
@@ -217,11 +217,12 @@ def read(self, strategy: Optional[ReadStrategy] = None):
if self.reward_fn_key in sample
else self.default_reward_fn_cls
)
- assert workflow_class is not None, "`default_reward_fn_type` or `workflow_key` is required"
+ assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required"
task = Task(
workflow=workflow_class,
format_args=self.meta.format,
rollout_args=self.meta.rollout_args,
+ workflow_args=self.meta.workflow_args,
is_eval=self.meta.task_type == TaskType.EVAL,
reward_fn=reward_fn,
raw_task=sample,
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index ffd013d4ef..3b26014fc4 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -16,13 +16,13 @@
class QueueReader(BufferReader):
"""Reader of the Queue buffer."""
- def __init__(self, meta: StorageConfig, config: BufferConfig):
- assert meta.storage_type == StorageType.QUEUE
+ def __init__(self, storage_config: StorageConfig, config: BufferConfig):
+ assert storage_config.storage_type == StorageType.QUEUE
self.config = config
self.queue = QueueActor.options(
- name=f"queue-{meta.name}",
+ name=f"queue-{storage_config.name}",
get_if_exists=True,
- ).remote(meta, config)
+ ).remote(storage_config, config)
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
if strategy is not None and strategy != ReadStrategy.FIFO:
diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py
index 4da2920816..dcd9d942bb 100644
--- a/trinity/buffer/reader/sql_reader.py
+++ b/trinity/buffer/reader/sql_reader.py
@@ -1,21 +1,13 @@
"""Reader of the SQL buffer."""
-import time
from typing import List, Optional
-from sqlalchemy import asc, create_engine, desc
-from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import NullPool
+import ray
from trinity.buffer.buffer_reader import BufferReader
-from trinity.buffer.schema import Base, create_dynamic_table
-from trinity.buffer.utils import retry_session
+from trinity.buffer.db_wrapper import DBWrapper
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import ReadStrategy, StorageType
-from trinity.utils.log import get_logger
-
-logger = get_logger(__name__)
class SQLReader(BufferReader):
@@ -23,57 +15,11 @@ class SQLReader(BufferReader):
def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
assert meta.storage_type == StorageType.SQL
- self.engine = create_engine(meta.path, poolclass=NullPool)
-
- self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name)
- try:
- Base.metadata.create_all(self.engine, checkfirst=True)
- except OperationalError:
- logger.warning("Failed to create database, assuming it already exists.")
- self.session = sessionmaker(bind=self.engine)
- self.batch_size = config.read_batch_size
- self.max_retry_times = config.max_retry_times
- self.max_retry_interval = config.max_retry_interval
+ self.wrap_in_ray = meta.wrap_in_ray
+ self.db_wrapper = DBWrapper.get_wrapper(meta, config)
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
- if strategy is None:
- strategy = ReadStrategy.LFU
-
- if strategy == ReadStrategy.LFU:
- sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id))
-
- elif strategy == ReadStrategy.LRU:
- sortOrder = (desc(self.table_model_cls.id),)
-
- elif strategy == ReadStrategy.PRIORITY:
- sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id))
-
+ if self.wrap_in_ray:
+ return ray.get(self.db_wrapper.read.remote(strategy))
else:
- raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")
-
- exp_list = []
- while len(exp_list) < self.batch_size:
- if len(exp_list):
- logger.info("waiting for experiences...")
- time.sleep(1)
- with retry_session(
- self.session, self.max_retry_times, self.max_retry_interval
- ) as session:
- # get a batch of experiences from the database
- experiences = (
- session.query(self.table_model_cls)
- .filter(self.table_model_cls.reward.isnot(None))
- .order_by(*sortOrder) # TODO: very slow
- .limit(self.batch_size - len(exp_list))
- .with_for_update()
- .all()
- )
- # update the consumed field
- for exp in experiences:
- exp.consumed += 1
- exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
- logger.info(f"get {len(exp_list)} experiences:")
- logger.info(f"reward = {[exp.reward for exp in exp_list]}")
- logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
- logger.info(f"first response_text = {exp_list[0].response_text}")
- return exp_list
+ return self.db_wrapper.read(strategy)
diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py
index e0b0bdf640..3e054d58c6 100644
--- a/trinity/buffer/writer/sql_writer.py
+++ b/trinity/buffer/writer/sql_writer.py
@@ -1,19 +1,12 @@
"""Writer of the SQL buffer."""
-from sqlalchemy import create_engine
-from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import NullPool
+import ray
from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.buffer.buffer_writer import BufferWriter
-from trinity.buffer.schema import Base, create_dynamic_table
-from trinity.buffer.utils import retry_session
+from trinity.buffer.db_wrapper import DBWrapper
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
-from trinity.utils.log import get_logger
-
-logger = get_logger(__name__)
class SQLWriter(BufferWriter):
@@ -25,23 +18,14 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
# TODO: support other algorithms
algorithm = ALGORITHM_TYPE.get(meta.algorithm_type)
assert algorithm.use_rollout, "Only RFT buffer is supported for writing."
- self.engine = create_engine(meta.path, poolclass=NullPool)
- self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name)
-
- try:
- Base.metadata.create_all(self.engine, checkfirst=True)
- except OperationalError:
- logger.warning("Failed to create database, assuming it already exists.")
-
- self.session = sessionmaker(bind=self.engine)
- self.batch_size = config.read_batch_size
- self.max_retry_times = config.max_retry_times
- self.max_retry_interval = config.max_retry_interval
+ self.wrap_in_ray = meta.wrap_in_ray
+ self.db_wrapper = DBWrapper.get_wrapper(meta, config)
def write(self, data: list) -> None:
- with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
- experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
- session.add_all(experience_models)
+ if self.wrap_in_ray:
+ ray.get(self.db_wrapper.write.remote(data))
+ else:
+ self.db_wrapper.write(data)
def finish(self) -> None:
# TODO: implement this
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 6a01bfb688..cf4a7882aa 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -11,6 +11,7 @@
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
from trinity.utils.log import get_logger
+from trinity.utils.plugin_loader import load_plugins
logger = get_logger(__name__)
@@ -131,7 +132,8 @@ def activate_data_module(data_workflow_url: str, config_path: str):
return
-def run(config_path: str, dlc: bool = False):
+def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
+ load_plugins(plugin_dir)
config = load_config(config_path)
config.check_and_update()
pprint(config)
@@ -161,6 +163,11 @@ def run(config_path: str, dlc: bool = False):
elif config.mode == "bench":
bench(config)
+ if dlc:
+ from trinity.utils.dlc_utils import stop_ray_cluster
+
+ stop_ray_cluster()
+
def studio(port: int = 8501):
from streamlit.web import cli as stcli
@@ -188,6 +195,12 @@ def main() -> None:
# run command
run_parser = subparsers.add_parser("run", help="Run RFT process.")
run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
+ run_parser.add_argument(
+ "--plugin-dir",
+ type=str,
+ default=None,
+ help="Path to the directory containing plugin modules.",
+ )
run_parser.add_argument(
"--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
)
@@ -198,12 +211,10 @@ def main() -> None:
"--port", type=int, default=8501, help="The port for Trinity-Studio."
)
- # TODO: add more commands like `monitor`, `label`
-
args = parser.parse_args()
if args.command == "run":
# TODO: support parse all args from command line
- run(args.config, args.dlc)
+ run(args.config, args.dlc, args.plugin_dir)
elif args.command == "studio":
studio(args.port)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index dd863edbd3..22d8f3d711 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -7,7 +7,6 @@
from omegaconf import OmegaConf
from trinity.common.constants import (
- MonitorType,
PromptType,
ReadStrategy,
StorageType,
@@ -77,10 +76,14 @@ class StorageConfig:
format: FormatConfig = field(default_factory=FormatConfig)
index: int = 0
+ # used for StorageType.SQL
+ wrap_in_ray: bool = True
+
# used for rollout tasks
default_workflow_type: Optional[str] = None
default_reward_fn_type: Optional[str] = None
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
+ workflow_args: dict = field(default_factory=dict)
# ! DO NOT SET, automatically set from algorithm.algorithm_type
algorithm_type: Optional[str] = None
@@ -303,8 +306,10 @@ class TrainerConfig:
@dataclass
class MonitorConfig:
- # TODO: support multiple monitors (List[MonitorType])
- monitor_type: MonitorType = MonitorType.WANDB
+ # TODO: support multiple monitors (List[str])
+ monitor_type: str = "tensorboard"
+ # the default args for monitor
+ monitor_args: Dict = field(default_factory=dict)
# ! DO NOT SET, automatically generated as checkpoint_job_dir/monitor
cache_dir: str = ""
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index 47b04f853b..3c49d65c21 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -23,6 +23,9 @@ def __getattr__(cls, name):
return cls[name.upper()]
return super().__getattr__(name)
+ def __call__(cls, value, *args, **kwargs):
+ return super().__call__(value.lower(), *args, **kwargs)
+
class CaseInsensitiveEnum(Enum, metaclass=CaseInsensitiveEnumMeta):
pass
@@ -47,11 +50,11 @@ class ReadStrategy(CaseInsensitiveEnum):
"""Pop Strategy."""
DEFAULT = None
- FIFO = "FIFO"
- RANDOM = "RANDOM"
- LRU = "LRU"
- LFU = "LFU"
- PRIORITY = "PRIORITY"
+ FIFO = "fifo"
+ RANDOM = "random"
+ LRU = "lru"
+ LFU = "lfu"
+ PRIORITY = "priority"
class StorageType(CaseInsensitiveEnum):
diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index 25cb927799..fd5670b390 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -64,9 +64,9 @@ def create_inference_models(
else:
raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}")
- main_bundles = [{"GPU": 1, "CPU": 1} for _ in range(engine_num * tensor_parallel_size)]
+ main_bundles = [{"GPU": 1} for _ in range(engine_num * tensor_parallel_size)]
auxiliary_bundles = [
- {"GPU": 1, "CPU": 1}
+ {"GPU": 1}
for _ in range(
sum(
[
@@ -103,6 +103,7 @@ def create_inference_models(
num_gpus=0 if config.explorer.rollout_model.tensor_parallel_size > 1 else 1,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
+ placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundles_for_engine[0],
),
)
@@ -121,6 +122,7 @@ def create_inference_models(
bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size)
model_config.enable_openai_api = True
model_config.engine_type = "vllm_async"
+ model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine])
engines.append(
ray.remote(vLLMAysncRolloutModel)
.options(
@@ -128,6 +130,7 @@ def create_inference_models(
num_gpus=0 if model_config.tensor_parallel_size > 1 else 1,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
+ placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundles_for_engine[0],
),
)
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index b5104f2cc7..cb15b1ae3d 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -103,6 +103,11 @@ def get_ckp_version(self) -> int:
return ray.get(self.model.get_ckp_version.remote())
def get_openai_client(self) -> openai.OpenAI:
+ """Get the openai client.
+
+ Returns:
+ openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
+ """
if self.openai_client is not None:
return self.openai_client
if not ray.get(self.model.has_api_server.remote()):
@@ -110,9 +115,9 @@ def get_openai_client(self) -> openai.OpenAI:
"OpenAI API server is not running on current model."
"Please set `enable_openai_api` to `True`."
)
- api_address = None
+ api_address, model_path = None, None
while True:
- api_address = ray.get(self.model.api_server_ready.remote())
+ api_address, model_path = ray.get(self.model.api_server_ready.remote())
if api_address is not None:
break
else:
@@ -127,4 +132,5 @@ def get_openai_client(self) -> openai.OpenAI:
base_url=api_address,
api_key="EMPTY",
)
+ setattr(self.openai_client, "model_path", model_path) # TODO: may be removed
return self.openai_client
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 02ea52ec58..27faa4c44a 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -5,7 +5,7 @@
import os
import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
import aiohttp
import torch
@@ -319,26 +319,30 @@ async def run_api_server(self):
async def has_api_server(self) -> bool:
return self.config.enable_openai_api
- async def api_server_ready(self) -> Optional[str]:
+ async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]:
"""Check if the OpenAI API server is ready.
Returns:
- str: The URL of the OpenAI API server.
+ api_url (str): The URL of the OpenAI API server.
+ model_path (str): The path of the model.
"""
if not await self.has_api_server():
- return None
+ return None, None
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"http://{self.api_server_host}:{self.api_server_port}/health"
) as response:
if response.status == 200:
- return f"http://{self.api_server_host}:{self.api_server_port}/v1"
+ return (
+ f"http://{self.api_server_host}:{self.api_server_port}/v1",
+ self.config.model_path,
+ )
else:
- return None
+ return None, None
except Exception as e:
self.logger.error(e)
- return None
+ return None, None
async def reset_prefix_cache(self) -> None:
await self.async_llm.reset_prefix_cache()
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index 92bf29a64e..f5b1c9a7b9 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -3,10 +3,11 @@
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
-from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task
+from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow
__all__ = [
"Task",
+ "Workflow",
"WORKFLOWS",
"SimpleWorkflow",
"MathWorkflow",
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 9786bd6b77..fc4a87556b 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -28,8 +28,9 @@ class Task:
"""A Task class that defines a task and its associated reward function / workflow."""
workflow: Type[Workflow]
- format_args: FormatConfig
+ format_args: FormatConfig = field(default_factory=FormatConfig)
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
+ workflow_args: dict = field(default_factory=dict)
is_eval: bool = False
reward_fn: Optional[Type[RewardFn]] = None
raw_task: Optional[dict] = None # The raw data sample
@@ -41,6 +42,10 @@ def to_workflow(
Args:
model (ModelWrapper): The rollout model for the workflow.
+ auxiliary_models (List[openai.OpenAI]): The auxiliary models for the workflow.
+
+ Note:
+ `model_path` attribute is added to the `auxiliary_models` for use within the workflow.
Returns:
Workflow: The generated workflow object.
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 37257f71ce..0a897254fb 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -21,7 +21,7 @@
from trinity.explorer.runner_pool import RunnerPool
from trinity.manager.manager import CacheManager
from trinity.utils.log import get_logger
-from trinity.utils.monitor import Monitor
+from trinity.utils.monitor import MONITOR
@ray.remote(name="explorer", concurrency_groups={"get_weight": 32, "setup_weight_sync_group": 1})
@@ -49,7 +49,7 @@ def __init__(self, config: Config):
for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
self.eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer))
self.runner_pool = self._init_runner_pool()
- self.monitor = Monitor(
+ self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
project=self.config.project,
name=self.config.name,
role="explorer",
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index 9ac2d36f16..80b8992b3b 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -7,22 +7,15 @@
import streamlit as st
import yaml
-from trinity.common.constants import (
- AlgorithmType,
- MonitorType,
- PromptType,
- StorageType,
- SyncMethod,
-)
-from trinity.common.rewards import REWARD_FUNCTIONS
-from trinity.common.workflows.workflow import WORKFLOWS
-from trinity.trainer.verl.ray_trainer import AdvantageEstimator
+from trinity.common.constants import AlgorithmType, StorageType
+from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
+from trinity.manager.config_registry.trainer_config_manager import use_critic
class ConfigManager:
def __init__(self):
- self._init_default_config()
self.unfinished_fields = set()
+ CONFIG_GENERATORS.set_unfinished_fields(self.unfinished_fields)
st.set_page_config(page_title="Trinity-RFT Config Generator", page_icon=":robot:")
st.title("Trinity-RFT Config Generator")
if "_init_config_manager" not in st.session_state:
@@ -44,1319 +37,256 @@ def __init__(self):
st.session_state.is_running = False
self.generate_config()
- def _init_default_config(self):
- self.default_config = {
- "_init_config_manager": True,
- "mode": "both",
- "project": "Trinity-RFT",
- "exp_name": "qwen2.5-1.5B",
- "checkpoint_root_dir": "",
- "monitor_type": MonitorType.TENSORBOARD.value,
- # Algorithm Configs
- "algorithm_type": AlgorithmType.PPO.value,
- "_grouped_adv_repeat_times": 2,
- "_not_grouped_adv_repeat_times": 1,
- "repeat_times": 1,
- "gamma": 1.0,
- "lam": 1.0,
- # Model Configs
- "model_path": "",
- "critic_model_path": "",
- "max_prompt_tokens": 1024,
- "max_response_tokens": 1024,
- # Cluster Config
- "node_num": 1,
- "gpu_per_node": 8,
- "total_gpu_num": 8,
- "trainer_gpu_num": 6,
- # Buffer Configs
- "total_epochs": 20,
- "_train_batch_size_per_gpu": 16,
- "train_batch_size": 96,
- "buffer_max_retry_times": 3,
- "max_retry_interval": 1,
- # Taskset Configs
- "taskset_path": "",
- "taskset_subset_name": None,
- "taskset_split": "train",
- "taskset_prompt_key": "question",
- "taskset_response_key": "answer",
- "temperature": 1.0,
- "top_p": 1.0, # TODO: to be used
- "top_k": -1, # TODO: to be used
- "logprobs": 0,
- # Eval Taskset Configs
- "_eval_tasksets_num": 0,
- # Explorer Input Configs
- "default_workflow_type": "math_workflow",
- "default_reward_fn_type": "math_reward",
- "system_prompt": None,
- "reply_prefix": None,
- # Experience Buffer / DPO Dataset Configs
- "_dpo_storage_type": StorageType.FILE.value,
- "_not_dpo_storage_type": StorageType.QUEUE.value,
- "storage_type": StorageType.QUEUE.value,
- "_dpo_experience_buffer_path": "",
- "_not_dpo_experience_buffer_path": "",
- "experience_buffer_path": "",
- "dpo_dataset_train_split": "train",
- "dpo_dataset_prompt_type": PromptType.MESSAGES.value,
- "dpo_dataset_prompt_key": "prompt",
- "dpo_dataset_chosen_key": "chosen",
- "dpo_dataset_rejected_key": "rejected",
- # SFT Warmup Dataset Configs
- "sft_warmup_dataset_path": "",
- "sft_warmup_train_split": "train",
- "sft_warmup_prompt_type": PromptType.MESSAGES.value,
- "sft_warmup_messages_key": "messages",
- "sft_warmup_prompt_key": "prompt",
- "sft_warmup_response_key": "response",
- # TrainerInput Configs
- # TODO: read_experience_strategy
- "sft_warmup_steps": 0,
- # Explorer and Sync Configs
- "runner_num": 32,
- "max_timeout": 900,
- "explorer_max_retry_times": 2,
- "eval_interval": 1000,
- "eval_on_latest_checkpoint": True,
- # Rollout Model Configs
- "engine_type": "vllm_async",
- "engine_num": 2,
- "tensor_parallel_size": 1,
- "use_v1": True,
- "enforce_eager": True,
- "enable_prefix_caching": False,
- "enable_chunked_prefill": False,
- "gpu_memory_utilization": 0.9,
- "dtype": "bfloat16",
- "seed": 42,
- # TODO: max_prompt_tokens
- # TODO: max_response_tokens
- # TODO: chat_template
- "enable_thinking": False,
- "enable_openai_api": False,
- # TODO: Auxiliary Models Configs
- # Synchronizer Configs
- "_not_dpo_sync_method": SyncMethod.NCCL.value,
- "sync_method": SyncMethod.NCCL.value,
- "sync_interval": 10,
- "sync_timeout": 1200,
- # Trainer Configs
- "trainer_type": "verl",
- "_nccl_save_interval": 100,
- "save_interval": 100,
- # TODO: enable_preview
- "_not_dpo_actor_use_kl_loss": True,
- "actor_use_kl_loss": True,
- "actor_kl_loss_coef": 0.001,
- "actor_entropy_coef": 0.001,
- "actor_grad_clip": 1.0,
- "actor_clip_ratio": 0.2,
- # veRL Trainer Configs
- "training_args": [
- "balance_batch",
- "gradient_checkpointing",
- "remove_padding",
- "dynamic_bsz",
- ],
- "ppo_epochs": 1,
- "training_strategy": "fsdp",
- "param_offload": False,
- "optimizer_offload": False,
- "resume_mode": "auto",
- "resume_from_path": "",
- "critic_warmup": 0,
- "total_training_steps": None,
- "default_hdfs_dir": None,
- "remove_previous_ckpt_in_save": False,
- "del_local_ckpt_after_load": False,
- "max_actor_ckpt_to_keep": None,
- "max_critic_ckpt_to_keep": None,
- "adv_estimator": "gae",
- "norm_adv_by_std_in_grpo": True,
- "use_kl_in_reward": False,
- "kl_penalty": "low_var_kl",
- "kl_ctrl_type": "fixed",
- "kl_ctrl_coef": 0.001,
- "horizon": 10000,
- "target_kl": 0.1,
- "actor_ppo_micro_batch_size_per_gpu": 4,
- "ref_log_prob_micro_batch_size_per_gpu": 8,
- "actor_ulysses_sequence_parallel_size": 1,
- "actor_lr": 1e-6,
- "actor_warmup_style": "constant",
- "actor_lr_warmup_steps_ratio": 0.0,
- "actor_tau": 0.0,
- "actor_opmd_baseline": "mean",
- "actor_use_uid": False,
- "actor_kl_loss_type": "low_var_kl",
- "actor_checkpoint": ["model", "hf_model", "optimizer", "extra"],
- "critic_lr": 1e-6,
- "critic_warmup_style": "constant",
- "critic_lr_warmup_steps_ratio": 0.0,
- "critic_grad_clip": 1.0,
- "critic_cliprange_value": 0.5,
- "critic_ppo_micro_batch_size_per_gpu": 8,
- "critic_ulysses_sequence_parallel_size": 1,
- "critic_checkpoint": ["model", "optimizer", "extra"],
- }
-
def reset_session_state(self):
- for key, value in self.default_config.items():
+ st.session_state["_init_config_manager"] = True
+ for key, value in CONFIG_GENERATORS.default_config.items():
st.session_state[key] = value
def maintain_session_state(self):
- for key in self.default_config:
+ st.session_state["_init_config_manager"] = True
+ for key in CONFIG_GENERATORS.default_config:
st.session_state[key] = st.session_state[key]
- eavl_dataset_keys = ["name", "path", "subset_name", "split", "prompt_key", "response_key"]
+
+ eval_dataset_keys = [
+ "name",
+ "path",
+ "subset_name",
+ "split",
+ "prompt_key",
+ "response_key",
+ "temperature",
+ "logprobs",
+ "n",
+ ]
+ last_idx, del_num = 0, 0
for idx in range(st.session_state["_eval_tasksets_num"]):
- for key in eavl_dataset_keys:
+ if st.session_state.get(f"eval_taskset_{idx}_del_flag", False):
+ del_num += 1
+ continue
+ for key in eval_dataset_keys:
full_key = f"eval_taskset_{idx}_{key}"
- st.session_state[full_key] = st.session_state[full_key]
-
- def _set_project(self):
- st.text_input("Project", key="project")
-
- def _set_exp_name(self):
- st.text_input("Experiment Name", key="exp_name")
-
- def _set_monitor_type(self):
- st.selectbox(
- "Monitor Type",
- options=[monitor_type.value for monitor_type in MonitorType],
- key="monitor_type",
- )
-
- def _set_model_path(self):
- st.text_input("Model Path", key="model_path")
- if not st.session_state["model_path"].strip():
- self.unfinished_fields.add("model_path")
- st.warning("Please input model path.")
-
- def _set_critic_model_path(self):
- if st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value:
- st.text_input(
- "Critic Model Path (defaults to `model_path`)",
- key="critic_model_path",
- )
-
- def _set_checkpoint_root_dir(self):
- st.text_input("Checkpoint Root Dir", key="checkpoint_root_dir")
- if not st.session_state["checkpoint_root_dir"].strip(): # TODO: may auto generate
- self.unfinished_fields.add("checkpoint_root_dir")
- st.warning("Please input checkpoint root dir.")
- elif not os.path.isabs(st.session_state["checkpoint_root_dir"].strip()):
- self.unfinished_fields.add("checkpoint_root_dir")
- st.warning("Please input an absolute path.")
-
- def _set_node_num(self):
- st.number_input("Node Num", key="node_num", min_value=1, on_change=self._set_total_gpu_num)
-
- def _set_gpu_per_node(self):
- st.number_input(
- "GPU Per Node",
- key="gpu_per_node",
- min_value=1,
- max_value=8,
- on_change=self._set_total_gpu_num,
- )
-
- def _set_total_gpu_num(self):
- st.session_state["total_gpu_num"] = (
- st.session_state["gpu_per_node"] * st.session_state["node_num"]
- )
- self._set_trainer_gpu_num()
-
- def _set_trainer_gpu_num(self):
- if st.session_state["mode"] == "both":
- st.session_state["trainer_gpu_num"] = (
- st.session_state["total_gpu_num"]
- - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"]
- )
- else: # model == train
- st.session_state["trainer_gpu_num"] = st.session_state["total_gpu_num"]
-
- def _set_max_prompt_tokens(self):
- st.number_input("Max Prompt Tokens", key="max_prompt_tokens", min_value=1)
-
- def _set_max_response_tokens(self):
- st.number_input("Max Response Tokens", key="max_response_tokens", min_value=1)
-
- def _set_total_epochs(self):
- st.number_input("Total Epochs", key="total_epochs", min_value=1)
-
- @property
- def _str_for_train_batch_size(self):
- trainer_gpu_num_str = (
- "`gpu_per_node * node_num - engine_num * tensor_parallel_size`"
- if st.session_state["mode"] == "both"
- else "`gpu_per_node * node_num`"
- )
- return (
- f"Please ensure that `train_batch_size` can be divided by "
- f"{trainer_gpu_num_str} = {st.session_state['trainer_gpu_num']}."
- )
-
- def _set_train_batch_size(self):
- trainer_gpu_num = st.session_state["trainer_gpu_num"]
- st.session_state["train_batch_size"] = (
- st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
- )
-
- def on_change():
- st.session_state["_train_batch_size_per_gpu"] = max(
- st.session_state["train_batch_size"] // st.session_state["trainer_gpu_num"], 1
- )
-
- st.number_input(
- "Train Batch Size",
- key="train_batch_size",
- min_value=trainer_gpu_num,
- step=trainer_gpu_num,
- help=self._str_for_train_batch_size,
- on_change=on_change,
- )
-
- def _check_train_batch_size(self):
- if st.session_state["train_batch_size"] % st.session_state["trainer_gpu_num"] != 0:
- self.unfinished_fields.add("train_batch_size")
- st.warning(self._str_for_train_batch_size)
-
- def _set_taskset_path(self):
- st.text_input("Taskset Path", key="taskset_path")
- if not st.session_state["taskset_path"].strip():
- self.unfinished_fields.add("taskset_path")
- st.warning("Please input taskset path.")
-
- def _set_system_prompt(self):
- st.text_area(
- "System Prompt",
- key="system_prompt",
- placeholder="System prompt is used to guide the model behavior.",
- )
-
- def _set_reply_prefix(self):
- st.text_area(
- "Assistant Reply Prefix",
- key="reply_prefix",
- placeholder="""Assistant reply prefix is used to specify the initial content of model reply, """
- """and a common setting is: \nLet me solve this step by step. """,
- )
-
- def _set_taskset_args(self):
- if st.session_state["taskset_path"] and "://" not in st.session_state["taskset_path"]:
- subset_name_col, split_col = st.columns(2)
- subset_name_col.text_input(
- "Subset Name :orange-badge[(Needs review)]",
- key="taskset_subset_name",
- help="The subset name used for `datasets.load_datasets`, see "
- "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
- )
- split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split")
- prompt_key_col, response_key_col = st.columns(2)
- prompt_key_col.text_input(
- "Prompt Key :orange-badge[(Needs review)]", key="taskset_prompt_key"
- )
- response_key_col.text_input(
- "Response Key :orange-badge[(Needs review)]", key="taskset_response_key"
- )
- self._set_configs_with_st_columns(["temperature", "logprobs"])
-
- def _set_eval_taskset_idx(self, idx): # TODO: add delete
- st.text_input(
- "Taskset Name",
- key=f"eval_taskset_{idx}_name",
- )
- st.text_input(
- "Eval Taskset Path",
- key=f"eval_taskset_{idx}_path",
- )
- if not st.session_state[f"eval_taskset_{idx}_path"].strip():
- st.warning("Please input the taskset path, or it will be ignored.")
- subset_name_col, split_col = st.columns(2)
- subset_name_col.text_input(
- "Subset Name :orange-badge[(Needs review)]",
- key=f"eval_taskset_{idx}_subset_name",
- help="The subset name used for `datasets.load_datasets`, see "
- "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
- )
- split_col.text_input(
- "Eval Split :orange-badge[(Needs review)]",
- key=f"eval_taskset_{idx}_split",
- )
- prompt_key_col, response_key_col = st.columns(2)
- prompt_key_col.text_input(
- "Prompt Key :orange-badge[(Needs review)]",
- key=f"eval_taskset_{idx}_prompt_key",
- )
- response_key_col.text_input(
- "Response Key :orange-badge[(Needs review)]",
- key=f"eval_taskset_{idx}_response_key",
- )
-
- def _set_eval_tasksets(self):
- if st.button("Add Eval Taskset"):
- st.session_state["_eval_tasksets_num"] += 1
- if st.session_state["_eval_tasksets_num"] > 0:
- tabs = st.tabs(
- [f"Eval Taskset {i + 1}" for i in range(st.session_state["_eval_tasksets_num"])]
- )
- for idx, tab in enumerate(tabs):
- with tab:
- self._set_eval_taskset_idx(idx)
-
- def _set_default_workflow_type(self):
- st.selectbox(
- "Default Workflow Type :orange-badge[(Needs review)]",
- WORKFLOWS.modules.keys(),
- key="default_workflow_type",
- help=r"""`simple_workflow`: call 'model.chat()' to get responses.
-
-`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses.
-
-Other workflows: conduct multi-turn task for the given dataset.
-""",
- )
-
- def _set_default_reward_fn_type(self):
- st.selectbox(
- "Default Reward Fn Type :orange-badge[(Needs review)]",
- REWARD_FUNCTIONS.modules.keys(),
- key="default_reward_fn_type",
- help=r"""`accuracy_reward`: check the accuracy for math problems.
-
-`format_reward`: check if the response matches the format (default: `** *`).
-
-`math_reward`: `accuracy_reward` (1 or 0) + `format_reward` (+0.1 or -0.1).
-""",
- )
-
- def _set_storage_type(self):
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- st.session_state["storage_type"] = st.session_state["_dpo_storage_type"]
- storage_candidates = [StorageType.FILE.value, StorageType.SQL.value]
- else:
- st.session_state["storage_type"] = st.session_state["_not_dpo_storage_type"]
- storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value]
-
- def on_change():
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- st.session_state["_dpo_storage_type"] = st.session_state["storage_type"]
- else:
- st.session_state["_not_dpo_storage_type"] = st.session_state["storage_type"]
-
- st.selectbox(
- "Storage Type",
- storage_candidates,
- key="storage_type",
- on_change=on_change,
- )
-
- def _set_experience_buffer_path(self): # TODO
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- st.session_state["experience_buffer_path"] = st.session_state[
- "_dpo_experience_buffer_path"
- ]
- title = "DPO Dataset Path"
- help_msg = r"""This path to DPO dataset,
-
-if `storage_type == StorageType.FILE`, this should be a path to a file,
-
-if `storage_type == StorageType.SQL`, this should be a path to database."""
- else:
- st.session_state["experience_buffer_path"] = st.session_state[
- "_not_dpo_experience_buffer_path"
- ]
- title = "Experience Buffer Path"
- help_msg = r"""This path is used for `trainer`,
-
-if `storage_type == StorageType.QUEUE`, default to `None`,
-
-if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`."""
-
- def on_change():
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- st.session_state["_dpo_experience_buffer_path"] = st.session_state[
- "experience_buffer_path"
- ]
- else:
- st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[
- "experience_buffer_path"
- ]
-
- st.text_input(
- title,
- key="experience_buffer_path",
- help=help_msg,
- on_change=on_change,
- )
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- if not st.session_state["experience_buffer_path"].strip():
- self.unfinished_fields.add("experience_buffer_path")
- st.warning("Please input DPO dataset path.")
-
- def _set_buffer_max_retry_times(self):
- st.number_input("Max Retry Times", key="buffer_max_retry_times", min_value=1)
-
- def _set_max_retry_interval(self):
- st.number_input("Max Retry Interval", key="max_retry_interval", min_value=1)
-
- def _set_dpo_dataset_kwargs(self):
- dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2)
- dpo_dataset_train_split_col.text_input(
- "DPO Dataset Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split"
- )
- dpo_dataset_prompt_type_col.selectbox(
- "DPO Dataset Prompt Type :orange-badge[(Needs review)]",
- [prompt_type.value for prompt_type in PromptType],
- key="dpo_dataset_prompt_type",
- )
-
- (
- dpo_dataset_prompt_key_col,
- dpo_dataset_chosen_key_col,
- dpo_dataset_rejected_key_col,
- ) = st.columns(3)
- dpo_dataset_prompt_key_col.text_input(
- "DPO Dataset Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key"
- )
- dpo_dataset_chosen_key_col.text_input(
- "DPO Dataset Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key"
- )
- dpo_dataset_rejected_key_col.text_input(
- "DPO Dataset Rejected Key :orange-badge[(Needs review)]",
- key="dpo_dataset_rejected_key",
- )
-
- def _check_sft_warmup_dataset_path(self):
- if st.session_state["sft_warmup_steps"]:
- if not st.session_state["sft_warmup_dataset_path"].strip():
- self.unfinished_fields.add("sft_warmup_dataset_path")
- st.warning("Please input SFT warmup dataset path when `sft_warmup_steps` is not 0")
-
- def _set_sft_warmup_dataset_path(self):
- st.text_input("SFT Warmup Dataset Path", key="sft_warmup_dataset_path")
- self._check_sft_warmup_dataset_path()
-
- def _set_sft_warmup_dataset_args(self):
- if (
- st.session_state["sft_warmup_dataset_path"]
- and "://" not in st.session_state["sft_warmup_dataset_path"]
- ): # TODO
- (
- sft_warmup_train_split_col,
- sft_warmup_prompt_type_col,
- ) = st.columns(2)
- sft_warmup_train_split_col.text_input(
- "SFT Dataset Train Split :orange-badge[(Needs review)]",
- key="sft_warmup_train_split",
- )
- sft_warmup_prompt_type_col.selectbox(
- "SFT Dataset Prompt Type :orange-badge[(Needs review)]",
- [prompt_type.value for prompt_type in PromptType],
- key="sft_warmup_prompt_type",
- )
- (
- sft_warmup_messages_key_col,
- sft_warmup_prompt_key_col,
- sft_warmup_response_key_col,
- ) = st.columns(
- 3
- ) # TODO: select by prompt type
- sft_warmup_messages_key_col.text_input(
- "SFT Dataset Messages Key :orange-badge[(Needs review)]",
- key="sft_warmup_messages_key",
- )
- sft_warmup_prompt_key_col.text_input(
- "SFT Dataset Prompt Key :orange-badge[(Needs review)]", key="sft_warmup_prompt_key"
- )
- sft_warmup_response_key_col.text_input(
- "SFT Dataset Response Key :orange-badge[(Needs review)]",
- key="sft_warmup_response_key",
- )
-
- def _set_engine_type(self):
- st.selectbox("Explorer Engine Type", ["vllm_async", "vllm"], key="engine_type")
-
- @property
- def _str_for_engine_num_and_tp_size(self):
- return r"""and it must meet the following constraints:
-```python
-assert engine_num * tensor_parallel_size < gpu_per_node * node_num
-if node_num > 1:
- assert gpu_per_node % tensor_parallel_size == 0
- assert engine_num * tensor_parallel_size % gpu_per_node == 0
-```"""
-
- def _set_engine_num(self):
- total_gpu_num = st.session_state["total_gpu_num"]
- max_engine_num = (total_gpu_num - 1) // st.session_state["tensor_parallel_size"]
- if st.session_state["engine_num"] > max_engine_num:
- st.session_state["engine_num"] = max_engine_num
- self._set_trainer_gpu_num()
- st.number_input(
- "Engine Num",
- key="engine_num",
- min_value=1,
- max_value=max_engine_num,
- help=f"`engine_num` is used to set the quantity of inference engines, "
- f"{self._str_for_engine_num_and_tp_size}",
- on_change=self._set_trainer_gpu_num,
- )
-
- def _set_tensor_parallel_size(self):
- total_gpu_num = st.session_state["total_gpu_num"]
- max_tensor_parallel_size = (total_gpu_num - 1) // st.session_state["engine_num"]
- if st.session_state["tensor_parallel_size"] > max_tensor_parallel_size:
- st.session_state["tensor_parallel_size"] = max_tensor_parallel_size
- self._set_trainer_gpu_num()
- st.number_input(
- "Tensor Parallel Size",
- key="tensor_parallel_size",
- min_value=1,
- max_value=max_tensor_parallel_size,
- help=f"`tensor_parallel_size` is used to set the tensor parallel size of inference engines, "
- f"{self._str_for_engine_num_and_tp_size}",
- on_change=self._set_trainer_gpu_num,
- )
-
- def _check_engine_num_and_tp_size(self):
- node_num = st.session_state["node_num"]
- gpu_per_node = st.session_state["gpu_per_node"]
- engine_num = st.session_state["engine_num"]
- tensor_parallel_size = st.session_state["tensor_parallel_size"]
- if node_num > 1:
- if gpu_per_node % tensor_parallel_size != 0:
- self.unfinished_fields.add("tensor_parallel_size")
- st.warning(
- "Please ensure that `tensor_parallel_size` is a factor of `gpu_per_node` when `node_num > 1`."
- )
- if engine_num * tensor_parallel_size % gpu_per_node != 0:
- self.unfinished_fields.add("engine_num")
- st.warning(
- "Please ensure that `engine_num * tensor_parallel_size` can be divided by `gpu_per_node` when `node_num > 1`."
- )
-
- def _set_repeat_times(self): # TODO
- grouped_adv_algorithms = [
- AlgorithmType.GRPO.value,
- AlgorithmType.OPMD.value, # TODO: may add rloo
+ last_full_key = f"eval_taskset_{last_idx}_{key}"
+ st.session_state[last_full_key] = st.session_state[full_key]
+ last_idx += 1
+ st.session_state["_eval_tasksets_num"] -= del_num
+
+ auxiliary_model_keys = [
+ "model_path",
+ "engine_type",
+ "engine_num",
+ "tensor_parallel_size",
+ "gpu_memory_utilization",
+ "dtype",
+ "seed",
+ "use_v1",
+ "enforce_eager",
+ "enable_prefix_caching",
+ "enable_chunked_prefill",
+ "enable_thinking",
+ "enable_openai_api",
]
- if st.session_state["algorithm_type"] in grouped_adv_algorithms:
- min_repeat_times = 2
- st.session_state["repeat_times"] = st.session_state["_grouped_adv_repeat_times"]
- else:
- min_repeat_times = 1
- st.session_state["repeat_times"] = st.session_state["_not_grouped_adv_repeat_times"]
-
- def on_change():
- if st.session_state["algorithm_type"] in grouped_adv_algorithms:
- st.session_state["_grouped_adv_repeat_times"] = st.session_state["repeat_times"]
- else:
- st.session_state["_not_grouped_adv_repeat_times"] = st.session_state["repeat_times"]
-
- st.number_input(
- "Repeat Times",
- key="repeat_times",
- min_value=min_repeat_times,
- help="`repeat_times` is used to set how many experiences each task can generate, "
- "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.",
- on_change=on_change,
- )
-
- def _set_sync_method(self):
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- st.session_state["sync_method"] = SyncMethod.CHECKPOINT.value
- disabled = True
- else:
- st.session_state["sync_method"] = st.session_state["_not_dpo_sync_method"]
- disabled = False
-
- def on_change():
- if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
- st.session_state["_not_dpo_sync_method"] = st.session_state["sync_method"]
-
- st.selectbox(
- "Sync Method",
- [sync_method.value for sync_method in SyncMethod],
- key="sync_method",
- help="""`nccl`: the explorer and trainer sync model weights once every `sync_interval` steps.
-
-`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_interval`.""",
- disabled=disabled,
- on_change=on_change,
- )
-
- def _set_sync_interval(self):
- st.number_input(
- "Sync Interval",
- key="sync_interval",
- min_value=1,
- help="""The step interval at which the `explorer` and `trainer` synchronize model weight.""",
- )
-
- def _set_sync_timeout(self):
- st.number_input(
- "Sync Timeout",
- key="sync_timeout",
- min_value=1,
- help="The timeout value for the synchronization operation.",
- )
-
- def _set_runner_num(self):
- st.number_input("Runner Num", key="runner_num", min_value=1)
-
- def _set_dtype(self):
- st.selectbox("Dtype", ["float16", "bfloat16", "float32"], key="dtype")
-
- def _set_temperature(self):
- st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0)
-
- def _set_top_p(self):
- st.number_input("Top-p", key="top_p", min_value=0.0, max_value=1.0)
-
- def _set_top_k(self):
- st.number_input(
- "Top-k",
- key="top_k",
- min_value=-1,
- max_value=512,
- help="Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.",
- )
-
- def _set_seed(self):
- st.number_input("Seed", key="seed", step=1)
-
- def _set_logprobs(self):
- st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20)
-
- def _set_use_v1(self):
- st.checkbox("Use V1 Engine", key="use_v1")
-
- def _set_enable_prefix_caching(self):
- st.checkbox("Prefix Caching", key="enable_prefix_caching")
-
- def _set_enforce_eager(self):
- st.checkbox("Enforce Eager", key="enforce_eager")
-
- def _set_gpu_memory_utilization(self):
- st.number_input(
- "GPU Memory Utilization", key="gpu_memory_utilization", min_value=0.0, max_value=1.0
- )
-
- def _set_enable_chunked_prefill(self):
- st.checkbox("Chunked Prefill", key="enable_chunked_prefill")
-
- def _set_enable_thinking(self):
- st.checkbox("Enable Thinking For Qwen3", key="enable_thinking")
-
- def _set_enable_openai_api(self):
- st.checkbox("Enable OpenAI API", key="enable_openai_api")
-
- def _set_max_timeout(self):
- st.number_input("Max Timeout", key="max_timeout", min_value=0)
-
- def _set_explorer_max_retry_times(self):
- st.number_input("Explorer Max Retry Times", key="explorer_max_retry_times", min_value=0)
-
- def _set_trainer_type(self):
- st.selectbox("Trainer Type", ["verl"], key="trainer_type")
-
- def _set_algorithm_type(self):
- def on_change():
- if st.session_state["algorithm_type"] == AlgorithmType.PPO.value:
- st.session_state["mode"] = "both"
- st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value
- elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value:
- st.session_state["mode"] = "both"
- st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
- elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- st.session_state["mode"] = "train"
- st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
- elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value:
- st.session_state["mode"] = "both"
- st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
- else: # TODO: add more algorithms
- pass
- self._set_trainer_gpu_num()
-
- st.selectbox(
- "Algorithm Type",
- [
- AlgorithmType.PPO.value,
- AlgorithmType.GRPO.value,
- AlgorithmType.DPO.value,
- AlgorithmType.OPMD.value,
- ],
- key="algorithm_type",
- on_change=on_change,
- )
-
- def _set_sft_warmup_steps(self):
- st.number_input("SFT Warmup Steps", key="sft_warmup_steps", min_value=0)
-
- def _set_eval_interval(self):
- st.number_input("Eval Interval", key="eval_interval", min_value=1)
-
- def _set_eval_on_latest_checkpoint(self):
- st.checkbox("Eval on Latest Checkpoint", key="eval_on_latest_ckp")
-
- def _set_training_args(self):
- st.multiselect(
- "Training Args",
- [
- "balance_batch",
- "gradient_checkpointing",
- "remove_padding",
- "dynamic_bsz",
- ],
- key="training_args",
- )
-
- def _set_save_interval(self):
- if (
- st.session_state["algorithm_type"] == AlgorithmType.DPO.value
- or st.session_state["sync_method"] == SyncMethod.NCCL.value
- ):
- st.session_state["save_interval"] = st.session_state["_nccl_save_interval"]
- freeze_save_interval = False
- else:
- st.session_state["save_interval"] = st.session_state["sync_interval"]
- freeze_save_interval = True
-
- def on_change():
- if (
- st.session_state["algorithm_type"] == AlgorithmType.DPO.value
- or st.session_state["sync_method"] == SyncMethod.NCCL.value
- ):
- st.session_state["_nccl_save_interval"] = st.session_state["save_interval"]
-
- st.number_input(
- "Save Interval",
- key="save_interval",
- min_value=1,
- help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`",
- disabled=freeze_save_interval,
- on_change=on_change,
- )
-
- def _set_ppo_epochs(self):
- st.number_input("PPO Epochs", key="ppo_epochs", min_value=1)
-
- def _set_training_strategy(self):
- st.selectbox(
- "Training Strategy",
- ["fsdp", "megatron"],
- key="training_strategy",
- help="megatron is not tested",
- )
-
- def _set_param_offload(self):
- st.checkbox("FSDP Param Offload", key="param_offload")
-
- def _set_optimizer_offload(self):
- st.checkbox("FSDP Optimizer Offload", key="optimizer_offload")
-
- def _set_resume_mode(self):
- st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], key="resume_mode")
-
- def _set_resume_from_path(self):
- if st.session_state["resume_mode"] == "resume_path":
- st.text_input("Resume Path", key="resume_from_path")
- if (
- not st.session_state["resume_from_path"].strip()
- or "global_step_" not in st.session_state["resume_from_path"]
- ):
- self.unfinished_fields.add("resume_from_path")
- st.warning("Please input a valid resume path when `resume_mode == resume_path`")
-
- def _set_critic_warmup(self):
- st.number_input("Critic Warmup Steps", key="critic_warmup", min_value=0)
-
- def _set_total_training_steps(self):
- st.number_input("Total Training Steps", key="total_training_steps", min_value=1)
-
- def _set_default_hdfs_dir(self):
- st.text_input("Default HDFS Dir", key="default_hdfs_dir")
-
- def _set_remove_previous_ckpt_in_save(self):
- st.checkbox("Remove Previous Checkpoint in Save", key="remove_previous_ckpt_in_save")
-
- def _set_del_local_ckpt_after_load(self):
- st.checkbox("Delete Local Checkpoint After Load", key="del_local_ckpt_after_load")
-
- def _set_max_actor_ckpt_to_keep(self):
- st.number_input("Max Actor Checkpoint to Keep", key="max_actor_ckpt_to_keep", min_value=1)
-
- def _set_max_critic_ckpt_to_keep(self):
- st.number_input("Max Critic Checkpoint to Keep", key="max_critic_ckpt_to_keep", min_value=1)
-
- def _set_gamma(self):
- st.number_input(r"Gamma :blue-badge[$\gamma$]", key="gamma")
-
- def _set_lam(self):
- st.number_input(r"Lambda :blue-badge[$\lambda$]", key="lam")
-
- def _set_norm_adv_by_std_in_grpo(self):
- st.checkbox("Norm Adv by Std in GRPO", key="norm_adv_by_std_in_grpo")
-
- def _set_use_kl_in_reward(self):
- st.checkbox("Use KL in Reward", key="use_kl_in_reward")
-
- def _set_kl_penalty(self):
- st.selectbox("KL Penalty", ["kl", "abs", "mse", "low_var_kl"], key="kl_penalty")
-
- def _set_kl_ctrl_type(self):
- st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], key="kl_ctrl_type")
-
- def _set_kl_ctrl_coef(self):
- st.number_input("KL Ctrl Coef", key="kl_ctrl_coef", format="%.1e")
-
- def _set_horizon(self):
- st.number_input("Horizon", key="horizon", min_value=1.0)
-
- def _set_target_kl(self):
- st.number_input("Target KL", key="target_kl", format="%.1e")
-
- def _set_actor_ppo_micro_batch_size_per_gpu(self):
- st.session_state["actor_ppo_micro_batch_size_per_gpu"] = min(
- st.session_state["actor_ppo_micro_batch_size_per_gpu"],
- st.session_state["_train_batch_size_per_gpu"],
- )
- st.number_input(
- "Micro Batch Size Per GPU :blue-badge[(Actor)]",
- key="actor_ppo_micro_batch_size_per_gpu",
- min_value=1,
- max_value=st.session_state["_train_batch_size_per_gpu"],
- )
-
- def _set_ref_log_prob_micro_batch_size_per_gpu(self):
- st.session_state["ref_log_prob_micro_batch_size_per_gpu"] = min(
- st.session_state["ref_log_prob_micro_batch_size_per_gpu"],
- st.session_state["_train_batch_size_per_gpu"],
- )
- st.number_input(
- "Micro Batch Size Per GPU :blue-badge[(Ref)]",
- key="ref_log_prob_micro_batch_size_per_gpu",
- min_value=1,
- max_value=st.session_state["_train_batch_size_per_gpu"],
- )
-
- def _set_actor_ulysses_sequence_parallel_size(self):
- st.number_input(
- "Ulysses Sequence Parallel Size",
- key="actor_ulysses_sequence_parallel_size",
- min_value=1,
- max_value=8,
- )
-
- def _set_actor_lr(self):
- st.number_input(
- "Learning Rate :blue-badge[(Actor)]",
- key="actor_lr",
- min_value=1e-7,
- max_value=1e-3,
- format="%.1e",
- )
-
- def _set_actor_warmup_style(self):
- st.selectbox(
- "LR Warmup Style :blue-badge[(Actor)]",
- ["constant", "cosine"],
- key="actor_warmup_style",
- )
-
- def _set_actor_lr_warmup_steps_ratio(self):
- st.number_input(
- "LR Warmup Steps Ratio :blue-badge[(Actor)]",
- key="actor_lr_warmup_steps_ratio",
- min_value=0.0,
- max_value=1.0,
- )
-
- def _set_actor_grad_clip(self):
- st.number_input(
- "Grad Clip :blue-badge[(Actor)]",
- key="actor_grad_clip",
- min_value=0.0,
- max_value=1.0,
- help="Clipping by Norm",
- )
-
- def _set_actor_clip_ratio(self):
- st.number_input(
- r"Clip Ratio :blue-badge[$\epsilon$]",
- key="actor_clip_ratio",
- min_value=0.0,
- max_value=1.0,
- )
-
- def _set_actor_entropy_coef(self):
- st.number_input(
- "Entropy Coeff",
- key="actor_entropy_coef",
- min_value=0.0,
- max_value=1.0,
- format="%.1e",
- )
-
- def _set_actor_use_kl_loss(self):
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- st.session_state["actor_use_kl_loss"] = True
- else:
- st.session_state["actor_use_kl_loss"] = st.session_state["_not_dpo_actor_use_kl_loss"]
-
- def on_change():
- st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[
- "actor_use_kl_loss"
- ]
-
- st.checkbox("Use KL Loss", key="actor_use_kl_loss", on_change=on_change)
-
- def _set_actor_kl_loss_coef(self):
- st.number_input(
- r"KL Loss Coef :blue-badge[$\beta$]",
- key="actor_kl_loss_coef",
- min_value=0.0,
- max_value=1.0,
- format="%.1e",
- )
-
- def _set_actor_kl_loss_type(self):
- st.selectbox(
- "KL Loss Type",
- ["kl", "abs", "mse", "low_var_kl"],
- key="actor_kl_loss_type",
- )
-
- def _set_actor_tau(self):
- st.number_input(
- "Tau for OPMD",
- key="actor_tau",
- min_value=0.0,
- format="%.1e",
- )
-
- def _set_actor_opmd_baseline(self):
- st.selectbox(
- "OPMD Baseline",
- ["mean", "logavgexp"],
- key="actor_opmd_baseline",
- )
-
- def _set_actor_use_uid(self):
- st.checkbox("Use UID for OPMD", key="actor_use_uid")
-
- def _set_actor_checkpoint(self):
- st.multiselect(
- "Checkpoint",
- ["model", "hf_model", "optimizer", "extra"],
- key="actor_checkpoint",
- )
-
- def _set_critic_ppo_micro_batch_size_per_gpu(self):
- st.session_state["critic_ppo_micro_batch_size_per_gpu"] = min(
- st.session_state["critic_ppo_micro_batch_size_per_gpu"],
- st.session_state["_train_batch_size_per_gpu"],
- )
- st.number_input(
- "Micro Batch Size Per GPU :blue-badge[(Critic)]",
- key="critic_ppo_micro_batch_size_per_gpu",
- min_value=1,
- max_value=st.session_state["_train_batch_size_per_gpu"],
- )
-
- def _set_critic_ulysses_sequence_parallel_size(self):
- st.number_input(
- "Ulysses Sequence Parallel Size",
- key="critic_ulysses_sequence_parallel_size",
- min_value=1,
- max_value=8,
- )
-
- def _set_critic_lr(self):
- st.number_input(
- "Learning Rate :blue-badge[(Critic)]",
- key="critic_lr",
- min_value=1e-7,
- max_value=1e-3,
- format="%.1e",
- )
-
- def _set_critic_warmup_style(self):
- st.selectbox(
- "LR Warmup Style :blue-badge[(Critic)]",
- ["constant", "cosine"],
- key="critic_warmup_style",
- )
-
- def _set_critic_lr_warmup_steps_ratio(self):
- st.number_input(
- "LR Warmup Steps Ratio :blue-badge[(Critic)]",
- key="critic_lr_warmup_steps_ratio",
- min_value=0.0,
- max_value=1.0,
- )
-
- def _set_critic_grad_clip(self):
- st.number_input(
- "Grad Clip :blue-badge[(Critic)]",
- key="critic_grad_clip",
- min_value=0.0,
- max_value=1.0,
- help="Clipping by Norm",
- )
-
- def _set_critic_cliprange_value(self):
- st.number_input(
- "Cliprange Value",
- key="critic_cliprange_value",
- min_value=0.0,
- max_value=1.0,
- )
-
- def _set_critic_checkpoint(self):
- st.multiselect(
- "Checkpoint",
- ["model", "hf_model", "optimizer", "extra"],
- key="critic_checkpoint",
- )
-
- def _set_configs_with_st_columns(
- self, config_names: List[str], columns_config: List[int] = None
- ):
- if columns_config is None:
- columns_config = len(config_names)
- columns = st.columns(columns_config)
- for col, config_name in zip(columns, config_names):
- with col:
- getattr(self, f"_set_{config_name}")()
+ last_idx, del_num = 0, 0
+ for idx in range(st.session_state["_auxiliary_models_num"]):
+ if st.session_state.get(f"auxiliary_model_{idx}_del_flag", False):
+ del_num += 1
+ continue
+ for key in auxiliary_model_keys:
+ full_key = f"auxiliary_model_{idx}_{key}"
+ last_full_key = f"auxiliary_model_{last_idx}_{key}"
+ st.session_state[last_full_key] = st.session_state[full_key]
+ last_idx += 1
+ st.session_state["_auxiliary_models_num"] -= del_num
+
+ def get_configs(self, *config_names: str, columns_spec: List[int] = None):
+ CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec)
def beginner_mode(self):
st.header("Essential Configs")
- self._set_configs_with_st_columns(["project", "exp_name"], columns_config=[1, 3])
+ self.get_configs("project", "exp_name", columns_spec=[1, 2])
- self._set_model_path()
+ self.get_configs("model_path")
- self._set_checkpoint_root_dir()
+ self.get_configs("checkpoint_root_dir")
- self._set_taskset_path()
+ if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+ self.get_configs("taskset_path")
+ else:
+ self.get_configs("experience_buffer_path")
- self._set_configs_with_st_columns(["algorithm_type", "sft_warmup_steps", "monitor_type"])
+ self.get_configs("algorithm_type", "sft_warmup_steps", "monitor_type")
if st.session_state["sft_warmup_steps"] > 0:
- self._set_sft_warmup_dataset_path()
+ self.get_configs("sft_warmup_dataset_path")
st.header("Important Configs")
- self._set_configs_with_st_columns(
- ["node_num", "gpu_per_node", "engine_num", "tensor_parallel_size"]
- if st.session_state["mode"] == "both"
- else ["node_num", "gpu_per_node"]
- )
- self._check_engine_num_and_tp_size()
+ self.get_configs("node_num", "gpu_per_node", "engine_num", "tensor_parallel_size")
- self._set_configs_with_st_columns(
- ["total_epochs", "train_batch_size", "ppo_epochs", "repeat_times"]
- if st.session_state["mode"] == "both"
- else ["total_epochs", "train_batch_size", "ppo_epochs"]
- )
- self._check_train_batch_size()
+ self.get_configs("total_epochs", "train_batch_size", "ppo_epochs", "repeat_times")
- self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"])
+ self.get_configs("storage_type", "max_prompt_tokens", "max_response_tokens")
- self._set_configs_with_st_columns(
- ["sync_interval", "eval_interval", "save_interval"]
- if st.session_state["mode"] == "both"
- else ["eval_interval", "save_interval"]
- )
+ self.get_configs("sync_interval", "eval_interval", "save_interval")
if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
- self._set_taskset_args()
+ self.get_configs("taskset_args")
else:
- self._set_dpo_dataset_kwargs()
+ self.get_configs("dpo_dataset_kwargs")
if st.session_state["sft_warmup_steps"] > 0:
- self._set_sft_warmup_dataset_args()
+ self.get_configs("sft_warmup_dataset_args")
- self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"])
+ self.get_configs("default_workflow_type", "default_reward_fn_type")
- self._set_actor_use_kl_loss()
- if st.session_state["actor_use_kl_loss"]:
- self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"])
+ self.get_configs("actor_use_kl_loss")
+ self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type")
- self._set_configs_with_st_columns(
- [
- "actor_ppo_micro_batch_size_per_gpu",
- "actor_lr",
- "ref_log_prob_micro_batch_size_per_gpu",
- ]
+ self.get_configs(
+ "actor_ppo_micro_batch_size_per_gpu",
+ "actor_lr",
+ "ref_log_prob_micro_batch_size_per_gpu",
)
- use_critic = (
- st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value
- ) # TODO: may apply to expert mode
- if use_critic:
- self._set_configs_with_st_columns(["critic_ppo_micro_batch_size_per_gpu", "critic_lr"])
+ self.get_configs("critic_ppo_micro_batch_size_per_gpu", "critic_lr")
def _expert_model_part(self):
- self._set_configs_with_st_columns(["project", "exp_name"], columns_config=[1, 3])
+ self.get_configs("project", "exp_name", columns_spec=[1, 2])
- self._set_model_path()
- self._set_critic_model_path()
+ self.get_configs("model_path")
+ self.get_configs("critic_model_path")
- self._set_checkpoint_root_dir()
+ self.get_configs("checkpoint_root_dir")
- self._set_configs_with_st_columns(["monitor_type", "node_num", "gpu_per_node"])
- self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"])
+ self.get_configs("monitor_type", "node_num", "gpu_per_node")
+ self.get_configs("max_prompt_tokens", "max_response_tokens")
def _expert_buffer_part(self):
- self._set_configs_with_st_columns(["total_epochs", "train_batch_size"])
- self._check_train_batch_size()
+ self.get_configs("total_epochs", "train_batch_size")
- self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"])
- self._set_system_prompt()
- self._set_reply_prefix()
+ self.get_configs("default_workflow_type", "default_reward_fn_type")
+ self.get_configs("system_prompt")
+ self.get_configs("reply_prefix")
if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
with st.expander("Taskset Configs", expanded=True):
- self._set_taskset_path()
- self._set_taskset_args()
+ self.get_configs("taskset_path")
+ self.get_configs("taskset_args")
else:
with st.expander("DPO Dataset Configs", expanded=True):
- self._set_experience_buffer_path()
- self._set_dpo_dataset_kwargs()
+ self.get_configs("experience_buffer_path")
+ self.get_configs("storage_type")
+ self.get_configs("dpo_dataset_kwargs")
with st.expander("Eval Tasksets Configs", expanded=True):
- self._set_eval_tasksets()
+ self.get_configs("eval_tasksets")
with st.expander("SFT Dataset Configs"):
- self._set_sft_warmup_dataset_path()
- self._set_sft_warmup_dataset_args()
+ self.get_configs("sft_warmup_dataset_path")
+ self.get_configs("sft_warmup_dataset_args")
if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
with st.expander("Experiences Buffer Configs", expanded=True):
- self._set_storage_type()
- self._set_experience_buffer_path()
+ self.get_configs("storage_type")
+ self.get_configs("experience_buffer_path")
self.buffer_advanced_tab = st.expander("Advanced Config")
with self.buffer_advanced_tab:
- self._set_configs_with_st_columns(["buffer_max_retry_times", "max_retry_interval"])
+ self.get_configs("buffer_max_retry_times", "max_retry_interval")
def _expert_explorer_part(self):
- self._set_configs_with_st_columns(["sync_method", "sync_interval", "sync_timeout"])
-
- self._set_configs_with_st_columns(
- [
- "runner_num",
- "max_timeout",
- "explorer_max_retry_times",
- ]
- )
+ self.get_configs("sync_method", "sync_interval", "sync_timeout")
- self._set_configs_with_st_columns(["eval_interval", "eval_on_latest_checkpoint"])
+ self.get_configs("runner_num", "max_timeout", "explorer_max_retry_times", "eval_interval")
+
+ self.get_configs("eval_on_latest_checkpoint")
with st.expander("Rollout Model Config", expanded=True):
- self._set_configs_with_st_columns(["engine_type", "engine_num", "tensor_parallel_size"])
- self._check_engine_num_and_tp_size()
+ self.get_configs("engine_type", "engine_num", "tensor_parallel_size")
- self._set_configs_with_st_columns(["gpu_memory_utilization", "dtype", "seed"])
+ self.get_configs("gpu_memory_utilization", "dtype", "seed")
- self._set_configs_with_st_columns(
- ["use_v1", "enforce_eager", "enable_prefix_caching", "enable_chunked_prefill"]
+ self.get_configs(
+ "use_v1", "enforce_eager", "enable_prefix_caching", "enable_chunked_prefill"
)
- self._set_configs_with_st_columns(["enable_thinking", "enable_openai_api"])
+ self.get_configs("enable_thinking", "enable_openai_api")
- with st.expander("Auxiliary Models", expanded=True): # TODO
- pass
+ with st.expander("Auxiliary Models", expanded=True):
+ self.get_configs("auxiliary_models")
def _expert_trainer_part(self):
- self._set_configs_with_st_columns(["algorithm_type", "gamma", "lam"])
- self._set_configs_with_st_columns(["repeat_times", "save_interval"])
- self._check_sft_warmup_dataset_path()
+ self.get_configs("algorithm_type", "gamma", "lam")
+ self.get_configs("repeat_times", "save_interval")
+ self.get_configs("enable_preview")
if st.session_state["trainer_type"] == "verl":
self._expert_verl_trainer_part()
- def _expert_verl_trainer_part(self):
- rl_training_tab, rl_algorithm_tab, actor_ref_tab, critic_tab = st.tabs(
- [
- "RL Training Config",
- "RL Algorithm Config",
- "Actor and Ref Config",
- "Critic Config",
- ]
- )
- with rl_training_tab:
- st.subheader("RL Training Config")
- self._set_training_args()
+ def _expert_verl_training_part(self):
+ st.subheader("RL Training Config")
+ self.get_configs("training_args")
- self._set_configs_with_st_columns(["ppo_epochs", "training_strategy", "resume_mode"])
+ self.get_configs("ppo_epochs", "training_strategy", "resume_mode")
- if st.session_state["training_strategy"] == "fsdp":
- self._set_configs_with_st_columns(["param_offload", "optimizer_offload"])
- self._set_resume_from_path()
+ self.get_configs("param_offload", "optimizer_offload")
+ self.get_configs("resume_from_path")
- with st.expander("Advanced Config"):
- self._set_configs_with_st_columns(["critic_warmup", "total_training_steps"])
+ with st.expander("Advanced Config"):
+ self.get_configs("critic_warmup", "total_training_steps")
- self._set_default_hdfs_dir()
+ self.get_configs("default_hdfs_dir")
- self._set_configs_with_st_columns(
- ["remove_previous_ckpt_in_save", "del_local_ckpt_after_load"]
- )
+ self.get_configs("remove_previous_ckpt_in_save", "del_local_ckpt_after_load")
- self._set_configs_with_st_columns(
- ["max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep"]
- )
+ self.get_configs("max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep")
- with rl_algorithm_tab:
- st.subheader("RL Algorithm Config")
- self._set_configs_with_st_columns(["norm_adv_by_std_in_grpo", "use_kl_in_reward"])
- self._set_configs_with_st_columns(["kl_penalty", "kl_ctrl_type", "kl_ctrl_coef"])
- self._set_configs_with_st_columns(["horizon", "target_kl"])
+ def _expert_verl_algorithm_part(self):
+ st.subheader("RL Algorithm Config")
+ self.get_configs("norm_adv_by_std_in_grpo", "use_kl_in_reward")
+ self.get_configs("kl_penalty", "kl_ctrl_type", "kl_ctrl_coef")
+ self.get_configs("horizon", "target_kl")
- with actor_ref_tab:
- st.subheader("Actor Model Config")
- self._set_configs_with_st_columns(
- [
- "actor_ppo_micro_batch_size_per_gpu",
- "ref_log_prob_micro_batch_size_per_gpu",
- "actor_ulysses_sequence_parallel_size",
- ]
- )
+ def _expert_verl_actor_part(self):
+ st.subheader("Actor Model Config")
+ self.get_configs(
+ "actor_ppo_micro_batch_size_per_gpu",
+ "ref_log_prob_micro_batch_size_per_gpu",
+ "actor_ulysses_sequence_parallel_size",
+ )
- self._set_configs_with_st_columns(
- ["actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio"]
- )
+ self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio")
- self._set_configs_with_st_columns(
- ["actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef"]
- )
+ self.get_configs("actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef")
- self._set_actor_use_kl_loss()
- if st.session_state["actor_use_kl_loss"]:
- self._set_configs_with_st_columns(["actor_kl_loss_coef", "actor_kl_loss_type"])
+ self.get_configs("actor_use_kl_loss", "actor_use_uid")
+ self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type")
- if st.session_state["algorithm_type"] == "opmd":
- self._set_configs_with_st_columns(
- ["actor_tau", "actor_opmd_baseline", "actor_use_uid"]
- )
+ self.get_configs("actor_tau", "actor_opmd_baseline")
- self._set_actor_checkpoint()
+ self.get_configs("actor_checkpoint")
- with critic_tab:
- st.subheader("Critic Model Config")
- self._set_configs_with_st_columns(
- ["critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size"]
- )
+ def _expert_verl_critic_part(self):
+ st.subheader("Critic Model Config")
+ self.get_configs(
+ "critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size"
+ )
- self._set_configs_with_st_columns(
- ["critic_lr", "critic_warmup_style", "critic_lr_warmup_steps_ratio"]
- )
+ self.get_configs("critic_lr", "critic_warmup_style", "critic_lr_warmup_steps_ratio")
+
+ self.get_configs("critic_grad_clip", "critic_cliprange_value")
+ self.get_configs("critic_checkpoint")
+
+ def _expert_verl_trainer_part(self):
+ name2func = {
+ "RL Training Config": self._expert_verl_training_part,
+ "RL Algorithm Config": self._expert_verl_algorithm_part,
+ "Actor and Ref Config": self._expert_verl_actor_part,
+ }
+ if use_critic():
+ name2func["Critic Config"] = self._expert_verl_critic_part
- self._set_configs_with_st_columns(["critic_grad_clip", "critic_cliprange_value"])
- self._set_critic_checkpoint()
+ tabs = st.tabs([name for name in name2func])
+ for tab, func in zip(tabs, name2func.values()):
+ with tab:
+ func()
def expert_mode(self):
tab2func = {
@@ -1455,7 +385,6 @@ def _generate_verl_config(self):
},
"trainer": {
"balance_batch": balance_batch,
- "logger": ["tensorboard"],
"resume_mode": st.session_state["resume_mode"],
"resume_from_path": st.session_state["resume_from_path"],
"default_hdfs_dir": st.session_state["default_hdfs_dir"],
@@ -1467,7 +396,7 @@ def _generate_verl_config(self):
},
}
- if st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value:
+ if use_critic():
trainer_config["trainer"]["critic_warmup"] = st.session_state["critic_warmup"]
trainer_config["critic"] = {
"strategy": st.session_state["training_strategy"],
@@ -1510,8 +439,8 @@ def _generate_verl_config(self):
return trainer_config
def _gen_buffer_config(self):
+ experience_buffer_path = st.session_state["experience_buffer_path"].strip()
if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
- experience_buffer_path = st.session_state["experience_buffer_path"].strip()
if (
not experience_buffer_path
and st.session_state["storage_type"] == StorageType.SQL.value
@@ -1527,7 +456,20 @@ def _gen_buffer_config(self):
buffer_config = {
"batch_size": st.session_state["train_batch_size"],
"total_epochs": st.session_state["total_epochs"],
- "explorer_input": {
+ "trainer_input": {
+ "experience_buffer": {
+ "name": "experience_buffer",
+ "storage_type": st.session_state["storage_type"],
+ "path": experience_buffer_path,
+ },
+ "sft_warmup_steps": st.session_state["sft_warmup_steps"],
+ },
+ "max_retry_times": st.session_state["buffer_max_retry_times"],
+ "max_retry_interval": st.session_state["max_retry_interval"],
+ }
+
+ if st.session_state["mode"] != "train":
+ buffer_config["explorer_input"] = {
"taskset": {
"name": "taskset",
"storage_type": StorageType.FILE.value,
@@ -1548,31 +490,19 @@ def _gen_buffer_config(self):
"default_reward_fn_type": st.session_state["default_reward_fn_type"],
"system_prompt": st.session_state["system_prompt"],
"reply_prefix": st.session_state["reply_prefix"],
- },
- "trainer_input": {
- "experience_buffer": {
- "name": "experience_buffer",
- "storage_type": st.session_state["storage_type"],
- "path": experience_buffer_path,
- },
- "sft_warmup_steps": st.session_state["sft_warmup_steps"],
- },
- "max_retry_times": st.session_state["buffer_max_retry_times"],
- "max_retry_interval": st.session_state["max_retry_interval"],
- }
-
- for idx in range(st.session_state["_eval_tasksets_num"]):
- if st.session_state[f"eval_taskset_{idx}_path"].strip():
- buffer_config["explorer_input"]["eval_tasksets"].append(
- {
- "name": st.session_state[f"eval_taskset_{idx}_name"],
- "path": st.session_state[f"eval_taskset_{idx}_path"],
- "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"],
- "split": st.session_state[f"eval_taskset_{idx}_split"],
- "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"],
- "response_key": st.session_state[f"eval_taskset_{idx}_response_key"],
- }
- )
+ }
+ for idx in range(st.session_state["_eval_tasksets_num"]):
+ if st.session_state[f"eval_taskset_{idx}_path"].strip():
+ buffer_config["explorer_input"]["eval_tasksets"].append(
+ {
+ "name": st.session_state[f"eval_taskset_{idx}_name"],
+ "path": st.session_state[f"eval_taskset_{idx}_path"],
+ "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"],
+ "split": st.session_state[f"eval_taskset_{idx}_split"],
+ "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"],
+ "response_key": st.session_state[f"eval_taskset_{idx}_response_key"],
+ }
+ )
if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
experience_buffer = buffer_config["trainer_input"]["experience_buffer"]
experience_buffer["split"] = st.session_state["dpo_dataset_train_split"]
@@ -1676,7 +606,7 @@ def generate_config(self):
"trainer": {
"trainer_type": st.session_state["trainer_type"],
"save_interval": st.session_state["save_interval"],
- "enable_preview": True, # TODO
+ "enable_preview": st.session_state["enable_preview"],
"actor_use_kl_loss": st.session_state["actor_use_kl_loss"],
"actor_kl_loss_coef": st.session_state["actor_kl_loss_coef"],
"actor_entropy_coef": st.session_state["actor_entropy_coef"],
@@ -1694,7 +624,7 @@ def generate_config(self):
},
}
- if st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value:
+ if use_critic():
config["model"]["critic_model_path"] = (
st.session_state["critic_model_path"].strip()
if st.session_state["critic_model_path"].strip()
diff --git a/trinity/manager/config_registry/__init__.py b/trinity/manager/config_registry/__init__.py
new file mode 100644
index 0000000000..e62c565fb4
--- /dev/null
+++ b/trinity/manager/config_registry/__init__.py
@@ -0,0 +1,13 @@
+import trinity.manager.config_registry.buffer_config_manager as buffer_config_manager
+import trinity.manager.config_registry.explorer_config_manager as explorer_config_manager
+import trinity.manager.config_registry.model_config_manager as model_config_manager
+import trinity.manager.config_registry.trainer_config_manager as trainer_config_manager
+from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
+
+__all__ = [
+ "CONFIG_GENERATORS",
+ "buffer_config_manager",
+ "explorer_config_manager",
+ "model_config_manager",
+ "trainer_config_manager",
+]
diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py
new file mode 100644
index 0000000000..044f982e94
--- /dev/null
+++ b/trinity/manager/config_registry/buffer_config_manager.py
@@ -0,0 +1,433 @@
+import streamlit as st
+
+from trinity.common.constants import AlgorithmType, PromptType, StorageType
+from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
+from trinity.common.workflows.workflow import WORKFLOWS
+from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
+
+
+@CONFIG_GENERATORS.register_config(default_value=20)
+def set_total_epochs(**kwargs):
+ st.number_input("Total Epochs", min_value=1, **kwargs)
+
+
+def _str_for_train_batch_size():
+ trainer_gpu_num_str = (
+ "`gpu_per_node * node_num - engine_num * tensor_parallel_size`"
+ if st.session_state["mode"] == "both"
+ else "`gpu_per_node * node_num`"
+ )
+ return (
+ f"Please ensure that `train_batch_size` can be divided by "
+ f"{trainer_gpu_num_str} = {st.session_state['trainer_gpu_num']}."
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=96,
+ visible=lambda: st.session_state["trainer_gpu_num"] > 0,
+ other_configs={"_train_batch_size_per_gpu": 16},
+)
+def set_train_batch_size(**kwargs):
+ key = kwargs.get("key")
+ trainer_gpu_num = st.session_state["trainer_gpu_num"]
+ st.session_state[key] = (
+ st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
+ )
+
+ def on_change():
+ st.session_state["_train_batch_size_per_gpu"] = max(
+ st.session_state[key] // st.session_state["trainer_gpu_num"], 1
+ )
+
+ st.number_input(
+ "Train Batch Size",
+ min_value=trainer_gpu_num,
+ step=trainer_gpu_num,
+ help=_str_for_train_batch_size(),
+ on_change=on_change,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_check()
+def check_train_batch_size(unfinished_fields: set, key: str):
+ if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
+ unfinished_fields.add(key)
+ st.warning(_str_for_train_batch_size())
+
+
+@CONFIG_GENERATORS.register_config(default_value=3)
+def set_buffer_max_retry_times(**kwargs):
+ st.number_input("Max Retry Times", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=1)
+def set_max_retry_interval(**kwargs):
+ st.number_input("Max Retry Interval", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="")
+def set_taskset_path(**kwargs):
+ st.text_input("Taskset Path", **kwargs)
+
+
+@CONFIG_GENERATORS.register_check()
+def check_taskset_path(unfinished_fields: set, key: str):
+ if not st.session_state[key].strip():
+ unfinished_fields.add(key)
+ st.warning("Please input taskset path.")
+
+
+# def _set_temperature(self):
+# st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0)
+
+# def _set_top_p(self):
+# st.number_input("Top-p", key="top_p", min_value=0.0, max_value=1.0)
+
+# def _set_top_k(self):
+# st.number_input(
+# "Top-k",
+# key="top_k",
+# min_value=-1,
+# max_value=512,
+# help="Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.",
+# )
+
+# def _set_logprobs(self):
+# st.number_input("Logprobs", key="logprobs", min_value=0, max_value=20)
+
+
+@CONFIG_GENERATORS.register_config(
+ visible=lambda: st.session_state["taskset_path"]
+ and "://" not in st.session_state["taskset_path"],
+ other_configs={
+ "taskset_subset_name": None,
+ "taskset_split": "train",
+ "taskset_prompt_key": "question",
+ "taskset_response_key": "answer",
+ "temperature": 1.0,
+ "top_p": 1.0, # TODO: to be used
+ "top_k": -1, # TODO: to be used
+ "logprobs": 0,
+ },
+)
+def set_taskset_args(**kwargs):
+ subset_name_col, split_col = st.columns(2)
+ subset_name_col.text_input(
+ "Subset Name :orange-badge[(Needs review)]",
+ key="taskset_subset_name",
+ help="The subset name used for `datasets.load_datasets`, see "
+ "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
+ )
+ split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split")
+ prompt_key_col, response_key_col = st.columns(2)
+ prompt_key_col.text_input("Prompt Key :orange-badge[(Needs review)]", key="taskset_prompt_key")
+ response_key_col.text_input(
+ "Response Key :orange-badge[(Needs review)]", key="taskset_response_key"
+ )
+ # self._set_configs_with_st_columns(["temperature", "logprobs"])
+ temperature_col, logprobs_col = st.columns(2)
+ temperature_col.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0)
+ logprobs_col.number_input("Logprobs", key="logprobs", min_value=0, max_value=20)
+
+
+def _set_eval_taskset_idx(idx):
+ col1, col2 = st.columns([9, 1])
+ col1.text_input(
+ "Taskset Name",
+ key=f"eval_taskset_{idx}_name",
+ )
+ if col2.button("✖️", key=f"eval_taskset_{idx}_del_flag", type="primary"):
+ st.rerun()
+ st.text_input(
+ "Eval Taskset Path",
+ key=f"eval_taskset_{idx}_path",
+ )
+ if not st.session_state[f"eval_taskset_{idx}_path"].strip():
+ st.warning("Please input the taskset path, or it will be ignored.")
+ subset_name_col, split_col = st.columns(2)
+ subset_name_col.text_input(
+ "Subset Name :orange-badge[(Needs review)]",
+ key=f"eval_taskset_{idx}_subset_name",
+ help="The subset name used for `datasets.load_datasets`, see "
+ "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
+ )
+ split_col.text_input(
+ "Eval Split :orange-badge[(Needs review)]",
+ key=f"eval_taskset_{idx}_split",
+ )
+ prompt_key_col, response_key_col = st.columns(2)
+ prompt_key_col.text_input(
+ "Prompt Key :orange-badge[(Needs review)]",
+ key=f"eval_taskset_{idx}_prompt_key",
+ )
+ response_key_col.text_input(
+ "Response Key :orange-badge[(Needs review)]",
+ key=f"eval_taskset_{idx}_response_key",
+ )
+
+ temperature_col, logprobs_col, n_col = st.columns(3)
+ temperature_col.number_input(
+ "Temperature",
+ key=f"eval_taskset_{idx}_temperature",
+ min_value=0.0,
+ max_value=1.0,
+ )
+ logprobs_col.number_input(
+ "Logprobs",
+ key=f"eval_taskset_{idx}_logprobs",
+ min_value=0,
+ max_value=20,
+ )
+ n_col.number_input(
+ "Eval repeat times",
+ key=f"eval_taskset_{idx}_n",
+ min_value=1,
+ max_value=20,
+ )
+
+
+@CONFIG_GENERATORS.register_config(other_configs={"_eval_tasksets_num": 0})
+def set_eval_tasksets(**kwargs):
+ if st.button("Add Eval Taskset"):
+ idx = st.session_state["_eval_tasksets_num"]
+ st.session_state[f"eval_taskset_{idx}_split"] = "test"
+ st.session_state[f"eval_taskset_{idx}_prompt_key"] = "prompt"
+ st.session_state[f"eval_taskset_{idx}_response_key"] = "response"
+ st.session_state[f"eval_taskset_{idx}_temperature"] = 0.1
+ st.session_state["_eval_tasksets_num"] += 1
+ if st.session_state["_eval_tasksets_num"] > 0:
+ tabs = st.tabs(
+ [f"Eval Taskset {i + 1}" for i in range(st.session_state["_eval_tasksets_num"])]
+ )
+ for idx, tab in enumerate(tabs):
+ with tab:
+ _set_eval_taskset_idx(idx)
+
+
+@CONFIG_GENERATORS.register_config(default_value="math_workflow")
+def set_default_workflow_type(**kwargs):
+ st.selectbox(
+ "Default Workflow Type :orange-badge[(Needs review)]",
+ WORKFLOWS.modules.keys(),
+ help=r"""`simple_workflow`: call 'model.chat()' to get responses.
+
+`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses.
+
+Other workflows: conduct multi-turn task for the given dataset.
+""",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value="math_reward")
+def set_default_reward_fn_type(**kwargs):
+ st.selectbox(
+ "Default Reward Fn Type :orange-badge[(Needs review)]",
+ REWARD_FUNCTIONS.modules.keys(),
+ help=r"""`accuracy_reward`: check the accuracy for math problems.
+
+`format_reward`: check if the response matches the format (default: `** *`).
+
+`math_reward`: `accuracy_reward` (1 or 0) + `format_reward` (+0.1 or -0.1).
+""",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_system_prompt(**kwargs):
+ st.text_area(
+ "System Prompt",
+ placeholder="System prompt is used to guide the model behavior.",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_reply_prefix(**kwargs):
+ st.text_area(
+ "Assistant Reply Prefix",
+ placeholder="""Assistant reply prefix is used to specify the initial content of model reply, """
+ """and a common setting is: \nLet me solve this step by step. """,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=StorageType.QUEUE.value,
+ other_configs={
+ "_dpo_storage_type": StorageType.FILE.value,
+ "_not_dpo_storage_type": StorageType.QUEUE.value,
+ },
+)
+def set_storage_type(**kwargs):
+ key = kwargs.get("key")
+ if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ st.session_state[key] = st.session_state["_dpo_storage_type"]
+ storage_candidates = [StorageType.FILE.value, StorageType.SQL.value]
+ else:
+ st.session_state[key] = st.session_state["_not_dpo_storage_type"]
+ storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value]
+
+ def on_change():
+ if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ st.session_state["_dpo_storage_type"] = st.session_state[key]
+ else:
+ st.session_state["_not_dpo_storage_type"] = st.session_state[key]
+
+ st.selectbox(
+ "Storage Type",
+ storage_candidates,
+ on_change=on_change,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value="",
+ other_configs={
+ "_dpo_experience_buffer_path": "",
+ "_not_dpo_experience_buffer_path": "",
+ },
+)
+def set_experience_buffer_path(**kwargs): # TODO
+ key = kwargs.get("key")
+ if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ if st.session_state["taskset_path"] and not st.session_state["_dpo_experience_buffer_path"]:
+ st.session_state["_dpo_experience_buffer_path"] = st.session_state["taskset_path"]
+ st.session_state[key] = st.session_state["_dpo_experience_buffer_path"]
+ title = "DPO Dataset Path"
+ help_msg = r"""This path to DPO dataset,
+
+if `storage_type == StorageType.FILE`, this should be a path to a file,
+
+if `storage_type == StorageType.SQL`, this should be a path to database."""
+ else:
+ st.session_state[key] = st.session_state["_not_dpo_experience_buffer_path"]
+ title = "Experience Buffer Path"
+ help_msg = r"""This path is used for `trainer`,
+
+if `storage_type == StorageType.QUEUE`, default to `None`,
+
+if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`."""
+
+ def on_change():
+ if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ st.session_state["_dpo_experience_buffer_path"] = st.session_state[key]
+ else:
+ st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[key]
+
+ st.text_input(title, help=help_msg, on_change=on_change, **kwargs)
+
+
+@CONFIG_GENERATORS.register_check()
+def check_experience_buffer_path(unfinished_fields: set, key: str):
+ if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ if not st.session_state[key].strip():
+ unfinished_fields.add(key)
+ st.warning("Please input DPO dataset path.")
+
+
+@CONFIG_GENERATORS.register_config(
+ other_configs={
+ "dpo_dataset_train_split": "train",
+ "dpo_dataset_prompt_type": PromptType.MESSAGES.value,
+ "dpo_dataset_prompt_key": "prompt",
+ "dpo_dataset_chosen_key": "chosen",
+ "dpo_dataset_rejected_key": "rejected",
+ }
+)
+def set_dpo_dataset_kwargs(**kwargs):
+ dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2)
+ dpo_dataset_train_split_col.text_input(
+ "DPO Dataset Train Split :orange-badge[(Needs review)]", key="dpo_dataset_train_split"
+ )
+ dpo_dataset_prompt_type_col.selectbox(
+ "DPO Dataset Prompt Type :orange-badge[(Needs review)]",
+ [prompt_type.value for prompt_type in PromptType],
+ key="dpo_dataset_prompt_type",
+ )
+
+ (
+ dpo_dataset_prompt_key_col,
+ dpo_dataset_chosen_key_col,
+ dpo_dataset_rejected_key_col,
+ ) = st.columns(3)
+ dpo_dataset_prompt_key_col.text_input(
+ "DPO Dataset Prompt Key :orange-badge[(Needs review)]", key="dpo_dataset_prompt_key"
+ )
+ dpo_dataset_chosen_key_col.text_input(
+ "DPO Dataset Chosen Key :orange-badge[(Needs review)]", key="dpo_dataset_chosen_key"
+ )
+ dpo_dataset_rejected_key_col.text_input(
+ "DPO Dataset Rejected Key :orange-badge[(Needs review)]",
+ key="dpo_dataset_rejected_key",
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value="")
+def set_sft_warmup_dataset_path(**kwargs):
+ st.text_input("SFT Warmup Dataset Path", **kwargs)
+
+
+@CONFIG_GENERATORS.register_check()
+def check_sft_warmup_dataset_path(unfinished_fields: set, key: str):
+ if st.session_state["sft_warmup_steps"]:
+ if not st.session_state[key].strip():
+ unfinished_fields.add(key)
+ st.warning("Please input SFT warmup dataset path when `sft_warmup_steps` is not 0")
+
+
+@CONFIG_GENERATORS.register_config(
+ visible=lambda: st.session_state["sft_warmup_dataset_path"]
+ and "://" not in st.session_state["sft_warmup_dataset_path"],
+ other_configs={
+ "sft_warmup_train_split": "train",
+ "sft_warmup_prompt_type": PromptType.MESSAGES.value,
+ "sft_warmup_messages_key": "messages",
+ "sft_warmup_prompt_key": "prompt",
+ "sft_warmup_response_key": "response",
+ },
+)
+def set_sft_warmup_dataset_args(**kwargs):
+ (
+ sft_warmup_train_split_col,
+ sft_warmup_prompt_type_col,
+ ) = st.columns(2)
+ sft_warmup_train_split_col.text_input(
+ "SFT Dataset Train Split :orange-badge[(Needs review)]",
+ key="sft_warmup_train_split",
+ )
+ sft_warmup_prompt_type_col.selectbox(
+ "SFT Dataset Prompt Type :orange-badge[(Needs review)]",
+ [prompt_type.value for prompt_type in PromptType],
+ key="sft_warmup_prompt_type",
+ )
+ (
+ sft_warmup_messages_key_col,
+ sft_warmup_prompt_key_col,
+ sft_warmup_response_key_col,
+ ) = st.columns(
+ 3
+ ) # TODO: select by prompt type
+ sft_warmup_messages_key_col.text_input(
+ "SFT Dataset Messages Key :orange-badge[(Needs review)]",
+ key="sft_warmup_messages_key",
+ )
+ sft_warmup_prompt_key_col.text_input(
+ "SFT Dataset Prompt Key :orange-badge[(Needs review)]", key="sft_warmup_prompt_key"
+ )
+ sft_warmup_response_key_col.text_input(
+ "SFT Dataset Response Key :orange-badge[(Needs review)]",
+ key="sft_warmup_response_key",
+ )
+
+
+# TODO: read_experience_strategy
+
+
+@CONFIG_GENERATORS.register_config(default_value=0)
+def set_sft_warmup_steps(**kwargs):
+ st.number_input("SFT Warmup Steps", min_value=0, **kwargs)
diff --git a/trinity/manager/config_registry/config_registry.py b/trinity/manager/config_registry/config_registry.py
new file mode 100644
index 0000000000..3b621a2de2
--- /dev/null
+++ b/trinity/manager/config_registry/config_registry.py
@@ -0,0 +1,209 @@
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Set
+
+import streamlit as st
+
+from trinity.utils.registry import Registry
+
+
+class ConfigRegistry(Registry):
+ """
+ A registry for managing configuration settings and their associated functions.
+ """
+
+ def __init__(self, name: str):
+ super().__init__(name)
+ self._default_config = {} # Stores default values for configs
+ self._config_visibles = {} # Stores visibles for config visibility
+ self.unfinished_fields = set()
+
+ def set_unfinished_fields(self, unfinished_fields: set):
+ """
+ Set the unfinished fields to track incomplete configurations.
+
+ Args:
+ unfinished_fields (set): Set of field names that are not yet configured.
+ """
+ self.unfinished_fields = unfinished_fields
+
+ @property
+ def default_config(self) -> dict:
+ """
+ Get the dictionary of default configuration values.
+ """
+ return self._default_config
+
+ def get(self, config_name: str):
+ """
+ Retrieve a configuration function if its visible is met (if any).
+
+ Args:
+ config_name (str): Name of the configuration to retrieve.
+
+ Returns:
+ The configuration function if visibles are met, else None.
+ """
+ if config_name in self._config_visibles:
+ if not self._config_visibles[config_name]():
+ return None
+ return super().get(config_name)
+
+ def get_check_func(self, config_name: str):
+ """
+ Get the check function associated with a configuration.
+
+ Args:
+ config_name (str): Name of the configuration.
+
+ Returns:
+ The check function for the specified configuration.
+ """
+ check_func_name = f"check_{config_name}"
+ return super().get(check_func_name)
+
+ def get_configs(self, *config_names: str, columns_spec: List[int] = None):
+ """
+ Retrieve and display multiple configurations in Streamlit columns.
+
+ Args:
+ *config_names (str): Names of configurations to retrieve.
+ columns_spec (List[int], optional): Configuration for Streamlit columns.
+ """
+ config_pair = []
+ for config_name in config_names:
+ config_func = self.get(config_name)
+ if config_func is not None:
+ config_pair.append((config_name, config_func))
+ if len(config_pair) == 0:
+ return
+
+ if columns_spec is None:
+ columns_spec = len(config_pair)
+ columns = st.columns(columns_spec)
+ for col, (_, config_func) in zip(columns, config_pair):
+ with col:
+ config_func()
+ for config_name, _ in config_pair:
+ check_func = self.get_check_func(config_name)
+ if check_func is not None:
+ check_func(unfinished_fields=self.unfinished_fields)
+
+ def _register_config(
+ self,
+ config_name: str,
+ config_func: Callable[[None], None],
+ default_value: Optional[Any] = None,
+ visible: Optional[Callable[[], bool]] = None,
+ other_configs: Optional[Dict[str, Any]] = None,
+ ):
+ """
+ Internal method to register a configuration and its associated function.
+
+ Args:
+ config_name (str): Name of the configuration.
+ config_func (Callable): Function to set the configuration.
+ default_value (Any, optional): Default value for the configuration.
+ visible (Callable, optional): visible for when the config should be visible/applicable.
+ other_configs (Dict[str, Any], optional): Additional configurations to register.
+ """
+ assert config_name not in self._default_config, f"{config_name} already exists."
+ self._default_config[config_name] = default_value
+ if visible is not None:
+ self._config_visibles[config_name] = visible
+ if other_configs is not None:
+ for name, value in other_configs.items():
+ assert name not in self._default_config, f"{name} already exists."
+ self._default_config[name] = value
+ super()._register_module(module_name=config_name, module_cls=config_func)
+
+ def register_config(
+ self,
+ default_value: Optional[Any] = None,
+ config_func: Optional[Callable[[None], None]] = None,
+ visible: Optional[Callable[[], bool]] = None,
+ other_configs: Optional[Dict[str, Any]] = None,
+ ):
+ """
+ Decorator to register a configuration function.
+
+ The function name must start with 'set_', and the part after 'set_' becomes the config name.
+
+ Note: This function will automatically pass `key=config_name` as an argument to the
+ registered configuration function. Ensure your function accepts this keyword argument.
+
+ Args:
+ default_value (Any, optional): Default value for the configuration.
+ config_func (Callable, optional): The configuration function to register.
+ visible (Callable, optional): visible for when the config should be visible.
+ other_configs (Dict[str, Any], optional): Additional configurations to register.
+
+ Returns:
+ A decorator function if config_func is None, else the registered config function.
+ """
+
+ # if config_func is None, should return a decorator function
+ def _register(config_func: Callable[[None], None]):
+ config_name = config_func.__name__
+ prefix = "set_"
+ assert config_name.startswith(
+ prefix
+ ), f"Config function name should start with `{prefix}`, got {config_name}"
+ config_name = config_name[len(prefix) :]
+ config_func = partial(config_func, key=config_name)
+ self._register_config(
+ config_name=config_name,
+ config_func=config_func,
+ default_value=default_value,
+ visible=visible,
+ other_configs=other_configs,
+ )
+ return config_func
+
+ if config_func is not None:
+ return _register(config_func)
+ return _register
+
+ def _register_check(self, config_name: str, check_func: Callable[[Set, str], None]):
+ """
+ Internal method to register a check function for a configuration.
+
+ Args:
+ config_name (str): Name of the configuration to check.
+ check_func (Callable): Function to check the configuration.
+ """
+ assert config_name in self._default_config, f"`{config_name}` is not registered."
+ super()._register_module(module_name=f"check_{config_name}", module_cls=check_func)
+
+ def register_check(self, check_func: Callable[[Set, str], None] = None):
+ """
+ Decorator to register a check function for a configuration.
+
+ The function name must start with 'check_', and the part after 'check_' should match a config name.
+
+ Note: This function will automatically pass `key=config_name` and `unfinished_fields=self.unfinished_fields` as an argument to the registered check function. Ensure your function accepts these keyword arguments.
+
+ Args:
+ check_func (Callable, optional): The check function to register.
+
+ Returns:
+ A decorator function if check_func is None, else the registered check function.
+ """
+
+ def _register(check_func: Callable[[Set, str], None]):
+ config_name = check_func.__name__
+ prefix = "check_"
+ assert config_name.startswith(
+ prefix
+ ), f"Check function name must start with `{prefix}`, got {config_name}"
+ config_name = config_name[len(prefix) :]
+ check_func = partial(check_func, key=config_name)
+ self._register_check(config_name, check_func)
+ return check_func
+
+ if check_func is not None:
+ return _register(check_func)
+ return _register
+
+
+# Global registry for configuration generators
+CONFIG_GENERATORS = ConfigRegistry("config_generators")
diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py
new file mode 100644
index 0000000000..9393187f60
--- /dev/null
+++ b/trinity/manager/config_registry/explorer_config_manager.py
@@ -0,0 +1,298 @@
+import streamlit as st
+
+from trinity.common.constants import AlgorithmType, SyncMethod
+from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
+from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num
+
+
+def explorer_visible() -> bool:
+ return st.session_state["mode"] == "both"
+
+
+@CONFIG_GENERATORS.register_config(default_value=32, visible=explorer_visible)
+def set_runner_num(**kwargs):
+ st.number_input("Runner Num", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=900, visible=explorer_visible)
+def set_max_timeout(**kwargs):
+ st.number_input("Max Timeout", min_value=0, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible)
+def set_explorer_max_retry_times(**kwargs):
+ st.number_input("Explorer Max Retry Times", min_value=0, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=1000, visible=explorer_visible)
+def set_eval_interval(**kwargs):
+ st.number_input("Eval Interval", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible)
+def set_eval_on_latest_checkpoint(**kwargs):
+ st.checkbox("Eval on Latest Checkpoint", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="vllm_async", visible=explorer_visible)
+def set_engine_type(**kwargs):
+ st.selectbox("Engine Type", ["vllm_async", "vllm"], **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=2, visible=explorer_visible)
+def set_engine_num(**kwargs):
+ key = kwargs.get("key")
+ total_gpu_num = st.session_state["total_gpu_num"]
+ max_engine_num = (total_gpu_num - 1) // st.session_state["tensor_parallel_size"]
+ if st.session_state[key] > max_engine_num:
+ st.session_state[key] = max_engine_num
+ set_trainer_gpu_num()
+ st.number_input(
+ "Engine Num",
+ min_value=1,
+ max_value=max_engine_num,
+ on_change=set_trainer_gpu_num,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1, visible=explorer_visible)
+def set_tensor_parallel_size(**kwargs):
+ key = kwargs.get("key")
+ total_gpu_num = st.session_state["total_gpu_num"]
+ max_tensor_parallel_size = (total_gpu_num - 1) // st.session_state["engine_num"]
+ if st.session_state[key] > max_tensor_parallel_size:
+ st.session_state[key] = max_tensor_parallel_size
+ set_trainer_gpu_num()
+ st.number_input(
+ "Tensor Parallel Size",
+ min_value=1,
+ max_value=max_tensor_parallel_size,
+ on_change=set_trainer_gpu_num,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_check()
+def check_tensor_parallel_size(unfinished_fields: set, key: str):
+ if st.session_state["trainer_gpu_num"] <= 0:
+ unfinished_fields.add("engine_num")
+ unfinished_fields.add("tensor_parallel_size")
+ st.warning(
+ "Please check the settings of each `engine_num` and `tensor_marallel_size` to ensure that at least one GPU is reserved for the `trainer`."
+ )
+ elif (
+ st.session_state["node_num"] > 1
+ and st.session_state["trainer_gpu_num"] % st.session_state["gpu_per_node"] != 0
+ ):
+ unfinished_fields.add("engine_num")
+ unfinished_fields.add("tensor_parallel_size")
+ st.warning(
+ "When `node_num > 1`, please check the settings of each `engine_num` and `tensor_marallel_size` to ensure that the number of GPUs reserved for the `trainer` is divisible by `gpu_per_node`"
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible)
+def set_use_v1(**kwargs):
+ st.checkbox("Use V1 Engine", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible)
+def set_enforce_eager(**kwargs):
+ st.checkbox("Enforce Eager", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible)
+def set_enable_prefix_caching(**kwargs):
+ st.checkbox("Prefix Caching", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible)
+def set_enable_chunked_prefill(**kwargs):
+ st.checkbox("Chunked Prefill", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=0.9, visible=explorer_visible)
+def set_gpu_memory_utilization(**kwargs):
+ st.number_input("GPU Memory Utilization", min_value=0.0, max_value=1.0, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="bfloat16", visible=explorer_visible)
+def set_dtype(**kwargs):
+ st.selectbox("Dtype", ["bfloat16", "float16", "float32"], **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=42, visible=explorer_visible)
+def set_seed(**kwargs):
+ st.number_input("Seed", step=1, **kwargs)
+
+
+# TODO: max_prompt_tokens
+# TODO: max_response_tokens
+# TODO: chat_template
+
+
+@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible)
+def set_enable_thinking(**kwargs):
+ st.checkbox("Enable Thinking For Qwen3", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False, visible=explorer_visible)
+def set_enable_openai_api(**kwargs):
+ st.checkbox("Enable OpenAI API", **kwargs)
+
+
+def _set_auxiliary_model_idx(idx):
+ col1, col2 = st.columns([9, 1])
+ col1.text_input(
+ "Model Path",
+ key=f"auxiliary_model_{idx}_model_path",
+ )
+ if col2.button("✖️", key=f"auxiliary_model_{idx}_del_flag", type="primary"):
+ st.rerun()
+
+ engine_type_col, engine_num_col, tensor_parallel_size_col = st.columns(3)
+ total_gpu_num = st.session_state["total_gpu_num"]
+ engine_type_col.selectbox(
+ "Engine Type", ["vllm_async"], key=f"auxiliary_model_{idx}_engine_type"
+ )
+ engine_num_col.number_input(
+ "Engine Num",
+ min_value=1,
+ max_value=total_gpu_num - 1,
+ on_change=set_trainer_gpu_num,
+ key=f"auxiliary_model_{idx}_engine_num",
+ )
+ tensor_parallel_size_col.number_input(
+ "Tensor Parallel Size",
+ min_value=1,
+ max_value=8,
+ on_change=set_trainer_gpu_num,
+ key=f"auxiliary_model_{idx}_tensor_parallel_size",
+ )
+
+ gpu_memory_utilization_col, dtype_col, seed_col = st.columns(3)
+ gpu_memory_utilization_col.number_input(
+ "GPU Memory Utilization",
+ min_value=0.0,
+ max_value=1.0,
+ key=f"auxiliary_model_{idx}_gpu_memory_utilization",
+ )
+ dtype_col.selectbox(
+ "Dtype", ["bfloat16", "float16", "float32"], key=f"auxiliary_model_{idx}_dtype"
+ )
+ seed_col.number_input("Seed", step=1, key=f"auxiliary_model_{idx}_seed")
+
+ (
+ use_v1_col,
+ enforce_eager_col,
+ enable_prefix_caching_col,
+ enable_chunked_prefill_col,
+ ) = st.columns(4)
+ use_v1_col.checkbox("Use V1 Engine", key=f"auxiliary_model_{idx}_use_v1")
+ enforce_eager_col.checkbox("Enforce Eager", key=f"auxiliary_model_{idx}_enforce_eager")
+ enable_prefix_caching_col.checkbox(
+ "Prefix Caching", key=f"auxiliary_model_{idx}_enable_prefix_caching"
+ )
+ enable_chunked_prefill_col.checkbox(
+ "Chunked Prefill", key=f"auxiliary_model_{idx}_enable_chunked_prefill"
+ )
+
+ enable_thinking_col, enable_openai_api = st.columns(2)
+ enable_thinking_col.checkbox(
+ "Enable Thinking For Qwen3", key=f"auxiliary_model_{idx}_enable_thinking"
+ )
+ enable_openai_api.checkbox("Enable OpenAI API", key=f"auxiliary_model_{idx}_enable_openai_api")
+
+
+@CONFIG_GENERATORS.register_config(other_configs={"_auxiliary_models_num": 0})
+def set_auxiliary_models(**kwargs):
+ if st.button("Add Auxiliary Models"):
+ idx = st.session_state["_auxiliary_models_num"]
+ st.session_state[f"auxiliary_model_{idx}_engine_num"] = 1
+ st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"] = 1
+ st.session_state[f"auxiliary_model_{idx}_gpu_memory_utilization"] = 0.9
+ st.session_state[f"auxiliary_model_{idx}_seed"] = 42
+ st.session_state[f"auxiliary_model_{idx}_use_v1"] = True
+ st.session_state[f"auxiliary_model_{idx}_enforce_eager"] = True
+ st.session_state["_auxiliary_models_num"] += 1
+ set_trainer_gpu_num()
+ if st.session_state["_auxiliary_models_num"] > 0:
+ tabs = st.tabs(
+ [f"Auxiliary Model {i + 1}" for i in range(st.session_state["_auxiliary_models_num"])]
+ )
+ for idx, tab in enumerate(tabs):
+ with tab:
+ _set_auxiliary_model_idx(idx)
+
+
+@CONFIG_GENERATORS.register_check()
+def check_auxiliary_models(unfinished_fields: set, key: str):
+ if st.session_state["trainer_gpu_num"] <= 0:
+ unfinished_fields.add("engine_num")
+ unfinished_fields.add("tensor_parallel_size")
+ st.warning(
+ "Please check the settings of each `engine_num` and `tensor_marallel_size` to ensure that at least one GPU is reserved for the `trainer`."
+ )
+ elif (
+ st.session_state["node_num"] > 1
+ and st.session_state["trainer_gpu_num"] % st.session_state["gpu_per_node"] != 0
+ ):
+ unfinished_fields.add("engine_num")
+ unfinished_fields.add("tensor_parallel_size")
+ st.warning(
+ "When `node_num > 1`, please check the settings of each `engine_num` and `tensor_marallel_size` to ensure that the number of GPUs reserved for the `trainer` is divisible by `gpu_per_node`"
+ )
+
+
+# Synchronizer Configs
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=SyncMethod.NCCL.value,
+ visible=explorer_visible,
+ other_configs={"_not_dpo_sync_method": SyncMethod.NCCL.value},
+)
+def set_sync_method(**kwargs):
+ key = kwargs.get("key")
+ if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ st.session_state[key] = SyncMethod.CHECKPOINT.value
+ disabled = True
+ else:
+ st.session_state[key] = st.session_state["_not_dpo_sync_method"]
+ disabled = False
+
+ def on_change():
+ if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+ st.session_state["_not_dpo_sync_method"] = st.session_state[key]
+
+ st.selectbox(
+ "Sync Method",
+ [sync_method.value for sync_method in SyncMethod],
+ help="""`nccl`: the explorer and trainer sync model weights once every `sync_interval` steps.
+
+`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_interval`.""",
+ disabled=disabled,
+ on_change=on_change,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=10, visible=explorer_visible)
+def set_sync_interval(**kwargs):
+ st.number_input(
+ "Sync Interval",
+ min_value=1,
+ help="""The step interval at which the `explorer` and `trainer` synchronize model weight.""",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1200, visible=explorer_visible)
+def set_sync_timeout(**kwargs):
+ st.number_input(
+ "Sync Timeout",
+ min_value=1,
+ help="The timeout value for the synchronization operation.",
+ **kwargs,
+ )
diff --git a/trinity/manager/config_registry/model_config_manager.py b/trinity/manager/config_registry/model_config_manager.py
new file mode 100644
index 0000000000..837bf27679
--- /dev/null
+++ b/trinity/manager/config_registry/model_config_manager.py
@@ -0,0 +1,206 @@
+import os
+
+import streamlit as st
+
+from trinity.common.constants import AlgorithmType, MonitorType
+from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
+from trinity.manager.config_registry.trainer_config_manager import use_critic
+from trinity.trainer.verl.ray_trainer import AdvantageEstimator
+
+
+def set_total_gpu_num():
+ st.session_state["total_gpu_num"] = (
+ st.session_state["gpu_per_node"] * st.session_state["node_num"]
+ )
+ set_trainer_gpu_num()
+
+
+def set_trainer_gpu_num():
+ if st.session_state["mode"] == "both":
+ trainer_gpu_num = (
+ st.session_state["total_gpu_num"]
+ - st.session_state["engine_num"] * st.session_state["tensor_parallel_size"]
+ )
+ for idx in range(st.session_state["_auxiliary_models_num"]):
+ engine_num = st.session_state[f"auxiliary_model_{idx}_engine_num"]
+ tensor_parallel_size = st.session_state[f"auxiliary_model_{idx}_tensor_parallel_size"]
+ trainer_gpu_num -= engine_num * tensor_parallel_size
+ st.session_state["trainer_gpu_num"] = trainer_gpu_num
+ else: # model == train
+ st.session_state["trainer_gpu_num"] = st.session_state["total_gpu_num"]
+
+
+@CONFIG_GENERATORS.register_config(default_value="Trinity-RFT")
+def set_project(**kwargs):
+ st.text_input("Project", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="qwen2.5-1.5B")
+def set_exp_name(**kwargs):
+ st.text_input("Experiment Name", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="")
+def set_checkpoint_root_dir(**kwargs):
+ st.text_input("Checkpoint Root Dir", **kwargs)
+
+
+@CONFIG_GENERATORS.register_check()
+def check_checkpoint_root_dir(unfinished_fields: set, key: str):
+ if not st.session_state[key].strip(): # TODO: may auto generate
+ unfinished_fields.add(key)
+ st.warning("Please input checkpoint root dir.")
+ elif not os.path.isabs(st.session_state[key].strip()):
+ unfinished_fields.add("checkpoint_root_dir")
+ st.warning("Please input an absolute path.")
+
+
+@CONFIG_GENERATORS.register_config(default_value=MonitorType.TENSORBOARD.value)
+def set_monitor_type(**kwargs):
+ st.selectbox(
+ "Monitor Type",
+ options=[monitor_type.value for monitor_type in MonitorType],
+ **kwargs,
+ )
+
+
+# Algorithm Configs
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=AlgorithmType.PPO.value,
+ other_configs={"mode": "both", "adv_estimator": AdvantageEstimator.GAE.value},
+)
+def set_algorithm_type(**kwargs):
+ def on_change():
+ if st.session_state["algorithm_type"] == AlgorithmType.PPO.value:
+ st.session_state["mode"] = "both"
+ st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value
+ elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value:
+ st.session_state["mode"] = "both"
+ st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
+ elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ st.session_state["mode"] = "train"
+ st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
+ elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value:
+ st.session_state["mode"] = "both"
+ st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
+ else: # TODO: add more algorithms
+ pass
+ set_trainer_gpu_num()
+
+ st.selectbox(
+ "Algorithm Type",
+ [
+ AlgorithmType.PPO.value,
+ AlgorithmType.GRPO.value,
+ AlgorithmType.DPO.value,
+ AlgorithmType.OPMD.value,
+ ],
+ key="algorithm_type",
+ on_change=on_change,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=1,
+ visible=lambda: st.session_state["mode"] == "both",
+ other_configs={
+ "_grouped_adv_repeat_times": 2,
+ "_not_grouped_adv_repeat_times": 1,
+ },
+)
+def set_repeat_times(**kwargs): # TODO
+ key = kwargs.get("key")
+ grouped_adv_algorithms = [
+ AlgorithmType.GRPO.value,
+ AlgorithmType.OPMD.value, # TODO: may add rloo
+ ]
+ if st.session_state["algorithm_type"] in grouped_adv_algorithms:
+ min_repeat_times = 2
+ st.session_state[key] = st.session_state["_grouped_adv_repeat_times"]
+ else:
+ min_repeat_times = 1
+ st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"]
+
+ def on_change():
+ if st.session_state["algorithm_type"] in grouped_adv_algorithms:
+ st.session_state["_grouped_adv_repeat_times"] = st.session_state[key]
+ else:
+ st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key]
+
+ st.number_input(
+ "Repeat Times",
+ min_value=min_repeat_times,
+ help="`repeat_times` is used to set how many experiences each task can generate, "
+ "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.",
+ on_change=on_change,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1.0)
+def set_gamma(**kwargs):
+ st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=1.0)
+def set_lam(**kwargs):
+ st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs)
+
+
+# Model Configs
+
+
+@CONFIG_GENERATORS.register_config(default_value="")
+def set_model_path(**kwargs):
+ st.text_input("Model Path", **kwargs)
+
+
+@CONFIG_GENERATORS.register_check()
+def check_model_path(unfinished_fields: set, key: str):
+ if not st.session_state[key].strip():
+ unfinished_fields.add(key)
+ st.warning("Please input model path.")
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value="",
+ visible=use_critic,
+)
+def set_critic_model_path(**kwargs):
+ st.text_input(
+ "Critic Model Path (defaults to `model_path`)",
+ key="critic_model_path",
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1024)
+def set_max_prompt_tokens(**kwargs):
+ st.number_input("Max Prompt Tokens", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=1024)
+def set_max_response_tokens(**kwargs):
+ st.number_input("Max Response Tokens", min_value=1, **kwargs)
+
+
+# Cluster Config
+
+
+@CONFIG_GENERATORS.register_config(default_value=1)
+def set_node_num(**kwargs):
+ st.number_input("Node Num", min_value=1, on_change=set_total_gpu_num, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=8, other_configs={"total_gpu_num": 8, "trainer_gpu_num": 6}
+)
+def set_gpu_per_node(**kwargs):
+ st.number_input(
+ "GPU Per Node",
+ min_value=1,
+ max_value=8,
+ on_change=set_total_gpu_num,
+ **kwargs,
+ )
diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py
new file mode 100644
index 0000000000..d0f5d26897
--- /dev/null
+++ b/trinity/manager/config_registry/trainer_config_manager.py
@@ -0,0 +1,450 @@
+import streamlit as st
+
+from trinity.common.constants import AlgorithmType, SyncMethod
+from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
+from trinity.trainer.verl.ray_trainer import AdvantageEstimator
+
+
+def use_critic():
+ return st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value
+
+
+@CONFIG_GENERATORS.register_config(default_value="verl")
+def set_trainer_type(**kwargs):
+ st.selectbox("Trainer Type", ["verl"], **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=100, other_configs={"_nccl_save_interval": 100})
+def set_save_interval(**kwargs):
+ key = kwargs.get("key")
+ if (
+ st.session_state["algorithm_type"] == AlgorithmType.DPO.value
+ or st.session_state["sync_method"] == SyncMethod.NCCL.value
+ ):
+ st.session_state[key] = st.session_state["_nccl_save_interval"]
+ freeze_save_interval = False
+ else:
+ st.session_state[key] = st.session_state["sync_interval"]
+ freeze_save_interval = True
+
+ def on_change():
+ if (
+ st.session_state["algorithm_type"] == AlgorithmType.DPO.value
+ or st.session_state["sync_method"] == SyncMethod.NCCL.value
+ ):
+ st.session_state["_nccl_save_interval"] = st.session_state[key]
+
+ st.number_input(
+ "Save Interval",
+ min_value=1,
+ help="Set to `sync_interval` when `algorithm_type != DPO && sync_method == checkpoint`",
+ disabled=freeze_save_interval,
+ on_change=on_change,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=True)
+def set_enable_preview(**kwargs):
+ st.checkbox("Enable Preview", **kwargs)
+
+
+def _actor_use_kl_loss_visible():
+ if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ st.session_state["actor_use_kl_loss"] = True
+ return False
+ return True
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=True,
+ visible=_actor_use_kl_loss_visible,
+ other_configs={"_not_dpo_actor_use_kl_loss": True},
+)
+def set_actor_use_kl_loss(**kwargs):
+ key = kwargs.get("key")
+ st.session_state[key] = st.session_state["_not_dpo_actor_use_kl_loss"]
+
+ def on_change():
+ st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[key]
+
+ st.checkbox("Use KL Loss", on_change=on_change, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"]
+)
+def set_actor_kl_loss_coef(**kwargs):
+ st.number_input(
+ r"KL Loss Coef :blue-badge[$\beta$]",
+ min_value=0.0,
+ max_value=1.0,
+ format="%.1e",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"]
+)
+def set_actor_entropy_coef(**kwargs):
+ st.number_input(
+ "Entropy Coeff",
+ min_value=0.0,
+ max_value=1.0,
+ format="%.1e",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1.0)
+def set_actor_grad_clip(**kwargs):
+ st.number_input(
+ "Grad Clip :blue-badge[(Actor)]",
+ min_value=0.0,
+ max_value=1.0,
+ help="Clipping by Norm",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=0.2)
+def set_actor_clip_ratio(**kwargs):
+ st.number_input(
+ r"Clip Ratio :blue-badge[$\epsilon$]",
+ min_value=0.0,
+ max_value=1.0,
+ **kwargs,
+ )
+
+
+# veRL Trainer Configs
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=[
+ "balance_batch",
+ "gradient_checkpointing",
+ "remove_padding",
+ "dynamic_bsz",
+ ]
+)
+def set_training_args(**kwargs):
+ st.multiselect(
+ "Training Args",
+ [
+ "balance_batch",
+ "gradient_checkpointing",
+ "remove_padding",
+ "dynamic_bsz",
+ ],
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1)
+def set_ppo_epochs(**kwargs):
+ st.number_input("PPO Epochs", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="fsdp")
+def set_training_strategy(**kwargs):
+ st.selectbox(
+ "Training Strategy",
+ ["fsdp", "megatron"],
+ help="megatron is not tested",
+ **kwargs,
+ )
+
+
+def use_fsdp():
+ return st.session_state["training_strategy"] == "fsdp"
+
+
+@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp)
+def set_param_offload(**kwargs):
+ st.checkbox("FSDP Param Offload", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False, visible=use_fsdp)
+def set_optimizer_offload(**kwargs):
+ st.checkbox("FSDP Optimizer Offload", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="auto")
+def set_resume_mode(**kwargs):
+ st.selectbox("Resume Mode", ["disable", "auto", "resume_path"], **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value="", visible=lambda: st.session_state["resume_mode"] == "resume_path"
+)
+def set_resume_from_path(**kwargs):
+ st.text_input("Resume Path", **kwargs)
+
+
+@CONFIG_GENERATORS.register_check()
+def check_resume_from_path(unfinished_fields: set, key: str):
+ if st.session_state["resume_mode"] == "resume_path" and (
+ not st.session_state[key].strip() or "global_step_" not in st.session_state[key]
+ ):
+ unfinished_fields.add(key)
+ st.warning("Please input a valid resume path when `resume_mode == resume_path`")
+
+
+@CONFIG_GENERATORS.register_config(default_value=0)
+def set_critic_warmup(**kwargs):
+ st.number_input("Critic Warmup Steps", min_value=0, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_total_training_steps(**kwargs):
+ st.number_input("Total Training Steps", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_default_hdfs_dir(**kwargs):
+ st.text_input("Default HDFS Dir", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False)
+def set_remove_previous_ckpt_in_save(**kwargs):
+ st.checkbox("Remove Previous Checkpoint in Save", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False)
+def set_del_local_ckpt_after_load(**kwargs):
+ st.checkbox("Delete Local Checkpoint After Load", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_max_actor_ckpt_to_keep(**kwargs):
+ st.number_input("Max Actor Checkpoint to Keep", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=None)
+def set_max_critic_ckpt_to_keep(**kwargs):
+ st.number_input("Max Critic Checkpoint to Keep", min_value=1, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=True)
+def set_norm_adv_by_std_in_grpo(**kwargs):
+ st.checkbox("Norm Adv by Std in GRPO", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=False)
+def set_use_kl_in_reward(**kwargs):
+ st.checkbox("Use KL in Reward", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="low_var_kl")
+def set_kl_penalty(**kwargs):
+ st.selectbox("KL Penalty", ["kl", "abs", "mse", "low_var_kl"], **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="fixed")
+def set_kl_ctrl_type(**kwargs):
+ st.selectbox("KL Ctrl Type", ["fixed", "adaptive"], **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=0.001)
+def set_kl_ctrl_coef(**kwargs):
+ st.number_input("KL Ctrl Coef", format="%.1e", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=10000)
+def set_horizon(**kwargs):
+ st.number_input("Horizon", min_value=1.0, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=0.1)
+def set_target_kl(**kwargs):
+ st.number_input("Target KL", format="%.1e", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value=4)
+def set_actor_ppo_micro_batch_size_per_gpu(**kwargs):
+ key = kwargs.get("key")
+ max_value = st.session_state["_train_batch_size_per_gpu"]
+ st.session_state[key] = min(st.session_state[key], max_value)
+ st.number_input(
+ "Micro Batch Size Per GPU :blue-badge[(Actor)]", min_value=1, max_value=max_value, **kwargs
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=8)
+def set_ref_log_prob_micro_batch_size_per_gpu(**kwargs):
+ key = kwargs.get("key")
+ max_value = st.session_state["_train_batch_size_per_gpu"]
+ st.session_state[key] = min(st.session_state[key], max_value)
+ st.number_input(
+ "Micro Batch Size Per GPU :blue-badge[(Ref)]", min_value=1, max_value=max_value, **kwargs
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1)
+def set_actor_ulysses_sequence_parallel_size(**kwargs):
+ st.number_input(
+ "Ulysses Sequence Parallel Size",
+ min_value=1,
+ max_value=8,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1e-6)
+def set_actor_lr(**kwargs):
+ st.number_input(
+ "Learning Rate :blue-badge[(Actor)]",
+ min_value=1e-7,
+ max_value=1e-3,
+ format="%.1e",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value="constant")
+def set_actor_warmup_style(**kwargs):
+ st.selectbox(
+ "LR Warmup Style :blue-badge[(Actor)]",
+ ["constant", "cosine"],
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=0.0)
+def set_actor_lr_warmup_steps_ratio(**kwargs):
+ st.number_input(
+ "LR Warmup Steps Ratio :blue-badge[(Actor)]",
+ min_value=0.0,
+ max_value=1.0,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=0.0, visible=lambda: st.session_state["algorithm_type"] == "opmd"
+)
+def set_actor_tau(**kwargs):
+ st.number_input("Tau for OPMD", min_value=0.0, format="%.1e", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value="mean", visible=lambda: st.session_state["algorithm_type"] == "opmd"
+)
+def set_actor_opmd_baseline(**kwargs):
+ st.selectbox(
+ "OPMD Baseline",
+ ["mean", "logavgexp"],
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=False, visible=lambda: st.session_state["algorithm_type"] == "opmd"
+)
+def set_actor_use_uid(**kwargs):
+ st.checkbox("Use UID for OPMD", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(default_value="low_var_kl")
+def set_actor_kl_loss_type(**kwargs):
+ st.selectbox(
+ "KL Loss Type",
+ ["kl", "abs", "mse", "low_var_kl"],
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=["model", "hf_model", "optimizer", "extra"])
+def set_actor_checkpoint(**kwargs):
+ st.multiselect(
+ "Checkpoint",
+ ["model", "hf_model", "optimizer", "extra"],
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1e-6, visible=use_critic)
+def set_critic_lr(**kwargs):
+ st.number_input(
+ "Learning Rate :blue-badge[(Critic)]",
+ min_value=1e-7,
+ max_value=1e-3,
+ format="%.1e",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value="constant", visible=use_critic)
+def set_critic_warmup_style(**kwargs):
+ st.selectbox(
+ "LR Warmup Style :blue-badge[(Critic)]",
+ ["constant", "cosine"],
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=0.0, visible=use_critic)
+def set_critic_lr_warmup_steps_ratio(**kwargs):
+ st.number_input(
+ "LR Warmup Steps Ratio :blue-badge[(Critic)]",
+ min_value=0.0,
+ max_value=1.0,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1.0, visible=use_critic)
+def set_critic_grad_clip(**kwargs):
+ st.number_input(
+ "Grad Clip :blue-badge[(Critic)]",
+ min_value=0.0,
+ max_value=1.0,
+ help="Clipping by Norm",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=0.5, visible=use_critic)
+def set_critic_cliprange_value(**kwargs):
+ st.number_input(
+ "Cliprange Value",
+ min_value=0.0,
+ max_value=1.0,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=8, visible=use_critic)
+def set_critic_ppo_micro_batch_size_per_gpu(**kwargs):
+ key = kwargs.get("key")
+ max_value = st.session_state["_train_batch_size_per_gpu"]
+ st.session_state[key] = min(st.session_state[key], max_value)
+ st.number_input(
+ "Micro Batch Size Per GPU :blue-badge[(Critic)]",
+ min_value=1,
+ max_value=max_value,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(default_value=1, visible=use_critic)
+def set_critic_ulysses_sequence_parallel_size(**kwargs):
+ st.number_input(
+ "Ulysses Sequence Parallel Size",
+ min_value=1,
+ max_value=8,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=["model", "optimizer", "extra"], visible=use_critic
+)
+def set_critic_checkpoint(**kwargs):
+ st.multiselect(
+ "Checkpoint",
+ ["model", "hf_model", "optimizer", "extra"],
+ **kwargs,
+ )
diff --git a/trinity/plugins/__init__.py b/trinity/plugins/__init__.py
new file mode 100644
index 0000000000..1b8629c9ca
--- /dev/null
+++ b/trinity/plugins/__init__.py
@@ -0,0 +1 @@
+"""Add your custom modules to this directory."""
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 595084ac02..616234d0d6 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -314,11 +314,6 @@ def update_policy(self, data: DataProto): # noqa: C901
else:
dataloader = batch.split(self.config.ppo_mini_batch_size)
- # TODO: for pairwise_opmd and use_uid, is it necessary to somehow sort samples within batch by uid,
- # to ensure that there are samples with the same uid within each micro-batch
- # (at which level pairwise loss is computed)?
- # (In comparison, advantage is computed at the level of batch, same for opmd, grpo, etc.)
-
metrics = {}
for epoch in range(self.config.ppo_epochs):
for batch_idx, data in enumerate(dataloader):
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index d040c329dd..e7eb8a209b 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -40,7 +40,7 @@
from trinity.common.config import Config
from trinity.common.experience import Experiences
from trinity.trainer.trainer import TrainEngineWrapper
-from trinity.utils.monitor import Monitor
+from trinity.utils.monitor import MONITOR
class _InternalDataLoader:
@@ -145,7 +145,7 @@ def __init__(
)
self.init_workers()
- self.logger = Monitor(
+ self.logger = MONITOR.get(global_config.monitor.monitor_type)(
project=config.trainer.project_name,
name=config.trainer.experiment_name,
role="trainer",
diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py
index 8d7a7d3a06..b250d856d6 100644
--- a/trinity/utils/dlc_utils.py
+++ b/trinity/utils/dlc_utils.py
@@ -9,6 +9,20 @@
logger = get_logger(__name__)
+CLUSTER_ACTOR_NAME = "cluster_status"
+
+
+@ray.remote
+class ClusterStatus:
+ def __init__(self):
+ self.finished = False
+
+ def finish(self) -> None:
+ self.finished = True
+
+ def running(self) -> bool:
+ return not self.finished
+
def get_dlc_env_vars() -> dict:
envs = {
@@ -71,16 +85,40 @@ def setup_ray_cluster(namespace: str):
logger.error(f"ret.stdout: {ret.stdout!r}")
logger.error(f"ret.stderr: {ret.stderr!r}")
sys.exit(1)
+
+ wait_for_ray_setup()
+ ray.init(
+ address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
+ namespace=namespace,
+ ignore_reinit_error=True,
+ )
if is_master:
- wait_for_ray_setup()
- ray.init(
- address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
- namespace=namespace,
- ignore_reinit_error=True,
- )
# master wait for worker nodes to join
wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])
+ else:
+ # woker wait on the cluster status actor
+ cluster_status = ClusterStatus.options(
+ name=CLUSTER_ACTOR_NAME,
+ get_if_exists=True,
+ ).remote()
+ while True:
+ if ray.get(cluster_status.running.remote()):
+ ret = subprocess.run("ray status", shell=True, capture_output=True)
+ print(ret.stdout.decode())
+ time.sleep(5)
+ else:
+ logger.info("Ray cluster is not running, exiting.")
+ break
+ sys.exit(0)
+
- if not is_master:
- # woker just exit
- sys.exit(0)
+def stop_ray_cluster():
+ """
+ Stop the ray cluster by sending a signal to the cluster status actor.
+ """
+ cluster_status = ClusterStatus.options(
+ name=CLUSTER_ACTOR_NAME,
+ get_if_exists=True,
+ ).remote()
+ ray.get(cluster_status.finish.remote())
+ logger.info("Stopping ray cluster...")
diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py
index e3aa216eda..e80afaf59b 100644
--- a/trinity/utils/eval_utils.py
+++ b/trinity/utils/eval_utils.py
@@ -15,12 +15,19 @@ def simple_answer_parser(response: str) -> str:
return parse(response)
-def find_boxed_answer(string):
+def find_boxed_answer(raw_answer, timeout=10):
"""
- Find answers from solutions where the answers are enclosed in LaTeX's `\boxed` tag
+ Find answers from solutions where the answers are enclosed in LaTeX's `\\boxed` tag
+
+ Args:
+ raw_answer (`str`): raw answer from model
+ timeout (`int`): timeout in seconds for regex
+
+ Returns:
+ `str`: answer if found, otherwise None
"""
pattern = r"\\boxed\s*(({(?:\\.|[^{}]|(?2))*})|(.))"
- res = re.findall(pattern, string)
+ res = re.findall(pattern, raw_answer, timeout=timeout)
if res:
answer = res[-1][0] # regard the last boxed as the answer
if answer.startswith("{"):
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 3044c6dcc8..f12a854335 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -1,5 +1,7 @@
"""Monitor"""
+
import os
+from abc import ABC, abstractmethod
from typing import List, Optional, Union
import numpy as np
@@ -8,11 +10,13 @@
from torch.utils.tensorboard import SummaryWriter
from trinity.common.config import Config
-from trinity.common.constants import MonitorType
from trinity.utils.log import get_logger
+from trinity.utils.registry import Registry
+
+MONITOR = Registry("monitor")
-class Monitor:
+class Monitor(ABC):
"""Monitor"""
def __init__(
@@ -22,15 +26,25 @@ def __init__(
role: str,
config: Config = None, # pass the global Config for recording
) -> None:
- if config.monitor.monitor_type == MonitorType.WANDB:
- self.logger = WandbLogger(project, name, role, config)
- elif config.monitor.monitor_type == MonitorType.TENSORBOARD:
- self.logger = TensorboardLogger(project, name, role, config)
- else:
- raise ValueError(f"Unknown monitor type: {config.monitor.monitor_type}")
+ self.project = project
+ self.name = name
+ self.role = role
+ self.config = config
+ @abstractmethod
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
- self.logger.log_table(table_name, experiences_table, step=step)
+ """Log a table"""
+
+ @abstractmethod
+ def log(self, data: dict, step: int, commit: bool = False) -> None:
+ """Log metrics."""
+
+ @abstractmethod
+ def close(self) -> None:
+ """Close the monitor"""
+
+ def __del__(self) -> None:
+ self.close()
def calculate_metrics(
self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
@@ -51,15 +65,9 @@ def calculate_metrics(
metrics[key] = val
return metrics
- def log(self, data: dict, step: int, commit: bool = False) -> None:
- """Log metrics."""
- self.logger.log(data, step=step, commit=commit)
-
- def close(self) -> None:
- self.logger.close()
-
-class TensorboardLogger:
+@MONITOR.register_module("tensorboard")
+class TensorboardMonitor(Monitor):
def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard")
os.makedirs(self.tensorboard_dir, exist_ok=True)
@@ -77,11 +85,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
def close(self) -> None:
self.logger.close()
- def __del__(self) -> None:
- self.logger.close()
-
-class WandbLogger:
+@MONITOR.register_module("wandb")
+class WandbMonitor(Monitor):
def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
self.logger = wandb.init(
project=project,
@@ -104,6 +110,3 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
def close(self) -> None:
self.logger.finish()
-
- def __del__(self) -> None:
- self.logger.finish()
diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py
new file mode 100644
index 0000000000..a5a779ae83
--- /dev/null
+++ b/trinity/utils/plugin_loader.py
@@ -0,0 +1,65 @@
+"""Load modules from custom directory"""
+
+import importlib
+import os
+import shutil
+import sys
+from pathlib import Path
+
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+def load_plugins(plugin_dir: str) -> None:
+ """
+ Load plugin modules from a directory.
+ """
+ if plugin_dir is None:
+ plugin_dir = Path(__file__).parent.parent / "plugins"
+ if not os.path.exists(plugin_dir):
+ logger.error(f"--plugin-dir [{plugin_dir}] does not exist.")
+ return None
+ if not os.path.isdir(plugin_dir):
+ logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.")
+ return None
+
+ logger.info(f"Loading plugin modules from [{plugin_dir}]...")
+ for file in Path(plugin_dir).glob("*.py"):
+ if file.name.startswith("__"):
+ continue
+ logger.info(f"Loading plugin modules from [{file}]...")
+ # load modules from file
+ load_from_file(os.path.join(plugin_dir, file))
+
+
+def load_from_file(file_path: str):
+ """
+ Load modules from a Python file
+
+ Args:
+ file_path (`str`): The python file path.
+
+ Returns:
+ `Any`: The loaded module.
+ """
+ module_name = os.path.splitext(os.path.basename(file_path))[0]
+
+ full_module_name = f"trinity.plugins.{module_name}"
+
+ spec = importlib.util.spec_from_file_location(full_module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load module from {file_path}")
+
+ module = importlib.util.module_from_spec(spec)
+
+ module.__package__ = "trinity.plugins"
+
+ spec.loader.exec_module(module)
+
+ if full_module_name in sys.modules:
+ raise ImportError(f"Module {module_name} already exists.")
+ sys.modules[full_module_name] = module
+ shutil.copy2(file_path, Path(__file__).parent.parent / "plugins")
+ logger.info(f"Load {file_path} as {full_module_name}")
+ return module
diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py
index b31f6872bd..d5ee37f36e 100644
--- a/trinity/utils/registry.py
+++ b/trinity/utils/registry.py
@@ -1,21 +1,4 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# --------------------------------------------------------
-# Most of the code here has been modified from:
-# https://github.com/modelscope/modelscope/blob/master/modelscope/utils/registry.py
-# --------------------------------------------------------
+from typing import Any, Type
from trinity.utils.log import get_logger
@@ -25,59 +8,57 @@
# TODO: support lazy load
# e.g. @MODULES.register_module("name", lazy=True)
class Registry(object):
- """This class is used to register some modules to registry by a repo
- name."""
+ """A class for registry."""
def __init__(self, name: str):
"""
- Initialization method.
-
- :param name: a registry repo name
+ Args:
+ name (`str`): The name of the registry.
"""
self._name = name
self._modules = {}
@property
- def name(self):
+ def name(self) -> str:
"""
Get name of current registry.
- :return: name of current registry.
+ Returns:
+ `str`: The name of current registry.
"""
return self._name
@property
- def modules(self):
+ def modules(self) -> dict:
"""
Get all modules in current registry.
- :return: a dict storing modules in current registry.
+ Returns:
+ `dict`: A dict storing modules in current registry.
"""
return self._modules
- def list(self):
+ def list(self) -> None:
"""Logging the list of module in current registry."""
for m in self._modules.keys():
logger.info(f"{self._name}\t{m}")
- def get(self, module_key):
+ def get(self, module_key) -> Any:
"""
Get module named module_key from in current registry. If not found,
return None.
- :param module_key: specified module name
- :return: module named module_key
+ Args:
+ module_key (`str`): specified module name
+
+ Returns:
+ `Any`: the module object
"""
return self._modules.get(module_key, None)
def _register_module(self, module_name=None, module_cls=None, force=False):
"""
Register module to registry.
-
- :param module_name: module name
- :param module_cls: module class object
- :param force: Whether to override an existing class with the
- same name. Default: False.
"""
if module_name is None:
@@ -89,25 +70,35 @@ def _register_module(self, module_name=None, module_cls=None, force=False):
self._modules[module_name] = module_cls
module_cls._name = module_name
- def register_module(self, module_name: str = None, module_cls: type = None, force=False):
+ def register_module(self, module_name: str, module_cls: Type = None, force=False, lazy=False):
"""
- Register module class object to registry with the specified modulename.
+ Register module class object to registry with the specified module name.
- :param module_name: module name
- :param module_cls: module class object
- :param force: Whether to override an existing class with
- the same name. Default: False.
+ Args:
+ module_name (`str`): The module name.
+ module_cls (`Type`): module class object
+ force (`bool`): Whether to override an existing class with
+ the same name. Default: False.
+ lazy (`bool`): Whether to register the module class object lazily.
+ Default: False.
Example:
- >>> registry = Registry()
- >>> @registry.register_module()
- >>> class TextFormatter:
- >>> pass
-
- >>> class TextFormatter2:
- >>> pass
- >>> registry.register_module( module_name='text_formatter2',
- module_cls=TextFormatter2)
+ ```python
+ WORKFLOWS = Registry("workflows")
+
+ # register a module using decorator
+ @WORKFLOWS.register_module(name="workflow_name")
+ class MyWorkflow(Workflow):
+ pass
+
+ # or register a module directly
+ WORKFLOWS.register_module(
+ name="workflow_name",
+ module_cls=MyWorkflow,
+ force=True,
+ )
+ ```
+
"""
if not (module_name is None or isinstance(module_name, str)):
raise TypeError(f"module_name must be either of None, str," f"got {type(module_name)}")
@@ -120,8 +111,10 @@ def _register(module_cls):
"""
Register module class object to registry.
- :param module_cls: module class object
- :return: module class object.
+ Args:
+ module_cls (`Type`): module class object
+ Returns:
+ `Type`: Decorated module class object.
"""
self._register_module(module_name=module_name, module_cls=module_cls, force=force)
return module_cls
From 3c759d94061d0f72109043437f5e7e598da14a49 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 11 Jun 2025 19:12:52 +0800
Subject: [PATCH 10/28] fix entropy lss
---
trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
index e575caa449..41583ec3ba 100644
--- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
+++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -65,8 +65,8 @@ class DummyEntropyLossFn(EntropyLossFn):
Dummy entropy loss function.
"""
- def __init__(self):
- pass
+ def __init__(self, entropy_coef: float):
+ self.entropy_coef = entropy_coef
def __call__(
self,
From dc8cb0c52e98bb3aefc0ebd34e4e95b1050258c7 Mon Sep 17 00:00:00 2001
From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com>
Date: Thu, 12 Jun 2025 15:41:20 +0800
Subject: [PATCH 11/28] Add Sample Strategy (#78)
---
.../source/tutorial/trinity_configs.md | 19 ++-
trinity/algorithm/__init__.py | 3 +
trinity/algorithm/algorithm.py | 22 ++--
.../entropy_loss_fn/entropy_loss_fn.py | 4 +-
trinity/algorithm/sample_strategy/__init__.py | 13 ++
.../sample_strategy/sample_strategy.py | 114 ++++++++++++++++++
trinity/algorithm/sample_strategy/utils.py | 78 ++++++++++++
trinity/common/config.py | 22 +++-
trinity/trainer/trainer.py | 40 +-----
trinity/trainer/verl_trainer.py | 77 ++++--------
trinity/utils/timer.py | 18 +++
11 files changed, 292 insertions(+), 118 deletions(-)
create mode 100644 trinity/algorithm/sample_strategy/__init__.py
create mode 100644 trinity/algorithm/sample_strategy/sample_strategy.py
create mode 100644 trinity/algorithm/sample_strategy/utils.py
create mode 100644 trinity/utils/timer.py
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 8cb8856fbc..dbb8402ceb 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -79,14 +79,25 @@ Specifies the algorithm type and its related hyperparameters.
algorithm:
algorithm_type: grpo
repeat_times: 1
- gamma: 1.0
- lam: 1.0
+
+ # The following parameters are optional
+ # If not specified, they will automatically be set based on the `algorithm_type`
+ sample_strategy: "default"
+ advantage_fn: "ppo"
+ kl_penalty_fn: "none"
+ kl_loss_fn: "k2"
+ entropy_loss_fn: "default"
```
- `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`.
- `repeat_times`: Number of times each task is repeated. Default is `1`. In `dpo`, this is automatically set to `2`.
-- `gamma`: Discount factor for future rewards. Default is `1.0`.
-- `lam`: Lambda value for Generalized Advantage Estimation (GAE). Default is `1.0`.
+
+- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer.
+- `advantage_fn`: The advantage function used for computing advantages.
+- `kl_penalty_fn`: The KL penalty function used for computing KL penalty.
+- `kl_loss_fn`: The KL loss function used for computing KL loss.
+- `entropy_loss_fn`: The entropy loss function used for computing entropy loss.
+
---
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index 101364c57c..ff52f609e5 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -2,6 +2,7 @@
from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
from trinity.algorithm.kl_fn import KL_FN, KLFn
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy
__all__ = [
"AdvantageFn",
@@ -12,4 +13,6 @@
"KL_FN",
"EntropyLossFn",
"ENTROPY_LOSS_FN",
+ "SampleStrategy",
+ "SAMPLE_STRATEGY",
]
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index f94798fe85..88b9b946b7 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -7,7 +7,6 @@
from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel
from trinity.common.config import Config
from trinity.common.constants import SyncMethod
-from trinity.common.experience import Experience, Experiences
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
@@ -31,10 +30,6 @@ class AlgorithmType(ABC, metaclass=ConstantMeta):
can_balance_batch: bool
schema: type
- @classmethod
- def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
- return Experiences.gather_experiences(exps, pad_token_id)
-
@classmethod
def get_default_config(cls) -> Dict:
raise NotImplementedError
@@ -62,6 +57,7 @@ class SFTAlgorithm(AlgorithmType):
@classmethod
def get_default_config(cls) -> Dict:
return {
+ "sample_strategy": "default",
"policy_loss_fn": "sft",
"kl_loss_fn": "none",
"entropy_loss_fn": "none",
@@ -83,11 +79,12 @@ class PPOAlgorithm(AlgorithmType):
def get_default_config(cls) -> Dict:
return {
"repeat_times": 1,
+ "sample_strategy": "warmup",
"policy_loss_fn": "ppo",
"advantage_fn": "ppo",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
- "entropy_loss_fn": "basic",
+ "entropy_loss_fn": "default",
}
@@ -106,11 +103,12 @@ class GRPOAlgorithm(AlgorithmType):
def get_default_config(cls) -> Dict:
return {
"repeat_times": 2,
+ "sample_strategy": "warmup",
"policy_loss_fn": "ppo",
"advantage_fn": "grpo",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
- "entropy_loss_fn": "basic",
+ "entropy_loss_fn": "default",
}
@@ -129,11 +127,12 @@ class OPMDAlgorithm(AlgorithmType):
def get_default_config(cls) -> Dict:
return {
"repeat_times": 2,
+ "sample_strategy": "warmup",
"policy_loss_fn": "opmd",
"advantage_fn": "opmd",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
- "entropy_loss_fn": "basic",
+ "entropy_loss_fn": "default",
}
@@ -148,17 +147,14 @@ class DPOAlgorithm(AlgorithmType):
can_balance_batch: bool = False
schema: type = DPODataModel
- @classmethod
- def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
- return Experiences.gather_dpo_experiences(exps, pad_token_id)
-
@classmethod
def get_default_config(cls) -> Dict:
return {
"repeat_times": 2, # fake repeat times
+ "sample_strategy": "dpo",
"policy_loss_fn": "dpo",
"kl_loss_fn": "k2",
- "entropy_loss_fn": "basic",
+ "entropy_loss_fn": "default",
}
@classmethod
diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
index 41583ec3ba..d6179a832c 100644
--- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
+++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -40,8 +40,8 @@ def default_args(cls) -> Dict:
return {"entropy_coef": 0.0}
-@ENTROPY_LOSS_FN.register_module("basic")
-class BasicEntropyLossFn(EntropyLossFn):
+@ENTROPY_LOSS_FN.register_module("default")
+class DefaultEntropyLossFn(EntropyLossFn):
"""
Basic entropy loss function.
"""
diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py
new file mode 100644
index 0000000000..60f2e268ae
--- /dev/null
+++ b/trinity/algorithm/sample_strategy/__init__.py
@@ -0,0 +1,13 @@
+from trinity.algorithm.sample_strategy.sample_strategy import (
+ SAMPLE_STRATEGY,
+ DefaultSampleStrategy,
+ SampleStrategy,
+ WarmupSampleStrategy,
+)
+
+__all__ = [
+ "SAMPLE_STRATEGY",
+ "SampleStrategy",
+ "DefaultSampleStrategy",
+ "WarmupSampleStrategy",
+]
diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py
new file mode 100644
index 0000000000..8686a0d497
--- /dev/null
+++ b/trinity/algorithm/sample_strategy/sample_strategy.py
@@ -0,0 +1,114 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Tuple
+
+from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto
+from trinity.buffer import get_buffer_reader
+from trinity.common.config import BufferConfig
+from trinity.common.experience import Experiences
+from trinity.utils.registry import Registry
+from trinity.utils.timer import Timer
+
+SAMPLE_STRATEGY = Registry("sample_strategy")
+
+
+class SampleStrategy(ABC):
+ def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+ self.pad_token_id = buffer_config.pad_token_id
+ self.trainer_type = trainer_type
+
+ @abstractmethod
+ def sample(self, step: int) -> Tuple[Any, Dict, List]:
+ """Sample experiences from buffer.
+
+ Args:
+ step (`int`): The step number of current step.
+
+ Returns:
+ `Any`: The sampled experiences.
+ `Dict`: Metrics for logging.
+ `List`: Representative experiences for logging.
+ """
+
+ @classmethod
+ def default_args(cls) -> dict:
+ return {}
+
+
+@SAMPLE_STRATEGY.register_module("warmup")
+class WarmupSampleStrategy(SampleStrategy):
+ """The default sample strategy."""
+
+ def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+ super().__init__(buffer_config, trainer_type)
+ self.exp_buffer = get_buffer_reader(
+ buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore
+ )
+ self.sft_warmup_steps = buffer_config.trainer_input.sft_warmup_steps
+ if self.sft_warmup_steps > 0 and buffer_config.trainer_input.sft_warmup_dataset is None:
+ raise ValueError("sft_warmup_dataset is required when sft_warmup_steps > 0")
+ if buffer_config.trainer_input.sft_warmup_dataset is not None:
+ self.sft_buffer = get_buffer_reader(
+ buffer_config.trainer_input.sft_warmup_dataset, buffer_config
+ )
+ else:
+ self.sft_buffer = None
+
+ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
+ metrics = {}
+ with Timer(metrics, "read_time"):
+ if step <= self.sft_warmup_steps:
+ exp_list = self.sft_buffer.read()
+ else:
+ exp_list = self.exp_buffer.read()
+ repr_samples = representative_sample(exp_list)
+ with Timer(metrics, "gather_time"):
+ exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
+ if self.trainer_type == "verl":
+ with Timer(metrics, "convert_time"):
+ data = to_data_proto(exps)
+ return data, metrics, repr_samples
+ else:
+ raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+
+
+@SAMPLE_STRATEGY.register_module("default")
+class DefaultSampleStrategy(SampleStrategy):
+ def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+ super().__init__(buffer_config, trainer_type)
+ self.exp_buffer = get_buffer_reader(
+ buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore
+ )
+
+ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
+ metrics = {}
+ with Timer(metrics, "read_time"):
+ exp_list = self.exp_buffer.read()
+ repr_samples = representative_sample(exp_list)
+ with Timer(metrics, "gather_time"):
+ exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
+ if self.trainer_type == "verl":
+ with Timer(metrics, "convert_time"):
+ data = to_data_proto(exps)
+ return data, metrics, repr_samples
+ else:
+ raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+
+
+@SAMPLE_STRATEGY.register_module("dpo")
+class DPOSampleStrategy(WarmupSampleStrategy):
+ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
+ metrics = {}
+ with Timer(metrics, "read_time"):
+ if step <= self.sft_warmup_steps:
+ exp_list = self.sft_buffer.read()
+ else:
+ exp_list = self.exp_buffer.read()
+ repr_samples = representative_sample(exp_list)
+ with Timer(metrics, "gather_time"):
+ exps = Experiences.gather_dpo_experiences(exp_list, pad_token_id=self.pad_token_id) # type: ignore
+ if self.trainer_type == "verl":
+ with Timer(metrics, "convert_time"):
+ data = to_data_proto(exps)
+ return data, metrics, repr_samples
+ else:
+ raise NotImplementedError(f"backend {self.trainer_type} is not supported")
diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py
new file mode 100644
index 0000000000..8c443a20b1
--- /dev/null
+++ b/trinity/algorithm/sample_strategy/utils.py
@@ -0,0 +1,78 @@
+import random
+from typing import List
+
+import numpy as np
+import torch
+from verl.trainer.ppo.ray_trainer import DataProto
+
+from trinity.common.experience import Experience, Experiences
+
+
+def to_data_proto(experiences: Experiences) -> DataProto:
+ attention_mask = experiences.attention_masks
+ cumsum = torch.cumsum(attention_mask, dim=-1)
+ position_ids = torch.clip(cumsum - 1, 0, None).long()
+ batch_dict = {
+ "uid": np.array(experiences.run_ids),
+ "position_ids": position_ids,
+ "input_ids": experiences.tokens.long(),
+ "responses": experiences.tokens[:, experiences.prompt_length :].long(),
+ "attention_mask": attention_mask.long(),
+ "response_mask": (
+ experiences.action_masks[:, experiences.prompt_length :].long()
+ if hasattr(experiences, "action_masks") and experiences.action_masks is not None
+ else attention_mask[:, experiences.prompt_length :].long()
+ ),
+ }
+ if experiences.rewards is not None:
+ token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
+ eos_mask_idx = cumsum.argmax(dim=-1)
+ token_level_rewards[
+ torch.arange(experiences.batch_size), eos_mask_idx
+ ] = experiences.rewards
+ token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
+ batch_dict.update(
+ {
+ "token_level_scores": token_level_rewards,
+ "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore
+ }
+ )
+ return DataProto.from_single_dict(batch_dict)
+
+
+def representative_sample(experiences: List[Experience]) -> List[dict]:
+ if experiences[0].reward is None:
+ sample = random.choice(experiences)
+ return [
+ {
+ "prompt": sample.prompt_text,
+ "response": sample.response_text,
+ }
+ ]
+ samples = []
+ min_reward_sample = None
+ max_reward_sample = None
+ for exp in experiences:
+ if exp.reward is None:
+ continue
+ if min_reward_sample is None or exp.reward < min_reward_sample.reward:
+ min_reward_sample = exp
+ if max_reward_sample is None or exp.reward > max_reward_sample.reward:
+ max_reward_sample = exp
+ if min_reward_sample is not None:
+ samples.append(
+ {
+ "prompt": min_reward_sample.prompt_text,
+ "response": min_reward_sample.response_text,
+ "reward": min_reward_sample.reward,
+ }
+ )
+ if max_reward_sample is not None:
+ samples.append(
+ {
+ "prompt": max_reward_sample.prompt_text,
+ "response": max_reward_sample.response_text,
+ "reward": max_reward_sample.reward,
+ }
+ )
+ return samples
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 22d8f3d711..7c371f4bcb 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -176,9 +176,8 @@ class AlgorithmConfig:
# for GRPO-like algorithms, repeat each task for `repeat_times` times
repeat_times: int = 1
- policy_loss_fn: Optional[str] = None # "ppo"
- # If not set, use PolicyLossFn.default_args()
- policy_loss_fn_args: Optional[dict] = None
+ sample_strategy: Optional[str] = None
+ sample_strategy_args: Optional[dict] = None
advantage_fn: Optional[str] = None # "ppo"
# If not set, use AdvantageFn.default_args()
@@ -188,11 +187,15 @@ class AlgorithmConfig:
# If not set, use kl_penalty_fn.default_args()
kl_penalty_fn_args: Optional[dict] = None
+ policy_loss_fn: Optional[str] = None # "ppo"
+ # If not set, use PolicyLossFn.default_args()
+ policy_loss_fn_args: Optional[dict] = None
+
kl_loss_fn: Optional[str] = None # "k2" # set to "none" to disable kl loss
# If not set, use kl_loss_fn.default_args()
kl_loss_fn_args: Optional[dict] = None
- entropy_loss_fn: Optional[str] = None # "basic"
+ entropy_loss_fn: Optional[str] = None # "default"
# If not set, use entropy_loss_fn.default_args()
entropy_loss_fn_args: Optional[dict] = None
@@ -489,23 +492,32 @@ def _check_algorithm(self) -> None:
ENTROPY_LOSS_FN,
KL_FN,
POLICY_LOSS_FN,
+ SAMPLE_STRATEGY,
)
from trinity.algorithm.algorithm import ALGORITHM_TYPE
algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type)
algorithm.check_config(self)
default_config = {
+ "sample_strategy": "warmup",
"policy_loss_fn": "ppo",
"advantage_fn": "ppo",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
- "entropy_loss_fn": "basic",
+ "entropy_loss_fn": "default",
}
default_config.update(algorithm.get_default_config())
for key, value in default_config.items():
if getattr(self.algorithm, key, None) is None:
setattr(self.algorithm, key, value)
+ # TODO: simplify the following code
+ sample_strategy_cls = SAMPLE_STRATEGY.get(self.algorithm.sample_strategy)
+ if sample_strategy_cls is None:
+ raise ValueError(f"Invalid sample_strategy: {self.algorithm.sample_strategy}")
+ if self.algorithm.sample_strategy_args is None:
+ self.algorithm.sample_strategy_args = sample_strategy_cls.default_args()
+
policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
if policy_fn_cls is None:
raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}")
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 95859685ee..2920604fbb 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -12,9 +12,7 @@
import ray
-from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm
from trinity.algorithm.algorithm_manager import AlgorithmManager
-from trinity.buffer import get_buffer_reader
from trinity.common.config import Config
from trinity.common.constants import SyncMethod
from trinity.utils.log import get_logger
@@ -28,18 +26,6 @@ def __init__(self, config: Config) -> None:
self.config = config
self.logger = get_logger(__name__)
self.algorithm_manager = AlgorithmManager(config)
- self.train_buffer = get_buffer_reader(
- self.config.buffer.trainer_input.experience_buffer, # type: ignore
- self.config.buffer,
- )
- self.sft_warmup_buffer = (
- get_buffer_reader(
- self.config.buffer.trainer_input.sft_warmup_dataset, # type: ignore
- self.config.buffer,
- )
- if self.config.buffer.trainer_input.sft_warmup_steps > 0
- else None
- )
self.engine = get_trainer_wrapper(config)
def prepare(self) -> None:
@@ -71,29 +57,7 @@ def train_step(self) -> Tuple[bool, int]:
Returns:
bool: Whether to continue training.
"""
- algo_config = self.algorithm_manager.get_current_algorithm_config(
- self.engine.train_step_num + 1
- )
- algo_type = algo_config.algorithm_type
- algorithm = ALGORITHM_TYPE.get(algo_type)
- if algorithm.use_rollout:
- strategy = self.config.buffer.trainer_input.read_experience_strategy
- else:
- strategy = None
- try:
- if algorithm == SFTAlgorithm:
- exps = self.sft_warmup_buffer.read()
- else:
- exps = self.train_buffer.read(strategy=strategy)
- except StopIteration:
- self.logger.warning("No more data to train. Stop training.")
- return False, self.engine.train_step_num
-
- experiences = algorithm.gather_experience(
- exps,
- pad_token_id=self.config.buffer.pad_token_id, # type: ignore
- )
- return self.engine.train_step(experiences)
+ return self.engine.train_step()
def sync_weight(self) -> None:
"""Sync the model weight."""
@@ -126,7 +90,7 @@ def train_step_num(self) -> int:
"""Get the current training step number."""
@abstractmethod
- def train_step(self, experiences) -> Tuple[bool, int]:
+ def train_step(self) -> Tuple[bool, int]:
"""Training."""
@abstractmethod
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 84da4cbf98..110a54a7db 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -6,9 +6,8 @@
import os
import sys
from pprint import pprint
-from typing import Tuple
+from typing import Dict, List, Tuple
-import numpy as np
import pandas as pd
import ray
import torch
@@ -20,7 +19,6 @@
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import (
- DataProto,
RayClassWithInitArgs,
RayPPOTrainer,
RayWorkerGroup,
@@ -33,7 +31,7 @@
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs
-from trinity.algorithm import ADVANTAGE_FN, KL_FN
+from trinity.algorithm import ADVANTAGE_FN, KL_FN, SAMPLE_STRATEGY
from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm
from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.algorithm.utils import prefix_metrics
@@ -135,7 +133,11 @@ def __init__(
self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)(
**self.algorithm_config.kl_penalty_fn_args
)
-
+ self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)(
+ buffer_config=global_config.buffer,
+ trainer_type=global_config.trainer.trainer_type,
+ **global_config.algorithm.sample_strategy_args,
+ )
super().__init__(
config,
tokenizer,
@@ -237,9 +239,7 @@ def init_workers(self):
self.actor_rollout_wg.init_model()
def reset_experiences_example_table(self):
- self.experiences_example_table = pd.DataFrame(
- columns=["step", "reward", "prompt", "response"]
- )
+ self.sample_exps_to_log = []
@property
def train_step_num(self) -> int:
@@ -270,9 +270,15 @@ def _create_dataloader(self):
# TODO: compute total training steps
self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
- def train_step(self, experiences: Experiences) -> Tuple[bool, int]:
- self.global_steps += 1
+ def train_step(self) -> Tuple[bool, int]: # noqa C901
metrics = {}
+ try:
+ batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
+ prefix_metrics(sample_metrics, "sample", metrics)
+ except StopIteration:
+ print("No more data to train. Stop training.")
+ return False, self.global_steps
+ self.global_steps += 1
timing_raw = {}
algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps)
algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
@@ -283,39 +289,6 @@ def train_step(self, experiences: Experiences) -> Tuple[bool, int]:
self.algorithm = algorithm
with _timer("step", timing_raw):
- # Convert rewards to token_level_rewards
- attention_mask = experiences.attention_masks
- cumsum = torch.cumsum(attention_mask, dim=-1)
- position_ids = torch.clip(cumsum - 1, 0, None).long()
- batch_dict = {
- "uid": np.array(experiences.run_ids),
- "position_ids": position_ids,
- "input_ids": experiences.tokens.long(),
- "responses": experiences.tokens[:, experiences.prompt_length :].long(),
- "attention_mask": attention_mask.long(),
- "response_mask": (
- experiences.action_masks[:, experiences.prompt_length :].long()
- if hasattr(experiences, "action_masks") and experiences.action_masks is not None
- else attention_mask[:, experiences.prompt_length :].long()
- ),
- }
- if self.algorithm.use_advantage:
- token_level_rewards = torch.zeros(
- attention_mask.shape, dtype=experiences.rewards.dtype
- )
- eos_mask_idx = cumsum.argmax(dim=-1)
- token_level_rewards[
- torch.arange(experiences.batch_size), eos_mask_idx
- ] = experiences.rewards
- token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
- batch_dict.update(
- {
- "token_level_scores": token_level_rewards,
- "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore
- }
- )
-
- batch = DataProto.from_single_dict(batch_dict)
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
if self.algorithm.can_balance_batch and self.config.trainer.balance_batch:
@@ -381,7 +354,7 @@ def train_step(self, experiences: Experiences) -> Tuple[bool, int]:
)
if self.algorithm.use_advantage and self.config.enable_preview: # TODO
- self._log_experiences(experiences)
+ self._log_experiences(exp_samples)
# TODO: make a canonical logger that supports various backend
self.logger.log(data=metrics, step=self.global_steps)
@@ -419,21 +392,13 @@ def _log_single_experience(
"response": [response_text],
}
)
- self.experiences_example_table = pd.concat(
- [self.experiences_example_table, new_row], ignore_index=True
- )
-
- def _log_experiences(self, experiences: Experiences) -> None:
- skip_special_tokens = False
- reward_max_id = torch.argmax(experiences.rewards)
- self._log_single_experience(experiences, reward_max_id, skip_special_tokens)
-
- reward_min_id = torch.argmin(experiences.rewards)
- self._log_single_experience(experiences, reward_min_id, skip_special_tokens)
+ self.sample_exps_to_log = pd.concat([self.sample_exps_to_log, new_row], ignore_index=True)
+ def _log_experiences(self, samples: List[Dict]) -> None:
+ self.sample_exps_to_log.extend(samples)
if self.global_steps % self.config.trainer.sync_freq == 0:
self.logger.log_table(
- "rollout_examples", self.experiences_example_table, self.global_steps
+ "rollout_examples", pd.DataFrame(self.sample_exps_to_log), self.global_steps
)
self.reset_experiences_example_table()
diff --git a/trinity/utils/timer.py b/trinity/utils/timer.py
new file mode 100644
index 0000000000..5e80f406b8
--- /dev/null
+++ b/trinity/utils/timer.py
@@ -0,0 +1,18 @@
+"""Timer context manager"""
+
+import time
+
+
+class Timer:
+ def __init__(self, metrics_dict, key_name):
+ self.metrics = metrics_dict
+ self.key = key_name
+
+ def __enter__(self):
+ self.start_time = time.time()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ end_time = time.time()
+ elapsed_time = end_time - self.start_time
+ self.metrics[self.key] = elapsed_time
From 0e566079b9f49ba9550c3a3033d1366a82f6c6fe Mon Sep 17 00:00:00 2001
From: Yuchang Sun <52027540+hiyuchang@users.noreply.github.com>
Date: Fri, 13 Jun 2025 18:14:20 +0800
Subject: [PATCH 12/28] Add doc for SFT (#81)
---
.../sphinx_doc/source/tutorial/example_dpo.md | 60 +++++++++++++++----
1 file changed, 50 insertions(+), 10 deletions(-)
diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md
index 44543ff2bc..b5846bc24b 100644
--- a/docs/sphinx_doc/source/tutorial/example_dpo.md
+++ b/docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -1,6 +1,6 @@
-# Offline DPO
+# Offline DPO and SFT
-This example describes DPO based on the Qwen-2.5-1.5B-Instruct model and [Human-like-DPO-dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset).
+This example describes DPO and SFT based on the Qwen-2.5-1.5B-Instruct model.
## Step 1: Model and Data Preparation
@@ -20,7 +20,7 @@ More details of model downloading are referred to [ModelScope](https://modelscop
### Data Preparation
-Download the Human-Like-DPO-Dataset dataset to the local directory `$DATASET_PATH/human_like_dpo_dataset`:
+For DPO, we download the [Human-like-DPO-dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset) to the local directory `$DATASET_PATH/human_like_dpo_dataset`:
```shell
# Using Modelscope
@@ -34,9 +34,11 @@ More details of dataset downloading are referred to [ModelScope](https://modelsc
Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pass the proper keys to the config.
-## Step 2: Setup Configuration and Run Experiment
+For SFT, we download the dataset to the local directory `/PATH/TO/SFT_DATASET/`, which usually contains message-based data.
-### Configuration
+## Step 2: Setup Configuration
+
+### Configuration for DPO
We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following:
@@ -53,7 +55,7 @@ algorithm:
kl_coef: 0.1 # value of beta in DPO
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
model:
- model_path: /PATH/TO/MODEL/
+ model_path: $MODEL_PATH/Qwen2.5-1.5B-Instruct
cluster:
node_num: 1
gpu_per_node: 8
@@ -62,9 +64,9 @@ buffer:
batch_size: 64
trainer_input:
experience_buffer:
- name: dpo_buffer
+ name: human_like_dpo
storage_type: file
- path: /PATH/TO/DATASET/
+ path: $DATASET_PATH/human_like_dpo_dataset
format:
prompt_type: plaintext # plaintext/messages/chatpair
prompt_key: prompt
@@ -75,10 +77,48 @@ trainer:
save_interval: 30
```
-### Run the Experiment
+### Configuration for SFT
+
+We set the `algorithm_type` as `sft` to run SFT process. Then we modify the config file `sft.yaml` with the following changes:
+
+```yaml
+project:
+name:
+mode: train
+algorithm:
+ algorithm_type: sft
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+ model_path: /PATH/TO/MODEL/
+cluster:
+ node_num: 1
+ gpu_per_node: 2
+buffer:
+ total_epochs: 5
+ batch_size: 64
+ trainer_input:
+ experience_buffer:
+ name:
+ storage_type: file
+ path: /PATH/TO/SFT_DATASET/
+ split: train
+ format:
+ prompt_type: messages
+ messages_key: messages
+trainer:
+ trainer_config_path: /PATH/TO/TRAIN_CONFIG_YAML/
+ save_interval: 50
+```
+
+## Step 3: Run the Experiment
-Run RFT process with the following command:
+Run DPO process with the following command:
```shell
trinity run --config examples/dpo_humanlike/dpo.yaml
```
+or, for SFT:
+
+```shell
+trinity run --config /PATH/TO/sft.yaml
+```
From aeabfe5b208670fa259b62d115bbf7e4af4b8136 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Tue, 17 Jun 2025 11:27:32 +0800
Subject: [PATCH 13/28] merge verl 0.4.0 (#79)
---
examples/dpo_humanlike/train_dpo.yaml | 3 +-
examples/opmd_gsm8k/train_opmd_gsm8k.yaml | 7 +-
pyproject.toml | 2 +-
trinity/common/verl_config.py | 13 +-
trinity/trainer/verl/dp_actor.py | 218 +----
trinity/trainer/verl/fsdp_workers.py | 935 ++++++++++------------
trinity/trainer/verl_trainer.py | 31 +-
7 files changed, 499 insertions(+), 710 deletions(-)
diff --git a/examples/dpo_humanlike/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml
index 8ffc68b397..028c997e06 100644
--- a/examples/dpo_humanlike/train_dpo.yaml
+++ b/examples/dpo_humanlike/train_dpo.yaml
@@ -26,8 +26,7 @@ actor_rollout_ref:
min_lr_ratio: 0.1 # only useful for warmup with cosine
warmup_style: cosine # select from constant/cosine
total_training_steps: 783 #
- beta1: 0.9
- beta2: 0.95
+ betas: [0.9, 0.95]
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
diff --git a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
index 326904d987..44a0111d64 100644
--- a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
+++ b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
@@ -15,8 +15,8 @@
# entropy_coeff: default to 0.0 for now
#
# optimizer:
-# beta1, beta2: 0.0, 0.95 # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift
-# lr: set smaller to account for beta1 = 0.0
+# betas: [0.0, 0.95] # smaller than default values (0.9, 0.999), as a remedy for abrupt distribution shift
+# lr: set smaller to account for betas[0] = 0.0
#
# misc:
# adv_estimator: grpo # merely to disable critic model, doesn't affect adv compute when algorithm_type is opmd
@@ -50,8 +50,7 @@ actor_rollout_ref:
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
- beta1: 0.0 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
- beta2: 0.95 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
+ betas: [0.0, 0.95] # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
diff --git a/pyproject.toml b/pyproject.toml
index 022c9a8ffe..bafa620470 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
]
requires-python = ">=3.10"
dependencies = [
- "verl==0.3.0.post1",
+ "verl==0.4.0",
"ray[default]>=2.45.0",
"vllm==0.8.5.post1",
"tensordict==0.6.2",
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 644fe9a8f5..e6b1b9e4e1 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -33,8 +33,7 @@ class Optim:
min_lr_ratio: Optional[float] = 0.0
warmup_style: str = "constant"
total_training_steps: int = -1
- beta1: float = 0.9
- beta2: float = 0.999
+ betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
@dataclass
@@ -82,6 +81,7 @@ class Actor:
tau: float = 0.001 # strength of regularization w.r.t. old / ref policy
opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd
use_uid: bool = False # True / False, applicable to pairwise_opmd
+ loss_agg_mode: str = "token-mean" # do not set
@dataclass
@@ -99,12 +99,20 @@ class _ValKwargs:
do_sample: bool = False
+@dataclass
+class _MultiTurn:
+ enable: bool = False
+
+
@dataclass
class Rollout:
# do not set
val_kwargs: _ValKwargs = field(default_factory=_ValKwargs)
+ multi_turn: _MultiTurn = field(default_factory=_MultiTurn)
temperature: float = 1.0
n: int = 1 # > 1 for grpo
+ log_prob_micro_batch_size: Optional[int] = None
+ log_prob_micro_batch_size_per_gpu: int = 1
@dataclass
@@ -148,6 +156,7 @@ class Critic:
cliprange_value: float = 0.0
checkpoint: Checkpoint = field(default_factory=Checkpoint)
rollout_n: int = 1
+ loss_agg_mode: str = "token-mean"
@dataclass
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 616234d0d6..0d750c8303 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -1,4 +1,6 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,49 +14,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-Modified from dp_actor.py
+Single Process Actor.
+Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/actor/dp_actor.py
"""
import itertools
-from typing import Tuple
+import logging
+import os
import torch
-import verl.utils.torch_functional as verl_F
-from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from torch import nn
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl import DataProto
+from verl.utils.debug import GPUMemoryLogger
+from verl.utils.device import get_torch_device
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
-from verl.utils.torch_functional import logprobs_from_logits
-from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
-from verl.workers.actor import BasePPOActor
+from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor
from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn
from trinity.algorithm.kl_fn.kl_fn import DummyKLFn
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import AlgorithmConfig
__all__ = ["DataParallelPPOActor"]
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
-class DataParallelPPOActor(BasePPOActor):
+
+class DataParallelPPOActor(DPActor):
def __init__(
- self,
- config,
- actor_module: nn.Module,
- actor_optimizer: torch.optim.Optimizer = None,
+ self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None
):
"""When optimizer is None, it is Reference Policy"""
- super().__init__(config)
- self.actor_module = actor_module
- self.actor_optimizer = actor_optimizer
- self.use_remove_padding = self.config.get("use_remove_padding", False)
- print(f"Actor use_remove_padding={self.use_remove_padding}")
- self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
- self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
-
- self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
+ super().__init__(config, actor_module, actor_optimizer)
+
self.policy_loss_fn = None
self.kl_loss_fn = None
self.entropy_loss_fn = None
@@ -68,150 +63,8 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig):
**algorithm_config.entropy_loss_fn_args
)
- def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Returns:
- entropy: # (bs, response_len)
- log_probs: # (bs, response_len)
- """
- response_length = micro_batch["responses"].size(-1)
- multi_modal_inputs = {}
- if "multi_modal_inputs" in micro_batch:
- for key in micro_batch["multi_modal_inputs"][0].keys():
- multi_modal_inputs[key] = torch.cat(
- [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
- )
-
- with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
- input_ids = micro_batch["input_ids"]
- batch_size, seqlen = input_ids.shape
- attention_mask = micro_batch["attention_mask"]
- position_ids = micro_batch["position_ids"]
- if position_ids.dim() == 3: # qwen2vl mrope
- position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
-
- if self.use_remove_padding:
- input_ids_rmpad, indices, *_ = unpad_input(
- input_ids.unsqueeze(-1), attention_mask
- ) # input_ids_rmpad (total_nnz, ...)
- input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
-
- # unpad the position_ids to align the rotary
- if position_ids.dim() == 3:
- position_ids_rmpad = (
- index_first_axis(
- rearrange(position_ids, "c b s ... -> (b s) c ..."), indices
- )
- .transpose(0, 1)
- .unsqueeze(1)
- ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
- else:
- position_ids_rmpad = index_first_axis(
- rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
- ).transpose(0, 1)
-
- # for compute the log_prob
- input_ids_rmpad_rolled = torch.roll(
- input_ids_rmpad, shifts=-1, dims=1
- ) # (1, total_nnz)
-
- # pad and slice the inputs if sp > 1
- if self.use_ulysses_sp:
- input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
- input_ids_rmpad,
- position_ids_rmpad,
- sp_size=self.ulysses_sequence_parallel_size,
- )
- input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
- input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size
- )
-
- input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(
- 0
- ) # ((total_nnz / sp) + pad)
-
- # only pass input_ids and position_ids to enable flash_attn_varlen
- output = self.actor_module(
- input_ids=input_ids_rmpad,
- attention_mask=None,
- position_ids=position_ids_rmpad,
- **multi_modal_inputs,
- use_cache=False,
- ) # prevent model thinks we are generating
- logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
-
- logits_rmpad.div_(temperature)
-
- # compute entropy
- entropy_rmpad = self.compute_entropy_from_logits(
- logits_rmpad
- ) # ((total_nnz / sp) + pad)
-
- # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
- log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
-
- # gather log_prob if sp > 1
- if self.use_ulysses_sp:
- # gather and unpad for the ulysses sp
- log_probs = gather_outpus_and_unpad(
- log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size
- )
- entropy_rmpad = gather_outpus_and_unpad(
- entropy_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
- )
- # pad back to (bsz, seqlen)
- full_entropy = pad_input(
- hidden_states=entropy_rmpad.unsqueeze(-1),
- indices=indices,
- batch=batch_size,
- seqlen=seqlen,
- )
- full_log_probs = pad_input(
- hidden_states=log_probs.unsqueeze(-1),
- indices=indices,
- batch=batch_size,
- seqlen=seqlen,
- )
-
- # only return response part:
- entropy = full_entropy.squeeze(-1)[
- :, -response_length - 1 : -1
- ] # (bsz, response_length)
- log_probs = full_log_probs.squeeze(-1)[
- :, -response_length - 1 : -1
- ] # (bsz, response_length)
-
- else: # not using rmpad and no ulysses sp
- output = self.actor_module(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- **multi_modal_inputs,
- use_cache=False,
- ) # prevent model thinks we are generating
- logits = output.logits
- logits.div_(temperature)
- logits = logits[
- :, -response_length - 1 : -1, :
- ] # (bsz, response_length, vocab_size)
- log_probs = logprobs_from_logits(logits, micro_batch["responses"])
- entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
-
- return entropy, log_probs
-
- def _optimizer_step(self):
- assert self.config.grad_clip is not None
-
- if isinstance(self.actor_module, FSDP):
- grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)
- else:
- grad_norm = torch.nn.utils.clip_grad_norm_(
- self.actor_module.parameters(), max_norm=self.config.grad_clip
- )
- self.actor_optimizer.step()
- return grad_norm
-
- def compute_log_prob(self, data: DataProto) -> torch.Tensor:
+ @GPUMemoryLogger(role="dp actor", logger=logger)
+ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
"""Compute the log probability of the responses given input_ids, attention_mask and position_ids
Args:
@@ -235,7 +88,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
micro_batch_size = data.meta_info["micro_batch_size"]
temperature = data.meta_info[
"temperature"
- ] # temperature must be in the data.meta_info to avoid slient error
+ ] # temperature must be in the data.meta_info to avoid silent error
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
@@ -258,30 +111,40 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
micro_batches = batch.split(micro_batch_size)
log_probs_lst = []
+ entropy_lst = []
for micro_batch in micro_batches:
if isinstance(micro_batch, DataProto):
micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
-
with torch.no_grad():
- _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
+ entropy, log_probs = self._forward_micro_batch(
+ micro_batch, temperature=temperature, calculate_entropy=calculate_entropy
+ )
log_probs_lst.append(log_probs)
- log_probs = torch.concat(log_probs_lst, dim=0)
+ if calculate_entropy:
+ entropy_lst.append(entropy)
+ log_probs = torch.concat(log_probs_lst, dim=0)
+ entropys = None
+ if calculate_entropy:
+ entropys = torch.concat(entropy_lst, dim=0)
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
log_probs = log_probs[revert_indices]
+ if calculate_entropy:
+ entropys = entropys[revert_indices] # type: ignore
- return log_probs
+ return log_probs, entropys
- def update_policy(self, data: DataProto): # noqa: C901
+ @GPUMemoryLogger(role="dp actor", logger=logger)
+ def update_policy(self, data: DataProto):
# make sure we are in training mode
self.actor_module.train()
temperature = data.meta_info[
"temperature"
- ] # temperature must be in the data.meta_info to avoid slient error
+ ] # temperature must be in the data.meta_info to avoid silent error
select_keys = [
"input_ids",
"position_ids",
@@ -351,12 +214,12 @@ def update_policy(self, data: DataProto): # noqa: C901
# Support all hardwares
if isinstance(data, DataProto):
data = {
- **data.batch.to(torch.cuda.current_device()),
+ **data.batch.to(get_torch_device().current_device()),
**data.non_tensor_batch,
}
else:
data = data.to(
- torch.cuda.current_device()
+ get_torch_device().current_device()
) # actor device is cpu when using offload
responses = data["responses"]
response_length = responses.size(1)
@@ -365,8 +228,11 @@ def update_policy(self, data: DataProto): # noqa: C901
assert response_mask.shape == attention_mask[:, -response_length:].shape
# all return: (bsz, response_length)
+ calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn
entropy, log_prob = self._forward_micro_batch(
- micro_batch=data, temperature=temperature
+ micro_batch=data,
+ temperature=temperature,
+ calculate_entropy=calculate_entropy,
)
kwargs = {
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index c0af427b4a..66d055feeb 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -12,74 +12,70 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-The main entry point to run the PPO algorithm
+The main entry point to run the PPO algorithm.
+Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/fsdp_workers.py
"""
+import json
import logging
import os
import warnings
+from dataclasses import asdict
import psutil
import torch
import torch.distributed
-import verl.utils.torch_functional as verl_F
+import torch.distributed as dist
+import vllm # noqa: F401 ; import vllm to avoid "Cuda failure 1 'invalid argument'"
from codetiming import Timer
from omegaconf import DictConfig, open_dict
+from peft import LoraConfig, TaskType, get_peft_model
+from safetensors.torch import save_file
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FlatParameter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import FSDP_PREFIX
from verl import DataProto
+from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_processor, hf_tokenizer
+from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
+from verl.utils.device import get_torch_device, is_cuda_available
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
+ CPUOffloadPolicy,
+ MixedPrecisionPolicy,
+ apply_fsdp2,
+ fsdp2_load_full_state_dict,
+ fsdp_version,
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
+ layered_summon_lora_params,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
-from verl.utils.model import compute_position_id_with_mask
+from verl.utils.py_functional import convert_to_regular_types
+from verl.workers.fsdp_workers import (
+ create_device_mesh,
+ device_name,
+ get_sharding_strategy,
+)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from trinity.common.config import AlgorithmConfig
-from trinity.common.constants import SyncMethod
+from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
from trinity.utils.distributed import init_process_group, is_ipv6_address
logger = logging.getLogger(__file__)
-logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
-
-
-def create_device_mesh(world_size, fsdp_size):
- if fsdp_size < 0 or fsdp_size >= world_size:
- device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
- else:
- device_mesh = init_device_mesh(
- "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
- )
- return device_mesh
-
-
-def get_sharding_strategy(device_mesh):
- from torch.distributed.fsdp import ShardingStrategy
-
- if device_mesh.ndim == 1:
- sharding_strategy = ShardingStrategy.FULL_SHARD
- elif device_mesh.ndim == 2:
- sharding_strategy = ShardingStrategy.HYBRID_SHARD
- else:
- raise NotImplementedError(
- f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2"
- )
- return sharding_strategy
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class ActorRolloutRefWorker(Worker):
@@ -94,7 +90,13 @@ def __init__(self, config: DictConfig, role: str):
import torch.distributed
if not torch.distributed.is_initialized():
- torch.distributed.init_process_group(backend="nccl")
+ rank = int(os.environ.get("RANK", 0))
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ torch.distributed.init_process_group(
+ backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl",
+ rank=rank,
+ world_size=world_size,
+ )
# build device mesh for FSDP
world_size = torch.distributed.get_world_size()
@@ -111,12 +113,14 @@ def __init__(self, config: DictConfig, role: str):
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
- "cuda",
+ device_name,
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
+ self._lora_rank = self.config.model.get("lora_rank", 0)
+ self._is_lora = self._lora_rank > 0
self.role = role
assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]
@@ -153,6 +157,8 @@ def __init__(self, config: DictConfig, role: str):
self.config.actor.ppo_micro_batch_size_per_gpu = (
self.config.actor.ppo_micro_batch_size
)
+
+ if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
assert (
self.config.actor.ppo_mini_batch_size
% self.config.actor.ppo_micro_batch_size_per_gpu
@@ -181,22 +187,22 @@ def __init__(self, config: DictConfig, role: str):
self.config.ref.log_prob_micro_batch_size
)
- def _build_model_optimizer(
+ def _build_model_optimizer( # noqa: C901
self,
model_path,
fsdp_config,
optim_config,
override_model_config,
use_remove_padding=False,
+ use_fused_kernels=False,
enable_gradient_checkpointing=False,
trust_remote_code=False,
use_liger=False,
role="actor",
+ enable_activation_offload=False,
):
from torch import optim
- from torch.distributed.fsdp import CPUOffload
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
- from torch.distributed.fsdp import MixedPrecision
+ from torch.distributed.fsdp import CPUOffload, MixedPrecision
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -211,8 +217,8 @@ def _build_model_optimizer(
assert role in ["actor", "ref"]
- log_gpu_memory_usage("Before init from HF AutoModel", logger=logger)
- local_path = copy_to_local(model_path)
+ log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
+ local_path = model_path
# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
# TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
@@ -227,9 +233,13 @@ def _build_model_optimizer(
# override model kwargs
actor_model_config = AutoConfig.from_pretrained(
- local_path, trust_remote_code=trust_remote_code
+ local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
)
+ # patch for kimi-vl
+ if getattr(actor_model_config, "model_type", None) == "kimi_vl":
+ actor_model_config.text_config.topk_method = "greedy"
+
self.generation_config = get_generation_config(
local_path, trust_remote_code=trust_remote_code
)
@@ -260,17 +270,9 @@ def _build_model_optimizer(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
- attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
- if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
- from verl.models.transformers.monkey_patch import apply_monkey_patch
-
- apply_monkey_patch(
- model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size
- )
-
# Apply Liger kernel to the model if use_liger is set to True
if use_liger:
from liger_kernel.transformers.monkey_patch import (
@@ -279,6 +281,13 @@ def _build_model_optimizer(
_apply_liger_kernel_to_instance(model=actor_module)
+ apply_monkey_patch(
+ model=actor_module,
+ use_remove_padding=use_remove_padding,
+ ulysses_sp_size=self.ulysses_sequence_parallel_size,
+ use_fused_kernels=use_fused_kernels,
+ )
+
# some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
actor_module.to(torch_dtype)
@@ -286,12 +295,24 @@ def _build_model_optimizer(
actor_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
+ if self._is_lora:
+ print("Applying LoRA to actor module")
+ actor_module.enable_input_require_grads()
+ # Convert config to regular Python types before creating PEFT model
+ lora_config = {
+ "task_type": TaskType.CAUSAL_LM,
+ "r": self.config.model.lora_rank,
+ "lora_alpha": self.config.model.lora_alpha,
+ "target_modules": convert_to_regular_types(self.config.model.target_modules),
+ "bias": "none",
+ }
+ actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
torch.distributed.barrier()
if self.rank == 0:
print_model_size(actor_module)
- log_gpu_memory_usage("After init from HF AutoModel", logger=logger)
+ log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)
# We wrap FSDP for rollout as well
mixed_precision_config = fsdp_config.get("mixed_precision", None)
@@ -313,14 +334,17 @@ def _build_model_optimizer(
)
auto_wrap_policy = get_fsdp_wrap_policy(
- module=actor_module, config=fsdp_config.get("wrap_policy", None)
+ module=actor_module,
+ config=fsdp_config.get("wrap_policy", None),
+ is_lora=self.config.model.get("lora_rank", 0) > 0,
)
if self._is_rollout and self.config.rollout.name == "hf":
# TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
auto_wrap_policy = None
- print(f"wrap_policy: {auto_wrap_policy}")
+ if self.rank == 0:
+ print(f"wrap_policy: {auto_wrap_policy}")
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
@@ -329,74 +353,104 @@ def _build_model_optimizer(
# We force reference policy to use CPUOffload to save memory.
# We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
- actor_module_fsdp = FSDP(
- actor_module,
- cpu_offload=cpu_offload,
- param_init_fn=init_fn,
- use_orig_params=False,
- auto_wrap_policy=auto_wrap_policy,
- device_id=torch.cuda.current_device(),
- sharding_strategy=sharding_strategy, # zero3
- mixed_precision=mixed_precision,
- sync_module_states=True,
- device_mesh=self.device_mesh,
- forward_prefetch=False,
- )
+ fsdp_strategy = self.config.actor.strategy
+ if fsdp_strategy == "fsdp":
+ actor_module_fsdp = FSDP(
+ actor_module,
+ cpu_offload=cpu_offload,
+ param_init_fn=init_fn,
+ use_orig_params=False,
+ auto_wrap_policy=auto_wrap_policy,
+ device_id=get_torch_device().current_device(),
+ sharding_strategy=sharding_strategy, # zero3
+ mixed_precision=mixed_precision,
+ sync_module_states=True,
+ device_mesh=self.device_mesh,
+ forward_prefetch=False,
+ )
+ elif fsdp_strategy == "fsdp2":
+ assert (
+ CPUOffloadPolicy is not None
+ ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
+ mp_policy = MixedPrecisionPolicy(
+ param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
+ )
+ if role == "actor" and fsdp_config.offload_policy:
+ cpu_offload = CPUOffloadPolicy(pin_memory=True)
+ self._is_offload_param = False
+ self._is_offload_optimizer = False
+ else:
+ cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)
+
+ fsdp_kwargs = {
+ "mesh": fsdp_mesh,
+ "mp_policy": mp_policy,
+ "offload_policy": cpu_offload,
+ "reshard_after_forward": fsdp_config.reshard_after_forward,
+ }
+ full_state = actor_module.state_dict()
+ apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
+ fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
+ actor_module_fsdp = actor_module
+ else:
+ raise NotImplementedError(f"not implement {fsdp_strategy}")
- log_gpu_memory_usage("After Actor FSDP init", logger=logger)
+ if enable_activation_offload:
+ enable_activation_offloading(
+ actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing
+ )
+
+ log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)
# TODO: add more optimizer args into config
if role == "actor" and optim_config is not None:
- beta1 = optim_config.get("beta1", 0.9)
- beta2 = optim_config.get("beta2", 0.999)
+ from verl.utils.torch_functional import (
+ get_constant_schedule_with_warmup,
+ get_cosine_schedule_with_warmup,
+ )
+
actor_optimizer = optim.AdamW(
actor_module_fsdp.parameters(),
lr=optim_config.lr,
- betas=(beta1, beta2),
+ betas=optim_config.get("betas", (0.9, 0.999)),
weight_decay=optim_config.get("weight_decay", 1e-2),
)
total_steps = optim_config.get("total_training_steps", 0)
num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
+ warmup_style = optim_config.get("warmup_style", "constant")
+ min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
+ num_cycles = optim_config.get("num_cycles", 0.5)
if num_warmup_steps < 0:
num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
- print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
-
- if optim_config.warmup_style == "constant":
- from verl.utils.torch_functional import (
- get_constant_schedule_with_warmup,
- )
+ if self.rank == 0:
+ print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
+ if warmup_style == "constant":
actor_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
)
- elif optim_config.warmup_style == "cosine":
- from verl.utils.torch_functional import get_cosine_schedule_with_warmup
-
- assert (
- total_steps > 0
- ), "Cosine scheduler of actor requires total_training_steps > 0"
+ elif warmup_style == "cosine":
actor_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=actor_optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
- min_lr_ratio=optim_config.min_lr_ratio,
+ min_lr_ratio=min_lr_ratio,
+ num_cycles=num_cycles,
)
else:
- raise NotImplementedError(
- f"Lr scheduler style {optim_config.warmup_style} is not supported"
- )
+ raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
+
+ log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
else:
actor_optimizer = None
actor_lr_scheduler = None
- log_gpu_memory_usage("After actor optimizer init", logger=logger)
-
return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
- def _build_rollout(self):
+ def _build_rollout(self, trust_remote_code=False):
from torch.distributed.device_mesh import init_device_mesh
# TODO(sgm): support FSDP hybrid shard for larger model
@@ -406,62 +460,129 @@ def _build_rollout(self):
self.world_size % infer_tp == 0
), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
rollout_device_mesh = init_device_mesh(
- "cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
+ device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
)
-
- if self.config.rollout.name == "hf":
+ rollout_name = self.config.rollout.name
+ if rollout_name == "hf":
from verl.workers.rollout import HFRollout
- from verl.workers.sharding_manager import BaseShardingManager
+ from verl.workers.sharding_manager.base import BaseShardingManager
rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
rollout_sharding_manager = BaseShardingManager()
# TODO: a sharding manager that do nothing?
- elif self.config.rollout.name == "vllm":
- if self.config.rollout.use_fire_sampling:
- from verl.workers.rollout.vllm_rollout import (
- FIREvLLMRollout as vLLMRollout,
- )
- from verl.workers.rollout.vllm_rollout import vllm_mode
- else:
- from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
- from verl.workers.sharding_manager import FSDPVLLMShardingManager
- log_gpu_memory_usage("Before building vllm rollout", logger=None)
- local_path = copy_to_local(self.config.model.path)
+ elif rollout_name == "vllm":
+ from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout
+ from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager
+
+ log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
+ local_path = copy_to_local(
+ self.config.model.path, use_shm=self.config.model.get("use_shm", False)
+ )
+ lora_kwargs = (
+ {
+ "lora_kwargs": {
+ "enable_lora": True,
+ "max_loras": 1,
+ "max_lora_rank": self._lora_rank,
+ }
+ }
+ if self._is_lora
+ else {}
+ )
+ # lora_kwargs = {}
if vllm_mode == "customized":
rollout = vLLMRollout(
actor_module=self.actor_module_fsdp,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
+ trust_remote_code=trust_remote_code,
+ **lora_kwargs,
)
elif vllm_mode == "spmd":
- rollout = vLLMRollout(
+ from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
+
+ vllm_rollout_cls = (
+ vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
+ )
+ rollout = vllm_rollout_cls(
model_path=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
device_mesh=rollout_device_mesh,
+ trust_remote_code=trust_remote_code,
+ **lora_kwargs,
)
else:
raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
- log_gpu_memory_usage("After building vllm rollout", logger=None)
- if torch.distributed.get_world_size() == 1:
- self.config.rollout.load_format = "dummy_hf"
+
+ log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
+ full_params = torch.distributed.get_world_size() == 1
rollout_sharding_manager = FSDPVLLMShardingManager(
module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
+ full_params=full_params,
+ device_mesh=rollout_device_mesh,
+ offload_param=self._is_offload_param,
+ load_format=self.config.rollout.load_format,
+ layered_summon=self.config.rollout.get("layered_summon", False),
+ )
+ log_gpu_memory_usage("After building sharding manager", logger=logger)
+
+ elif rollout_name in ["sglang", "sglang_async"]:
+ if rollout_name == "sglang_async":
+ warnings.warn(
+ "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ from verl.workers.rollout.sglang_rollout import SGLangRollout
+
+ # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
+ # SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
+ # the main process of ray can not find any CUDA device, which would potentially lead to:
+ # "RuntimeError: No CUDA GPUs are available".
+ # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
+ # we import it here use the abs path.
+ # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
+ from verl.workers.sharding_manager.fsdp_sglang import (
+ FSDPSGLangShardingManager,
+ )
+
+ local_path = copy_to_local(self.config.model.path)
+ log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
+ rollout = SGLangRollout(
+ actor_module=local_path,
+ config=self.config.rollout,
+ tokenizer=self.tokenizer,
+ model_hf_config=self.actor_model_config,
+ trust_remote_code=trust_remote_code,
+ )
+ log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
+
+ if torch.distributed.get_world_size() == 1:
+ self.config.rollout.load_format = "dummy_hf"
+ rollout_sharding_manager = FSDPSGLangShardingManager(
+ module=self.actor_module_fsdp,
+ inference_engine=rollout._engine,
+ model_config=self.actor_model_config,
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
+ offload_param=self._is_offload_param,
)
- log_gpu_memory_usage("After building sharding manager", logger=None)
+ log_gpu_memory_usage("After building sharding manager", logger=logger)
+
+ else:
+ raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported")
return rollout, rollout_sharding_manager
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
- from .dp_actor import DataParallelPPOActor
+ from trinity.trainer.verl.dp_actor import DataParallelPPOActor
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
@@ -473,6 +594,8 @@ def init_model(self):
)
use_remove_padding = self.config.model.get("use_remove_padding", False)
+ use_shm = self.config.model.get("use_shm", False)
+ use_fused_kernels = self.config.model.get("use_fused_kernels", False)
if self._is_actor or self._is_rollout:
# we need the model for actor and rollout
@@ -482,27 +605,36 @@ def init_model(self):
else:
optim_config = None
fsdp_config = OmegaConf.create()
+
+ local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
(
self.actor_module_fsdp,
self.actor_optimizer,
self.actor_lr_scheduler,
self.actor_model_config,
) = self._build_model_optimizer(
- model_path=self.config.model.path,
+ model_path=local_path,
fsdp_config=fsdp_config,
optim_config=optim_config,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
+ use_fused_kernels=use_fused_kernels,
enable_gradient_checkpointing=self.config.model.get(
"enable_gradient_checkpointing", False
),
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="actor",
+ enable_activation_offload=self.config.model.get("enable_activation_offload", False),
)
# get the original unwrapped module
- self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module
+ if fsdp_version(self.actor_module_fsdp) == 1:
+ self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module
+
+ if self._is_offload_param:
+ offload_fsdp_model_to_cpu(self.actor_module_fsdp)
+ log_gpu_memory_usage("After offload actor model during init", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
@@ -512,6 +644,7 @@ def init_model(self):
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
self.config.actor.use_remove_padding = use_remove_padding
+ self.config.actor.use_fused_kernels = use_fused_kernels
self.actor = DataParallelPPOActor(
config=self.config.actor,
actor_module=self.actor_module_fsdp,
@@ -519,15 +652,19 @@ def init_model(self):
)
if self._is_rollout:
- self.rollout, self.rollout_sharding_manager = self._build_rollout()
+ self.rollout, self.rollout_sharding_manager = self._build_rollout(
+ trust_remote_code=self.config.model.get("trust_remote_code", False)
+ )
if self._is_ref:
+ local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
self.ref_module_fsdp = self._build_model_optimizer(
- model_path=self.config.model.path,
+ model_path=local_path,
fsdp_config=self.config.ref.fsdp_config,
optim_config=None,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
+ use_fused_kernels=use_fused_kernels,
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="ref",
@@ -535,6 +672,7 @@ def init_model(self):
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
self.config.ref.use_remove_padding = use_remove_padding
+ self.config.ref.use_fused_kernels = use_fused_kernels
self.ref_policy = DataParallelPPOActor(
config=self.config.ref, actor_module=self.ref_module_fsdp
)
@@ -555,8 +693,6 @@ def init_model(self):
checkpoint_contents=self.config.actor.checkpoint.contents,
)
- torch.cuda.empty_cache()
-
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def setup_weight_sync_group(self):
if (
@@ -588,7 +724,6 @@ def setup_weight_sync_group(self):
world_size = self.config.synchronizer.explorer_world_size + 1
print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
explorer = ray.get_actor("explorer")
- group_name = "rollout_weight_sync"
setup_ref = explorer.setup_weight_sync_group.remote(
master_address, master_port, self.state_dict_meta
)
@@ -605,7 +740,7 @@ def setup_weight_sync_group(self):
timeout=timeout,
world_size=world_size,
rank=0,
- group_name=group_name,
+ group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
)
ray.get(setup_ref)
@@ -630,18 +765,16 @@ def set_algorithm(self, algo_config: AlgorithmConfig):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
# Support all hardwares
- data = data.to(torch.cuda.current_device())
+ data = data.to(get_torch_device().current_device())
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
if self._is_offload_optimizer:
load_fsdp_optimizer(
- optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()
+ optimizer=self.actor_optimizer, device_id=get_torch_device().current_device()
)
- log_gpu_memory_usage("Before update policy", logger=logger)
-
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
# perform training
@@ -655,17 +788,17 @@ def update_actor(self, data: DataProto):
metrics["perf/mfu/actor"] = (
estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
)
- metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (
+ metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (
+ 1024**3
+ )
+ metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (
1024**3
)
- metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
- self.actor_lr_scheduler.step()
lr = self.actor_lr_scheduler.get_last_lr()[0]
metrics["actor/lr"] = lr
-
- log_gpu_memory_usage("After update policy", logger=logger)
+ self.actor_lr_scheduler.step()
# TODO: here, we should return all metrics
output = DataProto(meta_info={"metrics": metrics})
@@ -675,19 +808,19 @@ def update_actor(self, data: DataProto):
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
+ log_gpu_memory_usage("After offload actor model during update_actor", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
+ log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
# Support all hardwares
- prompts = prompts.to(torch.cuda.current_device())
+ prompts = prompts.to(get_torch_device().current_device())
assert self._is_rollout
- if self._is_offload_param:
- load_fsdp_model_to_gpu(self.actor_module_fsdp)
meta_info = {
"eos_token_id": self.generation_config.eos_token_id
@@ -699,12 +832,6 @@ def generate_sequences(self, prompts: DataProto):
}
prompts.meta_info.update(meta_info)
with self.rollout_sharding_manager:
- # after parameters sync with rollout, offload actor model to CPU
- if self._is_offload_param:
- offload_fsdp_model_to_cpu(self.actor_module_fsdp)
- if self._is_offload_optimizer:
- offload_fsdp_optimizer(optimizer=self.actor_optimizer)
-
log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)
prompts = self.rollout_sharding_manager.preprocess_data(prompts)
@@ -717,18 +844,23 @@ def generate_sequences(self, prompts: DataProto):
output = output.to("cpu")
# clear kv cache
- torch.cuda.empty_cache()
- log_gpu_memory_usage("After recompute log prob", logger=logger)
+ get_torch_device().empty_cache()
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
+ # when is_lora is True, we use the actor without lora applied to calculate the log_prob
+ # which is mostly used for ref log_prob calculation
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
# Support all hardwares
- data = data.to(torch.cuda.current_device())
+ from contextlib import nullcontext
+
+ is_lora = data.meta_info.pop("is_lora", False)
+ adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext()
+ data = data.to(get_torch_device().current_device())
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
@@ -737,9 +869,10 @@ def compute_log_prob(self, data: DataProto):
# perform recompute log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
- output = self.actor.compute_log_prob(data=data)
+ with adapter_ctx:
+ output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
output = DataProto.from_dict(
- tensors={"old_log_probs": output},
+ tensors={"old_log_probs": output, "entropys": entropys},
meta_info={"temperature": self.config.rollout.temperature},
)
output = self.ulysses_sharding_manager.postprocess_data(output)
@@ -748,21 +881,29 @@ def compute_log_prob(self, data: DataProto):
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
- if self.world_size > 1:
+ if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1:
self.actor.actor_module._handle.reshard(True)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
+ log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger)
- log_gpu_memory_usage("After compute_log_prob", logger=logger)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
+ if self._is_lora:
+ # if _is_lora, actor without lora applied is the ref
+ data.meta_info["is_lora"] = True
+ data = self.compute_log_prob(data)
+ # this old_log_probs is in fact ref_log_prob
+ data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]})
+ return data
assert self._is_ref
-
+ # else:
+ # otherwise, the class have a standalone ref model
# Support all hardwares
- data = data.to(torch.cuda.current_device())
+ data = data.to(get_torch_device().current_device())
micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
@@ -771,7 +912,7 @@ def compute_ref_log_prob(self, data: DataProto):
data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
- output = self.ref_policy.compute_log_prob(data=data)
+ output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
@@ -779,17 +920,15 @@ def compute_ref_log_prob(self, data: DataProto):
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
- if self.world_size > 1:
+ if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1:
self.ref_policy.actor_module._handle.reshard(True)
- torch.cuda.empty_cache()
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
# only support save and load ckpt for actor
assert self._is_actor
- import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
@@ -800,8 +939,42 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to
global_step=global_step,
max_ckpt_to_keep=max_ckpt_to_keep,
)
+ dist.barrier()
+
+ if self._is_lora and hasattr(
+ getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"
+ ):
+ lora_save_path = os.path.join(local_path, "lora_adapter")
+ peft_model = getattr(self, "actor_module", self.actor_module_fsdp)
+ peft_config = {}
+ if dist.get_rank() == 0:
+ os.makedirs(lora_save_path, exist_ok=True)
+ peft_config = asdict(peft_model.peft_config.get("default", {}))
+ peft_config["task_type"] = peft_config["task_type"].value
+ peft_config["peft_type"] = peft_config["peft_type"].value
+ peft_config["target_modules"] = list(peft_config["target_modules"])
+ try:
+ if fsdp_version(self.actor_module_fsdp) > 0:
+ self.actor_module_fsdp = self.actor_module_fsdp.cuda()
+ lora_params = layered_summon_lora_params(self.actor_module_fsdp)
+ if dist.get_rank() == 0:
+ save_file(
+ lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")
+ )
+ with open(
+ os.path.join(lora_save_path, "adapter_config.json"),
+ "w",
+ encoding="utf-8",
+ ) as f:
+ json.dump(peft_config, f, ensure_ascii=False, indent=4)
+ except Exception as e:
+ if dist.get_rank() == 0:
+ print(f"[rank-{self.rank}]: Save LoRA Adapter Error ({e})")
+
+ dist.barrier()
+ if dist.get_rank() == 0:
+ print(f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}")
- torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
@@ -839,7 +1012,7 @@ def __init__(self, config):
import torch.distributed
if not torch.distributed.is_initialized():
- torch.distributed.init_process_group(backend="nccl")
+ torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl")
self.config = config
# build device mesh for Ulysses Sequence Parallel
@@ -854,7 +1027,7 @@ def __init__(self, config):
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
- "cuda",
+ device_name,
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
@@ -879,26 +1052,29 @@ def __init__(self, config):
)
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
+
+ if self.config.ppo_micro_batch_size_per_gpu is not None:
assert (
self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
assert (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
+ self._is_lora = self.config.model.get("lora_rank", 0) > 0
def _build_critic_model_optimizer(self, config):
# the following line is necessary
from torch import optim
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from verl.utils.model import print_model_size
from verl.utils.torch_dtypes import PrecisionType
- local_path = copy_to_local(config.model.path)
+ use_shm = config.model.get("use_shm", False)
+ local_path = copy_to_local(config.model.path, use_shm=use_shm)
# note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info
# using random initialized model from any architecture. May not be the same as Actor.
- tokenizer_path = copy_to_local(config.model.tokenizer_path)
+ tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)
self.tokenizer = hf_tokenizer(
tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)
)
@@ -925,11 +1101,15 @@ def _build_critic_model_optimizer(self, config):
from transformers import AutoConfig, AutoModelForTokenClassification
- trust_remote_code = False
critic_model_config = AutoConfig.from_pretrained(
- local_path, trust_remote_code=trust_remote_code
+ local_path,
+ attn_implementation="flash_attention_2",
+ trust_remote_code=config.model.get("trust_remote_code", False),
)
critic_model_config.num_labels = 1
+ # patch for kimi-vl
+ if getattr(critic_model_config, "model_type", None) == "kimi_vl":
+ critic_model_config.text_config.topk_method = "greedy"
init_context = get_init_weight_context_manager(
use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh
@@ -937,23 +1117,22 @@ def _build_critic_model_optimizer(self, config):
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
- setattr(critic_model_config, "classifier_dropout", 0.0)
- setattr(critic_model_config, "hidden_dropout", "0")
+ critic_model_config.classifier_dropout = 0.0
+ critic_model_config.hidden_dropout = "0"
critic_module = AutoModelForTokenClassification.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=critic_model_config,
- attn_implementation="flash_attention_2",
- trust_remote_code=trust_remote_code,
+ trust_remote_code=config.model.get("trust_remote_code", False),
)
use_remove_padding = config.model.get("use_remove_padding", False)
- if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
- from verl.models.transformers.monkey_patch import apply_monkey_patch
- apply_monkey_patch(
- model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size
- )
+ apply_monkey_patch(
+ model=critic_module,
+ use_remove_padding=use_remove_padding,
+ ulysses_sp_size=self.ulysses_sequence_parallel_size,
+ )
# some parameters may not in torch_dtype
critic_module.to(torch_dtype)
@@ -962,6 +1141,20 @@ def _build_critic_model_optimizer(self, config):
critic_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
+
+ if self._is_lora:
+ print("Applying LoRA to critic module")
+ critic_module.enable_input_require_grads()
+ # Convert config to regular Python types before creating PEFT model
+ lora_config = {
+ "task_type": TaskType.CAUSAL_LM,
+ "r": self.config.model.lora_rank,
+ "lora_alpha": self.config.model.lora_alpha,
+ "target_modules": convert_to_regular_types(self.config.model.target_modules),
+ "bias": "none",
+ }
+ critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))
+
if self.rank == 0:
print_model_size(critic_module)
@@ -987,7 +1180,9 @@ def _build_critic_model_optimizer(self, config):
)
auto_wrap_policy = get_fsdp_wrap_policy(
- module=critic_module, config=self.config.model.fsdp_config.wrap_policy
+ module=critic_module,
+ config=self.config.model.fsdp_config.wrap_policy,
+ is_lora=self.config.model.get("lora_rank", 0) > 0,
)
log_gpu_memory_usage("Before critic FSDP", logger=None)
@@ -996,59 +1191,87 @@ def _build_critic_model_optimizer(self, config):
sharding_strategy = get_sharding_strategy(fsdp_mesh)
# Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation
- critic_module = FSDP(
- critic_module,
- param_init_fn=init_fn,
- use_orig_params=False,
- auto_wrap_policy=auto_wrap_policy,
- device_id=torch.cuda.current_device(),
- sharding_strategy=sharding_strategy,
- mixed_precision=mixed_precision,
- sync_module_states=True,
- forward_prefetch=False,
- device_mesh=self.device_mesh,
- cpu_offload=None,
- )
+ if config.strategy == "fsdp":
+ critic_module = FSDP(
+ critic_module,
+ param_init_fn=init_fn,
+ use_orig_params=False,
+ auto_wrap_policy=auto_wrap_policy,
+ device_id=get_torch_device().current_device(),
+ sharding_strategy=sharding_strategy,
+ mixed_precision=mixed_precision,
+ sync_module_states=True,
+ forward_prefetch=False,
+ device_mesh=self.device_mesh,
+ cpu_offload=None,
+ )
+ elif config.strategy == "fsdp2":
+ assert (
+ CPUOffloadPolicy is not None
+ ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
+ mp_policy = MixedPrecisionPolicy(
+ param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
+ )
+ offload_policy = None
+ if fsdp_config.offload_policy:
+ self._is_offload_param = False
+ self._is_offload_optimizer = False
+ offload_policy = CPUOffloadPolicy(pin_memory=True)
+
+ fsdp_kwargs = {
+ "mesh": fsdp_mesh,
+ "mp_policy": mp_policy,
+ "offload_policy": offload_policy,
+ "reshard_after_forward": fsdp_config.reshard_after_forward,
+ }
+ full_state = critic_module.state_dict()
+ apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)
+ fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)
+ else:
+ raise NotImplementedError(f"Unknown strategy {config.strategy}")
+
+ if config.model.get("enable_activation_offload", False):
+ enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False)
+ enable_activation_offloading(
+ critic_module, config.strategy, enable_gradient_checkpointing
+ )
log_gpu_memory_usage("After critic FSDP", logger=None)
- beta1 = config.optim.get("beta1", 0.9)
- beta2 = config.optim.get("beta2", 0.999)
critic_optimizer = optim.AdamW(
critic_module.parameters(),
lr=config.optim.lr,
- betas=(beta1, beta2),
+ betas=config.optim.get("betas", (0.9, 0.999)),
weight_decay=config.optim.get("weight_decay", 1e-2),
)
total_steps = config.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
+ warmup_style = config.optim.get("warmup_style", "constant")
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
- print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
+ if self.rank == 0:
+ print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
- if config.optim.warmup_style == "constant":
- from verl.utils.torch_functional import get_constant_schedule_with_warmup
+ from verl.utils.torch_functional import (
+ get_constant_schedule_with_warmup,
+ get_cosine_schedule_with_warmup,
+ )
+ if warmup_style == "constant":
critic_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
)
- elif config.optim.warmup_style == "cosine":
- from verl.utils.torch_functional import get_cosine_schedule_with_warmup
-
- assert total_steps > 0, "Cosine scheduler of critic requires total_training_steps > 0"
+ elif warmup_style == "cosine":
critic_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=critic_optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
- min_lr_ratio=config.optim.min_lr_ratio,
)
else:
- raise NotImplementedError(
- f"Lr scheduler style {config.optim.warmup_style} is not supported"
- )
+ raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
return critic_module, critic_optimizer, critic_lr_scheduler
@@ -1067,8 +1290,10 @@ def init_model(self):
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
+ log_gpu_memory_usage("After offload critic model during init", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.critic_optimizer)
+ log_gpu_memory_usage("After offload critic optimizer during init", logger=logger)
self.critic = DataParallelPPOCritic(
config=self.config,
@@ -1088,7 +1313,7 @@ def init_model(self):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
# Support all hardwares
- data = data.to(torch.cuda.current_device())
+ data = data.to(get_torch_device().current_device())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
@@ -1111,12 +1336,12 @@ def compute_values(self, data: DataProto):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
# Support all hardwares
- data = data.to(torch.cuda.current_device())
+ data = data.to(get_torch_device().current_device())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
if self._is_offload_optimizer:
load_fsdp_optimizer(
- optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()
+ optimizer=self.critic_optimizer, device_id=get_torch_device().current_device()
)
# perform forward computation
@@ -1197,327 +1422,3 @@ def clear_optimizer_state(self):
self.critic_optimizer.zero_grad()
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.critic_optimizer)
-
-
-# TODO(sgm): we may need to extract it to dp_reward_model.py
-class RewardModelWorker(Worker):
- """
- Note that we only implement the reward model that is subclass of AutoModelForTokenClassification.
- """
-
- def __init__(self, config):
- super().__init__()
- import torch.distributed
-
- if not torch.distributed.is_initialized():
- torch.distributed.init_process_group(backend="nccl")
- self.config = config
-
- # build device mesh for Ulysses Sequence Parallel
- world_size = torch.distributed.get_world_size()
- from torch.distributed.device_mesh import init_device_mesh
-
- fsdp_size = self.config.model.fsdp_config.fsdp_size
- self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
-
- self.ulysses_device_mesh = None
- self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
- dp = world_size // self.ulysses_sequence_parallel_size
- if self.ulysses_sequence_parallel_size > 1:
- self.ulysses_device_mesh = init_device_mesh(
- "cuda",
- mesh_shape=(dp, self.ulysses_sequence_parallel_size),
- mesh_dim_names=["dp", "sp"],
- )
-
- self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
-
- self.use_remove_padding = self.config.model.get("use_remove_padding", False)
-
- # normalize config
- if self.config.micro_batch_size is not None:
- self.config.micro_batch_size //= torch.distributed.get_world_size()
- self.config.micro_batch_size_per_gpu = self.config.micro_batch_size
-
- def _build_model(self, config):
- # the following line is necessary
- from torch.distributed.fsdp import CPUOffload
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
- from transformers import AutoConfig, AutoModelForTokenClassification
-
- # download the checkpoint from hdfs
- local_path = copy_to_local(config.model.path)
-
- if self.config.model.input_tokenizer is None:
- self._do_switch_chat_template = False
- else:
- self._do_switch_chat_template = True
- input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer)
- self.input_tokenizer = hf_tokenizer(
- input_tokenizer_local_path,
- trust_remote_code=config.model.get("trust_remote_code", False),
- )
- self.tokenizer = hf_tokenizer(
- local_path, trust_remote_code=config.model.get("trust_remote_code", False)
- )
-
- trust_remote_code = config.model.get("trust_remote_code", False)
- model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
- model_config.num_labels = 1
-
- # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
- init_context = get_init_weight_context_manager(
- use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
- )
-
- with init_context(), warnings.catch_warnings():
- warnings.simplefilter("ignore")
- setattr(model_config, "classifier_dropout", 0.0)
- reward_module = AutoModelForTokenClassification.from_pretrained(
- pretrained_model_name_or_path=local_path,
- config=model_config,
- torch_dtype=torch.bfloat16,
- attn_implementation="flash_attention_2",
- trust_remote_code=trust_remote_code,
- )
-
- if (
- config.model.get("use_remove_padding", False)
- or self.ulysses_sequence_parallel_size > 1
- ):
- from verl.models.transformers.monkey_patch import apply_monkey_patch
-
- apply_monkey_patch(
- model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size
- )
-
- reward_module.to(torch.bfloat16)
- auto_wrap_policy = get_fsdp_wrap_policy(
- module=reward_module, config=self.config.model.fsdp_config
- )
-
- fsdp_mesh = self.device_mesh
- sharding_strategy = get_sharding_strategy(fsdp_mesh)
-
- reward_module = FSDP(
- reward_module,
- param_init_fn=init_fn,
- use_orig_params=False,
- auto_wrap_policy=auto_wrap_policy,
- device_id=torch.cuda.current_device(),
- sharding_strategy=sharding_strategy, # zero3
- sync_module_states=True,
- cpu_offload=CPUOffload(offload_params=True),
- forward_prefetch=False,
- device_mesh=self.device_mesh,
- )
-
- return reward_module
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def init_model(self):
- # This is used to import external_lib into the huggingface systems
- import_external_libs(self.config.model.get("external_lib", None))
- self.reward_module = self._build_model(config=self.config)
-
- def _forward_micro_batch(self, micro_batch):
- from flash_attn.bert_padding import (
- index_first_axis,
- pad_input,
- rearrange,
- unpad_input,
- )
- from verl.utils.ulysses import (
- gather_outpus_and_unpad,
- ulysses_pad_and_slice_inputs,
- )
-
- with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
- input_ids = micro_batch["input_ids"]
- batch_size, seqlen = input_ids.shape
- attention_mask = micro_batch["attention_mask"]
- position_ids = micro_batch["position_ids"]
-
- if self.use_remove_padding:
- input_ids_rmpad, indices, *_ = unpad_input(
- input_ids.unsqueeze(-1), attention_mask
- ) # input_ids_rmpad (total_nnz, ...)
- input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
-
- # unpad the position_ids to align the rotary
- position_ids_rmpad = index_first_axis(
- rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
- ).transpose(0, 1)
-
- # pad and slice the inputs if sp > 1
- if self.ulysses_sequence_parallel_size > 1:
- input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
- input_ids_rmpad,
- position_ids_rmpad,
- sp_size=self.ulysses_sequence_parallel_size,
- )
-
- # only pass input_ids and position_ids to enable flash_attn_varlen
- output = self.reward_module(
- input_ids=input_ids_rmpad,
- attention_mask=None,
- position_ids=position_ids_rmpad,
- use_cache=False,
- ) # prevent model thinks we are generating
- reward_rmpad = output.logits
- reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz)
-
- # gather output if sp > 1
- if self.ulysses_sequence_parallel_size > 1:
- reward_rmpad = gather_outpus_and_unpad(
- reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
- )
-
- # pad it back
- rm_score = pad_input(
- reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen
- ).squeeze(-1)
- else:
- output = self.reward_module(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- use_cache=False,
- )
- rm_score = output.logits # (batch_size, seq_len, 1)
- rm_score = rm_score.squeeze(-1)
-
- # extract the result of the last valid token
- eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,)
- rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]
- return rm_score
-
- def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
- batch_size = data.batch.batch_size[0]
- # expand as token_level_reward
- attention_mask = data.batch["attention_mask"]
- position_ids = data.batch["position_ids"]
- response_length = data.batch["responses"].shape[-1]
- eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,)
- token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen)
- token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores
-
- # select the response part
- token_level_scores = token_level_scores[:, -response_length:]
-
- return token_level_scores
-
- def _switch_chat_template(self, data: DataProto):
- src_max_length = data.batch["attention_mask"].shape[-1]
-
- src_tokenizer = self.input_tokenizer
- target_tokenizer = self.tokenizer
-
- rm_input_ids = []
- rm_attention_mask = []
-
- for i in range(data.batch.batch_size[0]):
- # extract raw prompt
- chat: list = data.non_tensor_batch["raw_prompt"][i].tolist()
-
- # extract response
- response_ids = data.batch["responses"][i]
- response_length = response_ids.shape[-1]
- valid_response_length = data.batch["attention_mask"][i][-response_length:].sum()
- valid_response_ids = response_ids[:valid_response_length]
-
- # decode
- response = src_tokenizer.decode(valid_response_ids)
- # remove bos and eos
- response = response.replace(src_tokenizer.eos_token, "")
-
- chat.append({"role": "assistant", "content": response})
-
- prompt_with_chat_template = target_tokenizer.apply_chat_template(
- chat, add_generation_prompt=False, tokenize=False
- )
- if self.rank == 0 and i == 0:
- # for debugging purpose
- print(f"Switch template. chat: {prompt_with_chat_template}")
-
- # the maximum length is actually determined by the reward model itself
- max_length = self.config.get("max_length", src_max_length)
- if max_length is None:
- max_length = src_max_length
- input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
- prompt=prompt_with_chat_template,
- tokenizer=target_tokenizer,
- max_length=max_length,
- pad_token_id=target_tokenizer.pad_token_id,
- left_pad=False, # right padding
- truncation=self.config.get("truncation", "right"),
- ) # truncate from the right
-
- rm_input_ids.append(input_ids)
- rm_attention_mask.append(attention_mask)
-
- rm_input_ids = torch.cat(rm_input_ids, dim=0)
- rm_attention_mask = torch.cat(rm_attention_mask, dim=0)
-
- rm_position_ids = compute_position_id_with_mask(rm_attention_mask)
-
- rm_inputs = {
- "input_ids": rm_input_ids,
- "attention_mask": rm_attention_mask,
- "position_ids": rm_position_ids,
- }
-
- return DataProto.from_dict(rm_inputs)
-
- @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
- def compute_rm_score(self, data: DataProto):
- import itertools
-
- from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
-
- # Support all hardwares
- data = data.to(torch.cuda.current_device())
- if self._do_switch_chat_template:
- rm_data = self._switch_chat_template(data)
-
- # Support all hardwares
- rm_data.batch = rm_data.batch.to(torch.cuda.current_device())
-
- # perform forward computation
- with self.ulysses_sharding_manager:
- rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)
- data = self.ulysses_sharding_manager.preprocess_data(data=data)
-
- use_dynamic_bsz = self.config.use_dynamic_bsz
- if use_dynamic_bsz:
- max_token_len = (
- self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
- )
- micro_batches, indices = rearrange_micro_batches(
- batch=rm_data.batch, max_token_len=max_token_len
- )
- else:
- micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)
- output = []
- for micro_batch in micro_batches:
- rm_score = self._forward_micro_batch(micro_batch)
- output.append(rm_score)
- scores = torch.cat(output, dim=0) # (batch_size)
-
- if use_dynamic_bsz:
- indices = list(itertools.chain.from_iterable(indices))
- assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}"
- revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
- scores = scores[revert_indices]
-
- token_level_scores = self._expand_to_token_level(data, scores)
- # Note that this is only the scores, may not be the final rewards used to train RL
- output = DataProto.from_dict(tensors={"rm_scores": token_level_scores})
- output = self.ulysses_sharding_manager.postprocess_data(data=output)
-
- # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
- # unshard the root FSDP module
- self.reward_module._handle.reshard(True)
-
- output = output.to("cpu")
- return output
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 110a54a7db..bc15a25446 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -160,7 +160,16 @@ def _validate_config(self): # TODO
super()._validate_config()
def init_workers(self):
- """Init resource pool and worker group"""
+ """Initialize distributed training workers using Ray backend.
+
+
+ Creates:
+
+ 1. Ray resource pools from configuration
+
+ 2. Worker groups for each role (actor, critic, etc.)
+
+ """
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {
@@ -208,25 +217,31 @@ def init_workers(self):
# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
- # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
+ # you should not use `create_colocated_worker_cls`.
+ # Instead, directly pass different resource pool to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
- self.wg_dicts = []
+ wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
+ if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
+ wg_kwargs[
+ "ray_wait_register_center_timeout"
+ ] = self.config.trainer.ray_wait_register_center_timeout
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(
- resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls
+ resource_pool=resource_pool,
+ ray_cls_with_init=worker_dict_cls,
+ device_name=self.device_name,
+ **wg_kwargs,
)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
- # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
- self.wg_dicts.append(wg_dict)
if self.use_critic:
self.critic_wg = all_wg["critic"]
self.critic_wg.init_model()
- if self.use_reference_policy:
+ if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = all_wg["ref"]
self.ref_policy_wg.init_model()
@@ -265,7 +280,7 @@ def prepare(self):
if self.config.trainer.get("val_only", False):
return
- def _create_dataloader(self):
+ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
self.train_dataloader = _InternalDataLoader(self.config)
# TODO: compute total training steps
self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
From b8d1faa32aa1398545e7677f907da06330ab9045 Mon Sep 17 00:00:00 2001
From: Yuchang Sun <52027540+hiyuchang@users.noreply.github.com>
Date: Tue, 17 Jun 2025 14:42:21 +0800
Subject: [PATCH 14/28] [Feature] Add MIX algorithm (#83)
---
.../source/tutorial/example_mix_algo.md | 303 ++++++++++++++++++
examples/mix_math/README.md | 7 +
examples/mix_math/mix_math.yaml | 88 +++++
examples/mix_math/train_mix_math.yaml | 70 ++++
trinity/algorithm/algorithm.py | 21 ++
trinity/algorithm/policy_loss_fn/__init__.py | 2 +
.../policy_loss_fn/mix_policy_loss.py | 133 ++++++++
trinity/algorithm/sample_strategy/__init__.py | 2 +
.../sample_strategy/mix_sample_strategy.py | 118 +++++++
trinity/trainer/verl/dp_actor.py | 7 +-
10 files changed, 750 insertions(+), 1 deletion(-)
create mode 100644 docs/sphinx_doc/source/tutorial/example_mix_algo.md
create mode 100644 examples/mix_math/README.md
create mode 100644 examples/mix_math/mix_math.yaml
create mode 100644 examples/mix_math/train_mix_math.yaml
create mode 100644 trinity/algorithm/policy_loss_fn/mix_policy_loss.py
create mode 100644 trinity/algorithm/sample_strategy/mix_sample_strategy.py
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
new file mode 100644
index 0000000000..9dadc76b40
--- /dev/null
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -0,0 +1,303 @@
+# Integrate An New Algorithm
+
+
+This guide introduces how to integrate a new algorithm to Trinity-RFT.
+As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX , which optimizes the following policy objective:
+
+$$
+\mathcal{J}_{\text{Mix}}(\theta) =
+\mathcal{J}_{\text{GRPO}}(\theta)
++
+\mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'}
+\left[
+ \frac{1}{T'_b} \sum_{t=1}^{T'_b}
+ \log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b, Dict:
+ return {
+ "repeat_times": 8,
+ "policy_loss_fn": "mix",
+ "advantage_fn": "grpo",
+ "sample_strategy": "mix",
+ }
+```
+
+
+## Step 2: Define the Sampling Strategy
+
+We need to read two kinds of experiences: usual experiences and expert experiences in each step. For this purpose, we define a new experience sampling strategy named `MixSampleStrategy`.
+
+
+```python
+class MixSampleStrategy(SampleStrategy):
+ """The default sample strategy."""
+
+ def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+ super().__init__(buffer_config, trainer_type)
+ self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
+ tot_batch_size = buffer_config.read_batch_size
+ expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)
+
+ # experience buffer
+ usual_buffer_config = copy.deepcopy(buffer_config)
+ usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size
+ self.usual_exp_buffer = get_buffer_reader(
+ buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore
+ )
+
+ if buffer_config.trainer_input.sft_warmup_dataset is None:
+ raise ValueError(
+ "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm"
+ )
+
+ # expert experience buffer
+ expert_buffer_config = copy.deepcopy(buffer_config)
+ expert_buffer_config.read_batch_size = expert_batch_size
+ self.expert_exp_buffer = get_buffer_reader(
+ buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
+ )
+
+ def sample(self, step: int) -> Tuple[Any, Dict, List]:
+ metrics = {}
+ with Timer(metrics, "read_time"):
+ usual_exp_list = self.usual_exp_buffer.read()
+ for exp in usual_exp_list:
+ if exp.info is None:
+ exp.info = {}
+ exp.info["is_expert"] = False
+
+ expert_exp_list = self.expert_exp_buffer.read()
+ for exp in expert_exp_list:
+ exp.reward = 0.0
+ exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
+ if exp.info is None:
+ exp.info = {}
+ exp.info["is_expert"] = True
+
+ exp_list = usual_exp_list + expert_exp_list
+ repr_samples = representative_sample(exp_list)
+
+ is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)
+
+ with Timer(metrics, "gather_time"):
+ exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
+
+ if self.trainer_type == "verl":
+ with Timer(metrics, "convert_time"):
+ data = to_data_proto_mix(exps, is_expert_mask)
+ return data, metrics, repr_samples
+ else:
+ raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+```
+
+We also need to add an `is_expert_mask` field when transforming to DataProto to indicate the data type.
+
+```diff
++ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
+ attention_mask = experiences.attention_masks
+ cumsum = torch.cumsum(attention_mask, dim=-1)
+ position_ids = torch.clip(cumsum - 1, 0, None).long()
+ batch_dict = {
+ "uid": np.array(experiences.run_ids),
+ "position_ids": position_ids,
+ "input_ids": experiences.tokens.long(),
+ "responses": experiences.tokens[:, experiences.prompt_length :].long(),
+ "attention_mask": attention_mask.long(),
+ "response_mask": (
+ experiences.action_masks[:, experiences.prompt_length :].long()
+ if hasattr(experiences, "action_masks") and experiences.action_masks is not None
+ else attention_mask[:, experiences.prompt_length :].long()
+ ),
++ "is_expert_mask": is_expert_mask,
+ }
+ if experiences.rewards is not None:
+ token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
+ eos_mask_idx = cumsum.argmax(dim=-1)
+ token_level_rewards[
+ torch.arange(experiences.batch_size), eos_mask_idx
+ ] = experiences.rewards
+ token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
+ batch_dict.update(
+ {
+ "token_level_scores": token_level_rewards,
+ "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore
+ }
+ )
+ return DataProto.from_single_dict(batch_dict)
+```
+
+
+## Step 3: Define the Policy Loss Function
+
+We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes the sum of two loss terms regarding usual and expert experiences, respectively.
+
+```python
+@POLICY_LOSS_FN.register_module("mix")
+class MIXPolicyLossFn(PolicyLossFn):
+ def __init__(
+ self,
+ mu: float = 0.1,
+ clip_range: Optional[float] = None,
+ clip_range_low: Optional[float] = None,
+ clip_range_high: Optional[float] = None,
+ use_dynamic_bsz: Optional[bool] = None,
+ repeat_times: Optional[int] = None,
+ ppo_mini_batch_size: Optional[int] = None,
+ ppo_micro_batch_size_per_gpu: Optional[int] = None,
+ ngpus_trainer: Optional[int] = None,
+ read_batch_size_usual: Optional[int] = None,
+ read_batch_size_expert: Optional[int] = None,
+ use_token_level_loss_in_sft: bool = True,
+ ) -> None:
+ self.mu = mu
+ self.use_dynamic_bsz = use_dynamic_bsz
+ self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore
+ self.gradient_accumulation = (
+ ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore
+ )
+ self.read_batch_size_usual = read_batch_size_usual
+ self.read_batch_size_expert = read_batch_size_expert
+ self.grpo_loss_fn = PPOPolicyLossFn(
+ clip_range=clip_range,
+ clip_range_low=clip_range_low,
+ clip_range_high=clip_range_high,
+ )
+ self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft)
+
+ def __call__( # type: ignore
+ self,
+ logprob: torch.Tensor,
+ old_logprob: torch.Tensor,
+ action_mask: torch.Tensor,
+ advantages: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ is_expert_mask = kwargs.get("is_expert_mask", None)
+ if is_expert_mask is None:
+ raise ValueError("is_expert_mask is required in MIX")
+ assert (
+ len(is_expert_mask) == logprob.shape[0]
+ ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
+
+ n_usual_exp = torch.sum(~is_expert_mask).item()
+ n_expert_exp = torch.sum(is_expert_mask).item()
+
+ if self.use_dynamic_bsz:
+ per_micro_batch_weight_usual = self.experience_per_gpu / (
+ logprob.shape[0] * self.read_batch_size_usual
+ )
+ per_micro_batch_weight_expert = self.experience_per_gpu / (
+ logprob.shape[0] * self.read_batch_size_expert
+ )
+ else:
+ per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual # type: ignore
+ per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert # type: ignore
+
+ if n_usual_exp > 0:
+ grpo_loss, grpo_metrics = self.grpo_loss_fn(
+ logprob[~is_expert_mask],
+ old_logprob[~is_expert_mask],
+ action_mask[~is_expert_mask],
+ advantages[~is_expert_mask],
+ **kwargs,
+ )
+ grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
+ grpo_metrics = {
+ k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
+ }
+ else:
+ grpo_loss = torch.tensor(0.0, device=logprob.device)
+ grpo_metrics = {}
+
+ # SFT Loss (expert)
+ if n_expert_exp > 0:
+ sft_loss, sft_metrics = self.sft_loss_fn(
+ logprob[is_expert_mask],
+ action_mask[is_expert_mask],
+ )
+ sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
+ sft_metrics = {
+ k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
+ }
+ else:
+ sft_loss = torch.tensor(0.0, device=logprob.device)
+ sft_metrics = {}
+
+ loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss
+
+ metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
+ metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
+ metrics.update({"loss": loss.item()})
+
+ return loss, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {
+ "mu": 0.1,
+ "clip_range": 0.2,
+ }
+
+ @property
+ def select_keys(self) -> List[str]:
+ return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
+```
+
+## Step 4: Run the Experiment
+
+With the above newly-defined classes and functions, we can run the experiments without modifying other process.
+An example showing some important configurations is shown below, including the weighting factor $\mu$ as `algorithm.policy_loss_fn_args['mu']` and the batch size of expert experiences $B'$, calculated as the product of `buffer.batch_size`, `algorithm.sample_strategy_args['expert_data_ratio']` and `algorithm.repeat_times`.
+For the full configuration, please refer to [`mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/mix_math.yaml) and [`train_mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/train_mix_math.yaml).
+
+```yaml
+algorithm:
+ algorithm_type: mix
+ repeat_times: 8
+ sample_strategy_args:
+ expert_data_ratio: 0.25
+ policy_loss_fn_args:
+ mu: 0.1
+ clip_range: 0.2
+ use_token_level_loss_in_sft: False
+ use_dynamic_bsz: False
+ repeat_times: 8
+ ppo_mini_batch_size: 32
+ ppo_micro_batch_size_per_gpu: 4
+ ngpus_trainer: 4
+ read_batch_size_expert: 64
+ read_batch_size_usual: 192
+```
diff --git a/examples/mix_math/README.md b/examples/mix_math/README.md
new file mode 100644
index 0000000000..8e84f233bc
--- /dev/null
+++ b/examples/mix_math/README.md
@@ -0,0 +1,7 @@
+# Example: MIX on MATH dataset
+
+This example shows the usage of a new algorithm MIX on the MATH dataset.
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md).
+
+The config files are located in [`mix_math.yaml`](mix.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml
new file mode 100644
index 0000000000..339d8df394
--- /dev/null
+++ b/examples/mix_math/mix_math.yaml
@@ -0,0 +1,88 @@
+project: "mix_math"
+name: "expert0.25_mu0.1"
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+algorithm:
+ algorithm_type: mix
+ repeat_times: 8
+ sample_strategy_args:
+ expert_data_ratio: 0.25
+ policy_loss_fn_args:
+ mu: 0.1
+ clip_range: 0.2
+ use_token_level_loss_in_sft: False
+ use_dynamic_bsz: False
+ repeat_times: 8
+ ppo_mini_batch_size: 32
+ ppo_micro_batch_size_per_gpu: 4
+ ngpus_trainer: 4
+ read_batch_size_expert: 64
+ read_batch_size_usual: 192
+model:
+ model_path: /PATH/TO/MODEL/
+ max_prompt_tokens: 1024
+ max_response_tokens: 10240
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 40
+ explore_batch_size: 36
+ max_retry_times: 3
+ max_retry_interval: 1
+ explorer_input:
+ taskset:
+ name: math_train
+ storage_type: file
+ path: /PATH/TO/DATASET/
+ split: 'train'
+ format:
+ prompt_key: 'problem'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: math_eval
+ storage_type: file
+ path: /PATH/TO/DATASET/
+ split: 'test'
+ format:
+ prompt_key: 'problem'
+ response_key: 'answer'
+ default_workflow_type: 'math_workflow'
+ trainer_input:
+ experience_buffer:
+ name: math_buffer
+ storage_type: queue
+ path: /PATH/TO/BUFFER/
+ sft_warmup_dataset:
+ name: math_sft
+ storage_type: file
+ algorithm_type: sft
+ path: /PATH/TO/EXPERT_DATA/
+ split: 'train'
+ format:
+ prompt_type: messages
+ messages_key: 'messages'
+explorer:
+ eval_interval: 10
+ runner_num: 16
+ rollout_model:
+ engine_type: vllm_async
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: true
+ dtype: bfloat16
+ seed: 42
+synchronizer:
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 1200
+trainer:
+ trainer_type: 'verl'
+ trainer_config_path: 'examples/mix_math/train_math.yaml'
+ save_interval: 50
+monitor:
+ monitor_type: wandb
diff --git a/examples/mix_math/train_mix_math.yaml b/examples/mix_math/train_mix_math.yaml
new file mode 100644
index 0000000000..7b14a87fad
--- /dev/null
+++ b/examples/mix_math/train_mix_math.yaml
@@ -0,0 +1,70 @@
+actor_rollout_ref:
+ hybrid_engine: True
+ model:
+ external_lib: null
+ override_config: { }
+ enable_gradient_checkpointing: True
+ use_remove_padding: True # False
+ actor:
+ strategy: fsdp # This is for backward-compatibility
+ ppo_mini_batch_size: 128
+ ppo_micro_batch_size_per_gpu: 4
+ use_dynamic_bsz: True # False
+ ppo_max_token_len_per_gpu: 25600 # n * ${data.max_prompt_length} + ${data.max_response_length}
+ grad_clip: 1.0
+ clip_ratio: 0.2
+ entropy_coeff: 0.001
+ use_kl_loss: True # True for GRPO
+ kl_loss_coef: 0.0001 # for grpo
+ kl_loss_type: low_var_kl # for grpo
+ ppo_epochs: 1
+ shuffle: False
+ ulysses_sequence_parallel_size: 1 # sp size
+ optim:
+ lr: 1e-6
+ lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+ # min_lr_ratio: null # only useful for warmup with cosine
+ warmup_style: constant # select from constant/cosine
+ total_training_steps: -1 # must be override by program
+ fsdp_config:
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ param_offload: False
+ optimizer_offload: False
+ fsdp_size: -1
+ # --- below: opmd ---
+ tau: 0.000 # strength of regularization w.r.t. old / ref policy
+ opmd_baseline: mean # mean / logavgexp, applicable to opmd
+ use_uid: False # True / False, applicable to pairwise_opmd
+ ref:
+ fsdp_config:
+ param_offload: False
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ log_prob_micro_batch_size_per_gpu: 16
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+custom_reward_function:
+ path: null
+ name: compute_score
+
+algorithm:
+ gamma: 1.0
+ lam: 1.0
+ kl_penalty: kl # how to estimate kl divergence
+ kl_ctrl:
+ type: fixed
+ kl_coef: 0.0001
+
+trainer:
+ balance_batch: True
+ # auto: find the last ckpt to resume. If can't find, start from scratch
+ resume_mode: auto # or auto or resume_path if
+ default_hdfs_dir: null
+ remove_previous_ckpt_in_save: False
+ del_local_ckpt_after_load: False
+ val_before_train: False
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 88b9b946b7..6f0a2d19a7 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -180,3 +180,24 @@ def check_config(cls, config: Config) -> None:
logger.warning(
"DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2."
) # no need to warn
+
+
+@ALGORITHM_TYPE.register_module("mix")
+class MIXAlgorithm(AlgorithmType):
+ """MIX algorithm."""
+
+ use_critic: bool = False
+ use_reference: bool = True
+ use_advantage: bool = True
+ use_rollout: bool = True
+ can_balance_batch: bool = True
+ schema: type = ExperienceModel
+
+ @classmethod
+ def get_default_config(cls) -> Dict:
+ return {
+ "repeat_times": 8,
+ "policy_loss_fn": "mix",
+ "advantage_fn": "grpo",
+ "sample_strategy": "mix",
+ }
diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py
index 66dce16cab..705fb2525a 100644
--- a/trinity/algorithm/policy_loss_fn/__init__.py
+++ b/trinity/algorithm/policy_loss_fn/__init__.py
@@ -1,4 +1,5 @@
from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
+from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn
from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
@@ -11,4 +12,5 @@
"OPMDPolicyLossFn",
"DPOLossFn",
"SFTLossFn",
+ "MIXPolicyLossFn",
]
diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
new file mode 100644
index 0000000000..84679b0ea8
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
@@ -0,0 +1,133 @@
+"""Mix policy loss function."""
+
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
+from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
+
+
+@POLICY_LOSS_FN.register_module("mix")
+class MIXPolicyLossFn(PolicyLossFn):
+ """Implements a mixed policy loss combining GRPO and SFT losses.
+
+ This loss function applies different loss components to data based on whether
+ it comes from an expert or not, as indicated by `is_expert_mask`. It combines:
+ - GRPO loss (self.grpo_loss_fn) for non-expert data
+ - SFT loss (self.sft_loss_fn) for expert data
+ - Weighting parameter `mu`
+
+ The per-sample weights are normalized using either `experience_per_gpu` or
+ `gradient_accumulation`, depending on whether dynamic batch sizing is enabled,
+ to ensure consistent weighting across different batches of the same type experiences.
+ """
+
+ def __init__(
+ self,
+ mu: float = 0.1,
+ clip_range: Optional[float] = None,
+ clip_range_low: Optional[float] = None,
+ clip_range_high: Optional[float] = None,
+ use_dynamic_bsz: Optional[bool] = None,
+ repeat_times: Optional[int] = None,
+ ppo_mini_batch_size: Optional[int] = None,
+ ppo_micro_batch_size_per_gpu: Optional[int] = None,
+ ngpus_trainer: Optional[int] = None,
+ read_batch_size_usual: Optional[int] = None,
+ read_batch_size_expert: Optional[int] = None,
+ use_token_level_loss_in_sft: bool = True,
+ ) -> None:
+ self.mu = mu
+ self.use_dynamic_bsz = use_dynamic_bsz
+ self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore
+ self.gradient_accumulation = (
+ ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore
+ )
+ self.read_batch_size_usual = read_batch_size_usual
+ self.read_batch_size_expert = read_batch_size_expert
+ self.grpo_loss_fn = PPOPolicyLossFn(
+ clip_range=clip_range,
+ clip_range_low=clip_range_low,
+ clip_range_high=clip_range_high,
+ )
+ self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft)
+
+ def __call__( # type: ignore
+ self,
+ logprob: torch.Tensor,
+ old_logprob: torch.Tensor,
+ action_mask: torch.Tensor,
+ advantages: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ is_expert_mask = kwargs.get("is_expert_mask", None)
+ if is_expert_mask is None:
+ raise ValueError("is_expert_mask is required in MIX")
+ assert (
+ len(is_expert_mask) == logprob.shape[0]
+ ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
+
+ n_usual_exp = torch.sum(~is_expert_mask).item()
+ n_expert_exp = torch.sum(is_expert_mask).item()
+
+ if self.use_dynamic_bsz:
+ per_micro_batch_weight_usual = self.experience_per_gpu / (
+ logprob.shape[0] * self.read_batch_size_usual
+ )
+ per_micro_batch_weight_expert = self.experience_per_gpu / (
+ logprob.shape[0] * self.read_batch_size_expert
+ )
+ else:
+ per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual # type: ignore
+ per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert # type: ignore
+
+ if n_usual_exp > 0:
+ grpo_loss, grpo_metrics = self.grpo_loss_fn(
+ logprob[~is_expert_mask],
+ old_logprob[~is_expert_mask],
+ action_mask[~is_expert_mask],
+ advantages[~is_expert_mask],
+ **kwargs,
+ )
+ grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
+ grpo_metrics = {
+ k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
+ }
+ else:
+ grpo_loss = torch.tensor(0.0, device=logprob.device)
+ grpo_metrics = {}
+
+ # SFT Loss (expert)
+ if n_expert_exp > 0:
+ sft_loss, sft_metrics = self.sft_loss_fn(
+ logprob[is_expert_mask],
+ action_mask[is_expert_mask],
+ )
+ sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
+ sft_metrics = {
+ k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
+ }
+ else:
+ sft_loss = torch.tensor(0.0, device=logprob.device)
+ sft_metrics = {}
+
+ loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss
+
+ metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
+ metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
+ metrics.update({"loss": loss.item()})
+
+ return loss, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {
+ "mu": 0.1,
+ "clip_range": 0.2,
+ }
+
+ @property
+ def select_keys(self) -> List[str]:
+ return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py
index 60f2e268ae..cd4b9e0d66 100644
--- a/trinity/algorithm/sample_strategy/__init__.py
+++ b/trinity/algorithm/sample_strategy/__init__.py
@@ -1,3 +1,4 @@
+from trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy
from trinity.algorithm.sample_strategy.sample_strategy import (
SAMPLE_STRATEGY,
DefaultSampleStrategy,
@@ -10,4 +11,5 @@
"SampleStrategy",
"DefaultSampleStrategy",
"WarmupSampleStrategy",
+ "MixSampleStrategy",
]
diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
new file mode 100644
index 0000000000..acdd340b24
--- /dev/null
+++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -0,0 +1,118 @@
+import copy
+from math import ceil
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+import torch
+from verl.trainer.ppo.ray_trainer import DataProto
+
+from trinity.algorithm.sample_strategy.sample_strategy import (
+ SAMPLE_STRATEGY,
+ SampleStrategy,
+)
+from trinity.algorithm.sample_strategy.utils import representative_sample
+from trinity.buffer import get_buffer_reader
+from trinity.common.config import BufferConfig
+from trinity.common.experience import Experiences
+from trinity.utils.timer import Timer
+
+
+@SAMPLE_STRATEGY.register_module("mix")
+class MixSampleStrategy(SampleStrategy):
+ """The default sample strategy."""
+
+ def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+ super().__init__(buffer_config, trainer_type)
+ self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
+ tot_batch_size = buffer_config.read_batch_size
+ expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)
+
+ # experience buffer
+ usual_buffer_config = copy.deepcopy(buffer_config)
+ usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size
+ self.usual_exp_buffer = get_buffer_reader(
+ buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore
+ )
+
+ if buffer_config.trainer_input.sft_warmup_dataset is None:
+ raise ValueError(
+ "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm"
+ )
+
+ # expert experience buffer
+ expert_buffer_config = copy.deepcopy(buffer_config)
+ expert_buffer_config.read_batch_size = expert_batch_size
+ self.expert_exp_buffer = get_buffer_reader(
+ buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
+ )
+
+ def sample(self, step: int) -> Tuple[Any, Dict, List]:
+ metrics = {}
+ with Timer(metrics, "read_time"):
+ usual_exp_list = self.usual_exp_buffer.read()
+ for exp in usual_exp_list:
+ if exp.info is None:
+ exp.info = {}
+ exp.info["is_expert"] = False
+
+ expert_exp_list = self.expert_exp_buffer.read()
+ for exp in expert_exp_list:
+ exp.reward = 0.0
+ exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
+ if exp.info is None:
+ exp.info = {}
+ exp.info["is_expert"] = True
+
+ exp_list = usual_exp_list + expert_exp_list
+ repr_samples = representative_sample(exp_list)
+
+ is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)
+
+ with Timer(metrics, "gather_time"):
+ exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
+
+ if self.trainer_type == "verl":
+ with Timer(metrics, "convert_time"):
+ data = to_data_proto_mix(exps, is_expert_mask)
+ return data, metrics, repr_samples
+ else:
+ raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+
+ @classmethod
+ def get_default_config(cls) -> Dict:
+ return {
+ "expert_data_ratio": 0.5,
+ }
+
+
+def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
+ attention_mask = experiences.attention_masks
+ cumsum = torch.cumsum(attention_mask, dim=-1)
+ position_ids = torch.clip(cumsum - 1, 0, None).long()
+ batch_dict = {
+ "uid": np.array(experiences.run_ids),
+ "position_ids": position_ids,
+ "input_ids": experiences.tokens.long(),
+ "responses": experiences.tokens[:, experiences.prompt_length :].long(),
+ "attention_mask": attention_mask.long(),
+ "response_mask": (
+ experiences.action_masks[:, experiences.prompt_length :].long()
+ if hasattr(experiences, "action_masks") and experiences.action_masks is not None
+ else attention_mask[:, experiences.prompt_length :].long()
+ ),
+ "is_expert_mask": is_expert_mask,
+ }
+ if experiences.rewards is not None:
+ token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
+ eos_mask_idx = cumsum.argmax(dim=-1)
+ token_level_rewards[
+ torch.arange(experiences.batch_size), eos_mask_idx
+ ] = experiences.rewards
+ token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
+ batch_dict.update(
+ {
+ "token_level_scores": token_level_rewards,
+ "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore
+ }
+ )
+ return DataProto.from_single_dict(batch_dict)
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 0d750c8303..6a57e58144 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -160,7 +160,12 @@ def update_policy(self, data: DataProto):
}
select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()}
for trinity_key in self.policy_loss_fn.select_keys:
- verl_key = select_keys_trinity2verl[trinity_key]
+ if trinity_key in select_keys_trinity2verl:
+ verl_key = select_keys_trinity2verl[trinity_key]
+ else:
+ verl_key = trinity_key
+ select_keys_verl2trinity.update({verl_key: trinity_key})
+ select_keys_trinity2verl.update({trinity_key: verl_key})
select_keys.append(verl_key)
if not isinstance(self.kl_loss_fn, DummyKLFn):
select_keys.append("ref_log_prob")
From 69ddbd039998d3763528645b81ef871d4f2e7741 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Wed, 18 Jun 2025 10:06:17 +0800
Subject: [PATCH 15/28] Refactor on `select_keys` (#84)
---
docs/sphinx_doc/source/conf.py | 3 +-
docs/sphinx_doc/source/index.rst | 1 +
.../source/tutorial/example_mix_algo.md | 14 +--
examples/mix_math/README.md | 2 +-
examples/mix_math/mix_math.yaml | 3 +-
tests/algorithm/policy_loss_test.py | 94 ++++++++++++++++++
tests/common/config_test.py | 1 +
trinity/algorithm/key_mapper.py | 29 ++++++
trinity/algorithm/policy_loss_fn/dpo_loss.py | 11 +--
.../policy_loss_fn/mix_policy_loss.py | 32 +++----
.../policy_loss_fn/opmd_policy_loss.py | 14 +--
.../policy_loss_fn/policy_loss_fn.py | 96 ++++++++++++++++---
.../policy_loss_fn/ppo_policy_loss.py | 12 +--
trinity/algorithm/policy_loss_fn/sft_loss.py | 9 +-
trinity/trainer/verl/dp_actor.py | 26 +----
15 files changed, 247 insertions(+), 100 deletions(-)
create mode 100644 tests/algorithm/policy_loss_test.py
create mode 100644 trinity/algorithm/key_mapper.py
diff --git a/docs/sphinx_doc/source/conf.py b/docs/sphinx_doc/source/conf.py
index 4842a34557..ffaabf72c9 100644
--- a/docs/sphinx_doc/source/conf.py
+++ b/docs/sphinx_doc/source/conf.py
@@ -22,12 +22,13 @@
"sphinx.ext.napoleon",
"sphinx.ext.autosectionlabel",
"myst_parser",
+ "sphinx.ext.mathjax",
]
source_suffix = {
".rst": "restructuredtext",
".md": "markdown",
}
-myst_enable_extensions = ["colon_fence"]
+myst_enable_extensions = ["colon_fence", "amsmath", "dollarmath"]
# Prefix document path to section labels, otherwise autogenerated labels would
# look like 'heading' rather than 'path/to/file:heading'
diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst
index 5604faa15d..a1b6fde647 100644
--- a/docs/sphinx_doc/source/index.rst
+++ b/docs/sphinx_doc/source/index.rst
@@ -24,6 +24,7 @@ Welcome to Trinity-RFT's documentation!
tutorial/example_data_functionalities.md
tutorial/trinity_configs.md
tutorial/trinity_programming_guide.md
+ tutorial/example_mix_algo.md
.. toctree::
:maxdepth: 1
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index 9dadc76b40..ee0010ba24 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -1,4 +1,4 @@
-# Integrate An New Algorithm
+# Integrate A New Algorithm
This guide introduces how to integrate a new algorithm to Trinity-RFT.
@@ -6,7 +6,7 @@ As an example, we incorporate some "expert" data generated by a more advanced LL
$$
\mathcal{J}_{\text{Mix}}(\theta) =
-\mathcal{J}_{\text{GRPO}}(\theta)
+(1-\mu) \mathcal{J}_{\text{GRPO}}(\theta)
+
\mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'}
\left[
@@ -170,6 +170,7 @@ We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_polic
class MIXPolicyLossFn(PolicyLossFn):
def __init__(
self,
+ backend: str = "verl",
mu: float = 0.1,
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
@@ -183,6 +184,7 @@ class MIXPolicyLossFn(PolicyLossFn):
read_batch_size_expert: Optional[int] = None,
use_token_level_loss_in_sft: bool = True,
) -> None:
+ super().__init__(backend=backend)
self.mu = mu
self.use_dynamic_bsz = use_dynamic_bsz
self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore
@@ -204,11 +206,9 @@ class MIXPolicyLossFn(PolicyLossFn):
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
+ is_expert_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
- is_expert_mask = kwargs.get("is_expert_mask", None)
- if is_expert_mask is None:
- raise ValueError("is_expert_mask is required in MIX")
assert (
len(is_expert_mask) == logprob.shape[0]
), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -271,10 +271,6 @@ class MIXPolicyLossFn(PolicyLossFn):
"mu": 0.1,
"clip_range": 0.2,
}
-
- @property
- def select_keys(self) -> List[str]:
- return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
```
## Step 4: Run the Experiment
diff --git a/examples/mix_math/README.md b/examples/mix_math/README.md
index 8e84f233bc..2ef160b0f2 100644
--- a/examples/mix_math/README.md
+++ b/examples/mix_math/README.md
@@ -4,4 +4,4 @@ This example shows the usage of a new algorithm MIX on the MATH dataset.
For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md).
-The config files are located in [`mix_math.yaml`](mix.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
+The config files are located in [`mix_math.yaml`](mix_math.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml
index 339d8df394..b92edd4b25 100644
--- a/examples/mix_math/mix_math.yaml
+++ b/examples/mix_math/mix_math.yaml
@@ -27,7 +27,6 @@ cluster:
buffer:
total_epochs: 1
batch_size: 40
- explore_batch_size: 36
max_retry_times: 3
max_retry_interval: 1
explorer_input:
@@ -82,7 +81,7 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
- trainer_config_path: 'examples/mix_math/train_math.yaml'
+ trainer_config_path: 'examples/mix_math/train_mix_math.yaml'
save_interval: 50
monitor:
monitor_type: wandb
diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
new file mode 100644
index 0000000000..ba88feb2d7
--- /dev/null
+++ b/tests/algorithm/policy_loss_test.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+"""Test for policy loss functions"""
+
+import unittest
+
+import torch
+from verl import DataProto
+
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
+
+
+class VerlPolicyLossTest(unittest.TestCase):
+ def setUp(self):
+ seed = 42
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ shape = (5, 20)
+ self.logprob = 2 * torch.rand(shape) - 1
+ self.input_data = DataProto.from_dict(
+ {
+ "old_log_probs": 2 * torch.rand(shape) - 1,
+ "ref_log_prob": 2 * torch.rand(shape) - 1,
+ "response_mask": torch.rand(shape) > 0.5,
+ "advantages": 2 * torch.rand(shape) - 1,
+ "is_expert_mask": torch.rand(shape[0]) > 0.5,
+ }
+ )
+
+ def test_ppo_policy_loss(self):
+ policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
+ policy_loss_fn_args = policy_loss_fn_cls.default_args()
+ policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+ loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+ ppo_loss = torch.tensor(0.28560468554496765)
+ pg_clipfrac = torch.tensor(0.3541666567325592)
+ ppo_kl = torch.tensor(-0.21663446724414825)
+ self.assertTrue(torch.allclose(loss, ppo_loss))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss))
+
+ def test_sft_policy_loss(self):
+ policy_loss_fn_cls = POLICY_LOSS_FN.get("sft")
+ policy_loss_fn_args = policy_loss_fn_cls.default_args()
+ policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+ loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+ sft_loss = torch.tensor(-0.07560186833143234)
+ self.assertTrue(torch.allclose(loss, sft_loss))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["sft_loss"]), sft_loss))
+
+ def test_dpo_policy_loss(self):
+ policy_loss_fn_cls = POLICY_LOSS_FN.get("dpo")
+ policy_loss_fn_args = policy_loss_fn_cls.default_args()
+ policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+ loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+ dpo_loss = torch.tensor(0.5406752228736877)
+ chosen_reward = torch.tensor(0.7082431316375732)
+ rejected_reward = torch.tensor(0.3757950782775879)
+ accuracy_mean = torch.tensor(1.0)
+ self.assertTrue(torch.allclose(loss, dpo_loss))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["chosen_reward"]), chosen_reward))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["rejected_reward"]), rejected_reward))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["accuracy_mean"]), accuracy_mean))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["dpo_loss"]), dpo_loss))
+
+ def test_opmd_policy_loss(self):
+ policy_loss_fn_cls = POLICY_LOSS_FN.get("opmd")
+ policy_loss_fn_args = policy_loss_fn_cls.default_args()
+ policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+ loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+ opmd_loss = torch.tensor(-0.009589947760105133)
+ self.assertTrue(torch.allclose(loss, opmd_loss))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["opmd_loss"]), opmd_loss))
+
+ def test_mix_policy_loss(self):
+ policy_loss_fn_cls = POLICY_LOSS_FN.get("mix")
+ policy_loss_fn_args = policy_loss_fn_cls.default_args()
+ policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+ loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+ mix_loss = torch.tensor(0.6581965088844299)
+ pg_clipfrac = torch.tensor(0.7777777910232544)
+ ppo_kl = torch.tensor(-1.0737695693969727)
+ pg_loss = torch.tensor(0.7236452102661133)
+ sft_loss = torch.tensor(0.06915830634534359)
+ self.assertTrue(torch.allclose(loss, mix_loss))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_clipfrac"]), pg_clipfrac))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["usual/ppo_kl"]), ppo_kl))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
+ self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index e1ac0aa7d4..da4fd914a0 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -47,6 +47,7 @@ def test_all_examples_are_valid(self):
config_path = os.path.join(example_dir, example_name, filename)
try:
config = load_config(config_path)
+ config.checkpoint_root_dir = "./.cache/"
config.check_and_update()
except Exception as e:
print(f"Error loading config {config_path}: {e}")
diff --git a/trinity/algorithm/key_mapper.py b/trinity/algorithm/key_mapper.py
new file mode 100644
index 0000000000..09c1f988a6
--- /dev/null
+++ b/trinity/algorithm/key_mapper.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+"""Key Mapper"""
+
+from typing import Dict
+
+
+class KeyMapper:
+ def __init__(self, to_trinity_map: Dict[str, str]):
+ self.to_trinity_map = to_trinity_map
+ self.from_trinity_map = {v: k for k, v in self.to_trinity_map.items()}
+
+ def to_trinity(self, key: str) -> str:
+ return self.to_trinity_map.get(key, key)
+
+ def from_trinity(self, key: str) -> str:
+ return self.from_trinity_map.get(key, key)
+
+
+ALL_MAPPERS = {
+ "verl": KeyMapper(
+ {
+ "log_prob": "logprob",
+ "old_log_probs": "old_logprob",
+ "ref_log_prob": "ref_logprob",
+ "response_mask": "action_mask",
+ "advantages": "advantages",
+ }
+ ),
+}
diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py
index 7dfbb7141d..0858cb7002 100644
--- a/trinity/algorithm/policy_loss_fn/dpo_loss.py
+++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -1,6 +1,6 @@
"""DPO loss function."""
-from typing import Dict, List, Tuple
+from typing import Dict, Tuple
import torch
import torch.nn.functional as F
@@ -13,9 +13,11 @@
class DPOLossFn(PolicyLossFn):
def __init__(
self,
+ backend: str = "verl",
beta: float = 0.1,
label_smoothing: float = 0.0,
) -> None:
+ super().__init__(backend=backend)
self.beta = beta
self.label_smoothing = label_smoothing
@@ -63,10 +65,3 @@ def default_args(cls) -> Dict:
"beta": 0.1,
"label_smoothing": 0.0,
}
-
- @property
- def select_keys(self) -> List[str]:
- return [
- "ref_logprob",
- "action_mask",
- ]
diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
index 84679b0ea8..76c89c42d9 100644
--- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py
@@ -1,6 +1,6 @@
"""Mix policy loss function."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
import torch
@@ -26,27 +26,29 @@ class MIXPolicyLossFn(PolicyLossFn):
def __init__(
self,
+ backend: str = "verl",
mu: float = 0.1,
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
clip_range_high: Optional[float] = None,
use_dynamic_bsz: Optional[bool] = None,
- repeat_times: Optional[int] = None,
- ppo_mini_batch_size: Optional[int] = None,
- ppo_micro_batch_size_per_gpu: Optional[int] = None,
- ngpus_trainer: Optional[int] = None,
- read_batch_size_usual: Optional[int] = None,
- read_batch_size_expert: Optional[int] = None,
+ repeat_times: int = 1,
+ ppo_mini_batch_size: int = 1,
+ ppo_micro_batch_size_per_gpu: int = 1,
+ ngpus_trainer: int = 1,
+ read_batch_size_usual: int = 1,
+ read_batch_size_expert: int = 1,
use_token_level_loss_in_sft: bool = True,
) -> None:
+ super().__init__(backend=backend)
self.mu = mu
self.use_dynamic_bsz = use_dynamic_bsz
- self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore
+ self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer
self.gradient_accumulation = (
- ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore
+ ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu
)
- self.read_batch_size_usual = read_batch_size_usual
- self.read_batch_size_expert = read_batch_size_expert
+ self.read_batch_size_usual = read_batch_size_usual // ngpus_trainer
+ self.read_batch_size_expert = read_batch_size_expert // ngpus_trainer
self.grpo_loss_fn = PPOPolicyLossFn(
clip_range=clip_range,
clip_range_low=clip_range_low,
@@ -60,11 +62,9 @@ def __call__( # type: ignore
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
+ is_expert_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
- is_expert_mask = kwargs.get("is_expert_mask", None)
- if is_expert_mask is None:
- raise ValueError("is_expert_mask is required in MIX")
assert (
len(is_expert_mask) == logprob.shape[0]
), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
@@ -127,7 +127,3 @@ def default_args(cls) -> Dict:
"mu": 0.1,
"clip_range": 0.2,
}
-
- @property
- def select_keys(self) -> List[str]:
- return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
index 042d26b341..618301b319 100644
--- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -1,6 +1,6 @@
"""OPMD policy loss function."""
-from typing import Dict, List, Tuple
+from typing import Dict, Tuple
import torch
@@ -10,13 +10,13 @@
@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
- def __init__(self, tau: float = 1.0) -> None:
+ def __init__(self, backend: str = "verl", tau: float = 1.0) -> None:
+ super().__init__(backend=backend)
self.tau = tau
def __call__( # type: ignore
self,
logprob: torch.Tensor,
- old_logprob: torch.Tensor, # NOT USED!
action_mask: torch.Tensor,
advantages: torch.Tensor,
**kwargs,
@@ -29,11 +29,3 @@ def __call__( # type: ignore
@classmethod
def default_args(cls) -> Dict:
return {"tau": 1.0}
-
- @property
- def select_keys(self) -> List[str]:
- return [
- "old_logprob",
- "action_mask",
- "advantages",
- ]
diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
index 6c1a29b3e9..aa6025252e 100644
--- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
+++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -1,18 +1,92 @@
-from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple
+import inspect
+from abc import ABC, ABCMeta, abstractmethod
+from typing import Dict, Tuple
import torch
+from trinity.algorithm.key_mapper import ALL_MAPPERS
from trinity.utils.registry import Registry
POLICY_LOSS_FN = Registry("policy_loss_fn")
-class PolicyLossFn(ABC):
+class PolicyLossFnMeta(ABCMeta):
+ """Metaclass for policy loss functions that handles parameter name mapping and filtering."""
+
+ ignore_keys = {"self", "kwargs", "logprob"} # Keys to exclude from parameter selection
+
+ def __new__(cls, name, bases, dct):
+ """
+ Metaclass constructor that automatically generates parameter handling logic.
+
+ For example with `PPOPolicyLossFn` class:
+ .. code-block:: python
+ class PPOPolicyLossFn(PolicyLossFn):
+ ...
+ def __call__(
+ self,
+ logprob: torch.Tensor,
+ old_logprob: torch.Tensor,
+ action_mask: torch.Tensor,
+ advantages: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ ...
+
+ This metaclass analyzes the __call__ method's parameters to:
+ 1. Generate _select_keys containing all non-ignored parameters
+ 2. Create select_keys property that maps parameters to trainer-specific names
+ 3. Apply decorator to automatically convert input parameter names using the mapper
+ """
+ signature = inspect.signature(dct["__call__"])
+ param_names = [
+ key for key in signature.parameters.keys() if key not in PolicyLossFnMeta.ignore_keys
+ ]
+ dct["_select_keys"] = param_names
+
+ # Property to return trainer-specific parameter names
+ def select_keys(self):
+ """Returns parameter keys mapped to the specific training framework's naming convention."""
+ keys = [self.mapper.from_trinity(key) for key in self._select_keys]
+ return keys
+
+ # Decorator to handle parameter name conversion before calling __call__
+ def decorator(func):
+ def wrapper(self, *args, **kwargs):
+ """Filters and converts parameter names according to the training framework's convention."""
+ new_kwargs = {}
+ for key, value in kwargs.items():
+ key = self.mapper.to_trinity(key)
+ if key == "logprob" or key in self._select_keys: # remove unused keys
+ new_kwargs[key] = value
+ return func(self, *args, **new_kwargs)
+
+ return wrapper
+
+ # Add the property and decorated method to the class
+ dct["select_keys"] = property(select_keys)
+ dct["__call__"] = decorator(dct["__call__"])
+ return super().__new__(cls, name, bases, dct)
+
+
+class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta):
"""
- Policy Loss Function
+ Abstract base class for policy loss functions.
+
+ This class provides the interface for implementing different policy gradient loss functions
+ while handling parameter name mapping between different training frameworks.
"""
+ def __init__(self, backend: str = "verl"):
+ """
+ Initialize the policy loss function.
+
+ Args:
+ backend: The training framework/backend to use (e.g., "verl")
+ """
+ self.backend = backend
+ self.mapper = ALL_MAPPERS[self.backend]
+
@abstractmethod
def __call__(
self,
@@ -20,8 +94,12 @@ def __call__(
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
"""
+ Calculate the policy loss.
+
Args:
logprob (`torch.Tensor`): The log probability generated by the policy model.
+
+ Kwargs (optional):
old_logprob (`torch.Tensor`): The log probability generated by the reference model.
action_mask (`torch.Tensor`): The action mask.
advantages (`torch.Tensor`): The advantages.
@@ -36,14 +114,8 @@ def __call__(
@abstractmethod
def default_args(cls) -> Dict:
"""
- Returns:
- `Dict`: The default init arguments for the policy loss function.
- """
+ Get default initialization arguments for this loss function.
- @property
- @abstractmethod
- def select_keys(self) -> List[str]:
- """
Returns:
- `List[str]`: The keys to select from input data.
+ `Dict`: The default init arguments for the policy loss function.
"""
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index 5c735d4d6a..a4cc0b2d03 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -3,7 +3,7 @@
Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
import torch
@@ -15,10 +15,12 @@
class PPOPolicyLossFn(PolicyLossFn):
def __init__(
self,
+ backend: str = "verl",
clip_range: Optional[float] = None,
clip_range_low: Optional[float] = None,
clip_range_high: Optional[float] = None,
) -> None:
+ super().__init__(backend=backend)
if clip_range_low is None:
self.clip_range_low = clip_range
else:
@@ -61,11 +63,3 @@ def default_args(cls) -> Dict:
return {
"clip_range": 0.2,
}
-
- @property
- def select_keys(self) -> List[str]:
- return [
- "old_logprob",
- "action_mask",
- "advantages",
- ]
diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py
index dd1c75a4a2..2c824f1c09 100644
--- a/trinity/algorithm/policy_loss_fn/sft_loss.py
+++ b/trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -1,6 +1,6 @@
"""SFT loss function."""
-from typing import Dict, List, Tuple
+from typing import Dict, Tuple
import torch
@@ -10,7 +10,8 @@
@POLICY_LOSS_FN.register_module("sft")
class SFTLossFn(PolicyLossFn):
- def __init__(self, use_token_level_loss: bool = True) -> None:
+ def __init__(self, backend: str = "verl", use_token_level_loss: bool = True) -> None:
+ super().__init__(backend=backend)
self.use_token_level_loss = use_token_level_loss
def __call__( # type: ignore
@@ -30,7 +31,3 @@ def default_args(cls):
return {
"use_token_level_loss": True,
}
-
- @property
- def select_keys(self) -> List[str]:
- return ["action_mask"]
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 6a57e58144..e7eb34ea17 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -56,7 +56,7 @@ def __init__(
def set_algorithm(self, algorithm_config: AlgorithmConfig):
self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
- **algorithm_config.policy_loss_fn_args
+ backend="verl", **algorithm_config.policy_loss_fn_args
)
self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args)
self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)(
@@ -152,21 +152,7 @@ def update_policy(self, data: DataProto):
"responses",
"response_mask",
]
- select_keys_verl2trinity = {
- "old_log_probs": "old_logprob",
- "ref_log_prob": "ref_logprob",
- "response_mask": "action_mask",
- "advantages": "advantages",
- }
- select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()}
- for trinity_key in self.policy_loss_fn.select_keys:
- if trinity_key in select_keys_trinity2verl:
- verl_key = select_keys_trinity2verl[trinity_key]
- else:
- verl_key = trinity_key
- select_keys_verl2trinity.update({verl_key: trinity_key})
- select_keys_trinity2verl.update({trinity_key: verl_key})
- select_keys.append(verl_key)
+ select_keys.extend(self.policy_loss_fn.select_keys)
if not isinstance(self.kl_loss_fn, DummyKLFn):
select_keys.append("ref_log_prob")
select_keys = list(set(select_keys))
@@ -240,14 +226,8 @@ def update_policy(self, data: DataProto):
calculate_entropy=calculate_entropy,
)
- kwargs = {
- select_keys_verl2trinity[verl_key]: value
- for verl_key, value in data.items()
- if verl_key in select_keys_verl2trinity
- }
pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore
- logprob=log_prob,
- **kwargs,
+ logprob=log_prob, **data
)
prefix_metrics(
src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
From a592af7d4accf5ebb639d5cb4af4690dac832dfd Mon Sep 17 00:00:00 2001
From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com>
Date: Wed, 18 Jun 2025 11:30:19 +0800
Subject: [PATCH 16/28] Add guideline for adding new algorithm (#85)
---
docs/sphinx_doc/source/index.rst | 12 +-
.../source/tutorial/example_mix_algo.md | 13 +-
.../tutorial/example_reasoning_advanced.md | 2 +-
.../tutorial/trinity_programming_guide.md | 239 ++++++++++++++++--
trinity/algorithm/__init__.py | 3 +
trinity/algorithm/algorithm.py | 6 -
trinity/buffer/writer/sql_writer.py | 4 -
trinity/utils/registry.py | 30 +--
8 files changed, 252 insertions(+), 57 deletions(-)
diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst
index a1b6fde647..4b4cab2aa9 100644
--- a/docs/sphinx_doc/source/index.rst
+++ b/docs/sphinx_doc/source/index.rst
@@ -14,7 +14,7 @@ Welcome to Trinity-RFT's documentation!
:maxdepth: 1
:glob:
:hidden:
- :caption: Tutorial
+ :caption: Examples
tutorial/example_reasoning_basic.md
tutorial/example_reasoning_advanced.md
@@ -22,8 +22,15 @@ Welcome to Trinity-RFT's documentation!
tutorial/example_multi_turn.md
tutorial/example_dpo.md
tutorial/example_data_functionalities.md
- tutorial/trinity_configs.md
+
+.. toctree::
+ :maxdepth: 2
+ :glob:
+ :hidden:
+ :caption: Guidelines
+
tutorial/trinity_programming_guide.md
+ tutorial/trinity_configs.md
tutorial/example_mix_algo.md
.. toctree::
@@ -34,6 +41,7 @@ Welcome to Trinity-RFT's documentation!
build_api/trinity.buffer
build_api/trinity.explorer
build_api/trinity.trainer
+ build_api/trinity.algorithm
build_api/trinity.manager
build_api/trinity.common
build_api/trinity.utils
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index ee0010ba24..61ecec33b1 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -1,5 +1,8 @@
-# Integrate A New Algorithm
+# Algorithm Development
+```{note}
+This guide is an advanced version of the {ref}`Algorithms ` section in the Developer Guide.
+```
This guide introduces how to integrate a new algorithm to Trinity-RFT.
As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX , which optimizes the following policy objective:
@@ -19,13 +22,10 @@ The first term corresponds to the standard GRPO objective, which aims to maximiz
## Step 0: Prepare the Expert Data
-We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected dta are viewed as some experiences from an expert. We store them in a JSON file `expert_data.json` with the following format:
+We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected dta are viewed as some experiences from an expert. We store them in a `jsonl` file `expert_data.jsonl` with the following format:
```json
-{
- "question": "What is the average of 4, 6, and 8?",
- "response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."
-}
+{"question": "What is the average of 4, 6, and 8?","response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."}
...
```
@@ -42,7 +42,6 @@ class MIXAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
- use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel
diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
index a80032bc12..aa4439e866 100644
--- a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
+++ b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
@@ -6,7 +6,7 @@ Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) a
-
+(OPMD)=
## OPMD: a native off-policy RL algorithm
diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index 4d158f86b9..931cb81506 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -1,6 +1,16 @@
# Developer Guide
-This guide introduces how to add new workflows to Trinity-RFT and provides relevant development guidelines.
+This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines.
+
+Trinity-RFT consists of three main modules: **Explorer**, **Trainer** and **Buffer**.
+We decouple the RL pipeline into three modules to make it easier to customize and extend.
+Below is a table summarizing the modules and components that developers with different tragets need to focus on.
+
+| Development Target | Core Module | Key Component |
+|--------------------|-------------|---------------|
+| Apply existing RL algorithms to new environments. | *Explorer* | `Workflow` |
+| Design new RL algorithms. | *Trainer* | `Algorithm` |
+| Enhance the RL process from the data perspective. | *Buffer* | Data Processing Module (Coming soon) |
```{note}
Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
@@ -8,9 +18,10 @@ Trinity-RFT is still under development, and the following interfaces may change.
---
-## Creating New Workflows
+## Workflows (For RL Environment Developers)
-Trinity-RFT allows developers to register new workflows (e.g., for multi-turn interactions or agentic scenarios). Below are the steps to create a new workflow:
+In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments.
+A qualified workflow needs to use the trained model to complete the specified task and obtain feedback information (reward) from the environment. Below are the steps to create a new workflow:
---
@@ -18,19 +29,16 @@ Trinity-RFT allows developers to register new workflows (e.g., for multi-turn in
Before starting development, it's important to understand several core concepts:
-
- **Task** ({class}`trinity.common.workflows.Task`): Represents a data structure that can be converted into a `Workflow`. The content of the `Task` varies depending on the task type:
- **Math problems**: A `Task` contains the problem description and the golden answer.
- **Programming scenarios**: A `Task` includes the problem description, test cases, runtime environment, and other complex information.
-
-- **Workflow** ({class}`trinity.common.workflows.Workflow`): Can be understood as the running state of a `Task`. It defines the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows:
+- **Workflow** ({class}`trinity.common.workflows.Workflow`): Describes how a `Task` is executed. It defines the interaction flow between Agents and Environments, including logic similar to *Rollout* and *Reward* calculations in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows:
- `MathWorkflow` ({class}`trinity.common.workflows.MathWorkflow`): For math scenarios, submits problems to LLM, parses LLM responses, and calculates scores (rewards).
- `WebShopWorkflow` ({class}`trinity.common.workflows.WebShopWorkflow`): For webshop scenarios, it contains multi-turn interaction with environment.
- `CodeWorkflow` (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results.
- ...
-
- **Experience** ({class}`trinity.common.experience.Experience`): The output of running a `Workflow`. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, `Experience` includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
---
@@ -40,12 +48,12 @@ Before starting development, it's important to understand several core concepts:
The task dataset is loaded via the `buffer.explorer_input.taskset` configuration entry in your YAML config file.
To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` interface containing the following fields.
- - **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file.
- - **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.
- - **`raw_task`** (`Dict`): An record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
- - **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`.
- - **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`.
- - **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field.
+- **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file.
+- **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.
+- **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
+- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`.
+- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`.
+- **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field.
```{tip}
`workflow`, `workflow_args` and `raw_task` provide different levels of customization.
@@ -82,7 +90,6 @@ buffer:
In this example, each task object's `raw_task` is a `Dict` with two keys (`question` and `answer`). The `MathWorkflow` uses the `prompt_key` and `response_key` to extract the question and answer from the `raw_task` and use the `rollout_args` to generate the response.
-
---
### Step 2: Implement a New Workflow
@@ -106,8 +113,7 @@ class Workflow(ABC):
"""Run the workflow and return a list of Experiences."""
```
-
-#### Initializing Your Workflow
+#### Initialize Your Workflow
During initialization, `Workflow` receives the following parameters:
@@ -115,7 +121,6 @@ During initialization, `Workflow` receives the following parameters:
- `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
- `auxiliary_models`(`List[openai.OpenAI]`):A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
-
```{tip}
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
```
@@ -143,7 +148,6 @@ We first call the model to generate multiple response using the provided questio
Then we calculate the reward for each response using the `calculate_reward` function.
Finally, we construct a list of `Experience` with the responses and rewards and return it.
-
```python
class ExampleWorkflow(Workflow):
@@ -215,7 +219,6 @@ For workflows that are not intended to be contributed to Trinity-RFT project, yo
You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `/trinity/plugins` as the default directory.
```
-
#### Avoid Re-initialization
For heavy workflows, re-initializing every time can incurs extra computational costs.
@@ -235,7 +238,6 @@ class ExampleWorkflow(Workflow):
self.answer = task.raw_task.get("answer")
```
-
#### Full Code Example
```python
@@ -289,7 +291,6 @@ class ExampleWorkflow(Workflow):
self.answer = task.raw_task.get("answer")
```
-
---
### Step 3: Use Your Workflow
@@ -314,6 +315,198 @@ trinity run --config
---
+(Algorithms)=
+## Algorithms (For RL Algorithm Developers)
+
+Trinity-RFT provides a standardized process for implementing new algorithms.
+
+### Step 0: Basic Concepts of Algorithm Module
+
+In Trinity-RFT, the algorithm module is primarily responsible for extracting experience data from the Replay Buffer during the RL process and calculating the loss to update models based on this data.
+To avoid implementing a new Trainer class each time a new algorithm is added, we have decomposed the representative PPO algorithm process into multiple sub-modules to adapt to various algorithms.
+
+- **Sample Strategy** ({class}`trinity.algorithm.SampleStrategy`): Responsible for sampling experience data from the buffer module. By customizing this module, you can implement functionalities like filtering experience data or mixed sampling from multiple data sources.
+- **Advantage Fn**({class}`trinity.algorithm.AdvantageFn`): Responsible for calculating the Advantage and Returns of experience data.
+- **Policy Loss Fn**({class}`trinity.algorithm.PolicyLossFn`): Responsible for calculating the core training loss of the policy network.
+- **KL Fn**({class}`trinity.algorithm.KLFn`): Responsible for calculating KL Divergence, which is generally used in two places in existing RL algorithms: Reward Penalty and Actor Loss.
+- **Entropy Loss Fn**({class}`trinity.algorithm.EntropyLossFn`): Responsible for calculating the entropy loss of the policy network.
+
+We provide several implementations of above modules in `trinity/algorithm`.
+
+---
+
+### Step 1: Implement Algorithm Components
+
+
+Trinity-RFT allows developers to customize all the above modules. Developers only need to implement specific modules according to the requirements of their new algorithm. This section will provide a simple introduction using the {ref}`OPMD ` algorithm as an example.
+
+The main difference between OPMD and PPO algorithms lies in the calculation of Advantage and Policy Loss. Therefore, only new Advantage Fn and Policy Loss Fn modules need to be implemented.
+
+---
+
+#### Step 1.1: Implement `AdvantageFn`
+
+Developers need to implement the {class}`trinity.algorithm.AdvantageFn` interface, which mainly includes two methods:
+
+- `__call__`: Calculates advantages and returns based on input experience data, records observable metrics during the calculation process, and returns the experience data containing advantages and returns as well as a metrics dictionary. The input experience data format is [verl](https://github.com/volcengine/verl)'s `DataProto`.
+- `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file.
+
+After implementation, you need to register this module through {class}`trinity.algorithm.ADVANTAGE_FN`. Once registered, the module can be configured in the configuration file using the registered name.
+
+Here's an implementation example for the OPMD algorithm's Advantage Fn:
+
+```python
+# trinity/algorithm/advantage_fn/opmd.py
+# import some modules
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+
+
+@ADVANTAGE_FN.register_module("opmd")
+class OPMDAdvantageFn(AdvantageFn):
+ """OPMD advantage computation"""
+
+ def __init__(
+ self,
+ opmd_baseline: str = "mean",
+ tau: float = 1.0,
+ ) -> None:
+ self.opmd_baseline = opmd_baseline
+ self.tau = tau
+
+
+ def __call__(
+ self,
+ exps: DataProto,
+ **kwargs,
+ ) -> Tuple[DataProto, Dict]:
+ # calculate advantages and returns based on the exps
+
+ # record some metrics
+
+ return exps, metrics
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {
+ "opmd_baseline": "mean",
+ "tau": 1.0,
+ }
+```
+
+#### Step 1.2: Implement `PolicyLossFn`
+
+Developers need to implement the {class}`trinity.algorithm.PolicyLossFn` interface, which is similar to `AdvantageFn` and includes two methods:
+
+- `__call__`: Calculates the loss based on input parameters. Unlike `AdvantageFn`, the input parameters here are all `torch.Tensor`. This interface automatically scans the parameter list of the `__call__` method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from `kwargs`.
+- `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file.
+
+Similarly, after implementation, you need to register this module through {class}`trinity.algorithm.POLICY_LOSS_FN`.
+
+Here's an implementation example for the OPMD algorithm's Policy Loss Fn. Since OPMD's Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the `__call__` method:
+
+
+```python
+@POLICY_LOSS_FN.register_module("opmd")
+class OPMDPolicyLossFn(PolicyLossFn):
+ def __init__(self, tau: float = 1.0) -> None:
+ self.tau = tau
+
+ def __call__( # type: ignore
+ self,
+ logprob: torch.Tensor,
+ action_mask: torch.Tensor,
+ advantages: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ pg_losses = -advantages * logprob
+ opmd_loss = masked_mean(pg_losses, action_mask)
+ opmd_loss = opmd_loss / (1.0 + self.tau) # for regularization (w.r.t. current pi_theta)
+ return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}
+
+ @classmethod
+ def default_args(cls) -> Dict:
+ return {"tau": 1.0}
+```
+
+---
+
+### Step 2: Register Your Algorithm
+
+The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.
+
+To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {object}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
+
+The `AlgorithmType` class includes the following attributes and methods:
+
+- `use_critic`: Whether to use the Critic model
+- `use_reference`: Whether to use the Reference model
+- `use_advantage`: Whether to calculate Advantage; if False, the `AdvantageFn` call will be skipped
+- `can_balance_batch`: Whether the algorithm allows automatic balancing when splitting a batch into microbatches (which permute the order of samples)
+- `schema`: The format of experience data corresponding to the algorithm
+- `get_default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE`
+
+Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`.
+
+Below is the implementation for the OPMD algorithm.
+Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`.
+The dictionary returned by the `get_default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss.
+
+```python
+@ALGORITHM_TYPE.register_module("opmd")
+class OPMDAlgorithm(AlgorithmType):
+ """OPMD algorithm."""
+
+ use_critic: bool = False
+ use_reference: bool = True
+ use_advantage: bool = True
+ can_balance_batch: bool = True
+ schema: type = ExperienceModel
+
+ @classmethod
+ def get_default_config(cls) -> Dict:
+ return {
+ "repeat_times": 2,
+ "sample_strategy": "warmup",
+ "policy_loss_fn": "opmd",
+ "advantage_fn": "opmd",
+ "kl_penalty_fn": "none",
+ "kl_loss_fn": "k2",
+ "entropy_loss_fn": "default",
+ }
+```
+
+---
+
+### Step 3: Use Your Algorithm
+
+After completing all the above steps, you can use the newly registered algorithm through a YAML configuration file.
+
+For default configurations, you just need to add the following content to your `config.yaml` file:
+
+```yaml
+# some other configs
+algorithm:
+ algorithm_type: "opmd"
+# some other configs
+```
+
+If you need to modify certain parameters, you can simply add the corresponding parameters within the `algorithm` section. For example, if you need to modify `repeat_times` and the initialization parameters of `AdvantageFn` and `PolicyLossFn`, the modified `config.yaml` file would be as follows:
+
+```yaml
+# some other configs
+algorithm:
+ algorithm_type: "opmd"
+ repeat_times: 8
+ advantage_fn_args:
+ opmd_baseline: "logavgexp"
+ tau: 0.99
+ policy_loss_fn_args:
+ tau: 0.99
+# some other configs
+```
+
+---
+
## Adding New Config Entries for the Config Generator (Advanced)
### Step 0: Understanding Streamlit
@@ -344,11 +537,11 @@ The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=conf
```
For `train_batch_size`, we will use the following settings:
+
- Default value: 96
- Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0`
- Additional config: `{"_train_batch_size_per_gpu": 16}`
-
Here's the complete code for the `train_batch_size` parameter:
```python
@@ -408,6 +601,7 @@ To successfully integrate new parameters into the `config_manager.py` file, plea
Incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class.
Example:
+
```python
class ConfigManager:
def _expert_buffer_part(self):
@@ -421,6 +615,7 @@ To successfully integrate new parameters into the `config_manager.py` file, plea
Utilize `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML.
Example:
+
```python
class ConfigManager:
def _gen_buffer_config(self):
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index ff52f609e5..667aa10d74 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -1,10 +1,13 @@
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType
from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
from trinity.algorithm.kl_fn import KL_FN, KLFn
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy
__all__ = [
+ "ALGORITHM_TYPE",
+ "AlgorithmType",
"AdvantageFn",
"ADVANTAGE_FN",
"PolicyLossFn",
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 6f0a2d19a7..805dd8f213 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -26,7 +26,6 @@ class AlgorithmType(ABC, metaclass=ConstantMeta):
use_critic: bool
use_reference: bool
use_advantage: bool
- use_rollout: bool
can_balance_batch: bool
schema: type
@@ -50,7 +49,6 @@ class SFTAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = False
use_advantage: bool = False
- use_rollout: bool = False
can_balance_batch: bool = True
schema: type = SFTDataModel
@@ -71,7 +69,6 @@ class PPOAlgorithm(AlgorithmType):
use_critic: bool = True
use_reference: bool = True
use_advantage: bool = True
- use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel
@@ -95,7 +92,6 @@ class GRPOAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
- use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel
@@ -119,7 +115,6 @@ class OPMDAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
- use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel
@@ -143,7 +138,6 @@ class DPOAlgorithm(AlgorithmType):
use_critic: bool = False
use_reference: bool = True
use_advantage: bool = False
- use_rollout: bool = False
can_balance_batch: bool = False
schema: type = DPODataModel
diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py
index 3e054d58c6..49ef0ec9b3 100644
--- a/trinity/buffer/writer/sql_writer.py
+++ b/trinity/buffer/writer/sql_writer.py
@@ -2,7 +2,6 @@
import ray
-from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.buffer.buffer_writer import BufferWriter
from trinity.buffer.db_wrapper import DBWrapper
from trinity.common.config import BufferConfig, StorageConfig
@@ -15,9 +14,6 @@ class SQLWriter(BufferWriter):
def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
assert meta.storage_type == StorageType.SQL
# we only support write RFT algorithm buffer for now
- # TODO: support other algorithms
- algorithm = ALGORITHM_TYPE.get(meta.algorithm_type)
- assert algorithm.use_rollout, "Only RFT buffer is supported for writing."
self.wrap_in_ray = meta.wrap_in_ray
self.db_wrapper = DBWrapper.get_wrapper(meta, config)
diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py
index d5ee37f36e..83cd393519 100644
--- a/trinity/utils/registry.py
+++ b/trinity/utils/registry.py
@@ -83,21 +83,21 @@ def register_module(self, module_name: str, module_cls: Type = None, force=False
Default: False.
Example:
- ```python
- WORKFLOWS = Registry("workflows")
-
- # register a module using decorator
- @WORKFLOWS.register_module(name="workflow_name")
- class MyWorkflow(Workflow):
- pass
-
- # or register a module directly
- WORKFLOWS.register_module(
- name="workflow_name",
- module_cls=MyWorkflow,
- force=True,
- )
- ```
+
+ .. code-block:: python
+ WORKFLOWS = Registry("workflows")
+
+ # register a module using decorator
+ @WORKFLOWS.register_module(name="workflow_name")
+ class MyWorkflow(Workflow):
+ pass
+
+ # or register a module directly
+ WORKFLOWS.register_module(
+ name="workflow_name",
+ module_cls=MyWorkflow,
+ force=True,
+ )
"""
if not (module_name is None or isinstance(module_name, str)):
From b7ea08f5acdddd2aab243ba9c31afa3930d26283 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 18 Jun 2025 12:21:05 +0800
Subject: [PATCH 17/28] fix file_reader
---
tests/trainer/trainer_test.py | 2 +-
trinity/buffer/reader/file_reader.py | 2 --
2 files changed, 1 insertion(+), 3 deletions(-)
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index bf064785cd..726e22290e 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -184,7 +184,7 @@ def test_trainer(self):
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
def tearDown(self):
- # remove dir only when the test passed
+ # TODO: remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index d43f2667b7..8bb9dbcd28 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -69,7 +69,6 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split)
) # TODO: support resume
- self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
def read(
@@ -146,7 +145,6 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split)
) # TODO: support resume
- self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
def _get_assistant_message(self, item) -> dict:
From eea4d85d9dc431d59d76c07b379cd739ca640948 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 18 Jun 2025 13:24:40 +0800
Subject: [PATCH 18/28] update pyproject
---
pyproject.toml | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index bafa620470..cc6ccba23f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,8 +23,9 @@ requires-python = ">=3.10"
dependencies = [
"verl==0.4.0",
"ray[default]>=2.45.0",
- "vllm==0.8.5.post1",
+ "vllm==0.9.1",
"tensordict==0.6.2",
+ "flash-attn==2.8.0.post2",
"wandb",
"omegaconf",
"sqlalchemy",
From aedfa53b9be89c33b443ffb0a51d8d517064b6d0 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 18 Jun 2025 13:50:34 +0800
Subject: [PATCH 19/28] update pyproject.toml
---
README.md | 7 +++++--
pyproject.toml | 5 ++++-
2 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index deadf54ae3..3ccc1fa19c 100644
--- a/README.md
+++ b/README.md
@@ -148,8 +148,11 @@ pip install -e .\[dev\]
# Install flash-attn after all dependencies are installed
# Note: flash-attn will take a long time to compile, please be patient.
-pip install flash-attn -v
-# Try the following command if you encounter errors during installation
+# for bash
+pip install -e .[flash_attn]
+# for zsh
+pip install -e .\[flash_attn\]
+# Try the following command if you encounter errors during flash-attn installation
# pip install flash-attn -v --no-build-isolation
```
diff --git a/pyproject.toml b/pyproject.toml
index cc6ccba23f..c6917217ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,6 @@ dependencies = [
"ray[default]>=2.45.0",
"vllm==0.9.1",
"tensordict==0.6.2",
- "flash-attn==2.8.0.post2",
"wandb",
"omegaconf",
"sqlalchemy",
@@ -70,6 +69,10 @@ doc = [
"myst-parser",
]
+flash_attn = [
+ "flash-attn==2.8.0.post2"
+]
+
[tool.setuptools.packages.find]
where = ["."]
include = ["trinity*"]
From 2a36e0e130d1c56cd9e700b9dd5a50d80c8cadba Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 18 Jun 2025 19:08:05 +0800
Subject: [PATCH 20/28] clean code
---
trinity/common/models/vllm_worker.py | 6 +-
trinity/explorer/explorer.py | 1 +
trinity/trainer/verl/fsdp_workers.py | 196 +--------------------------
3 files changed, 11 insertions(+), 192 deletions(-)
diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py
index 4a32628a96..4293811ab7 100644
--- a/trinity/common/models/vllm_worker.py
+++ b/trinity/common/models/vllm_worker.py
@@ -57,7 +57,7 @@ def init_process_group(
)
self._explorer_actor = None
- def update_weight(self, name, dtype, shape, empty_cache=False):
+ def update_weight(self, name: str, dtype_str: str, shape: tuple, empty_cache=False):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
if self._weight_update_rank == 0:
if self._explorer_actor is None:
@@ -65,8 +65,8 @@ def update_weight(self, name, dtype, shape, empty_cache=False):
weight = ray.get(self._explorer_actor.get_weight.remote(name))
weight = weight.to(self.device)
else:
- weight = torch.empty(shape, dtype=dtype, device="cuda")
-
+ dtype = getattr(torch, dtype_str.split(".")[-1])
+ weight = torch.empty(shape, dtype=dtype, device=self.device)
torch.distributed.broadcast(weight, 0, group=self._model_update_group)
weight = weight.type(self.model_config.dtype)
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index a542e2c10b..6b8e5b6288 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -85,6 +85,7 @@ def setup_weight_sync_group(
f"world_size={world_size}, rank_offset={base_offset}"
)
self.state_dict_meta = state_dict_meta
+ # TODO: save state_dict in models
refs = [
model.init_process_group.remote(
master_address=master_address,
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 66d055feeb..2a8308ea62 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -26,7 +26,7 @@
import torch
import torch.distributed
import torch.distributed as dist
-import vllm # noqa: F401 ; import vllm to avoid "Cuda failure 1 'invalid argument'"
+import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically.
from codetiming import Timer
from omegaconf import DictConfig, open_dict
from peft import LoraConfig, TaskType, get_peft_model
@@ -126,7 +126,6 @@ def __init__(self, config: DictConfig, role: str):
assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]
self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
- self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
self._is_ref = self.role in ["ref", "actor_rollout_ref"]
self._is_offload_param = False
@@ -170,14 +169,6 @@ def __init__(self, config: DictConfig, role: str):
> 0
), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
- # normalize rollout config
- if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
- self.config.rollout.log_prob_micro_batch_size //= (
- self.device_mesh.size() // self.ulysses_sequence_parallel_size
- )
- self.config.rollout.log_prob_micro_batch_size_per_gpu = (
- self.config.rollout.log_prob_micro_batch_size
- )
# normalize ref config
if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
self.config.ref.log_prob_micro_batch_size //= (
@@ -339,10 +330,6 @@ def _build_model_optimizer( # noqa: C901
is_lora=self.config.model.get("lora_rank", 0) > 0,
)
- if self._is_rollout and self.config.rollout.name == "hf":
- # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
- auto_wrap_policy = None
-
if self.rank == 0:
print(f"wrap_policy: {auto_wrap_policy}")
@@ -450,136 +437,6 @@ def _build_model_optimizer( # noqa: C901
return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
- def _build_rollout(self, trust_remote_code=False):
- from torch.distributed.device_mesh import init_device_mesh
-
- # TODO(sgm): support FSDP hybrid shard for larger model
- infer_tp = self.config.rollout.tensor_model_parallel_size
- dp = self.world_size // infer_tp
- assert (
- self.world_size % infer_tp == 0
- ), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
- rollout_device_mesh = init_device_mesh(
- device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
- )
- rollout_name = self.config.rollout.name
- if rollout_name == "hf":
- from verl.workers.rollout import HFRollout
- from verl.workers.sharding_manager.base import BaseShardingManager
-
- rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
- rollout_sharding_manager = BaseShardingManager()
- # TODO: a sharding manager that do nothing?
-
- elif rollout_name == "vllm":
- from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout
- from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager
-
- log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
- local_path = copy_to_local(
- self.config.model.path, use_shm=self.config.model.get("use_shm", False)
- )
- lora_kwargs = (
- {
- "lora_kwargs": {
- "enable_lora": True,
- "max_loras": 1,
- "max_lora_rank": self._lora_rank,
- }
- }
- if self._is_lora
- else {}
- )
- # lora_kwargs = {}
- if vllm_mode == "customized":
- rollout = vLLMRollout(
- actor_module=self.actor_module_fsdp,
- config=self.config.rollout,
- tokenizer=self.tokenizer,
- model_hf_config=self.actor_model_config,
- trust_remote_code=trust_remote_code,
- **lora_kwargs,
- )
- elif vllm_mode == "spmd":
- from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
-
- vllm_rollout_cls = (
- vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
- )
- rollout = vllm_rollout_cls(
- model_path=local_path,
- config=self.config.rollout,
- tokenizer=self.tokenizer,
- model_hf_config=self.actor_model_config,
- device_mesh=rollout_device_mesh,
- trust_remote_code=trust_remote_code,
- **lora_kwargs,
- )
- else:
- raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
-
- log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
- full_params = torch.distributed.get_world_size() == 1
- rollout_sharding_manager = FSDPVLLMShardingManager(
- module=self.actor_module_fsdp,
- inference_engine=rollout.inference_engine,
- model_config=self.actor_model_config,
- full_params=full_params,
- device_mesh=rollout_device_mesh,
- offload_param=self._is_offload_param,
- load_format=self.config.rollout.load_format,
- layered_summon=self.config.rollout.get("layered_summon", False),
- )
- log_gpu_memory_usage("After building sharding manager", logger=logger)
-
- elif rollout_name in ["sglang", "sglang_async"]:
- if rollout_name == "sglang_async":
- warnings.warn(
- "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.",
- DeprecationWarning,
- stacklevel=2,
- )
- from verl.workers.rollout.sglang_rollout import SGLangRollout
-
- # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
- # SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
- # the main process of ray can not find any CUDA device, which would potentially lead to:
- # "RuntimeError: No CUDA GPUs are available".
- # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
- # we import it here use the abs path.
- # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
- from verl.workers.sharding_manager.fsdp_sglang import (
- FSDPSGLangShardingManager,
- )
-
- local_path = copy_to_local(self.config.model.path)
- log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
- rollout = SGLangRollout(
- actor_module=local_path,
- config=self.config.rollout,
- tokenizer=self.tokenizer,
- model_hf_config=self.actor_model_config,
- trust_remote_code=trust_remote_code,
- )
- log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
-
- if torch.distributed.get_world_size() == 1:
- self.config.rollout.load_format = "dummy_hf"
- rollout_sharding_manager = FSDPSGLangShardingManager(
- module=self.actor_module_fsdp,
- inference_engine=rollout._engine,
- model_config=self.actor_model_config,
- full_params="hf" in self.config.rollout.load_format,
- device_mesh=rollout_device_mesh,
- offload_param=self._is_offload_param,
- )
- log_gpu_memory_usage("After building sharding manager", logger=logger)
-
- else:
- raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported")
-
- return rollout, rollout_sharding_manager
-
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
from trinity.trainer.verl.dp_actor import DataParallelPPOActor
@@ -597,14 +454,10 @@ def init_model(self):
use_shm = self.config.model.get("use_shm", False)
use_fused_kernels = self.config.model.get("use_fused_kernels", False)
- if self._is_actor or self._is_rollout:
+ if self._is_actor:
# we need the model for actor and rollout
- if self._is_actor:
- optim_config = self.config.actor.optim
- fsdp_config = self.config.actor.fsdp_config
- else:
- optim_config = None
- fsdp_config = OmegaConf.create()
+ optim_config = self.config.actor.optim
+ fsdp_config = self.config.actor.fsdp_config
local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
(
@@ -651,11 +504,6 @@ def init_model(self):
actor_optimizer=self.actor_optimizer,
)
- if self._is_rollout:
- self.rollout, self.rollout_sharding_manager = self._build_rollout(
- trust_remote_code=self.config.model.get("trust_remote_code", False)
- )
-
if self._is_ref:
local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
self.ref_module_fsdp = self._build_model_optimizer(
@@ -713,7 +561,9 @@ def setup_weight_sync_group(self):
realname = (
name_prefix[len(FSDP_PREFIX) :] + "." + name if name_prefix else name
)
- self.state_dict_meta.append((realname, param.dtype, param.shape))
+ self.state_dict_meta.append(
+ (realname, str(param.dtype), tuple(param.shape))
+ )
param = None
torch.cuda.empty_cache()
@@ -815,38 +665,6 @@ def update_actor(self, data: DataProto):
return output
- @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
- def generate_sequences(self, prompts: DataProto):
- # Support all hardwares
- prompts = prompts.to(get_torch_device().current_device())
-
- assert self._is_rollout
-
- meta_info = {
- "eos_token_id": self.generation_config.eos_token_id
- if self.generation_config is not None
- else self.tokenizer.eos_token_id,
- "pad_token_id": self.generation_config.pad_token_id
- if self.generation_config is not None
- else self.tokenizer.pad_token_id,
- }
- prompts.meta_info.update(meta_info)
- with self.rollout_sharding_manager:
- log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)
-
- prompts = self.rollout_sharding_manager.preprocess_data(prompts)
- output = self.rollout.generate_sequences(prompts=prompts)
-
- log_gpu_memory_usage("After rollout generation", logger=logger)
-
- output = self.rollout_sharding_manager.postprocess_data(output)
-
- output = output.to("cpu")
-
- # clear kv cache
- get_torch_device().empty_cache()
- return output
-
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
# when is_lora is True, we use the actor without lora applied to calculate the log_prob
From 5cb9ebe0279f2c9effa83eda459ad1989499b35e Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 19 Jun 2025 10:27:29 +0800
Subject: [PATCH 21/28] fix checkpoint sync mode
---
trinity/explorer/explorer.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 6b8e5b6288..87657c9b42 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -121,7 +121,7 @@ def _update_model_weight(self, state_dict: dict) -> None:
self.state_dict = state_dict
update_weight_args_list = []
for name, param in state_dict.items():
- update_weight_args_list.append((name, param.dtype, param.shape))
+ update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models])
self.state_dict.clear()
From f24db44b5d4f4cc48b72470a53c406fd5b5363bf Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Thu, 19 Jun 2025 14:45:12 +0800
Subject: [PATCH 22/28] Update config manager (#86)
---
.../source/tutorial/example_mix_algo.md | 2 +-
.../source/tutorial/trinity_configs.md | 31 --
.../tutorial/trinity_programming_guide.md | 6 +-
examples/async_gsm8k/verl_config.yaml | 21 -
examples/dpo_humanlike/train_dpo.yaml | 17 -
examples/grpo_alfworld/train_alfworld.yaml | 17 -
examples/grpo_gsm8k/train_gsm8k.yaml | 21 -
examples/grpo_math/train_math.yaml | 21 -
examples/grpo_sciworld/train_sciworld.yaml | 17 -
examples/grpo_webshop/train_webshop.yaml | 17 -
examples/mix_math/train_mix_math.yaml | 21 -
examples/opmd_gsm8k/train_opmd_gsm8k.yaml | 21 -
examples/ppo_countdown/train_countdown.yaml | 21 -
tests/template/verl_config.yaml | 17 -
trinity/algorithm/algorithm.py | 26 +-
trinity/algorithm/algorithm_manager.py | 2 +-
.../sample_strategy/mix_sample_strategy.py | 2 +-
trinity/buffer/reader/file_reader.py | 6 +-
trinity/common/config.py | 12 +-
trinity/common/verl_config.py | 44 +--
trinity/manager/config_manager.py | 210 +++++-----
trinity/manager/config_registry/__init__.py | 2 +
.../algorithm_config_manager.py | 371 ++++++++++++++++++
.../config_registry/buffer_config_manager.py | 12 +-
.../explorer_config_manager.py | 6 +-
.../config_registry/model_config_manager.py | 88 +----
.../config_registry/trainer_config_manager.py | 94 +----
27 files changed, 544 insertions(+), 581 deletions(-)
create mode 100644 trinity/manager/config_registry/algorithm_config_manager.py
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index 61ecec33b1..de664cae4a 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -46,7 +46,7 @@ class MIXAlgorithm(AlgorithmType):
schema: type = ExperienceModel
@classmethod
- def get_default_config(cls) -> Dict:
+ def default_config(cls) -> Dict:
return {
"repeat_times": 8,
"policy_loss_fn": "mix",
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index dbb8402ceb..8c0cab9a0a 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -376,11 +376,6 @@ actor_rollout_ref:
use_dynamic_bsz: True
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: False # True for GRPO
- kl_loss_coef: 0.001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -399,10 +394,6 @@ actor_rollout_ref:
param_offload: False
optimizer_offload: False
fsdp_size: -1
- # --- below: opmd ---
- tau: 0.000 # strength of regularization w.r.t. old / ref policy
- opmd_baseline: mean # mean / logavgexp, applicable to opmd
- use_uid: False # True / False, applicable to pairwise_opmd
ref:
fsdp_config:
param_offload: False
@@ -447,22 +438,6 @@ critic:
grad_clip: 1.0
cliprange_value: 0.5
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- norm_adv_by_std_in_grpo: True
- use_kl_in_reward: False
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
- horizon: 10000
- target_kl: 0.1
-
trainer:
balance_batch: True
# total_training_steps: null
@@ -483,11 +458,7 @@ trainer:
- `actor_rollout_ref.model.use_remove_padding`: Whether to remove pad tokens, which will reduce training time.
- `actor_rollout_ref.actor.use_dynamic_bsz`: Whether to reorganize the batch data, specifically to splice the shorter data to reduce the batch size in the actual training process.
- `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`: Batch size for one GPU in one forward pass.
-- `actor_rollout_ref.actor.kl_loss_type`: How to compute kl loss, optional value is `kl`, `abs`, `mse` or `low_var_kl`.
- `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size.
-- `actor_rollout_ref.actor.tau`: strength of regularization w.r.t. old / ref policy.
-- `actor_rollout_ref.actor.opmd_baseline`: mean / logavgexp, applicable to opmd.
-- `actor_rollout_ref.actor.use_uid`: True / False, applicable to pairwise_opmd.
- `actor_rollout_ref.actor.optim.lr`: Learning rate for actor model.
- `actor_rollout_ref.actor.optim.lr_warmup_steps_ratio`: Ratio of warmup steps for learning rate.
- `actor_rollout_ref.actor.optim.warmup_style`: Warmup style for learning rate.
@@ -505,8 +476,6 @@ trainer:
- `critic.grad_clip`: Gradient clip for critic model training.
- `critic.cliprange_value`: Used for compute value loss.
-- `algorithm`: Training algorithm settings.
-
- `trainer.balance_batch`: Whether to balance batch size between GPUs during training.
- `trainer.resume_mode`: Resume mode for training. Support `disable`, `auto` and `resume_path`.
- `trainer.resume_from_path`: Path to resume from.
diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index 931cb81506..e07e6bb3dc 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -443,13 +443,13 @@ The `AlgorithmType` class includes the following attributes and methods:
- `use_advantage`: Whether to calculate Advantage; if False, the `AdvantageFn` call will be skipped
- `can_balance_batch`: Whether the algorithm allows automatic balancing when splitting a batch into microbatches (which permute the order of samples)
- `schema`: The format of experience data corresponding to the algorithm
-- `get_default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE`
+- `default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE`
Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`.
Below is the implementation for the OPMD algorithm.
Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`.
-The dictionary returned by the `get_default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss.
+The dictionary returned by the `default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss.
```python
@ALGORITHM_TYPE.register_module("opmd")
@@ -463,7 +463,7 @@ class OPMDAlgorithm(AlgorithmType):
schema: type = ExperienceModel
@classmethod
- def get_default_config(cls) -> Dict:
+ def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"sample_strategy": "warmup",
diff --git a/examples/async_gsm8k/verl_config.yaml b/examples/async_gsm8k/verl_config.yaml
index de1b08f590..fc44fdad94 100644
--- a/examples/async_gsm8k/verl_config.yaml
+++ b/examples/async_gsm8k/verl_config.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: True # False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: True # True for GRPO
- kl_loss_coef: 0.001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -33,10 +28,6 @@ actor_rollout_ref:
param_offload: False
optimizer_offload: False
fsdp_size: -1
- # --- below: opmd ---
- tau: 0.000 # strength of regularization w.r.t. old / ref policy
- opmd_baseline: mean # mean / logavgexp, applicable to opmd
- use_uid: False # True / False, applicable to pairwise_opmd
ref:
fsdp_config:
param_offload: False
@@ -48,18 +39,6 @@ actor_rollout_ref:
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
-
trainer:
balance_batch: True
# total_training_steps: null
diff --git a/examples/dpo_humanlike/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml
index 028c997e06..d5074848b0 100644
--- a/examples/dpo_humanlike/train_dpo.yaml
+++ b/examples/dpo_humanlike/train_dpo.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: True
- kl_loss_coef: 0.1 # NOTE: beta for DPO
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -46,18 +41,6 @@ actor_rollout_ref:
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
-
trainer:
balance_batch: False
total_training_steps: 783 #
diff --git a/examples/grpo_alfworld/train_alfworld.yaml b/examples/grpo_alfworld/train_alfworld.yaml
index 215b1817ab..5b73ec7403 100644
--- a/examples/grpo_alfworld/train_alfworld.yaml
+++ b/examples/grpo_alfworld/train_alfworld.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: True # True for GRPO
- kl_loss_coef: 0.001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -44,18 +39,6 @@ actor_rollout_ref:
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
-
trainer:
balance_batch: True
# total_training_steps: null
diff --git a/examples/grpo_gsm8k/train_gsm8k.yaml b/examples/grpo_gsm8k/train_gsm8k.yaml
index de1b08f590..fc44fdad94 100644
--- a/examples/grpo_gsm8k/train_gsm8k.yaml
+++ b/examples/grpo_gsm8k/train_gsm8k.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: True # False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: True # True for GRPO
- kl_loss_coef: 0.001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -33,10 +28,6 @@ actor_rollout_ref:
param_offload: False
optimizer_offload: False
fsdp_size: -1
- # --- below: opmd ---
- tau: 0.000 # strength of regularization w.r.t. old / ref policy
- opmd_baseline: mean # mean / logavgexp, applicable to opmd
- use_uid: False # True / False, applicable to pairwise_opmd
ref:
fsdp_config:
param_offload: False
@@ -48,18 +39,6 @@ actor_rollout_ref:
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
-
trainer:
balance_batch: True
# total_training_steps: null
diff --git a/examples/grpo_math/train_math.yaml b/examples/grpo_math/train_math.yaml
index 78bcb862c6..0a46bd1788 100644
--- a/examples/grpo_math/train_math.yaml
+++ b/examples/grpo_math/train_math.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: True # False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: True # True for GRPO
- kl_loss_coef: 0.0001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -33,10 +28,6 @@ actor_rollout_ref:
param_offload: False
optimizer_offload: False
fsdp_size: -1
- # --- below: opmd ---
- tau: 0.000 # strength of regularization w.r.t. old / ref policy
- opmd_baseline: mean # mean / logavgexp, applicable to opmd
- use_uid: False # True / False, applicable to pairwise_opmd
ref:
fsdp_config:
param_offload: False
@@ -48,18 +39,6 @@ actor_rollout_ref:
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.0001
-
trainer:
balance_batch: True
# auto: find the last ckpt to resume. If can't find, start from scratch
diff --git a/examples/grpo_sciworld/train_sciworld.yaml b/examples/grpo_sciworld/train_sciworld.yaml
index 215b1817ab..5b73ec7403 100644
--- a/examples/grpo_sciworld/train_sciworld.yaml
+++ b/examples/grpo_sciworld/train_sciworld.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: True # True for GRPO
- kl_loss_coef: 0.001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -44,18 +39,6 @@ actor_rollout_ref:
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
-
trainer:
balance_batch: True
# total_training_steps: null
diff --git a/examples/grpo_webshop/train_webshop.yaml b/examples/grpo_webshop/train_webshop.yaml
index 215b1817ab..5b73ec7403 100644
--- a/examples/grpo_webshop/train_webshop.yaml
+++ b/examples/grpo_webshop/train_webshop.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: True # True for GRPO
- kl_loss_coef: 0.001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -44,18 +39,6 @@ actor_rollout_ref:
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
-
trainer:
balance_batch: True
# total_training_steps: null
diff --git a/examples/mix_math/train_mix_math.yaml b/examples/mix_math/train_mix_math.yaml
index 7b14a87fad..ca072b78f6 100644
--- a/examples/mix_math/train_mix_math.yaml
+++ b/examples/mix_math/train_mix_math.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: True # False
ppo_max_token_len_per_gpu: 25600 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: True # True for GRPO
- kl_loss_coef: 0.0001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -33,10 +28,6 @@ actor_rollout_ref:
param_offload: False
optimizer_offload: False
fsdp_size: -1
- # --- below: opmd ---
- tau: 0.000 # strength of regularization w.r.t. old / ref policy
- opmd_baseline: mean # mean / logavgexp, applicable to opmd
- use_uid: False # True / False, applicable to pairwise_opmd
ref:
fsdp_config:
param_offload: False
@@ -48,18 +39,6 @@ actor_rollout_ref:
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.0001
-
trainer:
balance_batch: True
# auto: find the last ckpt to resume. If can't find, start from scratch
diff --git a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
index 44a0111d64..5ddd5124ee 100644
--- a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
+++ b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml
@@ -36,11 +36,6 @@ actor_rollout_ref:
use_dynamic_bsz: True
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.000
- use_kl_loss: True
- kl_loss_coef: 0.001
- kl_loss_type: mse
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -58,10 +53,6 @@ actor_rollout_ref:
param_offload: False
optimizer_offload: False
fsdp_size: -1
- # --- below: opmd ---
- tau: 4.0 # strength of regularization w.r.t. old / ref policy
- opmd_baseline: mean # mean / logavgexp, applicable to opmd
- use_uid: False # True / False, applicable to pairwise_opmd
ref:
fsdp_config:
param_offload: False
@@ -73,18 +64,6 @@ actor_rollout_ref:
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.000
-
trainer:
balance_batch: True
# total_training_steps: null
diff --git a/examples/ppo_countdown/train_countdown.yaml b/examples/ppo_countdown/train_countdown.yaml
index ae16122ef7..191c345b90 100644
--- a/examples/ppo_countdown/train_countdown.yaml
+++ b/examples/ppo_countdown/train_countdown.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: True
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: False # True for GRPO
- kl_loss_coef: 0.001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -35,10 +30,6 @@ actor_rollout_ref:
param_offload: False
optimizer_offload: False
fsdp_size: -1
- # --- below: opmd ---
- tau: 0.000 # strength of regularization w.r.t. old / ref policy
- opmd_baseline: mean # mean / logavgexp, applicable to opmd
- use_uid: False # True / False, applicable to pairwise_opmd
ref:
fsdp_config:
param_offload: False
@@ -82,18 +73,6 @@ critic:
grad_clip: 1.0
cliprange_value: 0.5
-custom_reward_function:
- path: null
- name: compute_score
-
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
-
trainer:
balance_batch: True
# total_training_steps: null
diff --git a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml
index b17fc87958..d6dcf4a997 100644
--- a/tests/template/verl_config.yaml
+++ b/tests/template/verl_config.yaml
@@ -12,11 +12,6 @@ actor_rollout_ref:
use_dynamic_bsz: True
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
- clip_ratio: 0.2
- entropy_coeff: 0.001
- use_kl_loss: False # True for GRPO
- kl_loss_coef: 0.001 # for grpo
- kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
@@ -35,10 +30,6 @@ actor_rollout_ref:
param_offload: False
optimizer_offload: False
fsdp_size: -1
- # --- below: opmd ---
- tau: 0.000 # strength of regularization w.r.t. old / ref policy
- opmd_baseline: mean # mean / logavgexp, applicable to opmd
- use_uid: False # True / False, applicable to pairwise_opmd
ref:
fsdp_config:
param_offload: False
@@ -82,14 +73,6 @@ critic:
grad_clip: 1.0
cliprange_value: 0.5
-algorithm:
- gamma: 1.0
- lam: 1.0
- kl_penalty: kl # how to estimate kl divergence
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
-
trainer:
balance_batch: True
# total_training_steps: null
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 805dd8f213..54f5c3d296 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""Algorithm classes."""
-from abc import ABC, ABCMeta
+from abc import ABC, ABCMeta, abstractmethod
from typing import Dict
from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel
@@ -30,7 +30,8 @@ class AlgorithmType(ABC, metaclass=ConstantMeta):
schema: type
@classmethod
- def get_default_config(cls) -> Dict:
+ @abstractmethod
+ def default_config(cls) -> Dict:
raise NotImplementedError
@classmethod
@@ -53,7 +54,7 @@ class SFTAlgorithm(AlgorithmType):
schema: type = SFTDataModel
@classmethod
- def get_default_config(cls) -> Dict:
+ def default_config(cls) -> Dict:
return {
"sample_strategy": "default",
"policy_loss_fn": "sft",
@@ -73,7 +74,7 @@ class PPOAlgorithm(AlgorithmType):
schema: type = ExperienceModel
@classmethod
- def get_default_config(cls) -> Dict:
+ def default_config(cls) -> Dict:
return {
"repeat_times": 1,
"sample_strategy": "warmup",
@@ -96,7 +97,7 @@ class GRPOAlgorithm(AlgorithmType):
schema: type = ExperienceModel
@classmethod
- def get_default_config(cls) -> Dict:
+ def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"sample_strategy": "warmup",
@@ -119,7 +120,7 @@ class OPMDAlgorithm(AlgorithmType):
schema: type = ExperienceModel
@classmethod
- def get_default_config(cls) -> Dict:
+ def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"sample_strategy": "warmup",
@@ -142,9 +143,8 @@ class DPOAlgorithm(AlgorithmType):
schema: type = DPODataModel
@classmethod
- def get_default_config(cls) -> Dict:
+ def default_config(cls) -> Dict:
return {
- "repeat_times": 2, # fake repeat times
"sample_strategy": "dpo",
"policy_loss_fn": "dpo",
"kl_loss_fn": "k2",
@@ -170,10 +170,10 @@ def check_config(cls, config: Config) -> None:
"DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
if config.algorithm.repeat_times != 2:
- config.algorithm.repeat_times = 2
- logger.warning(
- "DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2."
- ) # no need to warn
+ config.algorithm.repeat_times = 2 # Fake repeat times
+ if config.algorithm.kl_loss_fn in {"none", None}:
+ config.algorithm.kl_loss_fn = "k2"
+ logger.warning("DPO must use KL loss. Set `algorithm.kl_loss_fn` to `k2`")
@ALGORITHM_TYPE.register_module("mix")
@@ -188,7 +188,7 @@ class MIXAlgorithm(AlgorithmType):
schema: type = ExperienceModel
@classmethod
- def get_default_config(cls) -> Dict:
+ def default_config(cls) -> Dict:
return {
"repeat_times": 8,
"policy_loss_fn": "mix",
diff --git a/trinity/algorithm/algorithm_manager.py b/trinity/algorithm/algorithm_manager.py
index 3c2983c80b..82cef5ebbd 100644
--- a/trinity/algorithm/algorithm_manager.py
+++ b/trinity/algorithm/algorithm_manager.py
@@ -12,7 +12,7 @@ class AlgorithmManager:
def __init__(self, config: Config):
self.config = config
sft_type = ALGORITHM_TYPE.get("sft")
- sft_default_config = sft_type.get_default_config()
+ sft_default_config = sft_type.default_config()
self.sft_algorithm_config = AlgorithmConfig(
algorithm_type="sft",
**sft_default_config,
diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
index acdd340b24..25811e9190 100644
--- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py
+++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -79,7 +79,7 @@ def sample(self, step: int) -> Tuple[Any, Dict, List]:
raise NotImplementedError(f"backend {self.trainer_type} is not supported")
@classmethod
- def get_default_config(cls) -> Dict:
+ def default_args(cls) -> Dict:
return {
"expert_data_ratio": 0.5,
}
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 8bb9dbcd28..0dd9aef75e 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -67,7 +67,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.response_key = meta.format.response_key
self.read_batch_size = config.read_batch_size
self.dataset = _HFBatchReader(
- load_dataset(meta.path, name=subset_name, split=self.split)
+ load_dataset(meta.path, name=subset_name, split=self.split),
+ max_epoch=meta.total_epochs,
) # TODO: support resume
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
@@ -143,7 +144,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.rejected_key = meta.format.rejected_key
self.read_batch_size = config.read_batch_size
self.dataset = _HFBatchReader(
- load_dataset(meta.path, name=subset_name, split=self.split)
+ load_dataset(meta.path, name=subset_name, split=self.split),
+ max_epoch=meta.total_epochs,
) # TODO: support resume
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 52f8c433fc..9c45627d32 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -206,15 +206,6 @@ class AlgorithmConfig:
# TODO: move this to SFT warmup
use_token_level_loss: bool = True
- # do not set
- algorithm_manager: Optional[Any] = None
-
- def get_current_algorithm_config(self, global_steps: int):
- return self.algorithm_manager.get_current_algorithm_config(global_steps)
-
- def need_save(self, global_steps: int):
- return self.algorithm_manager.need_save(global_steps)
-
@dataclass
class ClusterConfig:
@@ -303,7 +294,6 @@ class TrainerConfig:
# trainer configs
actor_grad_clip: Optional[float] = None
- actor_clip_ratio: Optional[float] = None
# TODO: extract more train-related params from underlying trainer engine
# Only one needs to be set for `trainer_config` and `trainer_config_path`
@@ -525,7 +515,7 @@ def _check_algorithm(self) -> None:
"kl_loss_fn": "k2",
"entropy_loss_fn": "default",
}
- default_config.update(algorithm.get_default_config())
+ default_config.update(algorithm.default_config())
for key, value in default_config.items():
if getattr(self.algorithm, key, None) is None:
setattr(self.algorithm, key, value)
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index e6b1b9e4e1..1ec0653503 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -4,7 +4,6 @@
from omegaconf import OmegaConf
-from trinity.algorithm.algorithm import DPOAlgorithm
from trinity.common.config import BufferConfig, Config, SynchronizerConfig
from trinity.utils.log import get_logger
@@ -66,22 +65,19 @@ class Actor:
16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
)
grad_clip: float = 1.0
- clip_ratio: float = 0.2
- entropy_coeff: float = 0.001
- use_kl_loss: bool = False
- kl_loss_coef: float = 0.001
- kl_loss_type: str = "low_var_kl"
ppo_epochs: int = 1
shuffle: bool = False
ulysses_sequence_parallel_size: int = 1
checkpoint: Checkpoint = field(default_factory=Checkpoint)
optim: Optim = field(default_factory=Optim)
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
- algorithm_type: str = "ppo" # TODO
- tau: float = 0.001 # strength of regularization w.r.t. old / ref policy
- opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd
- use_uid: bool = False # True / False, applicable to pairwise_opmd
- loss_agg_mode: str = "token-mean" # do not set
+ # do not set
+ loss_agg_mode: str = "token-mean"
+ clip_ratio: float = 0.2
+ entropy_coeff: float = 0.001
+ use_kl_loss: bool = False
+ kl_loss_coef: float = 0.001
+ kl_loss_type: str = "low_var_kl"
@dataclass
@@ -208,10 +204,6 @@ class Algorithm:
kl_penalty: str = "kl"
kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)
- # ! DO NOT SET THE FOLLOWING PARAMETERS
- policy_loss_fn: str = "ppo"
- policy_loss_fn_args: Optional[dict] = None
-
@dataclass
class Trainer:
@@ -323,33 +315,19 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
if config.trainer.actor_grad_clip is not None:
self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip
- if config.trainer.actor_clip_ratio is not None:
- self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio
# Algorithm related config
- adv_fn_args = config.algorithm.advantage_fn_args
- if adv_fn_args is not None and "gamma" in adv_fn_args:
- self.algorithm.gamma = adv_fn_args["gamma"]
- if adv_fn_args is not None and "lam" in adv_fn_args:
- self.algorithm.lam = adv_fn_args["lam"]
- self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type
self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
- self.actor_rollout_ref.actor.kl_loss_coef = config.algorithm.kl_loss_fn_args["kl_coef"] # type: ignore
- self.actor_rollout_ref.actor.entropy_coeff = config.algorithm.entropy_loss_fn_args[ # type: ignore
- "entropy_coef"
- ]
+ self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none"
# TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
# True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper).
# Need to double check whether this is indeed the case,
# and see if adv_estimator can be removed completely.
- if isinstance(self.actor_rollout_ref.actor.algorithm_type, DPOAlgorithm): # for DPO
- if not self.actor_rollout_ref.actor.use_kl_loss:
- self.actor_rollout_ref.actor.use_kl_loss = True
- logger.warning("DPO must use KL loss.")
+ if config.algorithm.algorithm_type == "dpo": # for DPO
logger.warning("DPO micro batch size is doubled for computing loss.")
- self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2 # type: ignore
- self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2 # type: ignore
+ self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2
+ self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2
if self.actor_rollout_ref.rollout.n != 2:
self.actor_rollout_ref.rollout.n = 2
# TODO: check other fields
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index 80b8992b3b..de4305a9cc 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -7,10 +7,25 @@
import streamlit as st
import yaml
-from trinity.common.constants import AlgorithmType, StorageType
+from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN
+from trinity.algorithm.algorithm import ALGORITHM_TYPE
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN
+from trinity.algorithm.kl_fn.kl_fn import KL_FN
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
+from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY
+from trinity.common.constants import StorageType
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
from trinity.manager.config_registry.trainer_config_manager import use_critic
+register_map = {
+ "sample_strategy": SAMPLE_STRATEGY,
+ "policy_loss_fn": POLICY_LOSS_FN,
+ "advantage_fn": ADVANTAGE_FN,
+ "kl_loss_fn": KL_FN,
+ "kl_penalty_fn": KL_FN,
+ "entropy_loss_fn": ENTROPY_LOSS_FN,
+}
+
class ConfigManager:
def __init__(self):
@@ -47,55 +62,48 @@ def maintain_session_state(self):
for key in CONFIG_GENERATORS.default_config:
st.session_state[key] = st.session_state[key]
- eval_dataset_keys = [
+ def maintain_list_state(prefix, key_list):
+ last_idx, del_num = 0, 0
+ for idx in range(st.session_state[f"_{prefix}_num"]):
+ if st.session_state.get(f"{prefix}_{idx}_del_flag", False):
+ del_num += 1
+ continue
+ for key in key_list:
+ full_key = f"{prefix}_{idx}_{key}"
+ last_full_key = f"{prefix}_{last_idx}_{key}"
+ st.session_state[last_full_key] = st.session_state[full_key]
+ last_idx += 1
+ st.session_state[f"_{prefix}_num"] -= del_num
+
+ self.eval_dataset_keys = [
"name",
"path",
- "subset_name",
"split",
+ "subset_name",
"prompt_key",
"response_key",
"temperature",
"logprobs",
"n",
]
- last_idx, del_num = 0, 0
- for idx in range(st.session_state["_eval_tasksets_num"]):
- if st.session_state.get(f"eval_taskset_{idx}_del_flag", False):
- del_num += 1
- continue
- for key in eval_dataset_keys:
- full_key = f"eval_taskset_{idx}_{key}"
- last_full_key = f"eval_taskset_{last_idx}_{key}"
- st.session_state[last_full_key] = st.session_state[full_key]
- last_idx += 1
- st.session_state["_eval_tasksets_num"] -= del_num
-
- auxiliary_model_keys = [
+ maintain_list_state("eval_tasksets", self.eval_dataset_keys)
+
+ self.inference_model_keys = [
"model_path",
"engine_type",
"engine_num",
"tensor_parallel_size",
- "gpu_memory_utilization",
- "dtype",
- "seed",
"use_v1",
"enforce_eager",
"enable_prefix_caching",
"enable_chunked_prefill",
+ "gpu_memory_utilization",
+ "dtype",
+ "seed",
"enable_thinking",
"enable_openai_api",
]
- last_idx, del_num = 0, 0
- for idx in range(st.session_state["_auxiliary_models_num"]):
- if st.session_state.get(f"auxiliary_model_{idx}_del_flag", False):
- del_num += 1
- continue
- for key in auxiliary_model_keys:
- full_key = f"auxiliary_model_{idx}_{key}"
- last_full_key = f"auxiliary_model_{last_idx}_{key}"
- st.session_state[last_full_key] = st.session_state[full_key]
- last_idx += 1
- st.session_state["_auxiliary_models_num"] -= del_num
+ maintain_list_state("auxiliary_models", self.inference_model_keys)
def get_configs(self, *config_names: str, columns_spec: List[int] = None):
CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec)
@@ -108,7 +116,7 @@ def beginner_mode(self):
self.get_configs("checkpoint_root_dir")
- if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] != "dpo":
self.get_configs("taskset_path")
else:
self.get_configs("experience_buffer_path")
@@ -126,7 +134,7 @@ def beginner_mode(self):
self.get_configs("sync_interval", "eval_interval", "save_interval")
- if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] != "dpo":
self.get_configs("taskset_args")
else:
self.get_configs("dpo_dataset_kwargs")
@@ -136,9 +144,6 @@ def beginner_mode(self):
self.get_configs("default_workflow_type", "default_reward_fn_type")
- self.get_configs("actor_use_kl_loss")
- self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type")
-
self.get_configs(
"actor_ppo_micro_batch_size_per_gpu",
"actor_lr",
@@ -165,7 +170,7 @@ def _expert_buffer_part(self):
self.get_configs("system_prompt")
self.get_configs("reply_prefix")
- if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] != "dpo":
with st.expander("Taskset Configs", expanded=True):
self.get_configs("taskset_path")
self.get_configs("taskset_args")
@@ -182,7 +187,7 @@ def _expert_buffer_part(self):
self.get_configs("sft_warmup_dataset_path")
self.get_configs("sft_warmup_dataset_args")
- if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] != "dpo":
with st.expander("Experiences Buffer Configs", expanded=True):
self.get_configs("storage_type")
self.get_configs("experience_buffer_path")
@@ -213,8 +218,30 @@ def _expert_explorer_part(self):
self.get_configs("auxiliary_models")
def _expert_trainer_part(self):
- self.get_configs("algorithm_type", "gamma", "lam")
- self.get_configs("repeat_times", "save_interval")
+ self.get_configs("algorithm_type", "repeat_times", "save_interval")
+ self.get_configs("sample_strategy", "advantage_fn", "entropy_loss_fn")
+ self.get_configs("policy_loss_fn", "kl_penalty_fn", "kl_loss_fn")
+
+ with st.expander("Advanced Algorithm Config"):
+ algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"])
+ default_config = algorithm.default_config()
+ config_key_list = []
+ for key in default_config.keys():
+ value = st.session_state[key]
+ if key == "repeat_times":
+ continue
+ default_args = register_map[key].get(value).default_args()
+ for sub_key in default_args.keys():
+ full_key = sub_key + "_in_" + key
+ config_key_list.append(full_key)
+
+ idx = 0
+ while idx < len(config_key_list):
+ delta = 3 if len(config_key_list) - idx != 4 else 2
+ key_list = config_key_list[idx : idx + delta]
+ idx += delta
+ self.get_configs(*key_list)
+
self.get_configs("enable_preview")
if st.session_state["trainer_type"] == "verl":
@@ -238,12 +265,6 @@ def _expert_verl_training_part(self):
self.get_configs("max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep")
- def _expert_verl_algorithm_part(self):
- st.subheader("RL Algorithm Config")
- self.get_configs("norm_adv_by_std_in_grpo", "use_kl_in_reward")
- self.get_configs("kl_penalty", "kl_ctrl_type", "kl_ctrl_coef")
- self.get_configs("horizon", "target_kl")
-
def _expert_verl_actor_part(self):
st.subheader("Actor Model Config")
self.get_configs(
@@ -254,12 +275,7 @@ def _expert_verl_actor_part(self):
self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio")
- self.get_configs("actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef")
-
- self.get_configs("actor_use_kl_loss", "actor_use_uid")
- self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type")
-
- self.get_configs("actor_tau", "actor_opmd_baseline")
+ self.get_configs("actor_grad_clip")
self.get_configs("actor_checkpoint")
@@ -277,7 +293,6 @@ def _expert_verl_critic_part(self):
def _expert_verl_trainer_part(self):
name2func = {
"RL Training Config": self._expert_verl_training_part,
- "RL Algorithm Config": self._expert_verl_algorithm_part,
"Actor and Ref Config": self._expert_verl_actor_part,
}
if use_critic():
@@ -359,9 +374,6 @@ def _generate_verl_config(self):
),
},
"fsdp_config": copy.deepcopy(fsdp_config),
- "tau": st.session_state["actor_tau"],
- "opmd_baseline": st.session_state["actor_opmd_baseline"],
- "use_uid": st.session_state["actor_use_uid"],
},
"ref": {
"fsdp_config": copy.deepcopy(fsdp_config),
@@ -375,14 +387,7 @@ def _generate_verl_config(self):
],
},
},
- "custom_reward_function": {"path": None, "name": "compute_score"},
- "algorithm": {
- "kl_penalty": st.session_state["kl_penalty"],
- "kl_ctrl": {
- "type": st.session_state["kl_ctrl_type"],
- "kl_coef": st.session_state["kl_ctrl_coef"],
- },
- },
+ "critic": {},
"trainer": {
"balance_batch": balance_batch,
"resume_mode": st.session_state["resume_mode"],
@@ -436,11 +441,35 @@ def _generate_verl_config(self):
"cliprange_value": st.session_state["critic_cliprange_value"],
"checkpoint": {"contents": st.session_state["critic_checkpoint"]},
}
+ else:
+ del trainer_config["critic"]
return trainer_config
+ def _gen_algorithm_config(self):
+ algorithm_config = {
+ "algorithm_type": st.session_state["algorithm_type"],
+ }
+ algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"])
+ default_config = algorithm.default_config()
+ current_config = {}
+ for key in default_config.keys():
+ current_config[key] = value = st.session_state[key]
+ if key == "repeat_times":
+ continue
+ default_args = register_map[key].get(value).default_args()
+ args = {}
+ for sub_key in default_args.keys():
+ full_key = sub_key + "_in_" + key
+ args[sub_key] = st.session_state.get(full_key, default_args[sub_key])
+ if default_args != args:
+ current_config[key + "_args"] = args
+ if default_config != current_config:
+ algorithm_config.update(current_config)
+ return algorithm_config
+
def _gen_buffer_config(self):
experience_buffer_path = st.session_state["experience_buffer_path"].strip()
- if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] != "dpo":
if (
not experience_buffer_path
and st.session_state["storage_type"] == StorageType.SQL.value
@@ -456,6 +485,7 @@ def _gen_buffer_config(self):
buffer_config = {
"batch_size": st.session_state["train_batch_size"],
"total_epochs": st.session_state["total_epochs"],
+ "explorer_input": {},
"trainer_input": {
"experience_buffer": {
"name": "experience_buffer",
@@ -497,13 +527,25 @@ def _gen_buffer_config(self):
{
"name": st.session_state[f"eval_taskset_{idx}_name"],
"path": st.session_state[f"eval_taskset_{idx}_path"],
- "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"],
"split": st.session_state[f"eval_taskset_{idx}_split"],
- "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"],
- "response_key": st.session_state[f"eval_taskset_{idx}_response_key"],
+ "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"],
+ "format": {
+ "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"],
+ "response_key": st.session_state[
+ f"eval_taskset_{idx}_response_key"
+ ],
+ },
+ "rollout_args": {
+ "temperature": st.session_state[f"eval_taskset_{idx}_temperature"],
+ "logprobs": st.session_state[f"eval_taskset_{idx}_logprobs"],
+ "n": st.session_state[f"eval_taskset_{idx}_n"],
+ },
}
)
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ else:
+ del buffer_config["explorer_input"]
+
+ if st.session_state["algorithm_type"] == "dpo":
experience_buffer = buffer_config["trainer_input"]["experience_buffer"]
experience_buffer["split"] = st.session_state["dpo_dataset_train_split"]
experience_buffer["format"] = {
@@ -534,26 +576,23 @@ def _gen_explorer_config(self):
"max_timeout": st.session_state["max_timeout"],
"max_retry_times": st.session_state["explorer_max_retry_times"],
"rollout_model": {
- "engine_type": st.session_state["engine_type"],
- "engine_num": st.session_state["engine_num"],
- "tensor_parallel_size": st.session_state["tensor_parallel_size"],
- "use_v1": st.session_state["use_v1"],
- "enforce_eager": st.session_state["enforce_eager"],
- "enable_prefix_caching": st.session_state["enable_prefix_caching"],
- "enable_chunked_prefill": st.session_state["enable_chunked_prefill"],
- "gpu_memory_utilization": st.session_state["gpu_memory_utilization"],
- "dtype": st.session_state["dtype"],
- "seed": st.session_state["seed"],
+ key: st.session_state[key]
+ for key in self.inference_model_keys
+ if key != "model_path"
# "max_prompt_tokens": None, # TODO
# "max_response_tokens": None, # TODO
# "chat_template": None, # TODO: add chat template
- "enable_thinking": st.session_state["enable_thinking"],
- "enable_openai_api": st.session_state["enable_openai_api"],
},
"auxiliary_models": [],
"eval_interval": st.session_state["eval_interval"],
"eval_on_latest_checkpoint": st.session_state["eval_on_latest_checkpoint"],
}
+ for i in range(st.session_state["_auxiliary_models_num"]):
+ auxiliary_model_config = {
+ key: st.session_state[f"auxiliary_model_{i}_{key}"]
+ for key in self.inference_model_keys
+ }
+ explorer_config["auxiliary_models"].append(auxiliary_model_config)
return explorer_config
def generate_config(self):
@@ -585,12 +624,7 @@ def generate_config(self):
"project": st.session_state["project"],
"name": st.session_state["exp_name"],
"checkpoint_root_dir": st.session_state["checkpoint_root_dir"],
- "algorithm": {
- "algorithm_type": st.session_state["algorithm_type"],
- "repeat_times": st.session_state["repeat_times"],
- "gamma": st.session_state["gamma"],
- "lam": st.session_state["lam"],
- },
+ "algorithm": self._gen_algorithm_config(),
"data_processor": {}, # TODO: Add data processor config
"model": {
"model_path": st.session_state["model_path"],
@@ -607,11 +641,7 @@ def generate_config(self):
"trainer_type": st.session_state["trainer_type"],
"save_interval": st.session_state["save_interval"],
"enable_preview": st.session_state["enable_preview"],
- "actor_use_kl_loss": st.session_state["actor_use_kl_loss"],
- "actor_kl_loss_coef": st.session_state["actor_kl_loss_coef"],
- "actor_entropy_coef": st.session_state["actor_entropy_coef"],
"actor_grad_clip": st.session_state["actor_grad_clip"],
- "actor_clip_ratio": st.session_state["actor_clip_ratio"],
"trainer_config": trainer_config,
},
"monitor": {
diff --git a/trinity/manager/config_registry/__init__.py b/trinity/manager/config_registry/__init__.py
index e62c565fb4..3896582755 100644
--- a/trinity/manager/config_registry/__init__.py
+++ b/trinity/manager/config_registry/__init__.py
@@ -1,3 +1,4 @@
+import trinity.manager.config_registry.algorithm_config_manager as algorithm_config_manager
import trinity.manager.config_registry.buffer_config_manager as buffer_config_manager
import trinity.manager.config_registry.explorer_config_manager as explorer_config_manager
import trinity.manager.config_registry.model_config_manager as model_config_manager
@@ -6,6 +7,7 @@
__all__ = [
"CONFIG_GENERATORS",
+ "algorithm_config_manager",
"buffer_config_manager",
"explorer_config_manager",
"model_config_manager",
diff --git a/trinity/manager/config_registry/algorithm_config_manager.py b/trinity/manager/config_registry/algorithm_config_manager.py
new file mode 100644
index 0000000000..c9694dec25
--- /dev/null
+++ b/trinity/manager/config_registry/algorithm_config_manager.py
@@ -0,0 +1,371 @@
+import streamlit as st
+
+from trinity.algorithm.advantage_fn import (
+ ADVANTAGE_FN,
+ GRPOAdvantageFn,
+ OPMDAdvantageFn,
+ PPOAdvantageFn,
+)
+from trinity.algorithm.algorithm import ALGORITHM_TYPE, PPOAlgorithm
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import (
+ ENTROPY_LOSS_FN,
+ EntropyLossFn,
+)
+from trinity.algorithm.kl_fn.kl_fn import KL_FN, KLFn
+from trinity.algorithm.policy_loss_fn import (
+ POLICY_LOSS_FN,
+ DPOLossFn,
+ MIXPolicyLossFn,
+ OPMDPolicyLossFn,
+ PPOPolicyLossFn,
+ SFTLossFn,
+)
+from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, MixSampleStrategy
+from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
+from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value="ppo",
+ other_configs={"mode": "both", "_current_default_config": PPOAlgorithm.default_config()},
+)
+def set_algorithm_type(**kwargs):
+ def on_change():
+ if st.session_state["algorithm_type"] == "dpo":
+ st.session_state["mode"] = "train"
+ else:
+ st.session_state["mode"] = "both"
+ algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"])
+ default_config = algorithm.default_config()
+ st.session_state["_current_default_config"] = default_config
+ for key, value in default_config.items():
+ st.session_state[key] = value
+ set_trainer_gpu_num()
+
+ candidates = list(ALGORITHM_TYPE.modules.keys())
+ st.selectbox(
+ "Algorithm Type",
+ candidates,
+ on_change=on_change,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOAlgorithm.default_config()["repeat_times"],
+ visible=lambda: "repeat_times" in st.session_state["_current_default_config"],
+ other_configs={
+ "_grouped_adv_repeat_times": 2,
+ "_not_grouped_adv_repeat_times": 1,
+ },
+)
+def set_repeat_times(**kwargs): # TODO
+ key = kwargs.get("key")
+ grouped_adv_algorithms = [
+ "grpo",
+ "opmd", # TODO: may add rloo
+ ]
+ if st.session_state["algorithm_type"] in grouped_adv_algorithms:
+ min_repeat_times = 2
+ st.session_state[key] = st.session_state["_grouped_adv_repeat_times"]
+ else:
+ min_repeat_times = 1
+ st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"]
+
+ def on_change():
+ if st.session_state["algorithm_type"] in grouped_adv_algorithms:
+ st.session_state["_grouped_adv_repeat_times"] = st.session_state[key]
+ else:
+ st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key]
+
+ st.number_input(
+ "Repeat Times",
+ min_value=min_repeat_times,
+ help="`repeat_times` is used to set how many experiences each task can generate, "
+ "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.",
+ on_change=on_change,
+ **kwargs,
+ )
+
+
+# Sample_strategy Configs
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOAlgorithm.default_config()["sample_strategy"],
+ visible=lambda: "sample_strategy" in st.session_state["_current_default_config"],
+)
+def set_sample_strategy(**kwargs):
+ candidates = list(SAMPLE_STRATEGY.modules.keys())
+ st.selectbox(
+ "Sample Strategy",
+ candidates,
+ help="The sample strategy used to obtain experiences.",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=MixSampleStrategy.default_args()["expert_data_ratio"],
+ visible=lambda: st.session_state["sample_strategy"] == "mix",
+)
+def set_expert_data_ratio_in_sample_strategy(**kwargs):
+ st.number_input(
+ "Expert Data Ratio",
+ min_value=0.0,
+ max_value=1.0,
+ value=0.5,
+ help="The ratio of expert data to be used in the training.",
+ **kwargs,
+ )
+
+
+# Advantage Configs
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOAlgorithm.default_config()["advantage_fn"],
+ visible=lambda: "advantage_fn" in st.session_state["_current_default_config"],
+)
+def set_advantage_fn(**kwargs):
+ candidates = list(ADVANTAGE_FN.modules.keys())
+ st.selectbox(
+ "Advantage Function",
+ candidates,
+ help="The advantage function used to compute advantages.",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOAdvantageFn.default_args()["gamma"],
+ visible=lambda: st.session_state["advantage_fn"] in {"ppo", "reinforceplusplus"},
+)
+def set_gamma_in_advantage_fn(**kwargs):
+ st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOAdvantageFn.default_args()["lam"],
+ visible=lambda: st.session_state["advantage_fn"] == "ppo",
+)
+def set_lam_in_advantage_fn(**kwargs):
+ st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=GRPOAdvantageFn.default_args()["epsilon"],
+ visible=lambda: st.session_state["advantage_fn"] == "grpo",
+)
+def set_epsilon_in_advantage_fn(**kwargs): # TODO: update help message
+ st.number_input(
+ r"GRPO Epsilon",
+ help=r"""
+```python
+scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+```
+""",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=OPMDAdvantageFn.default_args()["opmd_baseline"],
+ visible=lambda: st.session_state["advantage_fn"] == "opmd",
+)
+def set_opmd_baseline_in_advantage_fn(**kwargs):
+ st.selectbox(
+ "OPMD Baseline",
+ ["mean", "logavgexp"],
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=OPMDAdvantageFn.default_args()["tau"],
+ visible=lambda: st.session_state["advantage_fn"] == "opmd"
+ and st.session_state["opmd_baseline_in_advantage_fn"] == "logavgexp",
+)
+def set_tau_in_advantage_fn(**kwargs):
+ st.number_input("Tau for OPMD Adv.", min_value=0.0, format="%.1e", **kwargs)
+
+
+# KL Loss Configs
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOAlgorithm.default_config()["kl_loss_fn"],
+ visible=lambda: "kl_loss_fn" in st.session_state["_current_default_config"],
+)
+def set_kl_loss_fn(**kwargs):
+ candidates = list(KL_FN.modules.keys())
+ st.selectbox(
+ "KL Loss Type",
+ candidates,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=KLFn.default_args()["kl_coef"],
+ visible=lambda: st.session_state["kl_loss_fn"] != "none",
+)
+def set_kl_coef_in_kl_loss_fn(**kwargs):
+ st.number_input(
+ r"KL Loss Coef :blue-badge[$\beta$]",
+ min_value=0.0,
+ max_value=1.0,
+ format="%.1e",
+ **kwargs,
+ )
+
+
+# KL Penalty Configs
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOAlgorithm.default_config()["kl_penalty_fn"],
+ visible=lambda: "kl_penalty_fn" in st.session_state["_current_default_config"],
+)
+def set_kl_penalty_fn(**kwargs):
+ candidates = list(KL_FN.modules.keys())
+ st.selectbox(
+ "KL Penalty Type",
+ candidates,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=KLFn.default_args()["adaptive"],
+ visible=lambda: st.session_state["kl_penalty_fn"] != "none",
+)
+def set_adaptive_in_kl_penalty_fn(**kwargs):
+ st.checkbox(
+ "Adaptive KL Penalty",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=KLFn.default_args()["kl_coef"],
+ visible=lambda: st.session_state["kl_penalty_fn"] != "none",
+)
+def set_kl_coef_in_kl_penalty_fn(**kwargs):
+ st.number_input(
+ r"KL Penalty Coef",
+ min_value=0.0,
+ max_value=1.0,
+ format="%.1e",
+ **kwargs,
+ )
+
+
+# TODO: target_kl and horizon
+
+# Policy Loss Configs
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOAlgorithm.default_config()["policy_loss_fn"],
+ visible=lambda: "policy_loss_fn" in st.session_state["_current_default_config"],
+)
+def set_policy_loss_fn(**kwargs):
+ candidates = list(POLICY_LOSS_FN.modules.keys())
+ st.selectbox(
+ "Policy Loss Fn",
+ candidates,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOPolicyLossFn.default_args()["clip_range"],
+ visible=lambda: st.session_state["policy_loss_fn"] in {"ppo", "mix"},
+)
+def set_clip_range_in_policy_loss_fn(**kwargs):
+ st.number_input(
+ "Clip Range",
+ min_value=0.0,
+ max_value=1.0,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=SFTLossFn.default_args()["use_token_level_loss"],
+ visible=lambda: st.session_state["policy_loss_fn"] == "sft",
+)
+def set_use_token_level_loss_in_policy_loss_fn(**kwargs):
+ st.checkbox(
+ "Use Token Level Loss",
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=DPOLossFn.default_args()["beta"],
+ visible=lambda: st.session_state["policy_loss_fn"] == "dpo",
+)
+def set_beta_in_policy_loss_fn(**kwargs):
+ st.number_input(
+ "Beta for DPO",
+ min_value=0.0,
+ max_value=1.0,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=DPOLossFn.default_args()["label_smoothing"],
+ visible=lambda: st.session_state["policy_loss_fn"] == "dpo",
+)
+def set_label_smoothing_in_policy_loss_fn(**kwargs):
+ st.number_input(
+ "Label Smoothing",
+ min_value=0.0,
+ max_value=1.0,
+ **kwargs,
+ )
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=OPMDPolicyLossFn.default_args()["tau"],
+ visible=lambda: st.session_state["policy_loss_fn"] == "opmd",
+)
+def set_tau_in_policy_loss_fn(**kwargs):
+ st.number_input("Tau for OPMD Loss", min_value=0.0, format="%.1e", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=MIXPolicyLossFn.default_args()["mu"],
+ visible=lambda: st.session_state["policy_loss_fn"] == "mix",
+)
+def set_mu_in_policy_loss_fn(**kwargs):
+ st.number_input("Mu for Mix Policy Loss", min_value=0.0, **kwargs)
+
+
+# Entropy Loss Configs
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=PPOAlgorithm.default_config()["entropy_loss_fn"],
+ visible=lambda: "entropy_loss_fn" in st.session_state["_current_default_config"],
+)
+def set_entropy_loss_fn(**kwargs):
+ candidates = list(ENTROPY_LOSS_FN.modules.keys())
+ st.selectbox("Entropy Loss Function", candidates, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+ default_value=EntropyLossFn.default_args()["entropy_coef"],
+ visible=lambda: st.session_state["entropy_loss_fn"] != "none",
+)
+def set_entropy_coef_in_entropy_loss_fn(**kwargs):
+ st.number_input(
+ "Entropy Coeff",
+ min_value=0.0,
+ max_value=1.0,
+ format="%.1e",
+ **kwargs,
+ )
diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py
index 044f982e94..f704d0ecd2 100644
--- a/trinity/manager/config_registry/buffer_config_manager.py
+++ b/trinity/manager/config_registry/buffer_config_manager.py
@@ -1,6 +1,6 @@
import streamlit as st
-from trinity.common.constants import AlgorithmType, PromptType, StorageType
+from trinity.common.constants import PromptType, StorageType
from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
from trinity.common.workflows.workflow import WORKFLOWS
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
@@ -264,7 +264,7 @@ def set_reply_prefix(**kwargs):
)
def set_storage_type(**kwargs):
key = kwargs.get("key")
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] == "dpo":
st.session_state[key] = st.session_state["_dpo_storage_type"]
storage_candidates = [StorageType.FILE.value, StorageType.SQL.value]
else:
@@ -272,7 +272,7 @@ def set_storage_type(**kwargs):
storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value]
def on_change():
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] == "dpo":
st.session_state["_dpo_storage_type"] = st.session_state[key]
else:
st.session_state["_not_dpo_storage_type"] = st.session_state[key]
@@ -294,7 +294,7 @@ def on_change():
)
def set_experience_buffer_path(**kwargs): # TODO
key = kwargs.get("key")
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] == "dpo":
if st.session_state["taskset_path"] and not st.session_state["_dpo_experience_buffer_path"]:
st.session_state["_dpo_experience_buffer_path"] = st.session_state["taskset_path"]
st.session_state[key] = st.session_state["_dpo_experience_buffer_path"]
@@ -314,7 +314,7 @@ def set_experience_buffer_path(**kwargs): # TODO
if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`."""
def on_change():
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] == "dpo":
st.session_state["_dpo_experience_buffer_path"] = st.session_state[key]
else:
st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[key]
@@ -324,7 +324,7 @@ def on_change():
@CONFIG_GENERATORS.register_check()
def check_experience_buffer_path(unfinished_fields: set, key: str):
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] == "dpo":
if not st.session_state[key].strip():
unfinished_fields.add(key)
st.warning("Please input DPO dataset path.")
diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py
index 9393187f60..12e8034a30 100644
--- a/trinity/manager/config_registry/explorer_config_manager.py
+++ b/trinity/manager/config_registry/explorer_config_manager.py
@@ -1,6 +1,6 @@
import streamlit as st
-from trinity.common.constants import AlgorithmType, SyncMethod
+from trinity.common.constants import SyncMethod
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num
@@ -255,7 +255,7 @@ def check_auxiliary_models(unfinished_fields: set, key: str):
)
def set_sync_method(**kwargs):
key = kwargs.get("key")
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] == "dpo":
st.session_state[key] = SyncMethod.CHECKPOINT.value
disabled = True
else:
@@ -263,7 +263,7 @@ def set_sync_method(**kwargs):
disabled = False
def on_change():
- if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+ if st.session_state["algorithm_type"] != "dpo":
st.session_state["_not_dpo_sync_method"] = st.session_state[key]
st.selectbox(
diff --git a/trinity/manager/config_registry/model_config_manager.py b/trinity/manager/config_registry/model_config_manager.py
index 837bf27679..f9014e58a1 100644
--- a/trinity/manager/config_registry/model_config_manager.py
+++ b/trinity/manager/config_registry/model_config_manager.py
@@ -2,10 +2,9 @@
import streamlit as st
-from trinity.common.constants import AlgorithmType, MonitorType
+from trinity.common.constants import MonitorType
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
from trinity.manager.config_registry.trainer_config_manager import use_critic
-from trinity.trainer.verl.ray_trainer import AdvantageEstimator
def set_total_gpu_num():
@@ -64,91 +63,6 @@ def set_monitor_type(**kwargs):
)
-# Algorithm Configs
-
-
-@CONFIG_GENERATORS.register_config(
- default_value=AlgorithmType.PPO.value,
- other_configs={"mode": "both", "adv_estimator": AdvantageEstimator.GAE.value},
-)
-def set_algorithm_type(**kwargs):
- def on_change():
- if st.session_state["algorithm_type"] == AlgorithmType.PPO.value:
- st.session_state["mode"] = "both"
- st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value
- elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value:
- st.session_state["mode"] = "both"
- st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
- elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- st.session_state["mode"] = "train"
- st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
- elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value:
- st.session_state["mode"] = "both"
- st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
- else: # TODO: add more algorithms
- pass
- set_trainer_gpu_num()
-
- st.selectbox(
- "Algorithm Type",
- [
- AlgorithmType.PPO.value,
- AlgorithmType.GRPO.value,
- AlgorithmType.DPO.value,
- AlgorithmType.OPMD.value,
- ],
- key="algorithm_type",
- on_change=on_change,
- )
-
-
-@CONFIG_GENERATORS.register_config(
- default_value=1,
- visible=lambda: st.session_state["mode"] == "both",
- other_configs={
- "_grouped_adv_repeat_times": 2,
- "_not_grouped_adv_repeat_times": 1,
- },
-)
-def set_repeat_times(**kwargs): # TODO
- key = kwargs.get("key")
- grouped_adv_algorithms = [
- AlgorithmType.GRPO.value,
- AlgorithmType.OPMD.value, # TODO: may add rloo
- ]
- if st.session_state["algorithm_type"] in grouped_adv_algorithms:
- min_repeat_times = 2
- st.session_state[key] = st.session_state["_grouped_adv_repeat_times"]
- else:
- min_repeat_times = 1
- st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"]
-
- def on_change():
- if st.session_state["algorithm_type"] in grouped_adv_algorithms:
- st.session_state["_grouped_adv_repeat_times"] = st.session_state[key]
- else:
- st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key]
-
- st.number_input(
- "Repeat Times",
- min_value=min_repeat_times,
- help="`repeat_times` is used to set how many experiences each task can generate, "
- "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.",
- on_change=on_change,
- **kwargs,
- )
-
-
-@CONFIG_GENERATORS.register_config(default_value=1.0)
-def set_gamma(**kwargs):
- st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs)
-
-
-@CONFIG_GENERATORS.register_config(default_value=1.0)
-def set_lam(**kwargs):
- st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs)
-
-
# Model Configs
diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py
index d0f5d26897..9b3e5f3ea9 100644
--- a/trinity/manager/config_registry/trainer_config_manager.py
+++ b/trinity/manager/config_registry/trainer_config_manager.py
@@ -1,12 +1,13 @@
import streamlit as st
-from trinity.common.constants import AlgorithmType, SyncMethod
+from trinity.algorithm.algorithm import ALGORITHM_TYPE
+from trinity.common.constants import SyncMethod
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
-from trinity.trainer.verl.ray_trainer import AdvantageEstimator
def use_critic():
- return st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value
+ algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"])
+ return algorithm.use_critic
@CONFIG_GENERATORS.register_config(default_value="verl")
@@ -18,7 +19,7 @@ def set_trainer_type(**kwargs):
def set_save_interval(**kwargs):
key = kwargs.get("key")
if (
- st.session_state["algorithm_type"] == AlgorithmType.DPO.value
+ st.session_state["algorithm_type"] == "dpo"
or st.session_state["sync_method"] == SyncMethod.NCCL.value
):
st.session_state[key] = st.session_state["_nccl_save_interval"]
@@ -29,7 +30,7 @@ def set_save_interval(**kwargs):
def on_change():
if (
- st.session_state["algorithm_type"] == AlgorithmType.DPO.value
+ st.session_state["algorithm_type"] == "dpo"
or st.session_state["sync_method"] == SyncMethod.NCCL.value
):
st.session_state["_nccl_save_interval"] = st.session_state[key]
@@ -49,54 +50,6 @@ def set_enable_preview(**kwargs):
st.checkbox("Enable Preview", **kwargs)
-def _actor_use_kl_loss_visible():
- if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
- st.session_state["actor_use_kl_loss"] = True
- return False
- return True
-
-
-@CONFIG_GENERATORS.register_config(
- default_value=True,
- visible=_actor_use_kl_loss_visible,
- other_configs={"_not_dpo_actor_use_kl_loss": True},
-)
-def set_actor_use_kl_loss(**kwargs):
- key = kwargs.get("key")
- st.session_state[key] = st.session_state["_not_dpo_actor_use_kl_loss"]
-
- def on_change():
- st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[key]
-
- st.checkbox("Use KL Loss", on_change=on_change, **kwargs)
-
-
-@CONFIG_GENERATORS.register_config(
- default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"]
-)
-def set_actor_kl_loss_coef(**kwargs):
- st.number_input(
- r"KL Loss Coef :blue-badge[$\beta$]",
- min_value=0.0,
- max_value=1.0,
- format="%.1e",
- **kwargs,
- )
-
-
-@CONFIG_GENERATORS.register_config(
- default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"]
-)
-def set_actor_entropy_coef(**kwargs):
- st.number_input(
- "Entropy Coeff",
- min_value=0.0,
- max_value=1.0,
- format="%.1e",
- **kwargs,
- )
-
-
@CONFIG_GENERATORS.register_config(default_value=1.0)
def set_actor_grad_clip(**kwargs):
st.number_input(
@@ -108,16 +61,6 @@ def set_actor_grad_clip(**kwargs):
)
-@CONFIG_GENERATORS.register_config(default_value=0.2)
-def set_actor_clip_ratio(**kwargs):
- st.number_input(
- r"Clip Ratio :blue-badge[$\epsilon$]",
- min_value=0.0,
- max_value=1.0,
- **kwargs,
- )
-
-
# veRL Trainer Configs
@@ -322,31 +265,6 @@ def set_actor_lr_warmup_steps_ratio(**kwargs):
)
-@CONFIG_GENERATORS.register_config(
- default_value=0.0, visible=lambda: st.session_state["algorithm_type"] == "opmd"
-)
-def set_actor_tau(**kwargs):
- st.number_input("Tau for OPMD", min_value=0.0, format="%.1e", **kwargs)
-
-
-@CONFIG_GENERATORS.register_config(
- default_value="mean", visible=lambda: st.session_state["algorithm_type"] == "opmd"
-)
-def set_actor_opmd_baseline(**kwargs):
- st.selectbox(
- "OPMD Baseline",
- ["mean", "logavgexp"],
- **kwargs,
- )
-
-
-@CONFIG_GENERATORS.register_config(
- default_value=False, visible=lambda: st.session_state["algorithm_type"] == "opmd"
-)
-def set_actor_use_uid(**kwargs):
- st.checkbox("Use UID for OPMD", **kwargs)
-
-
@CONFIG_GENERATORS.register_config(default_value="low_var_kl")
def set_actor_kl_loss_type(**kwargs):
st.selectbox(
From c85d853b92df5a90e9a9213de33b2ac9333418c8 Mon Sep 17 00:00:00 2001
From: Yuchang Sun <52027540+hiyuchang@users.noreply.github.com>
Date: Thu, 19 Jun 2025 15:56:47 +0800
Subject: [PATCH 23/28] Update docs (#89)
---
README.md | 4 +-
docs/sphinx_doc/source/main.md | 22 +++++------
.../source/tutorial/example_async_mode.md | 2 +-
.../tutorial/example_data_functionalities.md | 12 +++---
.../sphinx_doc/source/tutorial/example_dpo.md | 4 +-
.../source/tutorial/example_mix_algo.md | 14 ++++++-
.../source/tutorial/example_multi_turn.md | 10 ++---
.../tutorial/example_reasoning_basic.md | 8 +++-
.../source/tutorial/trinity_configs.md | 37 ++++++++++---------
9 files changed, 65 insertions(+), 48 deletions(-)
diff --git a/README.md b/README.md
index 3ccc1fa19c..c69ceba6cf 100644
--- a/README.md
+++ b/README.md
@@ -266,7 +266,7 @@ Then, for command-line users, run the RFT process with the following command:
trinity run --config
```
-> For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
+> For example, below is the command for fine-tuning Qwen2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
> ```shell
> trinity run --config examples/grpo_gsm8k/gsm8k.yaml
> ```
@@ -279,7 +279,7 @@ For more detailed examples about how to use Trinity-RFT, please refer to the fol
+ [Off-policy mode of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md)
+ [Asynchronous mode of RFT](./docs/sphinx_doc/source/tutorial/example_async_mode.md)
+ [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
-+ [Offline learning by DPO](./docs/sphinx_doc/source/tutorial/example_dpo.md)
++ [Offline learning by DPO or SFT](./docs/sphinx_doc/source/tutorial/example_dpo.md)
+ [Advanced data processing / human-in-the-loop](./docs/sphinx_doc/source/tutorial/example_data_functionalities.md)
diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md
index 4992854eef..5e886a7ab4 100644
--- a/docs/sphinx_doc/source/main.md
+++ b/docs/sphinx_doc/source/main.md
@@ -84,15 +84,18 @@ e.g., utilizing NCCL (when feasible) for model weight synchronization, sequence
## Getting started
-
-*Note: this project is currently under active development; comments and suggestions are welcome!*
-
+```{note}
+Note: This project is currently under active development; comments and suggestions are welcome!
+```
### Step 1: preparations
-
+Trinity-RFT requires
+Python version >= 3.10,
+CUDA version >= 12.4,
+and at least 2 GPUs.
Installation from source (recommended):
@@ -146,11 +149,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest
```
-Trinity-RFT requires
-Python version >= 3.10,
-CUDA version >= 12.4,
-and at least 2 GPUs.
-
### Step 2: prepare dataset and model
@@ -243,7 +241,7 @@ trinity run --config
-For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
+For example, below is the command for fine-tuning Qwen2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
```shell
trinity run --config examples/grpo_gsm8k/gsm8k.yaml
@@ -251,7 +249,7 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
-More example config files can be found in `examples`.
+More example config files can be found in [`examples`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/).
@@ -260,7 +258,7 @@ For more detailed examples about how to use Trinity-RFT, please refer to the fol
+ [Off-policy mode of RFT](tutorial/example_reasoning_advanced.md)
+ [Asynchronous mode of RFT](tutorial/example_async_mode.md)
+ [Multi-turn tasks](tutorial/example_multi_turn.md)
-+ [Offline learning by DPO](tutorial/example_dpo.md)
++ [Offline learning by DPO or SFT](tutorial/example_dpo.md)
+ [Advanced data processing / human-in-the-loop](tutorial/example_data_functionalities.md)
diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md
index a565145d83..1f9a9c8665 100644
--- a/docs/sphinx_doc/source/tutorial/example_async_mode.md
+++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -1,6 +1,6 @@
# Asynchronous RFT
-This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen-2.5-1.5B-Instruct model and GSM8K dataset.
+This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen2.5-1.5B-Instruct model and GSM8K dataset.
Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.
diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
index d62f56de3f..6558efcdd9 100644
--- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
+++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -26,14 +26,14 @@ python scripts/start_servers.py
### Configure the Data Module
-Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data` section in the config file.
+Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data_processor` section in the config file.
In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example:
```yaml
data_processor:
# basic info
- source_data_path: '/path/to/gsm8k'
+ source_data_path: /PATH/TO/GSM8K/
load_kwargs:
split: 'train' # only need the train split
format: # set the field mappings
@@ -58,7 +58,7 @@ If you are not familiar with Data-Juicer, the data module provides a natural-lan
```yaml
data_processor:
# basic info
- source_data_path: '/path/to/gsm8k'
+ source_data_path: /PATH/TO/GSM8K/
load_kwargs:
split: 'train' # only need the train split
format: # set the field mappings
@@ -100,7 +100,7 @@ After preparing the Data-Juicer data processing recipe, you can set the `dj_conf
```yaml
data_processor:
# basic info
- source_data_path: '/path/to/gsm8k'
+ source_data_path: /PATH/TO/GSM8K/
load_kwargs:
split: 'train' # only need the train split
format: # set the field mappings
@@ -165,7 +165,7 @@ python scripts/start_servers.py
### Configure the Data Module
-Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data` section in the config file.
+Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data_processor` section in the config file.
In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example:
@@ -187,7 +187,7 @@ data_processor:
Here you can set the basic information for the example dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training, which is similar to the example above.
-For this example, we assume that you are somehow familiar with the basic usage of Data-Juicer, so we need to prepare a Data-Juicer data processing recipe in `tests/test_configs/human_annotator_test_dj_cfg.yaml` that includes an OP of `human_preference_annotation_mapper`. For example:
+For this example, we assume that you are somehow familiar with the basic usage of Data-Juicer, so we need to prepare a Data-Juicer data processing recipe in [`tests/test_configs/human_annotator_test_dj_cfg.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/tests/test_configs/human_annotator_test_dj_cfg.yaml) that includes an OP of `human_preference_annotation_mapper`. For example:
```yaml
project_name: 'demo-human-annotator'
diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md
index b5846bc24b..cd0c214725 100644
--- a/docs/sphinx_doc/source/tutorial/example_dpo.md
+++ b/docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -1,12 +1,12 @@
# Offline DPO and SFT
-This example describes DPO and SFT based on the Qwen-2.5-1.5B-Instruct model.
+This example describes DPO and SFT based on the Qwen2.5-1.5B-Instruct model.
## Step 1: Model and Data Preparation
### Model Preparation
-Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
+Download the Qwen2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
```shell
# Using Modelscope
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index de664cae4a..b106293eed 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -25,9 +25,15 @@ The first term corresponds to the standard GRPO objective, which aims to maximiz
We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected dta are viewed as some experiences from an expert. We store them in a `jsonl` file `expert_data.jsonl` with the following format:
```json
-{"question": "What is the average of 4, 6, and 8?","response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."}
+{
+ "messages": [
+ { "role": "system", "content": },
+ { "role": "user", "content": "What is the sum of 4 and 12?" },
+ { "role": "assistant", "content": "thinking process...\n16" } ]
+},
...
```
+The path to expert data is passed to `buffer.trainer_input.sft_warmup_dataset` for later use.
## Step 1: Define the Algorithm
@@ -296,3 +302,9 @@ algorithm:
read_batch_size_expert: 64
read_batch_size_usual: 192
```
+
+With the above configurations, the experiment can be run with the following command:
+
+```bash
+trinity run --config examples/mix_math/mix_math.yaml
+```
diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
index 46cc4ab32e..3cf5b89145 100644
--- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md
+++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -15,8 +15,8 @@ To run the ALFworld and WebShop env, you need to setup the corresponding environ
- WebShop is a simulated online shopping environment where AI agents learn to shop based on user requirements. The platform allows agents to browse products, compare options, and make purchase decisions, mimicking real-world e-commerce interactions.
You may refer to their original environment to complete the setup.
-- For ALFworld, refer to: https://github.com/alfworld/alfworld
-- For WebShop, refer to: https://github.com/princeton-nlp/WebShop
+- For ALFWorld, refer to the [ALFWorld](https://github.com/alfworld/alfworld) repository.
+- For WebShop, refer to the [WebShop](https://github.com/princeton-nlp/WebShop) repository.
### Data Preparation
Our dataset follows the format in Huggingface datasets library, so we should correspondingly convert our env dataset.
@@ -36,7 +36,7 @@ The task is described as an environment instead of a single prompt.
## Step 2: Config preparation and run the experiment
-You can refer to `example_reasoning_basic` to setup the config and others. The default config files are [`alfworld.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_alfworld/alfworld.yaml) and [`webshop.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_webshop/webshop.yaml), respectively.
+You can refer to [Quick Start](./example_reasoning_basic.md) to setup the config and others. The default config files are [`alfworld.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_alfworld/alfworld.yaml) and [`webshop.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_webshop/webshop.yaml), respectively.
You may revise the configurations properly and run the experiment!
```bash
@@ -104,7 +104,7 @@ class AlfworldWorkflow(MultiTurnWorkflow):
...
```
-and include them in the init files in `trinity/common/workflows/__init__.py`
+and include it in the init file `trinity/common/workflows/__init__.py`
```diff
# -*- coding: utf-8 -*-
@@ -120,7 +120,7 @@ and include them in the init files in `trinity/common/workflows/__init__.py`
]
```
-Then you are all set! It should be pretty simple😄, and both environments converge.
+Then you are all set! It should be pretty simple😄, and the training processes in both environments converge.


diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
index dc45994e98..8d8309a913 100644
--- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
+++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -37,6 +37,12 @@ pip install flash-attn -v
# pip install flash-attn -v --no-build-isolation
```
+Installation using pip:
+
+```shell
+pip install trinity-rft
+```
+
Installation from docker:
We provided a dockerfile for Trinity-RFT.
@@ -60,7 +66,7 @@ docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v 1.
- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer.
- `advantage_fn`: The advantage function used for computing advantages.
-- `kl_penalty_fn`: The KL penalty function used for computing KL penalty.
+- `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
- `kl_loss_fn`: The KL loss function used for computing KL loss.
- `entropy_loss_fn`: The entropy loss function used for computing entropy loss.
@@ -111,8 +110,8 @@ monitor:
```
- `monitor_type`: Type of monitoring system. Options:
- - `wandb`: Logs to Weights & Biases. Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
- - `tensorboard`: Logs to TensorBoard. Files are saved under `///monitor/tensorboard`.
+ - `wandb`: Logs to [Weights & Biases](https://docs.wandb.ai/quickstart/). Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
+ - `tensorboard`: Logs to [TensorBoard](https://www.tensorflow.org/tensorboard). Files are saved under `///monitor/tensorboard`.
---
@@ -122,13 +121,13 @@ Defines the model paths and token limits.
```yaml
model:
- model_path: '/PATH/TO/MODEL/CHECKPOINT/'
+ model_path: /PATH/TO/MODEL/
critic_model_path: ''
max_prompt_tokens: 4096
max_response_tokens: 16384
```
-- `model_path`: Path to the model checkpoint being trained.
+- `model_path`: Path to the model being trained.
- `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
- `max_prompt_tokens`: Maximum number of tokens allowed in input prompts.
- `max_response_tokens`: Maximum number of tokens allowed in generated responses.
@@ -175,8 +174,8 @@ buffer:
default_reward_fn_type: 'countdown_reward'
```
-- `batch_size`: Number of samples used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*.
-- `total_epochs`: Total number of training epochs. Not applicable for streaming datasets (e.g., queue-based buffers).
+- `batch_size`: Number of tasks used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*.
+- `total_epochs`: Total number of training epochs.
### Explorer Input
@@ -227,6 +226,8 @@ The configuration for each task dataset is defined as follows:
- For `file` storage type, the path is the path to the directory that contains the task dataset files.
- For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a sqlite database path here.
- For `sql` storage type, the path is the path to the sqlite database file.
+- `subset_name`: The subset name of the task dataset. Default is `None`.
+- `split`: The split of the task dataset. Default is `train`.
- `format`: Defines keys for prompts and responses in the dataset.
- `prompt_key`: Specifies which column in the dataset contains the prompt data.
- `response_key`: Specifies which column in the dataset contains the response data.
@@ -302,9 +303,9 @@ synchronizer:
```
- `sync_method`: Method of synchronization. Options:
- - `nccl`: Uses NCCL for fast synchronization.
- - `checkpoint`: Loads latest model from disk.
-- `sync_interval`: Interval (in steps) between synchronizations.
+ - `nccl`: Uses NCCL for fast synchronization. Supported for `both` mode.
+ - `checkpoint`: Loads latest model from disk. Supported for `train`, `explore`, or `bench` mode.
+- `sync_interval`: Interval (in steps) of model weight synchronization between trainer and explorer.
- `sync_timeout`: Timeout duration for synchronization.
---
@@ -324,7 +325,7 @@ trainer:
- `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
- `save_interval`: Frequency (in steps) at which to save model checkpoints.
- `trainer_config_path`: The path to the trainer configuration file.
-- `train_config`: The configuration of the trainer. Only one needs to be set for `trainer.trainer_config` and `trainer.trainer_config_path`
+- `trainer_config`: The trainer configuration provided inline. Only one of `trainer_config_path` and `trainer_config` should be specified.
---
@@ -334,7 +335,7 @@ Configures preprocessing and data cleaning pipelines.
```yaml
data_processor:
- source_data_path: '/PATH/TO/DATASET'
+ source_data_path: /PATH/TO/DATASET
load_kwargs:
split: 'train'
format:
@@ -345,7 +346,7 @@ data_processor:
db_url: 'postgresql://{username}@localhost:5432/{db_name}'
```
-- `source_data_path`: Path to the raw dataset.
+- `source_data_path`: Path to the task dataset.
- `load_kwargs`: Arguments passed to HuggingFace’s `load_dataset()`.
- `dj_config_path`: Path to Data-Juicer configuration for cleaning.
- `clean_strategy`: Strategy for iterative data cleaning.
From 7a1c526d833d1bc869b8f0fc86249c60a9cc1247 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Thu, 19 Jun 2025 17:38:49 +0800
Subject: [PATCH 24/28] Refactor `state_dict_meta` init (#90)
---
trinity/common/models/vllm_async_model.py | 9 +++++++--
trinity/common/models/vllm_model.py | 11 ++++++++---
trinity/explorer/explorer.py | 14 ++++++++++----
3 files changed, 25 insertions(+), 9 deletions(-)
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 27faa4c44a..177e0e1a81 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -100,6 +100,7 @@ def __init__(
self.action_mask_method = tokenize_and_mask_messages_default
else:
self.action_mask_method = tokenize_and_mask_messages_hf
+ self.state_dict_meta = None
self.ckp_version = 0 # TODO: resume the value from the checkpoint
self.api_server_host = None
self.api_server_port = None
@@ -264,9 +265,11 @@ async def _collective_rpc(
method, timeout, args, kwargs
)
- async def sync_model(self, update_weight_args_list) -> bool:
+ async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
"""Sync model weights to vLLM."""
- for args in update_weight_args_list:
+ if self.state_dict_meta is None:
+ self.state_dict_meta = update_weight_args_list
+ for args in self.state_dict_meta:
await self._collective_rpc("update_weight", args=args)
self.logger.info("Sync model weights to vLLM successfully.")
self.ckp_version += 1
@@ -282,7 +285,9 @@ async def init_process_group(
backend: str = "nccl",
timeout: int = 1200,
update_with_checkpoint: bool = True,
+ state_dict_meta: dict = None,
):
+ self.state_dict_meta = state_dict_meta
return await self._collective_rpc(
"init_process_group",
args=(
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 9459cd7511..c999a61bfa 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -8,7 +8,7 @@
import os
import re
import threading
-from typing import List
+from typing import List, Optional, Tuple
import torch
import vllm
@@ -85,6 +85,7 @@ def __init__(self, config: InferenceModelConfig):
else:
self.action_mask_method = tokenize_and_mask_messages_hf
self.lock = threading.Lock()
+ self.state_dict_meta = None
self.ckp_version = 0 # TODO: resume the value from the checkpoint
def init_process_group(
@@ -97,7 +98,9 @@ def init_process_group(
backend: str = "nccl",
timeout: int = 1200,
update_with_checkpoint: bool = True,
+ state_dict_meta: dict = None,
):
+ self.state_dict_meta = state_dict_meta
return self.llm.collective_rpc(
"init_process_group",
args=(
@@ -274,10 +277,12 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
def has_api_server(self) -> bool:
return False
- def sync_model(self, update_weight_args_list) -> bool:
+ def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
"""Sync model weights to vLLM."""
+ if self.state_dict_meta is None:
+ self.state_dict_meta = update_weight_args_list
with self.lock:
- for args in update_weight_args_list:
+ for args in self.state_dict_meta:
self.llm.collective_rpc("update_weight", args=args)
self.logger.info("Sync model weights to vLLM successfully.")
self.ckp_version += 1
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 87657c9b42..26ee0b53c2 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -96,6 +96,7 @@ def setup_weight_sync_group(
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
timeout=self.config.synchronizer.sync_timeout,
update_with_checkpoint=self.use_checkpoint_weights_update,
+ state_dict_meta=state_dict_meta,
)
for i, model in enumerate(self.models)
]
@@ -119,9 +120,13 @@ def _init_runner_pool(self) -> RunnerPool:
def _update_model_weight(self, state_dict: dict) -> None:
# TODO: update model weight
self.state_dict = state_dict
- update_weight_args_list = []
- for name, param in state_dict.items():
- update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
+ if self.state_dict_meta is None:
+ update_weight_args_list = []
+ for name, param in state_dict.items():
+ update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
+ self.state_dict_meta = update_weight_args_list
+ else:
+ update_weight_args_list = None
ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models])
self.state_dict.clear()
@@ -142,7 +147,8 @@ def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
self.logger.error(f"Error when loading state_dict: {e}")
def _nccl_weights_update(self):
- ray.get([model.sync_model.remote(self.state_dict_meta) for model in self.models])
+ assert self.state_dict_meta is not None
+ ray.get([model.sync_model.remote() for model in self.models])
def prepare(self) -> None:
"""Preparation before running."""
From 99a772a7ce541a2599dba6b0cd59af135fadd435 Mon Sep 17 00:00:00 2001
From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com>
Date: Fri, 20 Jun 2025 11:01:12 +0800
Subject: [PATCH 25/28] Unify async/sync RL (#91)
---
tests/template/verl_config.yaml | 4 +-
.../sample_strategy/sample_strategy.py | 38 +++-
trinity/cli/launcher.py | 90 +++------
trinity/common/config.py | 4 +-
trinity/common/models/utils.py | 1 -
trinity/common/models/vllm_async_model.py | 12 +-
trinity/common/models/vllm_model.py | 13 +-
trinity/common/models/vllm_worker.py | 61 +++---
trinity/explorer/explorer.py | 188 +++++++++---------
trinity/manager/manager.py | 16 +-
trinity/trainer/trainer.py | 44 ++--
trinity/trainer/verl/fsdp_workers.py | 2 +
trinity/trainer/verl_trainer.py | 8 +-
13 files changed, 238 insertions(+), 243 deletions(-)
diff --git a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml
index d6dcf4a997..bb5c21612a 100644
--- a/tests/template/verl_config.yaml
+++ b/tests/template/verl_config.yaml
@@ -16,7 +16,7 @@ actor_rollout_ref:
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
checkpoint:
- contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
+ contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
optim:
lr: 1e-6
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
@@ -72,6 +72,8 @@ critic:
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
cliprange_value: 0.5
+ checkpoint:
+ contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
trainer:
balance_batch: True
diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py
index 8686a0d497..6e530d32ce 100644
--- a/trinity/algorithm/sample_strategy/sample_strategy.py
+++ b/trinity/algorithm/sample_strategy/sample_strategy.py
@@ -12,26 +12,40 @@
class SampleStrategy(ABC):
- def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+ def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs) -> None:
self.pad_token_id = buffer_config.pad_token_id
self.trainer_type = trainer_type
@abstractmethod
def sample(self, step: int) -> Tuple[Any, Dict, List]:
- """Sample experiences from buffer.
+ """Sample data from buffer.
Args:
step (`int`): The step number of current step.
Returns:
- `Any`: The sampled experiences.
+ `Any`: The sampled data.
`Dict`: Metrics for logging.
- `List`: Representative experiences for logging.
+ `List`: Representative data for logging.
+ """
+
+ # Experimental API
+ @abstractmethod
+ def warmup_state(self, step: int) -> Tuple[bool, bool]:
+ """Check the warmup state of the current step.
+
+ Args:
+ step (`int`): The step number of current step.
+
+ Returns:
+ `bool`: Current step is in warmup or not.
+ `bool`: Warmup is finished on this step or not.
"""
@classmethod
+ @abstractmethod
def default_args(cls) -> dict:
- return {}
+ """Get the default arguments of the sample strategy."""
@SAMPLE_STRATEGY.register_module("warmup")
@@ -70,6 +84,13 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
else:
raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+ def warmup_state(self, step: int) -> Tuple[bool, bool]:
+ return step <= self.sft_warmup_steps, step == self.sft_warmup_steps
+
+ @classmethod
+ def default_args(cls) -> dict:
+ return {}
+
@SAMPLE_STRATEGY.register_module("default")
class DefaultSampleStrategy(SampleStrategy):
@@ -93,6 +114,13 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
else:
raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+ def warmup_state(self, step: int) -> Tuple[bool, bool]:
+ return False, False
+
+ @classmethod
+ def default_args(cls) -> dict:
+ return {}
+
@SAMPLE_STRATEGY.register_module("dpo")
class DPOSampleStrategy(WarmupSampleStrategy):
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index cf4a7882aa..348c9262a0 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -2,6 +2,7 @@
import argparse
import os
import sys
+import traceback
from pathlib import Path
from pprint import pprint
@@ -18,44 +19,41 @@
def bench(config: Config) -> None:
"""Evaluate model."""
- explorer = Explorer.remote(config)
+ explorer = ray.remote(Explorer).options(name="explorer").remote(config)
try:
ray.get(explorer.prepare.remote())
ray.get(explorer.benchmark.remote())
logger.info("Benchmark finished.")
ray.get(explorer.shutdown.remote())
- except Exception as e:
- logger.error(f"Benchmark failed: {e}")
- raise e
+ except Exception:
+ error_msg = traceback.format_exc()
+ logger.error(f"Benchmark failed:\n{error_msg}")
def explore(config: Config) -> None:
"""Run explorer."""
- explorer = Explorer.remote(config)
try:
+ explorer = ray.remote(Explorer).options(name="explorer").remote(config)
ray.get(explorer.prepare.remote())
ray.get(explorer.sync_weight.remote())
ray.get(explorer.explore.remote())
- logger.info("Explore finished.")
ray.get(explorer.shutdown.remote())
- except Exception as e:
- logger.error(f"Explore failed: {e}")
- raise e
+ except Exception:
+ error_msg = traceback.format_exc()
+ logger.error(f"Explorer failed:\n{error_msg}")
def train(config: Config) -> None:
"""Run trainer."""
-
- trainer = Trainer.remote(config)
- ray.get(trainer.prepare.remote())
-
try:
+ trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+ ray.get(trainer.prepare.remote())
+ ray.get(trainer.sync_weight.remote())
ray.get(trainer.train.remote())
- logger.info("Train finished.")
ray.get(trainer.shutdown.remote())
- except Exception as e:
- logger.error(f"Train failed {e}.")
- raise e
+ except Exception:
+ error_msg = traceback.format_exc()
+ logger.error(f"Trainer failed:\n{error_msg}")
def both(config: Config) -> None:
@@ -68,54 +66,30 @@ def both(config: Config) -> None:
the latest step. The specific number of experiences may vary for different
algorithms and tasks.
"""
- explorer = Explorer.remote(config)
- trainer = Trainer.remote(config)
+ explorer = ray.remote(Explorer).options(name="explorer").remote(config)
+ trainer = ray.remote(Trainer).options(name="trainer").remote(config)
ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
- logger.info("Setup explorer and trainer finished.")
ray.get(
[
explorer.prepare.remote(),
trainer.prepare.remote(),
]
)
- # sync weight before training start
- ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
-
- while True:
- try:
- ref_explore = explorer.explore_one_period.remote()
- ref_train = trainer.train_one_period.remote()
- explore_continue, explore_step_num = ray.get(ref_explore)
- train_continue, train_step_num = ray.get(ref_train)
- if not explore_continue:
- # If explore finished, the trainer may not have enough experiences to continue,
- # which will cause the trainer be blocked. So we stop the training process
- # immediately.
- # TODO: use a more elegant way to stop the training process.
- logger.info("Explorer finished, stopping...")
- break
- if not train_continue:
- logger.info("Trainer finished, stopping...")
- break
- ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
- logger.info("Model weight synchronized.")
- except Exception as e:
- logger.error(e)
- logger.error("Training stopped due to exception.")
- raise e
- if explore_step_num % config.explorer.eval_interval == 0:
- try:
- ray.get(explorer.eval.remote())
- logger.info("Evaluation finished.")
- except Exception as e:
- logger.error(e)
- logger.error("Evaluation failed.")
- raise e
- ray.get(explorer.flush_log.remote(step=explore_step_num))
- ray.get(trainer.flush_log.remote(step=train_step_num))
-
- ray.get(explorer.shutdown.remote())
- ray.get(trainer.shutdown.remote())
+ ray.get(
+ [
+ explorer.sync_weight.remote(),
+ trainer.sync_weight.remote(),
+ ]
+ )
+ _, _ = ray.wait(
+ [
+ explorer.explore.remote(),
+ trainer.train.remote(),
+ ],
+ num_returns=1,
+ )
+ explorer.shutdown.remote(),
+ trainer.shutdown.remote(),
def activate_data_module(data_workflow_url: str, config_path: str):
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 9c45627d32..1409fa33f3 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -319,8 +319,10 @@ class SynchronizerConfig:
sync_method: SyncMethod = SyncMethod.NCCL
# sync weights every `sync_interval` steps
sync_interval: int = 1
+ # allow explorer to run `sync_offset` steps before sync
+ sync_offset: int = 0
# waiting for `sync_timeout` seconds before timeout in `nccl` method
- sync_timeout: int = 1200
+ sync_timeout: int = 1800
# wait for the lastest checkpoint to be ready # TODO: to be used
wait_for_checkpoint: bool = False
diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py
index a8751e7240..5cc770e64f 100644
--- a/trinity/common/models/utils.py
+++ b/trinity/common/models/utils.py
@@ -156,7 +156,6 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None
iteration = f.read().strip()
return os.path.join(checkpoint_path, f"global_step_{iteration}")
else:
- logger.error(f"No iteration file found in {checkpoint_path}")
raise FileNotFoundError(f"No iteration file found in {checkpoint_path}")
else:
# load specific iteration checkpoint
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 177e0e1a81..8a8a089afa 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -267,10 +267,9 @@ async def _collective_rpc(
async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
"""Sync model weights to vLLM."""
- if self.state_dict_meta is None:
- self.state_dict_meta = update_weight_args_list
- for args in self.state_dict_meta:
- await self._collective_rpc("update_weight", args=args)
+ if update_weight_args_list is not None:
+ await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
+ await self._collective_rpc("update_weight")
self.logger.info("Sync model weights to vLLM successfully.")
self.ckp_version += 1
return True
@@ -287,7 +286,6 @@ async def init_process_group(
update_with_checkpoint: bool = True,
state_dict_meta: dict = None,
):
- self.state_dict_meta = state_dict_meta
return await self._collective_rpc(
"init_process_group",
args=(
@@ -299,12 +297,10 @@ async def init_process_group(
backend,
timeout,
update_with_checkpoint,
+ state_dict_meta,
),
)
- async def update_weight(self, name, dtype, shape, empty_cache=False):
- return await self._collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
-
async def run_api_server(self):
"""Run the OpenAI API server in a Ray actor.
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index c999a61bfa..878fe0bd9c 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -100,7 +100,6 @@ def init_process_group(
update_with_checkpoint: bool = True,
state_dict_meta: dict = None,
):
- self.state_dict_meta = state_dict_meta
return self.llm.collective_rpc(
"init_process_group",
args=(
@@ -112,12 +111,10 @@ def init_process_group(
backend,
timeout,
update_with_checkpoint,
+ state_dict_meta,
),
)
- def update_weight(self, name, dtype, shape, empty_cache=False):
- return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
-
def reset_prefix_cache(self):
self.llm.llm_engine.reset_prefix_cache()
@@ -279,11 +276,9 @@ def has_api_server(self) -> bool:
def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
"""Sync model weights to vLLM."""
- if self.state_dict_meta is None:
- self.state_dict_meta = update_weight_args_list
- with self.lock:
- for args in self.state_dict_meta:
- self.llm.collective_rpc("update_weight", args=args)
+ if update_weight_args_list is not None:
+ self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
+ self._collective_rpc("update_weight")
self.logger.info("Sync model weights to vLLM successfully.")
self.ckp_version += 1
return True
diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py
index 4293811ab7..4d5d3cf376 100644
--- a/trinity/common/models/vllm_worker.py
+++ b/trinity/common/models/vllm_worker.py
@@ -21,22 +21,21 @@ def init_process_group(
backend: str = "nccl",
timeout: int = 1200,
update_with_checkpoint: bool = True,
+ state_dict_meta: list = None,
):
"""Init torch process group for model weights update"""
assert torch.distributed.is_initialized(), "default torch process group must be initialized"
assert group_name != "", "group name must not be empty"
+ self.set_state_dict_meta(state_dict_meta)
self._update_with_checkpoint = update_with_checkpoint
- if self._update_with_checkpoint:
- logger.info(
- f"init_process_group (checkpoint): address={master_address}:{master_port}, rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}"
- )
- self._weight_update_rank = torch.distributed.get_rank() + rank_offset
- else:
- logger.info(
- f"init_process_group (nccl): rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}"
- )
- self._weight_update_rank = torch.distributed.get_rank() + rank_offset
-
+ self._weight_update_rank = torch.distributed.get_rank() + rank_offset
+ logger.info(
+ f"vLLM starting init_process_group ({'checkpoint' if self._update_with_checkpoint else 'nccl'}):\n"
+ f" > address={master_address}:{master_port}\n"
+ f" > rank={torch.distributed.get_rank()}\n"
+ f" > rank_offset={rank_offset}\n"
+ f" > world_size={world_size}"
+ )
if is_ipv6_address(master_address):
# using tcp://ipv6:port will lead to ValueError
init_method = f"tcp://[{master_address}]:{master_port}"
@@ -51,24 +50,28 @@ def init_process_group(
rank=self._weight_update_rank,
group_name=group_name,
)
- logger.info(
- f"init_process_group: master_address={master_address}, master_port={master_port}, "
- f"rank={self._weight_update_rank}, world_size={world_size}, group_name={group_name}"
- )
+ logger.info("vLLM init_process_group finished.")
self._explorer_actor = None
- def update_weight(self, name: str, dtype_str: str, shape: tuple, empty_cache=False):
- """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
- if self._weight_update_rank == 0:
- if self._explorer_actor is None:
- self._explorer_actor = ray.get_actor(name="explorer")
- weight = ray.get(self._explorer_actor.get_weight.remote(name))
- weight = weight.to(self.device)
- else:
- dtype = getattr(torch, dtype_str.split(".")[-1])
- weight = torch.empty(shape, dtype=dtype, device=self.device)
- torch.distributed.broadcast(weight, 0, group=self._model_update_group)
- weight = weight.type(self.model_config.dtype)
+ def set_state_dict_meta(self, state_dict_meta):
+ self._state_dict_meta = state_dict_meta
- self.model_runner.model.load_weights(weights=[(name, weight)])
- del weight
+ def update_weight(self):
+ """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
+ assert self._state_dict_meta is not None
+ if self._explorer_actor is None:
+ self._explorer_actor = ray.get_actor(name="explorer")
+ for name, dtype_str, shape in self._state_dict_meta:
+ if self._weight_update_rank == 0:
+ weight = ray.get(self._explorer_actor.get_weight.remote(name))
+ weight = weight.to(self.device)
+ else:
+ dtype = getattr(torch, dtype_str.split(".")[-1])
+ weight = torch.empty(shape, dtype=dtype, device=self.device)
+ torch.distributed.broadcast(weight, 0, group=self._model_update_group)
+ weight = weight.type(self.model_config.dtype)
+ self.model_runner.model.load_weights(weights=[(name, weight)])
+ del weight
+ torch.distributed.barrier()
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 26ee0b53c2..36527f1dcd 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -1,11 +1,13 @@
# -*- coding: utf-8 -*-
"""The explorer module"""
+from __future__ import annotations
+
+import asyncio
import os
import time
from collections import defaultdict
from typing import List, Optional, Tuple
-import ray
import torch
from trinity.algorithm.algorithm_manager import AlgorithmManager
@@ -24,7 +26,6 @@
from trinity.utils.monitor import MONITOR
-@ray.remote(name="explorer", concurrency_groups={"get_weight": 32, "setup_weight_sync_group": 1})
class Explorer:
"""Responsible for exploring the taskset."""
@@ -32,7 +33,7 @@ def __init__(self, config: Config):
self.logger = get_logger(__name__)
self.cache = CacheManager(config)
explorer_meta = self.cache.load_explorer()
- self.step_num = explorer_meta.get("latest_iteration", 0)
+ self.explore_step_num = explorer_meta.get("latest_iteration", 0)
self.config = config
self.algorithm_manager = AlgorithmManager(config)
self.models, self.auxiliary_models = create_inference_models(config)
@@ -70,8 +71,7 @@ def __init__(self, config: Config):
self.state_dict_meta = []
self.logger.info("Finished initializing Explorer.")
- @ray.method(concurrency_group="setup_weight_sync_group")
- def setup_weight_sync_group(
+ async def setup_weight_sync_group(
self, master_address: str, master_port: int, state_dict_meta: List = None
):
# In checkpoint mode, we use explorer to store the model weights which has no rank
@@ -100,7 +100,7 @@ def setup_weight_sync_group(
)
for i, model in enumerate(self.models)
]
- ray.get(refs)
+ await asyncio.gather(*refs)
def _init_runner_pool(self) -> RunnerPool:
if self.config.explorer.rollout_model.engine_type != "vllm_async":
@@ -117,7 +117,7 @@ def _init_runner_pool(self) -> RunnerPool:
self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners")
return RunnerPool(self.config, self.models, self.auxiliary_models)
- def _update_model_weight(self, state_dict: dict) -> None:
+ async def _update_model_weight(self, state_dict: dict) -> None:
# TODO: update model weight
self.state_dict = state_dict
if self.state_dict_meta is None:
@@ -127,10 +127,12 @@ def _update_model_weight(self, state_dict: dict) -> None:
self.state_dict_meta = update_weight_args_list
else:
update_weight_args_list = None
- ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models])
+ await asyncio.gather(
+ *[model.sync_model.remote(update_weight_args_list) for model in self.models]
+ )
self.state_dict.clear()
- def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
+ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
# TODO: support more checkpoint types
try:
checkpoint_dir = get_checkpoint_dir_with_step_num(
@@ -141,104 +143,62 @@ def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
if checkpoint_dir == self.old_checkpoint:
return
model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor"))
- self._update_model_weight(model_weights)
+ await self._update_model_weight(model_weights)
self.old_checkpoint = checkpoint_dir
except Exception as e:
- self.logger.error(f"Error when loading state_dict: {e}")
+ self.logger.warning(f"Fail to load checkpoint: {e}")
- def _nccl_weights_update(self):
+ async def _nccl_weights_update(self):
assert self.state_dict_meta is not None
- ray.get([model.sync_model.remote() for model in self.models])
+ await asyncio.gather(*[model.sync_model.remote() for model in self.models])
- def prepare(self) -> None:
+ async def prepare(self) -> None:
"""Preparation before running."""
if self.use_checkpoint_weights_update:
- master_address, master_port = ray.get(self.models[0].get_available_address.remote())
- self.setup_weight_sync_group(master_address, master_port)
+ master_address, master_port = await self.models[0].get_available_address.remote()
+ await self.setup_weight_sync_group(master_address, master_port)
- @ray.method(concurrency_group="get_weight")
- def get_weight(self, name: str) -> torch.Tensor:
+ async def get_weight(self, name: str) -> torch.Tensor:
"""Get the weight of the loaded model (For checkpoint weights update)."""
return self.state_dict[name]
- def explore(self) -> None:
- """Explore the entire dataset."""
+ async def explore(self) -> None:
while True:
- explore_status, explore_iter = self.explore_one_period()
- if not explore_status:
+ try:
+ explore_contionue = self.explore_step()
+ if self.need_sync():
+ self.wait_for_workflow_done()
+ await self.sync_weight()
+ if self.explore_step_num % self.config.explorer.eval_interval == 0:
+ self.wait_for_workflow_done()
+ self.eval()
+ if not explore_contionue:
+ break
+ except Exception as e:
+ self.logger.error(f"Error in Explorer: {e}")
break
- self.sync_weight()
- if explore_iter % self.config.explorer.eval_interval == 0:
- self.eval()
- self.logger.info("Evaluation finished.")
- self.logger.info("Explorer finished.")
+ self.logger.info("--------------------\n> Explorer finished.\n--------------------\n")
- def explore_one_period(self) -> Tuple[bool, int]:
- """Explore for one period.
-
- Different from `explore()` which consumes all tasks in the task set,
- `explore_one_period()` only consume `sync_interval * batch_size`
- number of tasks.
- Returns:
- explore_status: whether there are more tasks to explore.
- explore_step_num: the number of explore steps
- """
- # skip for sft
- algo_config = self.algorithm_manager.get_current_algorithm_config(self.step_num + 1)
+ def explore_step(self) -> bool:
+ self.explore_step_num += 1
+ algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num)
+ # skip warmup
if algo_config.algorithm_type == "sft":
- for _ in range(self.config.synchronizer.sync_interval):
- self.step_num += 1
- if self.algorithm_manager.need_save(self.step_num):
- break
- return True, self.step_num
-
- st = time.time()
- all_metrics = defaultdict(list)
-
- # submit tasks of this step
+ return True
try:
- tasks = []
- for _ in range(self.config.synchronizer.sync_interval):
- tasks.extend(self.taskset.read())
- self.runner_pool.run_tasks(tasks) # type: ignore
+ tasks = self.taskset.read()
except StopIteration:
- self.experience_buffer.finish()
- self.logger.warning("No more tasks in the task set. Stop exploring.")
- return False, self.step_num
-
- # wait for all tasks of this step to finish
- while self.runner_pool.has_next():
- status_list = self.runner_pool.get_next_unorder()
- if not isinstance(status_list, list):
- status_list = [status_list]
- for status in status_list:
- if not status.ok:
- self.logger.error(f"Error when running task: {status.message}")
- try:
- # submit another task to replace the failed task
- self.runner_pool.run_tasks(self.taskset.read())
- except StopIteration:
- self.logger.warning("No more tasks in the task set. Stop exploring.")
- return False, self.step_num
- else:
- for metric_name, metric_value in status.metric.items():
- all_metrics[metric_name].append(metric_value)
-
- # calculate metrics
- log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore
- log_metrics["rollout/step_time"] = time.time() - st
- self.step_num += self.config.synchronizer.sync_interval
- self.monitor.log(log_metrics, step=self.step_num)
-
- # save explore checkpoint
- self.cache.save_explorer(
- current_step=self.step_num,
- current_task_index=self.step_num * self.config.buffer.batch_size,
- # TODO: remove current_task_index
- )
+ self.logger.warning("No more tasks to explore. Stop exploring.")
+ return False
+ self.runner_pool.run_tasks(tasks)
+ return True
- self.logger.info(f"Explore step {self.step_num} finished.")
- return True, self.step_num
+ def need_sync(self) -> bool:
+ if self.explore_step_num <= self.config.synchronizer.sync_offset:
+ return False
+ return (
+ self.explore_step_num - self.config.synchronizer.sync_offset
+ ) % self.config.synchronizer.sync_interval == 0
def eval(self) -> Tuple[bool, int]:
"""Evaluation on all evaluation data samples."""
@@ -247,7 +207,7 @@ def eval(self) -> Tuple[bool, int]:
eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer))
if len(eval_tasksets) == 0:
self.logger.warning("No evaluation data samples. Skip evaluation.")
- return True, self.step_num
+ return True, self.explore_step_num
self.logger.info("Evaluation started.")
all_st = time.time()
log_metrics = {}
@@ -279,14 +239,15 @@ def wait():
log_metrics.update(metrics)
log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st
log_metrics["eval/total_time"] = time.time() - all_st
- self.monitor.log(log_metrics, step=self.step_num) # type: ignore
- return True, self.step_num
+ self.monitor.log(log_metrics, step=self.explore_step_num) # type: ignore
+ self.logger.info("Evaluation finished.")
+ return True, self.explore_step_num
- def benchmark(self) -> bool:
+ async def benchmark(self) -> bool:
"""Benchmark the model checkpoints."""
# benchmark on the latest checkpoint
if self.config.explorer.eval_on_latest_checkpoint:
- self._checkpoint_weights_update()
+ await self._checkpoint_weights_update()
self.eval()
return True
@@ -300,18 +261,47 @@ def benchmark(self) -> bool:
]
)
for step_num in all_ckp_steps:
- self.step_num = step_num
- self._checkpoint_weights_update(step_num=step_num)
+ self.explore_step_num = step_num
+ await self._checkpoint_weights_update(step_num=step_num)
self.eval()
return True
- def sync_weight(self) -> None:
+ def wait_for_workflow_done(self) -> None:
+ """Wait for workflow to finish."""
+ all_metrics = defaultdict(list)
+ # wait for all tasks of this step to finish
+ while self.runner_pool.has_next():
+ status_list = self.runner_pool.get_next_unorder()
+ if not isinstance(status_list, list):
+ status_list = [status_list]
+ for status in status_list:
+ if not status.ok:
+ self.logger.error(f"Error when running task: {status.message}")
+ # submit another task to replace the failed task
+ self.runner_pool.run_tasks(self.taskset.read(batch_size=1))
+ else:
+ for metric_name, metric_value in status.metric.items():
+ all_metrics[metric_name].append(metric_value)
+ # calculate metrics
+ log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore
+ self.monitor.log(log_metrics, step=self.explore_step_num)
+
+ self.logger.info(f"Explore step {self.explore_step_num} finished.")
+
+ async def sync_weight(self) -> None:
"""Synchronize model weights."""
# call this method before training start to load the latest model weights
+ self.logger.info(f"Explorer synchronizing weights at step {self.explore_step_num}.")
if self.use_checkpoint_weights_update:
- self._checkpoint_weights_update()
+ await self._checkpoint_weights_update()
else: # nccl weights update
- self._nccl_weights_update()
+ await self._nccl_weights_update()
+ # save explore checkpoint
+ self.cache.save_explorer(
+ current_step=self.explore_step_num,
+ current_task_index=self.explore_step_num * self.config.buffer.batch_size,
+ )
+ self.logger.info(f"Explorer synchronizing at step {self.explore_step_num} finished")
def flush_log(self, step: int) -> None:
"""Flush the log of the current step."""
diff --git a/trinity/manager/manager.py b/trinity/manager/manager.py
index 3c148cbe12..baaf1242c3 100644
--- a/trinity/manager/manager.py
+++ b/trinity/manager/manager.py
@@ -47,7 +47,13 @@ def load_explorer(self) -> dict:
try:
with open(self.explorer_meta_path, "r", encoding="utf-8") as f:
explorer_meta = json.load(f)
- logger.info(f"Find existing explorer meta: {explorer_meta}")
+ logger.info(
+ "----------------------------------\n"
+ "Found existing explorer checkpoint:\n"
+ f" > {explorer_meta}\n"
+ "Continue exploring from this point.\n"
+ "----------------------------------"
+ )
return explorer_meta
except Exception as e:
logger.error(f"Failed to load explore meta file: {e}")
@@ -62,7 +68,13 @@ def load_trainer(self) -> dict:
try:
with open(self.trainer_meta_path, "r", encoding="utf-8") as f:
trainer_meta = json.load(f)
- logger.info(f"Find existing trainer meta: {trainer_meta}")
+ logger.info(
+ "----------------------------------\n"
+ "Found existing trainer checkpoint:\n"
+ f" > {trainer_meta}\n"
+ "Continue training from this point.\n"
+ "----------------------------------"
+ )
return trainer_meta
except Exception as e:
logger.warning(f"Failed to load trainer meta file: {e}")
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 2920604fbb..ff43dfbb79 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -1,31 +1,23 @@
# -*- coding: utf-8 -*-
"""
Trainer Class
-This file is modified from verl.trainer.main_ppo.py
-And is a reproduction code of Jiayi-Pan/TinyZero.
-
-Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
+from __future__ import annotations
+
import os
from abc import ABC, abstractmethod
-from typing import Tuple
-
-import ray
-from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.common.config import Config
from trinity.common.constants import SyncMethod
from trinity.utils.log import get_logger
-@ray.remote(name="trainer")
class Trainer:
"""Consume the experience and train the model."""
def __init__(self, config: Config) -> None:
self.config = config
self.logger = get_logger(__name__)
- self.algorithm_manager = AlgorithmManager(config)
self.engine = get_trainer_wrapper(config)
def prepare(self) -> None:
@@ -35,23 +27,18 @@ def prepare(self) -> None:
def train(self):
"""Train the model."""
while True:
- train_status, _ = self.train_step()
- if not train_status:
+ try:
+ train_continue = self.train_step()
+ if self.need_sync():
+ self.sync_weight()
+ if not train_continue:
+ break
+ except Exception as e:
+ self.logger.error(f"Error in Trainer: {e}")
break
+ self.logger.info("--------------------\n> Trainer finished.\n--------------------\n")
- def train_one_period(self) -> Tuple[bool, int]:
- """Train for one period. Each period contains `sync_interval` steps.
- Returns:
- train_status: Whether to continue training.
- train_step_num: The number of training steps"""
- for _ in range(self.config.synchronizer.sync_interval):
- train_status, train_step_num = self.train_step()
- if not train_status:
- return False, train_step_num
- self.logger.info(f"Train step {train_step_num} finished.")
- return True, train_step_num
-
- def train_step(self) -> Tuple[bool, int]:
+ def train_step(self) -> bool:
"""Train one step.
Returns:
@@ -59,9 +46,14 @@ def train_step(self) -> Tuple[bool, int]:
"""
return self.engine.train_step()
+ def need_sync(self) -> bool:
+ """Whether to sync the model weight."""
+ return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0
+
def sync_weight(self) -> None:
"""Sync the model weight."""
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
+ self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.")
self.engine.sync_weight()
def flush_log(self, step: int) -> None:
@@ -90,7 +82,7 @@ def train_step_num(self) -> int:
"""Get the current training step number."""
@abstractmethod
- def train_step(self) -> Tuple[bool, int]:
+ def train_step(self) -> bool:
"""Training."""
@abstractmethod
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 2a8308ea62..69e99a153d 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -606,6 +606,8 @@ def sync_weight(self):
continue
torch.distributed.broadcast(param, 0, group=self._model_update_group)
param = None
+ torch.distributed.barrier()
+ torch.cuda.synchronize()
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index bc15a25446..4243e61d17 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -6,7 +6,7 @@
import os
import sys
from pprint import pprint
-from typing import Dict, List, Tuple
+from typing import Dict, List
import pandas as pd
import ray
@@ -285,14 +285,14 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
# TODO: compute total training steps
self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
- def train_step(self) -> Tuple[bool, int]: # noqa C901
+ def train_step(self) -> bool: # noqa C901
metrics = {}
try:
batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
prefix_metrics(sample_metrics, "sample", metrics)
except StopIteration:
print("No more data to train. Stop training.")
- return False, self.global_steps
+ return False
self.global_steps += 1
timing_raw = {}
algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps)
@@ -382,7 +382,7 @@ def train_step(self) -> Tuple[bool, int]: # noqa C901
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
- return train_status, self.global_steps
+ return train_status
def _log_single_experience(
self, experiences: Experiences, idx: int, skip_special_tokens: bool
From 6f2d7c7d9867c25ada6e701a9a1f444738228ea0 Mon Sep 17 00:00:00 2001
From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com>
Date: Fri, 20 Jun 2025 16:30:12 +0800
Subject: [PATCH 26/28] Support one-step ahead async RL (#93)
---
tests/trainer/trainer_test.py | 52 +++++++++++++++++++++++++++-
trinity/cli/launcher.py | 36 ++++++++++++++-----
trinity/common/constants.py | 11 ++++++
trinity/common/models/vllm_worker.py | 3 +-
trinity/explorer/explorer.py | 48 ++++++++++++++++++-------
trinity/trainer/trainer.py | 25 ++++++++++---
trinity/trainer/verl/fsdp_workers.py | 8 +++--
trinity/trainer/verl_trainer.py | 3 +-
8 files changed, 156 insertions(+), 30 deletions(-)
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 726e22290e..0ec438c2db 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -115,6 +115,56 @@ def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)
+class TestStepAheadAsyncRL(BaseTrainerCase):
+ def test_trainer(self):
+ """Test the explore step ahead trainer"""
+ # train 4 step, sync_offset=1, sync_interval=2
+ # Explorer:
+ # | 1 | 2 | 3 |sync| 4 |
+ # |---|---|---|sync|---|
+ # Trainer:
+ # | 1 | 2 |sync| 3 | 4 |
+ # |---|---|sync|---|---|
+ self.config.buffer.total_epochs = 1
+ self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+ self.config.trainer.save_interval = 4
+ self.config.synchronizer.sync_interval = 2
+ self.config.synchronizer.sync_offset = 1
+ self.config.check_and_update()
+ self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 1
+ self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 1
+
+ both(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+ actor_metrics = parser.metric_list("actor")
+ self.assertTrue(len(actor_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+ actor_kl_metrics = parser.metric_list("actor/kl")
+ self.assertTrue(len(actor_kl_metrics) > 0)
+ critic_kl_metrics = parser.metric_list("critic/kl")
+ self.assertTrue(len(critic_kl_metrics) > 0)
+ response_metrics = parser.metric_list("response_length")
+ self.assertTrue(len(response_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+ ray.shutdown(_exiting_interpreter=True)
+ # check checkpoint
+ from trinity.common.models.utils import get_checkpoint_dir_with_step_num
+
+ checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+ checkpoint_root_path=self.config.checkpoint_job_dir,
+ trainer_type=self.config.trainer.trainer_type,
+ step_num=4,
+ )
+ self.assertTrue(os.path.exists(checkpoint_step_4))
+
+ def tearDown(self):
+ # remove dir only when the test passed
+ shutil.rmtree(self.config.checkpoint_job_dir)
+
+
class TestTrainerGSM8K(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K."""
@@ -153,7 +203,7 @@ def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)
-class TestTrainerGSM8KWithSFT(BaseTrainerCase):
+class TestTrainerSFTWarmupGSM8K(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K With SFT."""
# test both mode
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 348c9262a0..a63b06a36d 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -9,6 +9,7 @@
import ray
from trinity.common.config import Config, load_config
+from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
from trinity.utils.log import get_logger
@@ -19,7 +20,7 @@
def bench(config: Config) -> None:
"""Evaluate model."""
- explorer = ray.remote(Explorer).options(name="explorer").remote(config)
+ explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
try:
ray.get(explorer.prepare.remote())
ray.get(explorer.benchmark.remote())
@@ -33,7 +34,7 @@ def bench(config: Config) -> None:
def explore(config: Config) -> None:
"""Run explorer."""
try:
- explorer = ray.remote(Explorer).options(name="explorer").remote(config)
+ explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
ray.get(explorer.prepare.remote())
ray.get(explorer.sync_weight.remote())
ray.get(explorer.explore.remote())
@@ -46,7 +47,7 @@ def explore(config: Config) -> None:
def train(config: Config) -> None:
"""Run trainer."""
try:
- trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+ trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
ray.get(trainer.prepare.remote())
ray.get(trainer.sync_weight.remote())
ray.get(trainer.train.remote())
@@ -66,8 +67,8 @@ def both(config: Config) -> None:
the latest step. The specific number of experiences may vary for different
algorithms and tasks.
"""
- explorer = ray.remote(Explorer).options(name="explorer").remote(config)
- trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+ explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
+ trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
ray.get(
[
@@ -81,15 +82,34 @@ def both(config: Config) -> None:
trainer.sync_weight.remote(),
]
)
- _, _ = ray.wait(
+ ready_ref, wait_ref = ray.wait(
[
explorer.explore.remote(),
trainer.train.remote(),
],
num_returns=1,
)
- explorer.shutdown.remote(),
- trainer.shutdown.remote(),
+
+ ready = ray.get(ready_ref[0])
+ if ready == TRAINER_NAME:
+ logger.info(
+ "===========================================================\n"
+ "> Launcher detected that the `Trainer` process has finished.\n"
+ "> Stopping the explorer process immediately.\n"
+ "==========================================================="
+ )
+ ray.wait(wait_ref, timeout=5)
+ elif ready == EXPLORER_NAME:
+ logger.info(
+ "============================================================\n"
+ "> Launcher detected that the `Explorer` process has finished.\n"
+ f"> Waiting {config.synchronizer.sync_timeout} s for the trainer process...\n"
+ "> You can force stop the Trainer process by pressing Ctrl+C.\n"
+ "============================================================"
+ )
+ ray.wait(wait_ref, timeout=config.synchronizer.sync_timeout)
+ explorer.shutdown.remote()
+ trainer.shutdown.remote()
def activate_data_module(data_workflow_url: str, config_path: str):
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index 3c49d65c21..9a428131fe 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -8,6 +8,9 @@
# names
+EXPLORER_NAME = "explorer"
+TRAINER_NAME = "trainer"
+
ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync"
@@ -92,3 +95,11 @@ class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta):
NCCL = "nccl"
CHECKPOINT = "checkpoint"
+
+
+class RunningStatus(Enum):
+ """Running status of explorer and trainer."""
+
+ RUNNING = "running"
+ WAITING_SYNC = "waiting_sync"
+ STOPPED = "stopped"
diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py
index 4d5d3cf376..883e470381 100644
--- a/trinity/common/models/vllm_worker.py
+++ b/trinity/common/models/vllm_worker.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed
+from trinity.common.constants import EXPLORER_NAME
from trinity.utils.distributed import init_process_group, is_ipv6_address
from trinity.utils.log import get_logger
@@ -60,7 +61,7 @@ def update_weight(self):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
assert self._state_dict_meta is not None
if self._explorer_actor is None:
- self._explorer_actor = ray.get_actor(name="explorer")
+ self._explorer_actor = ray.get_actor(name=EXPLORER_NAME)
for name, dtype_str, shape in self._state_dict_meta:
if self._weight_update_rank == 0:
weight = ray.get(self._explorer_actor.get_weight.remote(name))
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 36527f1dcd..31ade5f84b 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -14,7 +14,12 @@
from trinity.buffer import get_buffer_writer
from trinity.buffer.buffer import get_buffer_reader
from trinity.common.config import Config
-from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
+from trinity.common.constants import (
+ EXPLORER_NAME,
+ ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
+ RunningStatus,
+ SyncMethod,
+)
from trinity.common.models import create_inference_models
from trinity.common.models.utils import (
get_checkpoint_dir_with_step_num,
@@ -50,7 +55,7 @@ def __init__(self, config: Config):
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
project=self.config.project,
name=self.config.name,
- role="explorer",
+ role=EXPLORER_NAME,
config=config,
)
self.batch_size = config.buffer.batch_size
@@ -69,6 +74,7 @@ def __init__(self, config: Config):
self.state_dict = {}
else: # nccl mode
self.state_dict_meta = []
+ self.status = RunningStatus.RUNNING
self.logger.info("Finished initializing Explorer.")
async def setup_weight_sync_group(
@@ -162,35 +168,44 @@ async def get_weight(self, name: str) -> torch.Tensor:
"""Get the weight of the loaded model (For checkpoint weights update)."""
return self.state_dict[name]
- async def explore(self) -> None:
+ async def explore(self) -> str:
while True:
try:
explore_contionue = self.explore_step()
+ if not explore_contionue:
+ break
if self.need_sync():
self.wait_for_workflow_done()
await self.sync_weight()
if self.explore_step_num % self.config.explorer.eval_interval == 0:
self.wait_for_workflow_done()
self.eval()
- if not explore_contionue:
- break
except Exception as e:
self.logger.error(f"Error in Explorer: {e}")
break
- self.logger.info("--------------------\n> Explorer finished.\n--------------------\n")
+ self.logger.info("--------------------\n> Explorer finished.\n--------------------")
+ return EXPLORER_NAME
def explore_step(self) -> bool:
- self.explore_step_num += 1
- algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num)
+ algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1)
# skip warmup
if algo_config.algorithm_type == "sft":
+ self.explore_step_num += 1
return True
try:
tasks = self.taskset.read()
except StopIteration:
self.logger.warning("No more tasks to explore. Stop exploring.")
+ self.cache.save_explorer(
+ current_step=self.explore_step_num,
+ current_task_index=self.explore_step_num * self.config.buffer.batch_size,
+ )
+ self.status = RunningStatus.STOPPED
+ self.wait_for_workflow_done()
+ self.experience_buffer.finish()
return False
self.runner_pool.run_tasks(tasks)
+ self.explore_step_num += 1
return True
def need_sync(self) -> bool:
@@ -278,20 +293,25 @@ def wait_for_workflow_done(self) -> None:
if not status.ok:
self.logger.error(f"Error when running task: {status.message}")
# submit another task to replace the failed task
- self.runner_pool.run_tasks(self.taskset.read(batch_size=1))
+ try:
+ tasks = self.taskset.read(batch_size=1)
+ except StopIteration:
+ self.logger.warning("No more tasks in taskset. Stop retrying.")
+ return
+ self.runner_pool.run_tasks(tasks)
else:
for metric_name, metric_value in status.metric.items():
all_metrics[metric_name].append(metric_value)
# calculate metrics
log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore
self.monitor.log(log_metrics, step=self.explore_step_num)
-
self.logger.info(f"Explore step {self.explore_step_num} finished.")
async def sync_weight(self) -> None:
"""Synchronize model weights."""
# call this method before training start to load the latest model weights
- self.logger.info(f"Explorer synchronizing weights at step {self.explore_step_num}.")
+ self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.")
+ self.status = RunningStatus.WAITING_SYNC
if self.use_checkpoint_weights_update:
await self._checkpoint_weights_update()
else: # nccl weights update
@@ -301,7 +321,11 @@ async def sync_weight(self) -> None:
current_step=self.explore_step_num,
current_task_index=self.explore_step_num * self.config.buffer.batch_size,
)
- self.logger.info(f"Explorer synchronizing at step {self.explore_step_num} finished")
+ self.status = RunningStatus.RUNNING
+ self.logger.info(f"Explorer sync at step {self.explore_step_num} finished")
+
+ async def running_status(self) -> RunningStatus:
+ return self.status
def flush_log(self, step: int) -> None:
"""Flush the log of the current step."""
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index ff43dfbb79..216c916c69 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -7,8 +7,15 @@
import os
from abc import ABC, abstractmethod
+import ray
+
from trinity.common.config import Config
-from trinity.common.constants import SyncMethod
+from trinity.common.constants import (
+ EXPLORER_NAME,
+ TRAINER_NAME,
+ RunningStatus,
+ SyncMethod,
+)
from trinity.utils.log import get_logger
@@ -19,24 +26,26 @@ def __init__(self, config: Config) -> None:
self.config = config
self.logger = get_logger(__name__)
self.engine = get_trainer_wrapper(config)
+ self.explorer_ref = None
def prepare(self) -> None:
"""Prepare the trainer."""
self.engine.prepare()
- def train(self):
+ def train(self) -> str:
"""Train the model."""
while True:
try:
train_continue = self.train_step()
- if self.need_sync():
- self.sync_weight()
if not train_continue:
break
+ if self.need_sync():
+ self.sync_weight()
except Exception as e:
self.logger.error(f"Error in Trainer: {e}")
break
- self.logger.info("--------------------\n> Trainer finished.\n--------------------\n")
+ self.logger.info("--------------------\n> Trainer finished.\n--------------------")
+ return TRAINER_NAME
def train_step(self) -> bool:
"""Train one step.
@@ -53,6 +62,12 @@ def need_sync(self) -> bool:
def sync_weight(self) -> None:
"""Sync the model weight."""
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
+ if self.explorer_ref is None:
+ self.explorer_ref = ray.get_actor(EXPLORER_NAME)
+ explorer_status = ray.get(self.explorer_ref.running_status.remote())
+ if explorer_status == RunningStatus.STOPPED:
+ self.logger.warning("Explorer has already stopped. Skipping sync weight.")
+ return
self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.")
self.engine.sync_weight()
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 69e99a153d..cbc88902a0 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -71,7 +71,11 @@
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from trinity.common.config import AlgorithmConfig
-from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
+from trinity.common.constants import (
+ EXPLORER_NAME,
+ ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
+ SyncMethod,
+)
from trinity.utils.distributed import init_process_group, is_ipv6_address
logger = logging.getLogger(__file__)
@@ -573,7 +577,7 @@ def setup_weight_sync_group(self):
master_address, master_port = self.get_availale_master_addr_port()
world_size = self.config.synchronizer.explorer_world_size + 1
print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
- explorer = ray.get_actor("explorer")
+ explorer = ray.get_actor(EXPLORER_NAME)
setup_ref = explorer.setup_weight_sync_group.remote(
master_address, master_port, self.state_dict_meta
)
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 4243e61d17..d041bea128 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -36,6 +36,7 @@
from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import Config
+from trinity.common.constants import TRAINER_NAME
from trinity.common.experience import Experiences
from trinity.trainer.trainer import TrainEngineWrapper
from trinity.utils.monitor import MONITOR
@@ -149,7 +150,7 @@ def __init__(
self.logger = MONITOR.get(global_config.monitor.monitor_type)(
project=config.trainer.project_name,
name=config.trainer.experiment_name,
- role="trainer",
+ role=TRAINER_NAME,
config=global_config,
)
self.reset_experiences_example_table()
From eddf4e47b60ccf76a368724941a17c0be6e93169 Mon Sep 17 00:00:00 2001
From: Yilun Huang
Date: Fri, 20 Jun 2025 17:43:35 +0800
Subject: [PATCH 27/28] Refactor data module and support task pipeline in data
processor (#92)
---
.gitignore | 1 +
.../tutorial/example_data_functionalities.md | 196 ++++++++++--------
environments/data.yaml | 5 -
examples/grpo_gsm8k/gsm8k.yaml | 15 --
examples/grpo_gsm8k_task_pipeline/README.md | 7 +
examples/grpo_gsm8k_task_pipeline/gsm8k.yaml | 95 +++++++++
.../grpo_gsm8k_task_pipeline/train_gsm8k.yaml | 50 +++++
pyproject.toml | 1 +
tests/buffer/file_test.py | 42 ++++
tests/data/controllers/task_parser_test.py | 22 +-
tests/data/core/dataset_test.py | 105 ++++------
tests/data/core/formatter_test.py | 114 ++++++----
tests/data/processor/cleaner_test.py | 78 +++----
.../active_iterator_test_cfg.yaml | 27 ++-
.../active_iterator_test_dj_cfg.yaml | 2 -
tests/test_configs/cleaner_test_dj_cfg.yaml | 4 +-
tests/test_configs/cleaner_test_rft_cfg.yaml | 10 +-
.../human_annotator_test_rft_cfg.yaml | 18 +-
trinity/buffer/buffer.py | 4 +-
trinity/buffer/ray_wrapper.py | 2 +
trinity/buffer/reader/file_reader.py | 26 ++-
trinity/cli/client.py | 4 +-
trinity/cli/launcher.py | 70 ++++++-
trinity/common/config.py | 30 ++-
trinity/data/controllers/active_iterator.py | 149 +++++++------
trinity/data/controllers/task_parser.py | 49 ++---
trinity/data/core/dataset.py | 62 +++---
trinity/data/core/dataset_db.py | 84 --------
trinity/data/processors/cleaner.py | 6 +-
trinity/data/readme.md | 4 +-
trinity/data/server.py | 27 ++-
31 files changed, 756 insertions(+), 553 deletions(-)
create mode 100644 examples/grpo_gsm8k_task_pipeline/README.md
create mode 100644 examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
create mode 100644 examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml
delete mode 100644 trinity/data/core/dataset_db.py
diff --git a/.gitignore b/.gitignore
index 646848ade7..7ab517ff20 100644
--- a/.gitignore
+++ b/.gitignore
@@ -84,6 +84,7 @@ ENV/
logs/
# data-juicer
+tmp/
outputs/
# agentscope
runs/
diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
index 6558efcdd9..27b5fb26bf 100644
--- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
+++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -1,80 +1,97 @@
# Data Processing
-## Example: reasoning task
+## Example: Data Processor for Task Pipeline
-In this example, you will learn how to apply the data module of Trinity-RFT to prepare the dataset before exploring and training. This example takes GSM-8K dataset as the example dataset to figure out:
+In this example, you will learn how to apply the data processor of Trinity-RFT to prepare and prioritize the dataset before task exploring and training. This example takes GSM-8K dataset as the example dataset to figure out:
-1. how to prepare the data module
-2. how to configure the data module
-3. what the data module can do
+1. how to prepare the data processor
+2. how to configure the data processor
+3. what the data processor can do
-Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md), and you need to install [postgresql](https://www.postgresql.org/docs/current/tutorial-install.html) as well.
+Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md).
### Data Preparation
-#### Prepare the Data Module
+#### Prepare the Data Processor
-As the overall framework of Trinity-RFT shows, the data module is one of the high-level functions. Trinity-RFT encapsulates the data module as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server.
+As the overall framework of Trinity-RFT shows, the data processor is one of the high-level functions. Trinity-RFT encapsulates the data processor as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server.
```shell
-# prepare split environments, including the one of data module
+# prepare split environments, including the one of data processor
python scripts/install.py
# start all split servers
python scripts/start_servers.py
```
-### Configure the Data Module
+### Configure the Data Processor
-Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data_processor` section in the config file.
+Trinity-RFT uses a unified config file to manage all config items. For the data processor, you need to focus on the `data_processor` section in the config file.
In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example:
```yaml
data_processor:
- # basic info
- source_data_path: /PATH/TO/GSM8K/
- load_kwargs:
- split: 'train' # only need the train split
- format: # set the field mappings
- prompt_key: 'question'
- response_key: 'answer'
- # database related. The result dataset will be stored in the database.
- db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
+ data_processor_url: 'http://127.0.0.1:5005/data_processor'
+ # task pipeline related
+ task_pipeline:
+ # I/O buffers
+ input_buffers:
+ - name: 'raw_input'
+ path: /PATH/TO/GSM8K/
+ storage_type: 'file'
+ raw: true
+ output_buffer:
+ name: 'raw_output'
+ path: /PATH/TO/OUTPUT/JSONL/FILE
+ storage_type: 'file'
+ # format mapping
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
```
-Here you can set the basic information for the GSM-8K dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training:
+Here you can set the basic buffers for the GSM-8K dataset input and output and some other items about downstream dataset loading for exploring and training:
-+ `source_data_path`: the path to the raw dataset.
-+ `load_kwargs`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library.
-+ `format`: some dataset format config items, which are used to map original data field names to unified ones.
-+ `db_url`: the URL of the postgresql database to store the result dataset.
++ `data_processor_url`: the URL of the data processor service, which is started in the previous step.
++ `task_pipeline`: the configs for the task pipeline. Task pipeline is used to process the raw dataset. It consists of several inner configs:
+ + `input_buffers`: the input buffers for the task pipeline. We usually load from raw dataset files in this pipeline, thus we need to the dataset `path` and set the `storage_type` to "file" and set `raw` to True. It allows multiple input buffers. We can name each buffer with the `name` field.
+ + `output_buffer`: the output buffer for the task pipeline. We usually store the processed dataset in files as well, thus we need to set the `storage_type` to "file".
+ + `format`: some dataset format config items, which are used to map original data field names to unified ones.
-In addition, there are several config items related to the data active iterator, which is used to prepare a better dataset. The core part of the data active iterator, Data-Juicer, provides tens of operators to help clean or calculate key information for each sample in the dataset. You can configure this part depending on how familiar you are with Data-Juicer.
+In addition, there are several config items related to the data active iterator in `task_pipeline` part, which is used to prepare a better dataset. The core part of the data active iterator, Data-Juicer, provides tens of operators to help clean or calculate key information for each sample in the dataset. You can configure this part depending on how familiar you are with Data-Juicer.
#### Not familiar with Data-Juicer
-If you are not familiar with Data-Juicer, the data module provides a natural-language-based method to config the data processing recipe. What you need to do is only describe the demands of how you want to prepare for the raw dataset, and an agent will be invoked to arrange the data processing recipe for you. Here is an example:
+If you are not familiar with Data-Juicer, the data processor provides a natural-language-based method to config the data processing recipe. What you need to do is only describe the demands of how you want to prepare for the raw dataset, and an agent will be invoked to arrange the data processing recipe for you. Here is an example:
```yaml
data_processor:
- # basic info
- source_data_path: /PATH/TO/GSM8K/
- load_kwargs:
- split: 'train' # only need the train split
- format: # set the field mappings
- prompt_key: 'question'
- response_key: 'answer'
- # database related. The result dataset will be stored in the database.
- db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
-
- #### new part about data active iterator
- dj_process_desc: 'Please compute difficulty scores for these math questions.'
- agent_model_name: 'qwen-max'
- agent_model_config:
- config_name: 'my-qwen-instruction'
- model_type: 'dashscope_chat'
- model_name: 'qwen2.5-72b-instruct'
- clean_strategy: 'iterative'
+ data_processor_url: 'http://127.0.0.1:5005/data_processor'
+ # task pipeline related
+ task_pipeline:
+ # I/O buffers
+ input_buffers:
+ - name: 'raw_input'
+ path: /PATH/TO/GSM8K/
+ storage_type: 'file'
+ raw: true
+ output_buffer:
+ name: 'raw_output'
+ path: /PATH/TO/OUTPUT/JSONL/FILE
+ storage_type: 'file'
+ # format mapping
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+
+ #### new part about data active iterator
+ dj_process_desc: 'Please compute difficulty scores for these math questions.'
+ agent_model_name: 'qwen-max'
+ agent_model_config:
+ config_name: 'my-qwen-instruction'
+ model_type: 'dashscope_chat'
+ model_name: 'qwen2.5-72b-instruct'
+ clean_strategy: 'iterative'
```
You can write your demand description in config item `dj_process_desc`, and set the model name and configs used for the agent in config items `agent_model_name` and `agent_model_config`. Here we use Qwen2.5-72b-Instruct as our recipe managing agent. And you can set the `clean_strategy` to 'iterative' to get a better dataset.
@@ -99,19 +116,27 @@ After preparing the Data-Juicer data processing recipe, you can set the `dj_conf
```yaml
data_processor:
- # basic info
- source_data_path: /PATH/TO/GSM8K/
- load_kwargs:
- split: 'train' # only need the train split
- format: # set the field mappings
- prompt_key: 'question'
- response_key: 'answer'
- # database related. The result dataset will be stored in the database.
- db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
-
- #### new part about data active iterator
- dj_config_path: '/path/to/the/Data-Juicer/data/processing/recipe/above.yaml'
- clean_strategy: 'iterative'
+ data_processor_url: 'http://127.0.0.1:5005/data_processor'
+ # task pipeline related
+ task_pipeline:
+ # I/O buffers
+ input_buffers:
+ - name: 'raw_input'
+ path: /PATH/TO/GSM8K/
+ storage_type: 'file'
+ raw: true
+ output_buffer:
+ name: 'raw_output'
+ path: /PATH/TO/OUTPUT/JSONL/FILE
+ storage_type: 'file'
+ # format mapping
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+
+ #### new part about data active iterator
+ dj_config_path: '/path/to/the/Data-Juicer/data/processing/recipe/above.yaml'
+ clean_strategy: 'iterative'
```
And you can set the `clean_strategy` to 'iterative' to get a better dataset.
@@ -123,7 +148,7 @@ All config items in the `data` section can be found [here](trinity_configs.md).
```{note}
-Only when one of `dj_process_desc` and `dj_config_path` is provided, the data module and the data active iterator will be activated. Otherwise, this part will be skipped and it will enter into the exploring stage directly.
+Only when one of `xxx_pipeline` is provided, and one of `dj_process_desc` and `dj_config_path` in the pipeline config is provided, the data processor and the data active iterator will be activated. Otherwise, this part will be skipped and it will enter into the exploring stage directly.
```
### Exploring & Training
@@ -140,49 +165,54 @@ ray start --address=
trinity run --config
```
-If you follow the steps above, Trinity-RFT will send a request to the data module server, the data active iterator will be activated and compute difficulty scores for each sample in the raw dataset. After that, the data module server stores the result dataset into the database, when exploring begins, it will load the prepared dataset and continue the downstream steps.
+If you follow the steps above, Trinity-RFT will send a request to the data processor server, the data active iterator will be activated, compute difficulty scores for each sample in the raw dataset, and rank the dataset according to difficulty scores. After that, the data processor server stores the result dataset into the output buffer, when exploring begins, it will load the prepared dataset and continue the downstream steps.
-
-
-## Example: human in the loop
+## Example: Human in the Loop
Sometimes, you might need to involve human feedbacks for some raw data. In this example, you will learn how to annotate raw data to get a better dataset before training. This example takes an example Q&A dataset and tries to select the chosen and rejected ones for DPO method.
-Before getting started, you need to prepare the main environment of Trinity-RFT according to the installation section of the README file, install postgresql, and [start a label-studio server](https://github.com/modelscope/data-juicer/tree/main/tools/humanops) from Data-Juicer from source.
+Before getting started, you need to prepare the main environment of Trinity-RFT according to the installation section of the README file, and [start a label-studio server](https://github.com/modelscope/data-juicer/tree/main/tools/humanops) from Data-Juicer from source.
### Data Preparation
-#### Prepare the Data Module
+#### Prepare the Data Processor
-As the overall framework of Trinity-RFT shows, the data module is one of the high-level functions. Trinity-RFT encapsulates the data module as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server.
+As the overall framework of Trinity-RFT shows, the data processor is one of the high-level functions. Trinity-RFT encapsulates the data processor as an independent service to avoid dependency conflict issues. Thus you need to prepare a split environment for this module and start the server.
```shell
-# prepare split environments, including the one of data module
+# prepare split environments, including the one of data processor
python scripts/install.py
# start all split servers
python scripts/start_servers.py
```
-### Configure the Data Module
+### Configure the Data Processor
-Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data_processor` section in the config file.
+Trinity-RFT uses a unified config file to manage all config items. For the data processor, you need to focus on the `data_processor` section in the config file.
-In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example:
+In this example, assume that you need to select the chosen and rejected responses for DPO method. So you can set these config items like the following example:
```yaml
data_processor:
- # basic info
- source_data_path: 'tests/test_data/test_human_annotator'
- load_kwargs:
- split: 'train' # only need the train split
- format: # set the field mappings
- prompt_key: 'prompt'
- chosen_key: 'chosen'
- rejected_key: 'rejected'
- #### new part about data active iterator
- dj_config_path: 'tests/test_configs/human_annotator_test_dj_cfg.yaml'
- # database related. The result dataset will be stored in the database.
- db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
+ data_processor_url: 'http://127.0.0.1:5005/data_processor'
+ # task pipeline related
+ task_pipeline:
+ # I/O buffers
+ input_buffers:
+ - name: 'raw_input'
+ path: 'tests/test_data/test_human_annotator'
+ storage_type: 'file'
+ raw: true
+ output_buffer:
+ name: 'raw_output'
+ path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl'
+ storage_type: 'file'
+ format: # set the field mappings
+ prompt_key: 'prompt'
+ chosen_key: 'chosen'
+ rejected_key: 'rejected'
+ #### new part about data active iterator
+ dj_config_path: 'tests/test_configs/human_annotator_test_dj_cfg.yaml'
```
Here you can set the basic information for the example dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training, which is similar to the example above.
@@ -223,7 +253,7 @@ You can set more config items for this OP (e.g. notification when annotation is
### Start Running
-When you start running with the RFT config, the data module will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server.
+When you start running with the RFT config, the data processor will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server.

diff --git a/environments/data.yaml b/environments/data.yaml
index 6acdf04dc9..d43ece076b 100644
--- a/environments/data.yaml
+++ b/environments/data.yaml
@@ -6,10 +6,5 @@ dependencies:
- pip:
- py-data-juicer
- agentscope
- - flask
- - omegaconf
- - sqlalchemy
- - psycopg2
- - networkx
- transformers
- "-e ..[dev]"
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index 2a87ef288b..0763586457 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -4,19 +4,6 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: grpo
repeat_times: 8
-data_processor:
- # basic info
- source_data_path: 'openai/gsm8k'
- # data active iterator related
- dj_process_desc: 'Please compute difficulty scores for these math questions.'
- agent_model_name: 'qwen-max'
- agent_model_config:
- config_name: 'my-qwen-instruction'
- model_type: 'dashscope_chat'
- model_name: 'qwen2.5-72b-instruct'
- clean_strategy: 'iterative'
- # db related
- db_url: ''
model:
model_path: /PATH/TO/MODEL/
@@ -41,9 +28,7 @@ buffer:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
- n: 8
temperature: 1.0
- logprobs: 0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
diff --git a/examples/grpo_gsm8k_task_pipeline/README.md b/examples/grpo_gsm8k_task_pipeline/README.md
new file mode 100644
index 0000000000..ead6a56185
--- /dev/null
+++ b/examples/grpo_gsm8k_task_pipeline/README.md
@@ -0,0 +1,7 @@
+# GRPO on GSM8K dataset with Task Pipeline
+
+This example shows the usage of GRPO on the GSM8K dataset, with a task pipeline to prioritize the raw dataset before training.
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_data_functionalities.md).
+
+The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).
diff --git a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
new file mode 100644
index 0000000000..36514e0e01
--- /dev/null
+++ b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
@@ -0,0 +1,95 @@
+project: "Trinity-RFT-gsm8k-task-pipeline"
+name: "qwen2.5-1.5B-gsm8k-task-pipeline"
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+algorithm:
+ algorithm_type: grpo
+ repeat_times: 8
+data_processor:
+ data_processor_url: 'http://127.0.0.1:5005/data_processor'
+ # task pipeline related
+ task_pipeline:
+ # I/O buffers
+ input_buffers:
+ - name: 'raw_input'
+ path: 'openai/gsm8k'
+ storage_type: 'file'
+ raw: true
+ output_buffer:
+ name: 'raw_output'
+ path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl'
+ storage_type: 'file'
+ # format mapping
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ # data active iterator related
+ dj_process_desc: 'Please compute difficulty scores for these math questions.'
+ agent_model_name: 'qwen-max'
+ agent_model_config:
+ config_name: 'my-qwen-instruction'
+ model_type: 'dashscope_chat'
+ model_name: 'qwen2.5-72b-instruct'
+ clean_strategy: 'iterative'
+
+model:
+ model_path: /PATH/TO/MODEL/
+ max_prompt_tokens: 256
+ max_response_tokens: 1024
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ max_retry_times: 3
+ max_retry_interval: 1
+ explorer_input:
+ taskset:
+ name: gsm8k
+ storage_type: file
+ path: './outputs/task_pipeline_output/'
+ split: 'train'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 1.0
+ eval_tasksets:
+ - name: gsm8k-eval
+ storage_type: file
+ path: 'openai/gsm8k'
+ subset_name: 'main'
+ split: 'test'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ default_workflow_type: 'math_workflow'
+ trainer_input:
+ experience_buffer:
+ name: gsm8k_buffer
+ storage_type: queue
+ path: 'sqlite:///gsm8k.db'
+ # sft_warmup_steps: 0
+ # sft_warmup_dataset: # Uncomment these to enable sft warmup
+ # name: warmup_data
+ # storage_type: file
+ # path: '/PATH/TO/WARMUP_DATA/'
+explorer:
+ eval_interval: 50
+ runner_num: 32
+ rollout_model:
+ engine_type: vllm_async
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: true
+ dtype: bfloat16
+ seed: 42
+synchronizer:
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 1200
+trainer:
+ trainer_type: 'verl'
+ trainer_config_path: 'examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml'
+ save_interval: 100
diff --git a/examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml
new file mode 100644
index 0000000000..fc44fdad94
--- /dev/null
+++ b/examples/grpo_gsm8k_task_pipeline/train_gsm8k.yaml
@@ -0,0 +1,50 @@
+actor_rollout_ref:
+ hybrid_engine: True
+ model:
+ external_lib: null
+ override_config: { }
+ enable_gradient_checkpointing: True
+ use_remove_padding: True # False
+ actor:
+ strategy: fsdp # This is for backward-compatibility
+ ppo_mini_batch_size: 128
+ ppo_micro_batch_size_per_gpu: 4
+ use_dynamic_bsz: True # False
+ ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+ grad_clip: 1.0
+ ppo_epochs: 1
+ shuffle: False
+ ulysses_sequence_parallel_size: 1 # sp size
+ optim:
+ lr: 1e-5
+ lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+ # min_lr_ratio: null # only useful for warmup with cosine
+ warmup_style: constant # select from constant/cosine
+ total_training_steps: -1 # must be override by program
+ fsdp_config:
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ param_offload: False
+ optimizer_offload: False
+ fsdp_size: -1
+ ref:
+ fsdp_config:
+ param_offload: False
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ log_prob_micro_batch_size_per_gpu: 16
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+trainer:
+ balance_batch: True
+ # total_training_steps: null
+ # auto: find the last ckpt to resume. If can't find, start from scratch
+ resume_mode: auto # or auto or resume_path if
+ default_hdfs_dir: null
+ remove_previous_ckpt_in_save: False
+ del_local_ckpt_after_load: False
+ val_before_train: False
diff --git a/pyproject.toml b/pyproject.toml
index c6917217ad..6ba60afab3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,7 @@ dependencies = [
"requests",
"tensorboard",
"openai",
+ "jsonlines",
]
[project.scripts]
diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py
index 363a4939ad..e53669a850 100644
--- a/tests/buffer/file_test.py
+++ b/tests/buffer/file_test.py
@@ -9,12 +9,54 @@
get_unittest_dataset_config,
)
from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer
+from trinity.buffer.reader.file_reader import RawDataReader
from trinity.buffer.utils import default_storage_path
+from trinity.buffer.writer.file_writer import JSONWriter
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType
class TestFileBuffer(unittest.TestCase):
+ temp_output_path = "tmp/test_file_buffer/"
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ os.makedirs(cls.temp_output_path, exist_ok=True)
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ if os.path.exists(cls.temp_output_path):
+ os.system(f"rm -rf {cls.temp_output_path}")
+
+ def test_file_buffer(self):
+ meta = StorageConfig(
+ name="test_buffer",
+ path=os.path.join(self.temp_output_path, "buffer.jsonl"),
+ storage_type=StorageType.FILE,
+ raw=True,
+ )
+ data = [
+ {"key1": 1, "key2": 2},
+ {"key1": 3, "key2": 4},
+ {"key1": 5, "key2": 6},
+ {"key1": 7, "key2": 8},
+ ]
+
+ # test writer
+ writer = JSONWriter(meta, None)
+ writer.write(data)
+ writer.finish()
+
+ # test reader
+ meta.path = self.temp_output_path
+ reader = RawDataReader(meta, None)
+ loaded_data = reader.read()
+ self.assertEqual(len(loaded_data), 4)
+ self.assertEqual(loaded_data, data)
+ self.assertRaises(StopIteration, reader.read)
+
def test_file_reader(self):
"""Test file reader."""
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
diff --git a/tests/data/controllers/task_parser_test.py b/tests/data/controllers/task_parser_test.py
index 542c491f41..af36f8777a 100644
--- a/tests/data/controllers/task_parser_test.py
+++ b/tests/data/controllers/task_parser_test.py
@@ -1,12 +1,13 @@
# -*- coding: utf-8 -*-
"""Test cases for data task parser."""
+import os
import unittest
import agentscope
from agentscope.models import DashScopeChatWrapper
from loguru import logger
-from trinity.common.config import Config
+from trinity.common.config import DataPipelineConfig
from trinity.data.controllers.task_parser import DataTaskParser
@@ -16,7 +17,7 @@ class TestTaskParser(unittest.TestCase):
def setUp(self) -> None:
print("setup", flush=True)
- api_key = "your_dashscope_key"
+ api_key = os.environ.get("OPENAI_API_KEY", None)
agentscope.init(
model_configs=[
@@ -43,25 +44,20 @@ def _run_test(self, rft_config, return_none=False):
logger.info("None dj config.")
else:
self.assertIsNotNone(dj_config)
- op_weights = {}
- for op in dj_config.process:
- op_name = list(op.keys())[0]
- op_weights[op_name] = op[op_name]["op_weight"]
- logger.info(op_weights)
def test_instruction1(self):
- rft_config = Config()
- rft_config.data.dj_process_desc = "Please recommend a data filtering strategy for me."
+ rft_config = DataPipelineConfig()
+ rft_config.dj_process_desc = "Please recommend a data filtering strategy for me."
self._run_test(rft_config)
def test_instruction2(self):
- rft_config = Config()
- rft_config.data.dj_process_desc = "Do nothing."
+ rft_config = DataPipelineConfig()
+ rft_config.dj_process_desc = "Do nothing."
self._run_test(rft_config, return_none=True)
def test_instruction3(self):
- rft_config = Config()
- rft_config.data.dj_process_desc = "Remove samples with repeat contents."
+ rft_config = DataPipelineConfig()
+ rft_config.dj_process_desc = "Remove samples with repeat contents."
self._run_test(rft_config)
diff --git a/tests/data/core/dataset_test.py b/tests/data/core/dataset_test.py
index be6e765fbd..76758e84d6 100644
--- a/tests/data/core/dataset_test.py
+++ b/tests/data/core/dataset_test.py
@@ -3,10 +3,7 @@
import os
import unittest
-from trinity.common.config import DataProcessorConfig, FormatConfig
-from trinity.common.rewards import AccuracyReward
-from trinity.common.task import TaskSet
-from trinity.common.workflows import MathWorkflow, SimpleWorkflow
+from trinity.common.config import DataPipelineConfig, FormatConfig, StorageConfig
from trinity.data.core.dataset import RewardSchema, RftDataset
from trinity.data.core.formatter import BoxedMathAnswerFormatter, RLHFFormatter
@@ -15,28 +12,38 @@ class TestRftDataset(unittest.TestCase):
"""Test cases for RftDataset"""
def setUp(self) -> None:
- self.data_config = DataProcessorConfig(
- source_data_path=os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- "test_data",
- "test_10",
- ),
+ self.data_pipeline_config = DataPipelineConfig(
+ input_buffers=[
+ StorageConfig(
+ path=os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ "test_data",
+ "test_10",
+ ),
+ raw=True,
+ )
+ ],
format=FormatConfig(
prompt_key="problem",
response_key="solution",
solution_key="solution",
),
)
- self.data_config_sample_level_setting = DataProcessorConfig(
- source_data_path=os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- "test_data",
- "test_10_with_rewfn_workflow",
- ),
+ self.data_pipeline_config_sample_level_setting = DataPipelineConfig(
+ input_buffers=[
+ StorageConfig(
+ path=os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ "test_data",
+ "test_10_with_rewfn_workflow",
+ ),
+ raw=True,
+ )
+ ],
format=FormatConfig(
prompt_key="problem",
response_key="solution",
@@ -47,13 +54,19 @@ def setUp(self) -> None:
)
def test_rft_dataset_init(self):
- dataset = RftDataset(data_config=self.data_config, reward_schema="default")
+ dataset = RftDataset(
+ data_pipeline_config=self.data_pipeline_config, reward_schema="default"
+ )
+ dataset.read_from_buffer()
self.assertEqual(len(dataset), 10)
self.assertIsInstance(dataset.reward_schema, RewardSchema)
def test_format_dataset(self):
- dataset = RftDataset(data_config=self.data_config, reward_schema="default")
+ dataset = RftDataset(
+ data_pipeline_config=self.data_pipeline_config, reward_schema="default"
+ )
+ dataset.read_from_buffer()
original_data = dataset.data
# no formatter
dataset.format(formatters=[])
@@ -62,56 +75,12 @@ def test_format_dataset(self):
# apply formatters
dataset.format(
formatters=[
- BoxedMathAnswerFormatter(config=self.data_config.format),
- RLHFFormatter(config=self.data_config.format),
+ BoxedMathAnswerFormatter(config=self.data_pipeline_config.format),
+ RLHFFormatter(config=self.data_pipeline_config.format),
]
)
self.assertNotEqual(dataset.data, original_data)
- def test_to_taskset(self):
- dataset = RftDataset(data_config=self.data_config, reward_schema="default")
- taskset = dataset.to_taskset()
- self.assertIsInstance(taskset, TaskSet)
- self.assertEqual(len(taskset), 10)
- self.assertIsNone(taskset.reward_fn)
- self.assertIsNone(taskset.workflow)
- self.assertEqual(taskset._index, 0)
-
- def test_to_taskset_with_global_settings(self):
- dataset = RftDataset(data_config=self.data_config, reward_schema="default")
- taskset = dataset.to_taskset(
- reward_fn=AccuracyReward,
- workflow=SimpleWorkflow,
- )
- self.assertIsInstance(taskset, TaskSet)
- self.assertEqual(taskset.workflow, SimpleWorkflow)
- self.assertEqual(taskset.reward_fn, AccuracyReward)
-
- def test_to_taskset_with_sample_level_settings(self):
- dataset = RftDataset(
- data_config=self.data_config_sample_level_setting, reward_schema="default"
- )
- taskset = dataset.to_taskset()
- self.assertIsInstance(taskset, TaskSet)
- for task in taskset.tasks:
- self.assertEqual(task.workflow, MathWorkflow)
- self.assertEqual(task.reward_fn, AccuracyReward)
-
- def test_to_taskset_with_both_settings(self):
- dataset = RftDataset(
- data_config=self.data_config_sample_level_setting, reward_schema="default"
- )
- taskset = dataset.to_taskset(
- reward_fn=AccuracyReward,
- workflow=SimpleWorkflow,
- )
- self.assertIsInstance(taskset, TaskSet)
- for task in taskset.tasks:
- self.assertEqual(task.workflow, MathWorkflow)
- self.assertEqual(task.reward_fn, AccuracyReward)
- self.assertEqual(taskset.workflow, SimpleWorkflow)
- self.assertEqual(taskset.reward_fn, AccuracyReward)
-
if __name__ == "__main__":
unittest.main()
diff --git a/tests/data/core/formatter_test.py b/tests/data/core/formatter_test.py
index 363c736ed9..dbb73ed971 100644
--- a/tests/data/core/formatter_test.py
+++ b/tests/data/core/formatter_test.py
@@ -3,7 +3,7 @@
import os
import unittest
-from trinity.common.config import DataProcessorConfig, FormatConfig
+from trinity.common.config import DataPipelineConfig, FormatConfig, StorageConfig
from trinity.data.core.dataset import RftDataset
from trinity.data.core.formatter import (
BoxedMathAnswerFormatter,
@@ -18,14 +18,19 @@ class TestBoxedMathDataset(unittest.TestCase):
"""Test cases for RftDataset"""
def setUp(self) -> None:
- self.data_config = DataProcessorConfig(
- source_data_path=os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- "test_data",
- "test_10",
- ),
+ self.data_config = DataPipelineConfig(
+ input_buffers=[
+ StorageConfig(
+ path=os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ "test_data",
+ "test_10",
+ ),
+ raw=True,
+ )
+ ],
format=FormatConfig(
prompt_key="problem",
response_key="answer",
@@ -43,12 +48,13 @@ def test_init(self):
self.assertEqual(formatter.config.chat_template, "User: {}\nAssistant: ")
# test for default configs
self.assertEqual(formatter.config.reward_key, "")
- self.assertEqual(formatter.config.chosen_key, "")
- self.assertEqual(formatter.config.rejected_key, "")
+ self.assertEqual(formatter.config.chosen_key, "chosen")
+ self.assertEqual(formatter.config.rejected_key, "rejected")
self.assertEqual(formatter.config.label_key, "")
def test_transform(self):
- dataset = RftDataset(data_config=self.data_config, reward_schema="default")
+ dataset = RftDataset(data_pipeline_config=self.data_config, reward_schema="default")
+ dataset.read_from_buffer()
formatter = BoxedMathAnswerFormatter(config=self.data_config.format)
self.assertNotIn(formatter.config.response_key, dataset.data.column_names)
dataset.format(formatter)
@@ -59,14 +65,19 @@ class TestRLHFFormatter(unittest.TestCase):
"""Test cases for RLHFFormatter"""
def setUp(self) -> None:
- self.data_config = DataProcessorConfig(
- source_data_path=os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- "test_data",
- "test_10",
- ),
+ self.data_config = DataPipelineConfig(
+ input_buffers=[
+ StorageConfig(
+ path=os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ "test_data",
+ "test_10",
+ ),
+ raw=True,
+ )
+ ],
format=FormatConfig(
prompt_key="problem",
chat_template="User: {}\nAssistant: ",
@@ -107,14 +118,19 @@ class TestRewardFormatter(unittest.TestCase):
"""Test cases for RewardFormatter"""
def setUp(self) -> None:
- self.data_config = DataProcessorConfig(
- source_data_path=os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- "test_data",
- "test_10",
- ),
+ self.data_config = DataPipelineConfig(
+ input_buffers=[
+ StorageConfig(
+ path=os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ "test_data",
+ "test_10",
+ ),
+ raw=True,
+ )
+ ],
format=FormatConfig(
prompt_key="problem",
chosen_key="chosen",
@@ -164,14 +180,19 @@ class TestSFTFormatter(unittest.TestCase):
"""Test cases for SFTFormatter"""
def setUp(self) -> None:
- self.data_config = DataProcessorConfig(
- source_data_path=os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- "test_data",
- "test_10",
- ),
+ self.data_config = DataPipelineConfig(
+ input_buffers=[
+ StorageConfig(
+ path=os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ "test_data",
+ "test_10",
+ ),
+ raw=True,
+ )
+ ],
format=FormatConfig(
prompt_key="problem",
response_key="answer",
@@ -217,14 +238,19 @@ class TestComposedFormatter(unittest.TestCase):
"""Test cases for ComposedFormatter"""
def setUp(self) -> None:
- self.data_config = DataProcessorConfig(
- source_data_path=os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "..",
- "..",
- "test_data",
- "test_10",
- ),
+ self.data_config = DataPipelineConfig(
+ input_buffers=[
+ StorageConfig(
+ path=os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ "..",
+ "..",
+ "test_data",
+ "test_10",
+ ),
+ raw=True,
+ )
+ ],
format=FormatConfig(
prompt_key="problem",
response_key="answer",
diff --git a/tests/data/processor/cleaner_test.py b/tests/data/processor/cleaner_test.py
index d21a6960c5..ef2aa13d20 100644
--- a/tests/data/processor/cleaner_test.py
+++ b/tests/data/processor/cleaner_test.py
@@ -15,7 +15,7 @@ def setUp(self) -> None:
print("setup", flush=True)
self.rft_config = load_config("./tests/test_configs/cleaner_test_rft_cfg.yaml")
- print(self.rft_config)
+ # print(self.rft_config)
self.ds_list = [
{"text": "Today is"},
{"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"},
@@ -25,95 +25,67 @@ def setUp(self) -> None:
]
def _run_test(self, tgt_list, weight=1, data_dist="gaussian"):
- task_parser = DataTaskParser(self.rft_config)
+ task_parser = DataTaskParser(self.rft_config.data_processor.task_pipeline)
dj_config, _, _, _ = task_parser.parse_to_dj_config()
+ op_weights = {}
for op_config in dj_config.process:
- _, op_args = list(op_config.items())[0]
- op_args["op_weight"] = weight
+ op_name, _ = list(op_config.items())[0]
+ op_weights[op_name] = weight
cleaner = DataCleaner(
dj_config,
clean_strategy="iterative",
- min_size_ratio=self.rft_config.data.min_size_ratio,
+ min_size_ratio=self.rft_config.data_processor.task_pipeline.min_size_ratio,
data_dist=data_dist,
+ op_weights=op_weights,
)
- dataset = RftDataset(self.rft_config.data)
+ dataset = RftDataset(self.rft_config.data_processor.task_pipeline)
+ dataset.read_from_buffer()
dataset = cleaner.process([dataset])
- res_list = dataset.to_list()
+ res_list = dataset.data.select_columns("text").to_list()
+ print(res_list)
self.assertEqual(res_list, tgt_list)
self.assertNotIn("clean_email_mapper", cleaner.dj_cfg.process)
def test_dj_executor(self):
tgt_list = [
- {
- "text": "a v s e c s f e f g a a a ",
- "__dj__stats__": {"text_len": 27},
- },
- {
- "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►",
- "__dj__stats__": {"text_len": 34},
- },
- {
- "text": "中文也是一个字算一个长度",
- "__dj__stats__": {"text_len": 12},
- },
+ {"text": "Today is"},
+ {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"},
+ {"text": "a v s e c s f e f g a a a "},
+ {"text": "中文也是一个字算一个长度"},
]
- self.rft_config.data.min_size_ratio = None
+ self.rft_config.data_processor.task_pipeline.min_size_ratio = None
self._run_test(tgt_list)
def test_iterative_clean(self):
tgt_list = [
- {
- "text": "a v s e c s f e f g a a a ",
- "__dj__stats__": {"text_len": 27},
- },
- {
- "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►",
- "__dj__stats__": {"text_len": 34},
- },
+ {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"},
+ {"text": "a v s e c s f e f g a a a "},
]
- self.rft_config.data.min_size_ratio = 0.5
+ self.rft_config.data_processor.task_pipeline.min_size_ratio = 0.5
self._run_test(tgt_list)
def test_weight(self):
tgt_list = [
- {
- "text": "a v s e c s f e f g a a a ",
- "__dj__stats__": {"text_len": 27},
- },
- {
- "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►",
- "__dj__stats__": {"text_len": 34},
- },
- {
- "text": "中文也是一个字算一个长度",
- "__dj__stats__": {"text_len": 12},
- },
+ {"text": "Today is"},
+ {"text": "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!"},
+ {"text": "a v s e c s f e f g a a a "},
]
- self.rft_config.data.min_size_ratio = 0.5
+ self.rft_config.data_processor.task_pipeline.min_size_ratio = 0.5
self._run_test(tgt_list, weight=0.5)
def test_uniform_dist(self):
- tgt_list = [
- {
- "text": "a v s e c s f e f g a a a ",
- "__dj__stats__": {"text_len": 27},
- },
- {
- "text": ",。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►",
- "__dj__stats__": {"text_len": 34},
- },
- ]
+ tgt_list = []
- self.rft_config.data.min_size_ratio = 0.5
+ self.rft_config.data_processor.task_pipeline.min_size_ratio = 0.5
self._run_test(tgt_list, data_dist="uniform")
diff --git a/tests/test_configs/active_iterator_test_cfg.yaml b/tests/test_configs/active_iterator_test_cfg.yaml
index 3b105e1f66..3e6008b7cf 100644
--- a/tests/test_configs/active_iterator_test_cfg.yaml
+++ b/tests/test_configs/active_iterator_test_cfg.yaml
@@ -1,13 +1,18 @@
data_processor:
# basic info
- source_data_path: 'tests/test_data/test_10/'
- load_kwargs:
- split: 'train'
- format:
- prompt_key: 'problem'
- response_key: 'solution'
- # cleaner related
- dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml'
- clean_strategy: 'iterative'
- # db related
- db_url: 'postgresql://{username}@localhost:5432/{db_name}'
+ task_pipeline:
+ input_buffers:
+ - name: 'raw_input'
+ path: 'tests/test_data/test_10/'
+ storage_type: 'file'
+ raw: true
+ output_buffer:
+ name: 'raw_output'
+ path: './outputs/task_pipeline_output/processed.jsonl'
+ storage_type: 'file'
+ format:
+ prompt_key: 'problem'
+ response_key: 'solution'
+ # cleaner related
+ dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml'
+ clean_strategy: 'iterative'
diff --git a/tests/test_configs/active_iterator_test_dj_cfg.yaml b/tests/test_configs/active_iterator_test_dj_cfg.yaml
index f7f848e338..367709968f 100644
--- a/tests/test_configs/active_iterator_test_dj_cfg.yaml
+++ b/tests/test_configs/active_iterator_test_dj_cfg.yaml
@@ -1,7 +1,5 @@
project_name: 'demo-process'
-export_path: './outputs/demo-process/demo-processed.jsonl'
-
text_keys: 'solution'
process:
diff --git a/tests/test_configs/cleaner_test_dj_cfg.yaml b/tests/test_configs/cleaner_test_dj_cfg.yaml
index 9e2da88d64..cf11488963 100644
--- a/tests/test_configs/cleaner_test_dj_cfg.yaml
+++ b/tests/test_configs/cleaner_test_dj_cfg.yaml
@@ -3,7 +3,5 @@ project_name: 'demo-process'
export_path: './outputs/demo-process/demo-processed.jsonl'
process:
- - text_length_filter:
- min_len: 10
- max_len: 50
+ - alphanumeric_filter:
- clean_email_mapper:
diff --git a/tests/test_configs/cleaner_test_rft_cfg.yaml b/tests/test_configs/cleaner_test_rft_cfg.yaml
index 7f8581c0ef..c78e3a1ac8 100644
--- a/tests/test_configs/cleaner_test_rft_cfg.yaml
+++ b/tests/test_configs/cleaner_test_rft_cfg.yaml
@@ -1,5 +1,7 @@
data_processor:
- source_data_path: './tests/test_data/test_cleaner'
- load_kwargs: {"split": "train"}
- dj_config_path: './tests/test_configs/cleaner_test_dj_cfg.yaml'
- clean_strategy: 'iterative'
+ task_pipeline:
+ input_buffers:
+ - path: './tests/test_data/test_cleaner'
+ raw: true
+ dj_config_path: './tests/test_configs/cleaner_test_dj_cfg.yaml'
+ clean_strategy: 'iterative'
diff --git a/tests/test_configs/human_annotator_test_rft_cfg.yaml b/tests/test_configs/human_annotator_test_rft_cfg.yaml
index 79d8b8108b..b20f015182 100644
--- a/tests/test_configs/human_annotator_test_rft_cfg.yaml
+++ b/tests/test_configs/human_annotator_test_rft_cfg.yaml
@@ -1,10 +1,10 @@
data_processor:
- source_data_path: './tests/test_data/test_human_annotator'
- load_kwargs: {"split": "train"}
- dj_config_path: './tests/test_configs/human_annotator_test_dj_cfg.yaml'
- format:
- prompt_key: 'prompt'
- chosen_key: 'chosen'
- rejected_key: 'rejected'
- # db related
- db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
+ task_pipeline:
+ input_buffers:
+ - path: './tests/test_data/test_human_annotator'
+ raw: true
+ dj_config_path: './tests/test_configs/human_annotator_test_dj_cfg.yaml'
+ format:
+ prompt_key: 'prompt'
+ chosen_key: 'chosen'
+ rejected_key: 'rejected'
diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py
index 73bda50d5b..060ed05b9e 100644
--- a/trinity/buffer/buffer.py
+++ b/trinity/buffer/buffer.py
@@ -42,7 +42,9 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
from trinity.buffer.reader.file_reader import FILE_READERS
algorithm_type = storage_config.algorithm_type
- if algorithm_type is not None:
+ if storage_config.raw:
+ file_read_type = "raw"
+ elif algorithm_type is not None:
file_read_type = algorithm_type
else:
file_read_type = "rollout"
diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py
index 63de366db6..b7cf06b2b5 100644
--- a/trinity/buffer/ray_wrapper.py
+++ b/trinity/buffer/ray_wrapper.py
@@ -142,6 +142,8 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
raise ValueError(
f"File path must end with '.json' or '.jsonl', got {storage_config.path}"
)
+ path_dir = os.path.dirname(storage_config.path)
+ os.makedirs(path_dir, exist_ok=True)
self.file = open(storage_config.path, "a", encoding="utf-8")
self.encoder = _Encoder(ensure_ascii=False)
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 0dd9aef75e..507bfb7c82 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -67,7 +67,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.response_key = meta.format.response_key
self.read_batch_size = config.read_batch_size
self.dataset = _HFBatchReader(
- load_dataset(meta.path, name=subset_name, split=self.split),
+ load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
max_epoch=meta.total_epochs,
) # TODO: support resume
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
@@ -144,7 +144,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.rejected_key = meta.format.rejected_key
self.read_batch_size = config.read_batch_size
self.dataset = _HFBatchReader(
- load_dataset(meta.path, name=subset_name, split=self.split),
+ load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
max_epoch=meta.total_epochs,
) # TODO: support resume
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
@@ -216,7 +216,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.epoch = 0
datasets.disable_caching()
self.dataset = _HFBatchReader(
- load_dataset(meta.path, name=subset_name, split=self.split),
+ load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
max_epoch=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1,
offset=self.meta.index,
)
@@ -261,3 +261,23 @@ def read(
)
tasks.append(task)
return tasks
+
+
+@FILE_READERS.register_module("raw")
+class RawDataReader(BufferReader):
+ def __init__(self, meta: StorageConfig, config: Optional[BufferConfig]):
+ self.returned = False
+ self.dataset = load_dataset(
+ meta.path, name=meta.subset_name, split=meta.split, trust_remote_code=True
+ )
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def read(
+ self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
+ ) -> List:
+ if self.returned:
+ raise StopIteration
+ self.returned = True
+ return self.dataset.to_list()
diff --git a/trinity/cli/client.py b/trinity/cli/client.py
index 311de1b9d8..cc3318b570 100644
--- a/trinity/cli/client.py
+++ b/trinity/cli/client.py
@@ -31,12 +31,12 @@ def request(url, **kwargs):
if __name__ == "__main__":
# --- only for local testing
- LOCAL_DATA_WORKFLOW_SERVER_URL = "http://127.0.0.1:5005/data_workflow"
+ LOCAL_DATA_PROCESSOR_SERVER_URL = "http://127.0.0.1:5005/data_processor"
LOCAL_TRINITY_TRAINING_SERVER_URL = "http://127.0.0.1:5006/trinity_rft"
# --- only for local testing
res = request(
- url=LOCAL_DATA_WORKFLOW_SERVER_URL,
+ url=LOCAL_DATA_PROCESSOR_SERVER_URL,
configPath="examples/grpo_gsm8k/gsm8k.yaml",
)
if res:
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index a63b06a36d..3ea4f0486f 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -8,7 +8,7 @@
import ray
-from trinity.common.config import Config, load_config
+from trinity.common.config import Config, DataPipelineConfig, load_config
from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
@@ -112,13 +112,13 @@ def both(config: Config) -> None:
trainer.shutdown.remote()
-def activate_data_module(data_workflow_url: str, config_path: str):
+def activate_data_module(data_processor_url: str, config_path: str):
"""Check whether to activate data module and preprocess datasets."""
from trinity.cli.client import request
- logger.info("Activating data module...")
+ logger.info(f"Activating data module of {data_processor_url}...")
res = request(
- url=data_workflow_url,
+ url=data_processor_url,
configPath=config_path,
)
if res["return_code"] != 0:
@@ -126,17 +126,71 @@ def activate_data_module(data_workflow_url: str, config_path: str):
return
+def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: str):
+ """
+ Check if the data pipeline is valid. The config should:
+ 1. Non-empty input buffer
+ 2. Different input/output buffers
+
+ :param data_pipeline_config: the input data pipeline to be validated.
+ :param pipeline_type: the type of pipeline, should be one of ["task", "experience"]
+ """
+ input_buffers = data_pipeline_config.input_buffers
+ output_buffer = data_pipeline_config.output_buffer
+ # common checks
+ # check if the input buffer list is empty
+ if len(input_buffers) == 0:
+ logger.warning("Empty input buffers in the data pipeline. Won't activate it.")
+ return False
+ # check if the input and output buffers are different
+ input_buffer_names = [buffer.name for buffer in input_buffers]
+ if output_buffer.name in input_buffer_names:
+ logger.warning("Output buffer exists in input buffers. Won't activate it.")
+ return False
+ if pipeline_type == "task":
+ # task pipeline specific
+ # "raw" field should be True for task pipeline because the data source must be raw data files
+ for buffer in input_buffers:
+ if not buffer.raw:
+ logger.warning(
+ 'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.'
+ )
+ return False
+ elif pipeline_type == "experience":
+ # experience pipeline specific
+ raise NotImplementedError("experience_pipeline is not implemented yet.")
+ else:
+ logger.warning(
+ f'Invalid pipeline type: {pipeline_type}. Should be one of ["task", "experience"].'
+ )
+ return False
+ return True
+
+
def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
load_plugins(plugin_dir)
config = load_config(config_path)
config.check_and_update()
pprint(config)
- # try to activate data module
+ # try to activate task pipeline for raw data
data_processor_config = config.data_processor
- if data_processor_config.data_workflow_url and (
- data_processor_config.dj_config_path or data_processor_config.dj_process_desc
+ if (
+ data_processor_config.data_processor_url
+ and data_processor_config.task_pipeline
+ and validate_data_pipeline(data_processor_config.task_pipeline, "task")
):
- activate_data_module(data_processor_config.data_workflow_url, config_path)
+ activate_data_module(
+ f"{data_processor_config.data_processor_url}/task_pipeline", config_path
+ )
+ # try to activate experience pipeline for experiences
+ if (
+ data_processor_config.data_processor_url
+ and data_processor_config.experience_pipeline
+ and validate_data_pipeline(data_processor_config.experience_pipeline, "experience")
+ ):
+ activate_data_module(
+ f"{data_processor_config.data_processor_url}/experience_pipeline", config_path
+ )
ray_namespace = f"{config.project}-{config.name}"
if dlc:
from trinity.utils.dlc_utils import setup_ray_cluster
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 1409fa33f3..f4480da311 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -70,6 +70,9 @@ class StorageConfig:
storage_type: StorageType = StorageType.FILE
path: Optional[str] = None
+ # only available for StorageType.FILE. When requiring data processing on raw data, set the raw to True.
+ raw: bool = False
+
# used for StorageType.FILE
split: str = "train"
subset_name: Optional[str] = None
@@ -99,16 +102,17 @@ class StorageConfig:
@dataclass
-class DataProcessorConfig:
- """Data-Juicer config"""
+class DataPipelineConfig:
+ """Config for data pipeline."""
- data_workflow_url: Optional[str] = None
+ # I/O buffer
+ input_buffers: List[StorageConfig] = field(default_factory=list)
+ output_buffer: StorageConfig = field(default_factory=StorageConfig)
- source_data_path: str = ""
+ # data format
format: FormatConfig = field(default_factory=FormatConfig)
# data active iterator related
- load_kwargs: Dict[str, Any] = field(default_factory=dict)
dj_config_path: Optional[str] = None # The path to Data-Juicer config file.
dj_process_desc: Optional[
str
@@ -121,10 +125,18 @@ class DataProcessorConfig:
priority_weights: Optional[Dict[str, float]] = None
data_dist: Optional[str] = "gaussian" # one of ["gaussian", "uniform"]
- # dataset database related
- db_url: str = ""
- max_retry_times: int = 3
- max_retry_interval: int = 1
+
+@dataclass
+class DataProcessorConfig:
+ """Data-Juicer config"""
+
+ data_processor_url: Optional[str] = None
+
+ # support two types of data pipelines for now
+ # 1. For task. Data preprocessing from raw dataset to the task set
+ task_pipeline: Optional[DataPipelineConfig] = None
+ # 2. For experience. Data processing for rollouts
+ experience_pipeline: Optional[DataPipelineConfig] = None
@dataclass
diff --git a/trinity/data/controllers/active_iterator.py b/trinity/data/controllers/active_iterator.py
index 40da73384b..963a1015a9 100644
--- a/trinity/data/controllers/active_iterator.py
+++ b/trinity/data/controllers/active_iterator.py
@@ -1,14 +1,14 @@
import os
import traceback
+from numbers import Number
from typing import Any, Dict, List
import ray
-from trinity.common.config import Config
+from trinity.common.config import BufferConfig, DataPipelineConfig
from trinity.data.controllers.default_ops import DIMENSION_STATS_KEYS
from trinity.data.controllers.task_parser import DataTaskParser
from trinity.data.core.dataset import RftDataset
-from trinity.data.core.dataset_db import RftDatasetDB
from trinity.data.processors.cleaner import DataCleaner
from trinity.data.processors.human_annotator import DataHumanAnnotator
from trinity.data.processors.synthesizer import DataSynthesizer
@@ -21,42 +21,39 @@ class DataActiveIterator:
def __init__(
self,
- config: Config,
+ config: DataPipelineConfig,
+ buffer_config: BufferConfig,
):
self.config = config
- self.data_config = config.data
- if (
- self.data_config.agent_model_name is not None
- and self.data_config.agent_model_config is not None
- ):
+ self.buffer_config = buffer_config
+ if self.config.agent_model_name is not None and self.config.agent_model_config is not None:
# get the api key
api_key = os.environ.get("OPENAI_API_KEY")
# initialize the agent
import agentscope
from agentscope.models import DashScopeChatWrapper
- agentscope.init(model_configs=[self.data_config.agent_model_config])
+ agentscope.init(model_configs=[self.config.agent_model_config])
self.llm_agent = DashScopeChatWrapper(
config_name="_",
- model_name=self.data_config.agent_model_name,
+ model_name=self.config.agent_model_name,
api_key=api_key,
stream=False,
)
else:
self.llm_agent = None
self.task_parser = DataTaskParser(config, self.llm_agent)
- self.dsdb = RftDatasetDB(self.data_config)
# Priority weights
# larger positive values means larger scores --> higher priority
# smaller negative values means lower scores --> higher priority
- self.priority_weights = self.data_config.priority_weights or {
+ self.priority_weights = self.config.priority_weights or {
"difficulty": -0.7,
"diversity": 0.8,
"usage_frequency": -0.5,
"quality": 1.0,
}
- self.min_priority_score = self.data_config.min_priority_score
+ self.min_priority_score = self.config.min_priority_score
# Statistics tracking
self.state = {"iterations": 0, "samples_selected": 0, "avg_priority_score": 0.0}
@@ -67,17 +64,17 @@ def __init__(
# 2. input_keys: [prompt_key, response_key] if they are available
# 3. field_names: [prompt_key, response_key] if they are available
self.updated_op_args = {
- "text_key": self.data_config.format.prompt_key,
+ "text_key": self.config.format.prompt_key,
"input_keys": [
- self.data_config.format.prompt_key,
+ self.config.format.prompt_key,
],
"field_names": [
- self.data_config.format.prompt_key,
+ self.config.format.prompt_key,
],
}
- if self.data_config.format.response_key != "":
- self.updated_op_args["input_keys"].append(self.data_config.format.response_key)
- self.updated_op_args["field_names"].append(self.data_config.format.response_key)
+ if self.config.format.response_key != "":
+ self.updated_op_args["input_keys"].append(self.config.format.response_key)
+ self.updated_op_args["field_names"].append(self.config.format.response_key)
# flake8: noqa: C901
def run(self):
@@ -94,9 +91,9 @@ def run(self):
traceback.print_exc()
return 1, "config parsing failed."
- # step 2. load dataset
+ # step 2. load data from the input buffers
try:
- dataset = RftDataset(self.data_config)
+ dataset = RftDataset(self.config, self.buffer_config)
except Exception:
traceback.print_exc()
return 2, "RftDataset loading failed."
@@ -106,9 +103,9 @@ def run(self):
if hit_cleaner:
cleaner = DataCleaner(
dj_config,
- clean_strategy=self.data_config.clean_strategy,
- min_size_ratio=self.data_config.min_size_ratio,
- data_dist=self.data_config.data_dist,
+ clean_strategy=self.config.clean_strategy,
+ min_size_ratio=self.config.min_size_ratio,
+ data_dist=self.config.data_dist,
)
if hit_synthesizer:
synthesizer = DataSynthesizer(
@@ -122,43 +119,61 @@ def run(self):
traceback.print_exc()
return 3, "DataCleaner loading failed."
- # step 4. apply processors to calculate scores of different dimensions
- try:
- res_dataset = dataset
- if hit_cleaner:
- res_dataset = cleaner.process([res_dataset])
- if hit_synthesizer:
- res_dataset = synthesizer.process([res_dataset])
- if hit_human_annotator:
- res_dataset = human_annotator.process([res_dataset])
- except Exception:
- traceback.print_exc()
- return 4, "DataProcessors processing failed."
-
- # step 5. calculate the average and final scores, including priority
- try:
- if hit_cleaner:
- scored_dataset = self._group_scores(res_dataset)
- scored_dataset = self._compute_priority_scores(scored_dataset)
- else:
- scored_dataset = res_dataset
- except Exception:
- traceback.print_exc()
- return 5, "Grouping and computing priority score failed."
-
- # step 6. track lineage if they are changed
- try:
- res_dataset = scored_dataset
- except Exception:
- traceback.print_exc()
- return 6, "Tracking lineage failed."
-
- # step 7. export the result to the database
- try:
- self.dsdb.add_entries(res_dataset)
- except Exception:
- traceback.print_exc()
- return 7, "Exporting result to database failed."
+ while True:
+ # step 4. load data from the input buffers for the next batch
+ try:
+ dataset.read_from_buffer()
+ except StopIteration:
+ break
+ except Exception:
+ traceback.print_exc()
+ return 4, "RftDataset loading from buffers failed."
+
+ # step 5. apply processors to calculate scores of different dimensions
+ try:
+ res_dataset = dataset
+ if hit_cleaner:
+ res_dataset = cleaner.process([res_dataset])
+ if hit_synthesizer:
+ res_dataset = synthesizer.process([res_dataset])
+ if hit_human_annotator:
+ res_dataset = human_annotator.process([res_dataset])
+ except Exception:
+ traceback.print_exc()
+ return 5, "DataProcessors processing failed."
+
+ # step 6. calculate the average and final scores, including priority
+ try:
+ if hit_cleaner:
+ scored_dataset = self._group_scores(res_dataset)
+ scored_dataset = self._compute_priority_scores(scored_dataset)
+ else:
+ scored_dataset = res_dataset
+ except Exception:
+ traceback.print_exc()
+ return 6, "Grouping and computing priority score failed."
+
+ # step 7. track lineage if they are changed
+ try:
+ res_dataset = scored_dataset
+ except Exception:
+ traceback.print_exc()
+ return 7, "Tracking lineage failed."
+
+ # step 8
+ try:
+ if "priority" in res_dataset.data.features:
+ res_dataset.sort_by("priority", reverse=True)
+ except Exception:
+ traceback.print_exc()
+ return 8, "Sorting results by priority failed."
+
+ # step 9. sort and export the result to the output buffer
+ try:
+ res_dataset.write_to_buffer()
+ except Exception:
+ traceback.print_exc()
+ return 9, "Exporting result to output buffer failed."
return 0, "success"
@@ -171,7 +186,8 @@ def _group_scores(self, dataset: RftDataset) -> RftDataset:
all_stats = [
sample[Fields.stats][stats] for sample in dataset.data if Fields.stats in sample
]
- stats_min_max[stats] = [min(all_stats), max(all_stats)]
+ if len(all_stats) > 0 and isinstance(all_stats[0], Number):
+ stats_min_max[stats] = [min(all_stats), max(all_stats)]
def _group_single(sample):
stats = sample[Fields.stats]
@@ -240,7 +256,7 @@ def _compute_combined_score(
difficulty = stats.get("difficulty_score", 0.5)
score += self.priority_weights["difficulty"] * difficulty
- sample["priority"] = [score]
+ sample["priority"] = [score] if isinstance(sample[Fields.stats], list) else score
return sample
def _compute_diversity_score(self) -> float:
@@ -252,10 +268,6 @@ def _compute_priority_scores(self, dataset: RftDataset) -> RftDataset:
dataset.data = dataset.data.map(self._compute_combined_score)
return dataset
- def _select_top_k(self, dataset: RftDataset, k: int) -> List:
- """Select top-k samples based on utility scores"""
- return dataset.data.sort("priority", reverse=True).take(k).to_list()
-
@ray.method(num_returns=1)
def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, Any]]:
"""Select a batch of samples for training"""
@@ -267,7 +279,8 @@ def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, A
dataset.data = dataset.data.filter(lambda s: s["priority"] >= self.min_priority_score)
# Select top-k samples
- selected_samples = self._select_top_k(dataset, batch_size)
+ dataset.sort_by("priority", reverse=True, top_k=batch_size)
+ selected_samples = dataset.data.to_list()
# Update state
self._update_state(selected_samples, dataset.data["priority"])
diff --git a/trinity/data/controllers/task_parser.py b/trinity/data/controllers/task_parser.py
index 23b169ab2d..2e30dace63 100644
--- a/trinity/data/controllers/task_parser.py
+++ b/trinity/data/controllers/task_parser.py
@@ -7,7 +7,7 @@
from jsonargparse import Namespace
from loguru import logger
-from trinity.common.config import Config
+from trinity.common.config import DataPipelineConfig
from trinity.data.core.dataset import RftDataset
from .default_ops import (
@@ -128,7 +128,7 @@ class DataTaskParser:
def __init__(
self,
- rft_config: Config,
+ data_pipeline_config: DataPipelineConfig,
llm_agent: DashScopeChatWrapper = None,
dataset: RftDataset = None,
validate_config: bool = True,
@@ -136,12 +136,12 @@ def __init__(
"""
Initialization method.
- :param rft_config: All configs.
+ :param data_pipeline_config: All configs of specified data pipeline.
:param llm_agent: The LLM agent for natural language parsing.
:param dataset: The dataset to be processed.
:param validate_config: If execute the config validation check.
"""
- self.config = rft_config.data
+ self.config = data_pipeline_config
self.llm_agent = llm_agent
self.validate_config = validate_config
# TODO: refer dataset to support natural language parsing.
@@ -164,15 +164,21 @@ def parse_to_dj_config(self, extra_op_args=None):
return dj_config, hit_cleaner, hit_synthesizer, hit_human_annotator
def _check_types_of_processors(self, dj_config):
+ if dj_config is None:
+ return False, False, False
hit_cleaner, hit_synthesizer, hit_human_annotator = False, False, False
- for op in dj_config.process:
+ process_list = dj_config.get("process", [])
+ for op in process_list:
op_name = list(op.keys())[0]
- if op_name in DEFAULT_CLEANER:
- hit_cleaner = True
- elif op_name in DEFAULT_SYNTHESIZER:
+ if op_name in DEFAULT_SYNTHESIZER:
hit_synthesizer = True
elif op_name in DEFAULT_HUMAN_ANNOTATOR:
hit_human_annotator = True
+ else:
+ for dimension in DEFAULT_CLEANER:
+ if op_name in DEFAULT_CLEANER[dimension]:
+ hit_cleaner = True
+ break
return hit_cleaner, hit_synthesizer, hit_human_annotator
def _update_common_op_args(self, dj_config: Namespace, extra_op_args: Dict) -> Namespace:
@@ -185,20 +191,10 @@ def _update_common_op_args(self, dj_config: Namespace, extra_op_args: Dict) -> N
print(op)
return dj_config
- def _add_extra_args(self, dj_config: Namespace, op_weights: Dict = {}) -> Namespace:
- """Add extra argument for RFT project"""
- for op in dj_config.process:
- op_name = list(op.keys())[0]
- if "op_weight" not in op[op_name]:
- op[op_name]["op_weight"] = op_weights[op_name] if op_name in op_weights else 1
- op[op_name]["op_weight"] = max(0, op[op_name]["op_weight"])
- return dj_config
-
def _direct_mapping(self) -> Namespace:
"""Direct mapping from RFT config to DJ config"""
dj_config = prepare_side_configs(self.config.dj_config_path)
dj_config = get_init_configs(dj_config)
- dj_config = self._add_extra_args(dj_config)
return dj_config
def _agent_based_parsing(self, extra_op_args=None, try_num=3) -> Namespace:
@@ -251,13 +247,11 @@ def _parse_llm_response(self, response: ModelResponse, extra_op_args=None):
other_op_args = DEFAULT_OP_ARGS
dj_process = []
- op_weights = {}
def json_to_dj_config(parsed_json):
for dim in set(parsed_json.keys()) & set(cleaners.keys()):
for op_name in set(parsed_json[dim].keys()) & set(cleaners[dim].keys()):
dj_process.append({op_name: {}})
- op_weights[op_name] = float(parsed_json[dim][op_name])
json_match = re.search(r"```json\n(.*?)\n```", response.text, re.DOTALL)
if json_match:
@@ -284,20 +278,5 @@ def json_to_dj_config(parsed_json):
op[op_name][key] = val
dj_config = Namespace(process=dj_process)
dj_config = get_init_configs(dj_config)
- dj_config = self._add_extra_args(dj_config, op_weights)
-
- if self.validate_config and not self._validate_config(dj_config):
- return None
return dj_config
-
- def _validate_config(self, config: Namespace) -> bool:
- """Validate generated DJ config"""
- try:
- for op in config.process:
- op_name = list(op.keys())[0]
- weight = float(op[op_name]["op_weight"])
- assert 0 <= weight and weight <= 1
- except Exception:
- return False
- return True
diff --git a/trinity/data/core/dataset.py b/trinity/data/core/dataset.py
index 3e4af0fe12..93be832cc7 100644
--- a/trinity/data/core/dataset.py
+++ b/trinity/data/core/dataset.py
@@ -3,13 +3,10 @@
from typing import Any, Dict, List, Optional, Union
import networkx as nx
-from data_juicer.core.data.dj_dataset import Dataset
-from datasets import load_dataset
+from datasets import Dataset, concatenate_datasets
-from trinity.common.config import DataProcessorConfig
-from trinity.common.rewards import REWARD_FUNCTIONS
-from trinity.common.task import TaskSet
-from trinity.common.workflows import WORKFLOWS
+from trinity.buffer import get_buffer_reader, get_buffer_writer
+from trinity.common.config import BufferConfig, DataPipelineConfig, StorageConfig
from trinity.data.core.formatter import BaseDataFormatter
@@ -31,25 +28,27 @@ class RftDataset:
4. Basic statistics and metrics computation
Args:
- config (Dict): Configuration dict including DJ config
+ data_pipeline_config (DataPipelineConfig): Configuration including DJ config
reward_schema (Union[str, Dict]): Schema definition for reward fields
track_lineage (bool): Whether to track data lineage
"""
def __init__(
self,
- data_config: DataProcessorConfig,
+ data_pipeline_config: DataPipelineConfig,
+ buffer_config: BufferConfig = None,
reward_schema: Union[str, Dict] = "default",
track_lineage: bool = True,
):
- self.config = data_config
- source_data_path = data_config.source_data_path
- if not source_data_path:
- raise ValueError("source_data_path is not specified in DJ config")
- load_kwargs = data_config.load_kwargs
- self.data = load_dataset(source_data_path, trust_remote_code=True, **load_kwargs)
-
- self.format = data_config.format
+ self.config = data_pipeline_config
+ self.buffer_config = buffer_config
+ input_buffer_configs = self.config.input_buffers
+ if len(input_buffer_configs) == 0:
+ raise ValueError("input_buffers is empty in data pipeline config")
+ self.buffers = []
+ for input_buffer_config in input_buffer_configs:
+ self.buffers.append(get_buffer_reader(input_buffer_config, self.buffer_config))
+ self.data = Dataset.from_list([])
self.reward_schema = self._init_reward_schema(reward_schema)
self.stats: Dict[str, Any] = {}
@@ -65,15 +64,28 @@ def format(
for formatter in formatters:
self.data = formatter(self.data, num_proc)
- def to_taskset(self, **kwargs) -> TaskSet:
- default_workflow_cls = WORKFLOWS.get(self.config.default_workflow_type)
- default_reward_fn_cls = REWARD_FUNCTIONS.get(self.config.default_reward_fn_type)
- return TaskSet(
- dataset=self.data,
- config=self.config,
- workflow=default_workflow_cls,
- reward_fn=default_reward_fn_cls,
- )
+ def sort_by(self, key: str, reverse: bool = False, top_k: int = -1):
+ if top_k == -1:
+ top_k = len(self.data)
+ self.data = self.data.sort(key, reverse=reverse).take(top_k)
+
+ def read_from_buffer(self):
+ datasets = []
+ for buffer in self.buffers:
+ datasets.append(Dataset.from_list(buffer.read()))
+ self.data = concatenate_datasets(datasets)
+
+ def write_to_buffer(
+ self, output_storage_config: StorageConfig = None, buffer_config: BufferConfig = None
+ ):
+ if output_storage_config is None:
+ output_storage_config = self.config.output_buffer
+ if buffer_config is None:
+ buffer_config = self.buffer_config
+ output_buffer = get_buffer_writer(output_storage_config, buffer_config)
+ output_buffer.write(self.data.to_list())
+ output_buffer.finish()
+ self.data = Dataset.from_list([])
def to_parquet(self, path: str):
self.data.to_parquet(path)
diff --git a/trinity/data/core/dataset_db.py b/trinity/data/core/dataset_db.py
deleted file mode 100644
index f47b138995..0000000000
--- a/trinity/data/core/dataset_db.py
+++ /dev/null
@@ -1,84 +0,0 @@
-from typing import List
-
-from sqlalchemy import asc, create_engine, desc
-from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import NullPool
-
-from trinity.buffer.utils import retry_session
-from trinity.common.config import DataProcessorConfig
-from trinity.common.schema import Base, RftDatasetModel
-from trinity.data.core.dataset import RftDataset
-from trinity.utils.log import get_logger
-
-logger = get_logger(__name__)
-
-
-def rft_dataset_to_model(dataset: RftDataset) -> List[RftDatasetModel]:
- # hit keys of schema
- hit_schema_keys = []
- hit_dataset_keys = []
- # get hit keys & vals
- # - for content keys, we need to map it with content_key_mapping and try to
- # find them in the dataset
- # - for other keys, we just need to check if they are in the dataset
- data = dataset.data
- features = data.features
- content_key_mapping = dataset.format.__dict__
- schema_keys = {key for key in RftDatasetModel.__dict__.keys() if not key.startswith("_")}
- for schema_key in schema_keys:
- key = schema_key
- if f"{schema_key}_key" in content_key_mapping:
- key = content_key_mapping[f"{schema_key}_key"]
- if key in features:
- hit_schema_keys.append(schema_key)
- hit_dataset_keys.append(key)
- # construct entries
- entries = []
- for sample in data:
- valid_data = {
- schema_key: sample[key] for schema_key, key in zip(hit_schema_keys, hit_dataset_keys)
- }
- entries.append(RftDatasetModel(**valid_data))
- return entries
-
-
-class RftDatasetDB:
- def __init__(self, config: DataProcessorConfig) -> None:
- self.db_url = config.db_url
- self.engine = create_engine(self.db_url, poolclass=NullPool)
- self.config = config
- try:
- Base.metadata.create_all(self.engine, checkfirst=True)
- except OperationalError:
- logger.warning("Failed to create database, assuming it already exists.")
- self.session = sessionmaker(bind=self.engine)
-
- def add_entries(self, dataset: RftDataset):
- with retry_session(
- self, self.config.max_retry_times, self.config.max_retry_interval
- ) as session:
- session.add_all(rft_dataset_to_model(dataset))
-
- def get_entries(self, num_entries: int, order_by: str = None, ascending: bool = False):
- # get num_entries entries from the database
- if order_by is not None and hasattr(RftDatasetModel, order_by):
- order_by_key = getattr(RftDatasetModel, order_by)
- order_by_key = asc(order_by_key) if ascending else desc(order_by_key)
- else:
- order_by_key = None
- with retry_session(
- self, self.config.max_retry_times, self.config.max_retry_interval
- ) as session:
- entries = (
- session.query(RftDatasetModel)
- .order_by(order_by_key)
- .limit(num_entries)
- .with_for_update()
- .all()
- )
-
- for entry in entries:
- entry.consumed_cnt += 1
- samples = [entry.to_dict() for entry in entries]
- return samples
diff --git a/trinity/data/processors/cleaner.py b/trinity/data/processors/cleaner.py
index b031e528e1..10979990b1 100644
--- a/trinity/data/processors/cleaner.py
+++ b/trinity/data/processors/cleaner.py
@@ -36,6 +36,7 @@ def __init__(
clean_strategy: str = "iterative",
min_size_ratio: PositiveFloat = None,
data_dist: str = "gaussian",
+ op_weights: dict = None,
**kwargs,
):
"""
@@ -54,6 +55,7 @@ def __init__(
self.min_size_ratio = min_size_ratio
self.data_dist = data_dist
self.op_name_to_stats_key = {}
+ self.op_weights = op_weights
def keep_cleaner_op_cfg(self, dj_cfg):
"""Only consider cleaner op in data-juicer configs."""
@@ -112,7 +114,7 @@ def update_op_threshold(
update_record = {}
for process in exe_cfg.process:
op_name, args = list(process.items())[0]
- op_weight = args["op_weight"]
+ op_weight = self.op_weights.get(op_name, 1)
update_record[op_name] = {}
temp_args = copy.deepcopy(args)
@@ -164,7 +166,7 @@ def process(
else:
logger.info("Executing Data-Juicer analyzer...")
analyzer = Analyzer(self.dj_cfg)
- analyzer.run(dataset)
+ analyzer.run(dataset, skip_export=True)
df = analyzer.overall_result
mean_series = df[df.index == "mean"]
stats_key_to_mean = mean_series.iloc[0, :].to_dict()
diff --git a/trinity/data/readme.md b/trinity/data/readme.md
index 3294819f43..4b5c828ee6 100644
--- a/trinity/data/readme.md
+++ b/trinity/data/readme.md
@@ -88,14 +88,14 @@ synth_data = synthesizer.process(clean_data)
- Then you need to prepare the `data_processor` section in the config file (e.g. [test_cfg.yaml](tests/test_configs/active_iterator_test_cfg.yaml))
- For the `dj_config_path` argument in it, you can either specify a data-juicer config file path (e.g. [test_dj_cfg.yaml](tests/test_configs/active_iterator_test_dj_cfg.yaml)), or write the demand in `dj_process_desc` argument in natural language and our agent will help you to organize the data-juicer config.
- Finally you can send requests to the data server to start an active iterator to process datasets in many ways:
- - Request with `curl`: `curl "http://127.0.0.1:5000/data_workflow?configPath=tests%2Ftest_configs%2Factive_iterator_test_cfg.yaml"`
+ - Request with `curl`: `curl "http://127.0.0.1:5005/data_processor/task_pipeline?configPath=tests%2Ftest_configs%2Factive_iterator_test_cfg.yaml"`
- Request using our simple client:
```python
from trinity.cli.client import request
res = request(
- url="http://127.0.0.1:5005/data_workflow",
+ url="http://127.0.0.1:5005/data_processor/task_pipeline",
configPath="tests/test_configs/active_iterator_test_cfg.yaml"
)
diff --git a/trinity/data/server.py b/trinity/data/server.py
index 08ca5ebfea..e1f57ba81b 100644
--- a/trinity/data/server.py
+++ b/trinity/data/server.py
@@ -1,20 +1,39 @@
import fire
from flask import Flask, jsonify, request
+from markupsafe import escape
app = Flask(__name__)
-APP_NAME = "data_workflow"
+APP_NAME = "data_processor"
-@app.route(f"/{APP_NAME}", methods=["GET"])
-def data_workflow():
+@app.route(f"/{APP_NAME}/", methods=["GET"])
+def data_processor(pipeline_type):
from trinity.common.config import load_config
from trinity.data.controllers.active_iterator import DataActiveIterator
config_path = request.args.get("configPath")
+ pipeline_type = escape(pipeline_type)
config = load_config(config_path)
- iterator = DataActiveIterator(config)
+ pipeline_config = getattr(config.data_processor, pipeline_type)
+ if pipeline_config is None:
+ return jsonify(
+ {
+ "return_code": -1,
+ "message": f"Error: {pipeline_type} is not supported or the corresponding config is empty",
+ }
+ )
+
+ if pipeline_config.dj_config_path is None and pipeline_config.dj_process_desc is None:
+ return jsonify(
+ {
+ "return_code": -1,
+ "message": "Error: Both dj_config_path and dj_process_desc in the pipeline config are None.",
+ }
+ )
+
+ iterator = DataActiveIterator(pipeline_config, config.buffer)
ret, msg = iterator.run()
return jsonify({"return_code": ret, "message": msg})
From acf7788c88ec06e0ae72210ab37dd14c4314b201 Mon Sep 17 00:00:00 2001
From: pxc
Date: Fri, 20 Jun 2025 17:48:22 +0800
Subject: [PATCH 28/28] bumping repository code version to v0.2.0.dev0
---
pyproject.toml | 2 +-
trinity/__init__.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index d0d17d8065..fd7ca1f3c0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "trinity-rft"
-version = "0.1.1"
+version = "0.2.0.dev0"
authors = [
{name="Trinity-RFT Team", email="trinity-rft@outlook.com"},
]
diff --git a/trinity/__init__.py b/trinity/__init__.py
index ff7c8c4b29..c5b13a8976 100644
--- a/trinity/__init__.py
+++ b/trinity/__init__.py
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Trinity-RFT (Reinforcement Fine-Tuning)"""
-__version__ = "0.1.1"
+__version__ = "0.2.0.dev0"