From f0771a70f860c6d46574757129742bbb9db11022 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 18 Jul 2025 10:15:30 +0800 Subject: [PATCH 01/21] refactor workflow --- tests/explorer/workflow_test.py | 2 +- trinity/common/workflows/math_rm_workflow.py | 5 +++-- trinity/common/workflows/workflow.py | 20 +++++++++++++------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 2132a591a0..cc85eaf8c7 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -31,7 +31,7 @@ class MockResponse: class DummyWorkflow(Workflow): def __init__(self, model, task: Task, auxiliary_models=None): - super().__init__(model, task, auxiliary_models) + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py index 45940fdfae..f8e1d41720 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_rm_workflow.py @@ -19,14 +19,15 @@ class MathRMWorkflow(SimpleWorkflow): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): self.reset(task) super().__init__( - model=model, task=task, + model=model, auxiliary_models=auxiliary_models, ) diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index e9549d9e2e..5f3c47aa7f 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -84,10 +84,12 @@ class Workflow(ABC): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): + self.task = task self.model = model self.auxiliary_models = auxiliary_models @@ -102,6 +104,7 @@ def reset(self, task: Task): @abstractmethod def run(self) -> List[Experience]: """Run workflow and return a list of experiences.""" + raise NotImplementedError class MultiTurnWorkflow(Workflow): @@ -111,13 +114,14 @@ class MultiTurnWorkflow(Workflow): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): super().__init__( - model=model, task=task, + model=model, auxiliary_models=auxiliary_models, ) @@ -161,14 +165,15 @@ class SimpleWorkflow(Workflow): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): self.reset(task) super().__init__( - model=model, task=task, + model=model, auxiliary_models=auxiliary_models, ) @@ -236,14 +241,15 @@ class MathWorkflow(SimpleWorkflow): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): self.reset(task) super().__init__( - model=model, task=task, + model=model, auxiliary_models=auxiliary_models, ) From 662d654998c06db38c870f5941c4430e6fe0fb4c Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 18 Jul 2025 10:48:22 +0800 Subject: [PATCH 02/21] add step-wise workflow --- .../common/workflows/step_wise_workflow.py | 117 ++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 trinity/common/workflows/step_wise_workflow.py diff --git a/trinity/common/workflows/step_wise_workflow.py b/trinity/common/workflows/step_wise_workflow.py new file mode 100644 index 
0000000000..634a54030c --- /dev/null +++ b/trinity/common/workflows/step_wise_workflow.py @@ -0,0 +1,117 @@ +from abc import abstractmethod + +import openai + +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.common.workflows.workflow import Task, Workflow + + +class StepWiseRewardWorkflow(Workflow): + """A workflow that implements step-wise rewards for tasks.""" + + def __init__(self, *, task: Task, model: ModelWrapper, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + assert model.enable_history, ( + "Rollout Model must have history enabled for step-wise rewards, please " + "set `explorer.rollout_model.enable_history` to `True` in your config." + ) + # use the rollout model's OpenAI client to write your agent application + self.client: openai.OpenAI = model.get_openai_client() + + def run(self) -> list[Experience]: + """Run the workflow and return a list of experiences with step-wise rewards.""" + experiences = [] + for step in self.max_step_num: + # Run a single step of the agent application + continue_run = self.step(step_num=step) + # Collect experiences data of the current step + exps = self.model.extract_experience_from_history() + # Calculate the reward for the current step + exps = self.reward(exps, step_num=step) + # Store the step experiences + experiences.extend(exps) + if not continue_run: + break + + return experiences + + @abstractmethod + def step(self, step_num: int) -> bool: + """Run a single step of your agent application. + + Args: + step_num (int): The current step number. + + Returns: + bool: Whether to continue running the agent application. + + Tips: + You can use the openai client (`self.client`) to migrate your existing + applications at low cost. + """ + pass + + @abstractmethod + def reward(self, exps: list[Experience], step_num: int) -> float: + """Calculate the reward for the given experiences at the specified step.""" + pass + + @property + @abstractmethod + def max_step_num(self): + """Return the maximum number of steps in the task.""" + + +class RewardPropagationWorkflow(Workflow): + """A workflow that propagates rewards across multiple turns.""" + + def __init__(self, *, task: Task, model: ModelWrapper, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + assert model.enable_history, ( + "Rollout Model must have history enabled for step-wise rewards, please " + "set `explorer.rollout_model.enable_history` to `True` in your config." + ) + # use the rollout model's OpenAI client to write your agent application + self.client: openai.OpenAI = model.get_openai_client() + + def run(self) -> list[Experience]: + """Run the workflow and return a list of experiences with step-wise rewards.""" + experiences = [] + for step in self.max_step_num: + # Run a single step of the agent application + continue_run = self.step(step_num=step) + # Collect experiences data of the current step + exps = self.model.extract_experience_from_history() + # Store the step experiences + experiences.extend(exps) + if not continue_run: + break + self.reward(experiences) + return experiences + + @abstractmethod + def step(self, step_num: int) -> bool: + """Run a single step of your agent application. + + Args: + step_num (int): The current step number. + + Returns: + bool: Whether to continue running the agent application. 
+ + Tips: + You can use the openai client (`self.client`) to migrate your existing + applications at low cost. + """ + pass + + @abstractmethod + def reward(self, exps: list[Experience]) -> float: + """Calculate the reward for the given experiences of the entire run.""" + pass + + @property + @abstractmethod + def max_step_num(self): + """Return the maximum number of steps in the task.""" From ae9db40a92a30dd93bf722dba8ff3cc4c3b2b05f Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 18 Jul 2025 18:01:36 +0800 Subject: [PATCH 03/21] refactor experience --- tests/buffer/sql_test.py | 12 +- tests/common/experience_test.py | 2 +- tests/explorer/scheduler_test.py | 6 +- trinity/algorithm/algorithm.py | 2 +- .../sample_strategy/mix_sample_strategy.py | 4 +- trinity/algorithm/sample_strategy/utils.py | 4 +- trinity/buffer/reader/file_reader.py | 37 +- trinity/buffer/schema/sql_schema.py | 7 +- trinity/common/experience.py | 527 ++++++++++-------- trinity/common/models/model.py | 13 +- trinity/common/models/vllm_model.py | 22 +- trinity/common/workflows/workflow.py | 14 +- trinity/trainer/verl_trainer.py | 4 +- 13 files changed, 341 insertions(+), 313 deletions(-) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 22b1c739a6..07032f8c69 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -8,7 +8,7 @@ from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType -from trinity.common.experience import Experience +from trinity.common.experience import MultiTurnExperience, SingleTurnExperience db_path = os.path.join(os.path.dirname(__file__), "test.db") @@ -33,12 +33,11 @@ async def test_create_sql_buffer(self) -> None: sql_writer = SQLWriter(meta, config) sql_reader = SQLReader(meta, config) exps = [ - Experience( - tokens=torch.tensor([float(j) for j in range(i + 1)]), + SingleTurnExperience( + token_ids=torch.tensor([float(j) for j in range(i + 1)]), prompt_length=i, reward=float(i), logprobs=torch.tensor([0.1]), - action_mask=torch.tensor([j % 2 for j in range(i + 1)]), ) for i in range(1, put_batch_size + 1) ] @@ -52,9 +51,8 @@ async def test_create_sql_buffer(self) -> None: # dynamic read/write sql_writer.write( [ - Experience( - tokens=torch.tensor([float(j) for j in range(i + 1)]), - prompt_length=i, + MultiTurnExperience( + token_ids=torch.tensor([float(j) for j in range(i + 1)]), reward=float(i), logprobs=torch.tensor([0.1]), action_mask=torch.tensor([j % 2 for j in range(i + 1)]), diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 947c4d4ecb..ec903fd4c7 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -77,7 +77,7 @@ def test_batch_conversion(self): self.assertEqual(batch.rewards[i], exps[i].reward) self.assertTrue( torch.all( - batch.tokens[i][ + batch.token_ids[i][ prompt_length - exps[i].prompt_length : prompt_length - exps[i].prompt_length diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index d25c394bff..1c2b97cfd2 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -10,7 +10,7 @@ from trinity.buffer.reader.queue_reader import QueueReader from trinity.common.config import GenerationConfig, StorageConfig from trinity.common.constants import StorageType -from trinity.common.experience import Experience +from trinity.common.experience import Experience, SingleTurnExperience from trinity.common.models.model import 
InferenceModel from trinity.common.workflows import Task from trinity.common.workflows.workflow import WORKFLOWS, Workflow @@ -42,8 +42,8 @@ def run(self) -> List[Experience]: assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 return [ - Experience( - tokens=torch.zeros(5), + SingleTurnExperience( + token_ids=torch.zeros(5), prompt_length=2, prompt_text=self.error_type or "success", info={"repeat_times": self.repeat_times}, diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 54f5c3d296..1f0de5be64 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -153,7 +153,7 @@ def default_config(cls) -> Dict: @classmethod def check_config(cls, config: Config) -> None: - if config.model == "train": + if config.mode == "train": if ( config.buffer.trainer_input.experience_buffer is None or not config.buffer.trainer_input.experience_buffer.path diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 80a4af7d49..8ebd335246 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -94,8 +94,8 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor): "uid": np.array(experiences.group_ids), "unique_ids": np.array(experiences.unique_ids), "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "input_ids": experiences.token_ids.long(), + "responses": experiences.token_ids[:, experiences.prompt_length :].long(), "attention_mask": attention_mask.long(), "response_mask": ( experiences.action_masks[:, experiences.prompt_length :].long() diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py index f9df00ee4e..8bd868c22c 100644 --- a/trinity/algorithm/sample_strategy/utils.py +++ b/trinity/algorithm/sample_strategy/utils.py @@ -16,8 +16,8 @@ def to_data_proto(experiences: Experiences) -> DataProto: "uid": np.array(experiences.group_ids), "unique_ids": np.array(experiences.unique_ids), "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "input_ids": experiences.token_ids.long(), + "responses": experiences.token_ids[:, experiences.prompt_length :].long(), "attention_mask": attention_mask.long(), "response_mask": ( experiences.action_masks[:, experiences.prompt_length :].long() diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 1f5e29a2b9..673721c056 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -10,7 +10,7 @@ from trinity.buffer.buffer_reader import BufferReader from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import PromptType, ReadStrategy, TaskType -from trinity.common.experience import Experience +from trinity.common.experience import DPOExperience, SingleTurnExperience from trinity.common.rewards import REWARD_FUNCTIONS from trinity.common.workflows import WORKFLOWS, Task from trinity.utils.registry import Registry @@ -129,15 +129,15 @@ def read( if self.prompt_type == PromptType.MESSAGES: for sample in samples: messages = sample[self.messages_key] - tokens = self.tokenizer.apply_chat_template( + token_ids = self.tokenizer.apply_chat_template( messages, add_generation_prompt=False, 
return_tensors="pt" )[0] - prompt_tokens = self.tokenizer.apply_chat_template( + prompt_tokens_ids = self.tokenizer.apply_chat_template( messages[:-1], add_generation_prompt=True, return_tensors="pt" )[0] - experience = Experience( - tokens=tokens, - prompt_length=len(prompt_tokens), + experience = SingleTurnExperience( + token_ids=token_ids, + prompt_length=len(prompt_tokens_ids), ) exp_list.append(experience) @@ -151,17 +151,17 @@ def read( response_messages = [response_messages] full_messages = prompt_messages + response_messages - tokens = self.tokenizer.apply_chat_template( + token_ids = self.tokenizer.apply_chat_template( full_messages, add_generation_prompt=False, return_tensors="pt" )[0] - prompt_tokens = self.tokenizer.apply_chat_template( + prompt_tokens_ids = self.tokenizer.apply_chat_template( prompt_messages, add_generation_prompt=True, return_tensors="pt" )[0] - experience = Experience( - tokens=tokens, - prompt_length=len(prompt_tokens), + experience = SingleTurnExperience( + token_ids=token_ids, + prompt_length=len(prompt_tokens_ids), ) exp_list.append(experience) @@ -170,11 +170,11 @@ def read( for sample in samples: prompt = sample[self.prompt_key] response = sample[self.response_key] - tokens = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0] - prompt_tokens = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0] - experience = Experience( - tokens=tokens, - prompt_length=len(prompt_tokens), + token_ids = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0] + prompt_tokens_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0] + experience = SingleTurnExperience( + token_ids=token_ids, + prompt_length=len(prompt_tokens_ids), ) exp_list.append(experience) else: @@ -250,9 +250,8 @@ def read( add_generation_prompt=False, return_tensors="pt", )[0][prompt_length:] - experience = Experience( - tokens=prompt_tokens, - prompt_length=len(prompt_tokens), + experience = DPOExperience( + token_ids=prompt_tokens, chosen=chosen_tokens, rejected=rejected_tokens, ) diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index 21289c7768..d85985e287 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -5,7 +5,7 @@ from sqlalchemy import Column, Float, Integer, LargeBinary, String from sqlalchemy.ext.declarative import declarative_base -from trinity.common.experience import Experience +from trinity.common.experience import Experience, MultiTurnExperience Base = declarative_base() @@ -90,9 +90,8 @@ def from_messages( messages=messages, chat_template=chat_template, ) - exp = Experience( - tokens=token_ids, - prompt_length=0, + exp = MultiTurnExperience( + token_ids=token_ids, action_mask=action_mask, info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])}, ) diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 56bdd1c4c5..9f6e227dcc 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -3,8 +3,8 @@ from __future__ import annotations import pickle -from dataclasses import dataclass -from itertools import chain, repeat +import uuid +from dataclasses import dataclass, field from typing import List, Optional import torch @@ -12,66 +12,206 @@ @dataclass -class Experience: - """A single experience.""" +class EID: + """Experience ID class to uniquely identify an experience.""" + + # TODO: do we need to add project/name here to make it unique across different projects? 
+ batch: int = 0 # Batch number, e.g., the explorer step num + task: int = 0 # Task sequence in the batch, e.g., the first task in the batch has task=0 + run: int = 0 # Run id, e.g., the first run in the task has run=0 + step: int = 0 # Step number when running the task, e.g., the first step in the task has step=0 + suffix: str = field( + default_factory=lambda: uuid.uuid4().hex[:6] + ) # Unique identifier suffix, e.g., a UUID - tokens: Tensor # [seq] - prompt_length: int - logprobs: Optional[Tensor] = None # [seq] + @property + def uid(self) -> str: + """An unique identifier for the experience.""" + return f"{self.batch}/{self.task}/{self.run}/{self.step}/{self.suffix}" + + @property + def sid(self) -> str: + """Step ID of the experience. + + For example, experiences generated by all runs of a same task at the same step will have the same sid. + """ + return f"{self.batch}/{self.task}/{self.step}" + + @property + def rid(self) -> str: + """Run ID of the experience. + + For example, experiences generated by one run of a task at all steps will have the same run_id. + """ + return f"{self.batch}/{self.task}/{self.run}" + + @property + def gid(self) -> str: + """Group ID for the experience. + + For example, experiences generated by a group run in GRPO-like algorithms will have the same gid. + """ + return f"{self.batch}/{self.task}" + + def __str__(self): + return self.uid + + def __repr__(self): + return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})" + + +@dataclass +class Experience: + token_ids: Tensor # [seq_length] + eid: EID = field(default_factory=EID) # Unique identifier for the experience reward: Optional[float] = None - prompt_text: Optional[str] = None - response_text: Optional[str] = None - action_mask: Optional[Tensor] = None - chosen: Optional[Tensor] = None # for dpo - rejected: Optional[Tensor] = None # for dpo info: Optional[dict] = None metrics: Optional[dict[str, float]] = None - group_id: str = "" # for grpo - unique_id: str = "" - - def __post_init__(self): - if self.action_mask is not None: - assert ( - self.action_mask.shape == self.tokens.shape - ), "The provided action_mask must have the same shape as tokens." 
- - # explicit type cast - if not isinstance(self.tokens, Tensor): - self.tokens = Tensor(self.tokens) - if self.logprobs is not None and not isinstance(self.logprobs, Tensor): - self.logprobs = Tensor(self.logprobs) - if self.action_mask is not None and not isinstance(self.action_mask, Tensor): - self.action_mask = Tensor(self.action_mask) - if self.chosen is not None and not isinstance(self.chosen, Tensor): - self.chosen = Tensor(self.chosen) - if self.rejected is not None and not isinstance(self.rejected, Tensor): - self.rejected = Tensor(self.rejected) def serialize(self) -> bytes: """Serialize the experience to bytes.""" return pickle.dumps(self) - @staticmethod - def deserialize(data: bytes) -> Experience: - """Deserialize the experience from bytes.""" + @classmethod + def deserialize(cls, data: bytes) -> Experience: return pickle.loads(data) - def to_dict(self) -> dict: - """Convert the experience to a dictionary.""" - res = { - "prompt_text": self.prompt_text, - "info": self.info, - "metrics": self.metrics, - } - if self.response_text is not None: - res["response_text"] = self.response_text - if self.chosen is not None: - res["chosen"] = self.chosen.tolist() - if self.rejected is not None: - res["rejected"] = self.rejected.tolist() - if self.reward is not None: - res["reward"] = float(self.reward) - return res + +@dataclass +class SingleTurnExperience(Experience): + """A single-turn prompt-response experience.""" + + # Length of the prompt in tokens, used for generating attention masks + prompt_length: int # type: ignore[misc] + response_text: Optional[str] = None # Text of the response + prompt_text: Optional[str] = None # Text of the prompt + logprobs: Optional[Tensor] = None # [seq] + + @property + def action_mask(self) -> Tensor: + """Get the action mask for the single-turn experience.""" + # set the prompt length to 0 and the rest to 1 + action_mask = torch.zeros_like(self.token_ids, dtype=torch.bool) + action_mask[self.prompt_length :] = 1 + return action_mask + + @classmethod + def gather(cls, experiences: List[SingleTurnExperience], pad_token_id: int = 0) -> Experiences: + if len(experiences) == 0: + return empty_experiences() + max_prompt_length = max([exp.prompt_length for exp in experiences]) # type: ignore [type-var] + max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore [operator] + eids = [exp.eid for exp in experiences] + + # Gather token_ids + token_ids = gather_token_ids( + experiences, max_prompt_length, max_response_length, pad_token_id + ) + + # Gather rewards + if experiences[0].reward is not None: + rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float) + else: + rewards = None + + # gather action_masks + action_masks = gather_action_masks(experiences, max_prompt_length, max_response_length) + + # gather attention_masks + attention_masks = gather_attention_masks( + experiences, max_prompt_length, max_response_length + ) + + # gather logprobs + + if all(exp.logprobs is not None for exp in experiences): + logprobs = gather_logprobs(experiences, max_prompt_length, max_response_length) + else: + logprobs = None + + return Experiences( + eids=eids, + token_ids=token_ids, + rewards=rewards, + attention_masks=attention_masks, + action_masks=action_masks, + prompt_length=max_prompt_length, + logprobs=logprobs, + ) + + +@dataclass +class MultiTurnExperience(Experience): + """A multi-turn experience, which includes the conversation history in a single tensor.""" + + # Action mask which indicates 
which tokens are generated by the model + action_mask: Tensor # type: ignore[misc] + logprobs: Optional[Tensor] = None # [seq_length] + messages: Optional[List[dict]] = None # List of messages in the conversation + + @property + def prompt_length(self) -> int: + return 1 # use action mask to determine the response tokens, set prompt_length to 1 to avoid no prompt issues + + @classmethod + def gather(cls, experiences: List[SingleTurnExperience], pad_token_id: int = 0) -> Experiences: + return SingleTurnExperience.gather(experiences, pad_token_id=pad_token_id) + + +@dataclass +class DPOExperience(Experience): + """A DPO experience, which includes the chosen and rejected responses. + + `token_ids` should only contain the prompt tokens, while `chosen` and `rejected` should + contain the response tokens. + """ + + # Token ids of the chosen response [resp_length] + chosen: Tensor # type: ignore[misc] + # Token ids of the rejected response [resp_length] + rejected: Tensor # type: ignore[misc] + chosen_text: Optional[str] = None # Text of the chosen response + rejected_text: Optional[str] = None # Text of the rejected response + prompt_text: Optional[str] = None # Text of the prompt + + @classmethod + def gather(cls, experiences: List[DPOExperience], pad_token_id: int = 0) -> Experiences: + """Gather a batch of DPO experiences from a list of experiences.""" + single_turn_experiences = [] + for exp in experiences: + single_turn_experiences.append( + SingleTurnExperience( + eid=EID( + batch=exp.eid.batch, + task=exp.eid.task, + step=exp.eid.step, + ), + token_ids=torch.cat([exp.token_ids, exp.chosen]), + reward=exp.reward, + info=exp.info, + metrics=exp.metrics, + prompt_length=len(exp.token_ids), + prompt_text=exp.prompt_text, + response_text=exp.chosen_text, + ) + ) + single_turn_experiences.append( + SingleTurnExperience( + eid=EID( + batch=exp.eid.batch, + task=exp.eid.task, + step=exp.eid.step, + ), + token_ids=torch.cat([exp.token_ids, exp.rejected]), + reward=exp.reward, + info=exp.info, + metrics=exp.metrics, + prompt_length=len(exp.token_ids), + prompt_text=exp.prompt_text, + response_text=exp.rejected_text, + ) + ) + return SingleTurnExperience.gather(single_turn_experiences, pad_token_id=pad_token_id) @dataclass(frozen=True) @@ -81,7 +221,7 @@ class Experiences: Example: >>> |<- prompt_length ->| | - >>> tokens: ('P' represents prompt, 'O' represents output) + >>> token_ids: ('P' represents prompt, 'O' represents output) >>> exp1: |........PPPPPPPPPPP|OOOOOOOOOO.....| >>> exp2: |......PPPPPPPPPPPPP|OOOOOOO........| >>> @@ -90,19 +230,18 @@ class Experiences: >>> exp2: |......1111111111111|1111111........| """ - tokens: Tensor + eids: List[EID] # Experience IDs of each experience in the batch + token_ids: Tensor rewards: Tensor attention_masks: Tensor action_masks: Optional[Tensor] prompt_length: int logprobs: Optional[Tensor] - group_ids: List[str] - unique_ids: List[str] @property def batch_size(self) -> int: """Get the batch size.""" - return self.tokens.size(0) + return self.token_ids.size(0) @classmethod def gather_experiences( @@ -113,207 +252,103 @@ def gather_experiences( This method will automatically pad the `tokens` and `logprobs` of input experiences to the same length. 
""" if len(experiences) == 0: - return Experiences( - tokens=torch.empty(0, dtype=torch.int32), - rewards=torch.empty(0, dtype=torch.float32), - attention_masks=torch.empty(0, dtype=torch.bool), - action_masks=torch.empty(0, dtype=torch.bool), - logprobs=torch.empty(0, dtype=torch.float32), - prompt_length=torch.empty(0, dtype=torch.int32), - group_ids=[], - unique_ids=[], - ) - max_prompt_length = max([exp.prompt_length for exp in experiences]) - max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) - group_ids = [exp.group_id for exp in experiences] - unique_ids = [exp.unique_id for exp in experiences] - tokens_dtype = experiences[0].tokens.dtype - tokens = torch.stack( - [ - torch.cat( - [ - torch.full( - (max_prompt_length - exp.prompt_length,), - pad_token_id, - dtype=tokens_dtype, - ), - exp.tokens, - torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), - pad_token_id, - dtype=tokens_dtype, - ), - ] - ) - for exp in experiences - ] - ) - if experiences[0].reward is not None: - rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float) - else: - rewards = None - - # Calculate the action_masks according to the provided experience.action_mask - if experiences[0].action_mask is not None: - action_mask_dtype = experiences[0].action_mask.dtype - action_masks = torch.stack( + return empty_experiences() + return experiences[0].__class__.gather(experiences, pad_token_id=pad_token_id) + + +def empty_experiences() -> Experiences: + return Experiences( + token_ids=torch.empty(0, dtype=torch.int32), + rewards=torch.empty(0, dtype=torch.float32), + attention_masks=torch.empty(0, dtype=torch.bool), + action_masks=torch.empty(0, dtype=torch.bool), + logprobs=torch.empty(0, dtype=torch.float32), + prompt_length=torch.empty(0, dtype=torch.int32), + eids=[], + ) + + +def gather_token_ids( + experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int +) -> Tensor: + token_ids_dtype = experiences[0].token_ids.dtype + return torch.stack( + [ + torch.cat( [ - torch.cat( - [ - torch.full( - (max_prompt_length - exp.prompt_length,), - 0, - dtype=action_mask_dtype, - ), - exp.action_mask, - torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), - 0, - dtype=action_mask_dtype, - ), - ] - ) - for exp in experiences + torch.full( + (max_prompt_length - exp.prompt_length,), + pad_token_id, + dtype=token_ids_dtype, + ), + exp.tokens, + torch.full( + (max_response_length + exp.prompt_length - len(exp.tokens),), + pad_token_id, + dtype=token_ids_dtype, + ), ] ) - else: - action_masks = None - attention_masks = torch.zeros( - (len(experiences), max_prompt_length + max_response_length), dtype=torch.bool - ) - for i, exp in enumerate(experiences): - start = max_prompt_length - exp.prompt_length - end = start + len(exp.tokens) - attention_masks[i, start:end] = 1 + for exp in experiences + ] + ) - if all(exp.logprobs is not None for exp in experiences): - logprob_dtype = experiences[0].logprobs.dtype # type: ignore [union-attr] - logprobs = torch.stack( + +def gather_action_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor: + return torch.stack( + [ + torch.cat( [ - torch.cat( - [ - torch.full( - (max_prompt_length - exp.prompt_length,), - 0.0, - dtype=logprob_dtype, - ), - exp.logprobs, - torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), - 0.0, - dtype=logprob_dtype, - ), - ] - ) - for exp in experiences + torch.full( + (max_prompt_length - 
exp.prompt_length,), + 0, + dtype=torch.bool, + ), + exp.action_mask, + torch.full( + (max_response_length + exp.prompt_length - len(exp.tokens),), + 0, + dtype=torch.bool, + ), ] ) - else: - logprobs = None + for exp in experiences + ] + ) - return cls( - group_ids=group_ids, - unique_ids=unique_ids, - tokens=tokens, - rewards=rewards, - attention_masks=attention_masks, - action_masks=action_masks, - prompt_length=max_prompt_length, - logprobs=logprobs, - ) - - @classmethod - def gather_dpo_experiences( - cls, experiences: list[Experience], pad_token_id: int = 0 - ) -> Experiences: - """Gather a batch of dpo experiences from a list of experiences. - - Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L849 - Note: We arrange inputs in the order of (chosen, rejected, chosen, rejected, ...) - to ensure that each pair of (chosen, rejected) is not split by subsequent operations +def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor: + attention_masks = torch.zeros( + (len(experiences), max_prompt_length + max_response_length), dtype=torch.bool + ) - Args: - Experiences: `(list[Experience])` - - `"prompt"`: token ids of the prompt - - `"chosen"`: token ids of the chosen response - - `"rejected"`: token ids of the rejected response - pad_token_id: `(int)` - The pad token id. - - Returns: - Experiences: - - `"tokens"`: Concatenated chosen and rejected completion input IDs of shape `(2 * batch_size, max_completion_length)`. - - `"attention_masks"`: Concatenated chosen and rejected attention masks of shape `(2 * batch_size, max_completion_length)`. - """ - if len(experiences) == 0: - return Experiences( - tokens=torch.empty(0, dtype=torch.int32), - rewards=torch.empty(0, dtype=torch.float32), - attention_masks=torch.empty(0, dtype=torch.bool), - action_masks=torch.empty(0, dtype=torch.bool), - logprobs=torch.empty(0, dtype=torch.float32), - prompt_length=torch.empty(0, dtype=torch.int32), - group_ids=[], - unique_ids=[], - ) + for i, exp in enumerate(experiences): + start = max_prompt_length - exp.prompt_length + end = start + len(exp.tokens) + attention_masks[i, start:end] = 1 - # TODO: exp.tokens in DPO are prompt tokens - prompt_tokens = list(chain.from_iterable([repeat(exp.tokens, 2) for exp in experiences])) - max_prompt_length = max([exp.prompt_length for exp in experiences]) + return attention_masks - chosen_tokens = [exp.chosen for exp in experiences] - rejected_tokens = [exp.rejected for exp in experiences] - response_tokens = list(chain.from_iterable(zip(chosen_tokens, rejected_tokens))) - max_response_length = max([len(response) for response in response_tokens]) # type: ignore - group_ids = list(chain.from_iterable([repeat(exp.group_id, 2) for exp in experiences])) - unique_ids = list( - chain.from_iterable( - [(f"{exp.unique_id}/1", f"{exp.unique_id}/0") for exp in experiences] +def gather_logprobs(experiences, max_prompt_length: int, max_response_length: int) -> Tensor: + logprob_dtype = experiences[0].logprobs.dtype # type: ignore [union-attr] + return torch.stack( + [ + torch.cat( + [ + torch.full( + (max_prompt_length - exp.prompt_length,), + 0.0, + dtype=logprob_dtype, + ), + exp.logprobs, + torch.full( + (max_response_length + exp.prompt_length - len(exp.tokens),), + 0.0, + dtype=logprob_dtype, + ), + ] ) - ) - tokens_dtype = experiences[0].tokens.dtype - tokens = torch.stack( - [ - torch.cat( - [ - torch.full( - (max_prompt_length - len(prompt),), - pad_token_id, - dtype=tokens_dtype, - ), - prompt, - 
response, - torch.full( - (max_response_length - len(response),), # type: ignore - pad_token_id, - dtype=tokens_dtype, - ), - ] - ) - for prompt, response in zip(prompt_tokens, response_tokens) - ] - ) - - attention_masks = torch.zeros( - (len(tokens), max_prompt_length + max_response_length), dtype=torch.bool - ) - - for (i, prompt), response in zip(enumerate(prompt_tokens), response_tokens): - start = max_prompt_length - len(prompt) - end = max_prompt_length + len(response) # type: ignore - attention_masks[i, start:end] = 1 - - assert len(tokens) == 2 * len(experiences) - - return cls( - group_ids=group_ids, - unique_ids=unique_ids, - tokens=tokens, - attention_masks=attention_masks, - prompt_length=max_prompt_length, - rewards=None, - action_masks=None, - logprobs=None, - ) + for exp in experiences + ] + ) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 2568c88005..e32e7ceddb 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -4,25 +4,25 @@ import socket import time from abc import ABC, abstractmethod -from typing import Any, List, Tuple, Union +from typing import Any, List, Sequence, Tuple, Union import openai import ray import torch from torch import Tensor -from trinity.common.experience import Experience +from trinity.common.experience import Experience, SingleTurnExperience from trinity.utils.log import get_logger class InferenceModel(ABC): """A model for high performance for rollout inference.""" - async def generate(self, prompt: str, **kwargs) -> List[Experience]: + async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: """Generate a responses from a prompt in async.""" raise NotImplementedError - async def chat(self, messages: List[dict], **kwargs) -> List[Experience]: + async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]: """Generate experiences from a list of history chat messages in async.""" raise NotImplementedError @@ -189,8 +189,8 @@ def convert_api_output_to_experience( ) -> List[Experience]: """Convert the API output to a list of experiences.""" return [ - Experience( - tokens=torch.cat( + SingleTurnExperience( + token_ids=torch.cat( ( torch.tensor(output.prompt_token_ids, dtype=torch.int32), torch.tensor(choice.token_ids, dtype=torch.int32), @@ -207,7 +207,6 @@ def convert_api_output_to_experience( ) ), prompt_length=len(output.prompt_token_ids), - prompt_text=None, response_text=choice.message.content, ) for choice in output.choices diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 01b8135511..856ea169b6 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -3,7 +3,7 @@ import os import re -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import aiohttp import ray @@ -12,7 +12,11 @@ from vllm.sampling_params import RequestOutputKind from trinity.common.config import InferenceModelConfig -from trinity.common.experience import Experience +from trinity.common.experience import ( + Experience, + MultiTurnExperience, + SingleTurnExperience, +) from trinity.common.models.model import InferenceModel from trinity.common.models.utils import ( tokenize_and_mask_messages_default, @@ -29,7 +33,6 @@ class vLLMRolloutModel(InferenceModel): Args: config (Config): The config. - kwargs (dict): The keyword arguments for the engine. 
""" def __init__( @@ -103,7 +106,7 @@ def __init__( self.api_server_host = None self.api_server_port = None - async def chat(self, messages: List[Dict], **kwargs) -> List[Experience]: + async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]: """Chat with the model with a list of messages in async. Args: @@ -134,7 +137,7 @@ async def chat(self, messages: List[Dict], **kwargs) -> List[Experience]: ) return await self.generate(prompt=prompt, **kwargs) - async def generate(self, prompt: str, **kwargs) -> List[Experience]: + async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: """Generate a response from the provided prompt in async. Args: @@ -146,8 +149,8 @@ async def generate(self, prompt: str, **kwargs) -> List[Experience]: """ output = await self._generate_internal(prompt=prompt, **kwargs) experiences = [ - Experience( - tokens=torch.cat( + SingleTurnExperience( + token_ids=torch.cat( ( torch.tensor(output.prompt_token_ids, dtype=torch.int32), torch.tensor(output.outputs[i].token_ids, dtype=torch.int32), @@ -222,9 +225,8 @@ async def convert_messages_to_experience(self, messages: List[dict]) -> Experien self.tokenizer, messages, self.chat_template ) logprobs = await self.logprobs(token_ids=token_ids.tolist()) - return Experience( - tokens=token_ids, - prompt_length=len(token_ids), + return MultiTurnExperience( + token_ids=token_ids, logprobs=logprobs, action_mask=action_mask, ) diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 5f3c47aa7f..97952f2cd3 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -8,10 +8,9 @@ from typing import Any, List, Optional, Type, Union import openai -import torch from trinity.common.config import FormatConfig, GenerationConfig -from trinity.common.experience import Experience +from trinity.common.experience import Experience, MultiTurnExperience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathRewardFn from trinity.common.rewards.reward_fn import RewardFn @@ -132,24 +131,21 @@ def run(self) -> List[Experience]: def process_messages_to_experience(self, messages, reward, info={}) -> Experience: converted_experience = self.model.convert_messages_to_experience(messages) - tokens = converted_experience.tokens + token_ids = converted_experience.token_ids log_probs = converted_experience.logprobs assert converted_experience.action_mask is not None generation_mask = converted_experience.action_mask log_probs = log_probs * generation_mask - assert tokens.shape == log_probs.shape - # set prompt length to the first 1 in the gen_mask - prompt_length = torch.where(generation_mask == 1)[0][0].item() + assert token_ids.shape == log_probs.shape metrics = {} for k, v in info.items(): if isinstance(v, float) or isinstance(v, int): metrics[k] = float(v) - experience = Experience( - tokens=tokens, - prompt_length=prompt_length, + experience = MultiTurnExperience( + token_ids=token_ids, action_mask=generation_mask, reward=reward, logprobs=log_probs, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 2198f4d2d1..a2c195d146 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -405,10 +405,10 @@ def _log_single_experience( ) -> None: reward = experiences.rewards[idx] attn_mask = experiences.attention_masks[idx].bool() - prompt_token = experiences.tokens[idx][: experiences.prompt_length][ + prompt_token = experiences.token_ids[idx][: 
experiences.prompt_length][ attn_mask[: experiences.prompt_length] ] - response_token = experiences.tokens[idx][experiences.prompt_length :][ + response_token = experiences.token_ids[idx][experiences.prompt_length :][ attn_mask[experiences.prompt_length :] ] prompt_text = self.tokenizer.decode(prompt_token, skip_special_tokens=skip_special_tokens) From 83544ca9b44581a8482caed86099f743c5391a08 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 11:56:16 +0800 Subject: [PATCH 04/21] fix experience tests --- tests/buffer/queue_test.py | 17 +- tests/buffer/sql_test.py | 6 +- tests/common/experience_test.py | 251 ++++++++++++++++++++++--- tests/explorer/scheduler_test.py | 4 +- trinity/buffer/reader/file_reader.py | 14 +- trinity/buffer/schema/sql_schema.py | 8 +- trinity/common/experience.py | 271 +++++++++++++++++---------- trinity/common/models/model.py | 4 +- trinity/common/models/vllm_model.py | 10 +- trinity/common/workflows/workflow.py | 6 +- 10 files changed, 430 insertions(+), 161 deletions(-) diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 702e8b8ca1..c591cb6b05 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -30,7 +30,7 @@ class TestQueueBuffer(RayUnittestBaseAysnc): ) async def test_queue_buffer(self, name, use_priority_queue): meta = StorageConfig( - name="test_buffer", + name=name, algorithm_type="ppo", storage_type=StorageType.QUEUE, max_read_timeout=3, @@ -42,7 +42,7 @@ async def test_queue_buffer(self, name, use_priority_queue): self.assertEqual(await writer.acquire(), 1) exps = [ Experience( - tokens=torch.tensor([float(j) for j in range(i + 1)]), + token_ids=torch.tensor([float(j) for j in range(i + 1)]), prompt_length=i, reward=float(i), logprobs=torch.tensor([0.1]), @@ -59,8 +59,7 @@ async def test_queue_buffer(self, name, use_priority_queue): print(f"finish read {self.read_batch_size} experience") exps = [ Experience( - tokens=torch.tensor([float(j) for j in range(i + 1)]), - prompt_length=i, + token_ids=torch.tensor([float(j) for j in range(i + 1)]), reward=float(i), logprobs=torch.tensor([0.1]), action_mask=torch.tensor([j % 2 for j in range(i + 1)]), @@ -100,7 +99,7 @@ async def test_priority_queue_capacity(self): writer.write( [ Experience( - tokens=torch.tensor([1, 2, 3]), + token_ids=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": i, "use_count": 0}, ), @@ -164,12 +163,12 @@ async def test_priority_queue_buffer_reuse(self): writer.write( [ Experience( - tokens=torch.tensor([1, 2, 3]), + token_ids=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": i, "use_count": 0}, ), Experience( - tokens=torch.tensor([1, 2, 3]), + token_ids=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": i, "use_count": 0}, ), @@ -181,12 +180,12 @@ def replace_call(): writer.write( [ Experience( - tokens=torch.tensor([1, 2, 3]), + token_ids=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 4, "use_count": 0}, ), Experience( - tokens=torch.tensor([1, 2, 3]), + token_ids=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 4, "use_count": 0}, ), diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 07032f8c69..389c5195be 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -8,7 +8,7 @@ from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType -from trinity.common.experience import MultiTurnExperience, SingleTurnExperience 
+from trinity.common.experience import Experience db_path = os.path.join(os.path.dirname(__file__), "test.db") @@ -33,7 +33,7 @@ async def test_create_sql_buffer(self) -> None: sql_writer = SQLWriter(meta, config) sql_reader = SQLReader(meta, config) exps = [ - SingleTurnExperience( + Experience( token_ids=torch.tensor([float(j) for j in range(i + 1)]), prompt_length=i, reward=float(i), @@ -51,7 +51,7 @@ async def test_create_sql_buffer(self) -> None: # dynamic read/write sql_writer.write( [ - MultiTurnExperience( + Experience( token_ids=torch.tensor([float(j) for j in range(i + 1)]), reward=float(i), logprobs=torch.tensor([0.1]), diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index ec903fd4c7..097a7ff6a5 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -6,72 +6,242 @@ import torch from trinity.buffer.schema.sql_schema import ExperienceModel -from trinity.common.experience import Experience, Experiences +from trinity.common.experience import EID, Experience, Experiences db_url = os.path.join(os.path.dirname(__file__), "tmp", "test.db") dataset_path = os.path.join(os.path.dirname(__file__), "data") +class TestEID(unittest.TestCase): + def test_eid_properties(self): + # test properties + eid = EID(batch=1, task=2, run=3, step=4, suffix="abc123") + self.assertEqual(eid.uid, "1/2/3/4/abc123") + self.assertEqual(eid.sid, "1/2/4") + self.assertEqual(eid.rid, "1/2/3") + self.assertEqual(eid.gid, "1/2") + self.assertEqual(str(eid), "1/2/3/4/abc123") + self.assertIn("EID(batch=1, task=2, run=3, step=4, uuid=abc123)", repr(eid)) + + # test unique + eid1 = EID(batch=1, task=2, run=3, step=4) + eid2 = EID(batch=1, task=2, run=3, step=4) + self.assertNotEqual(eid1.suffix, eid2.suffix) + self.assertNotEqual(eid1.uid, eid2.uid) + + # test default + eid = EID() + eid2 = EID() + self.assertIsInstance(eid.suffix, str) + self.assertEqual(eid.batch, 0) + self.assertEqual(eid.task, 0) + self.assertEqual(eid.run, 0) + self.assertEqual(eid.step, 0) + self.assertNotEqual(eid.uid, eid2.uid) + + +class TestExperience(unittest.TestCase): + def test_single_turn_experience(self): + token_ids = torch.tensor([10, 11, 12], dtype=torch.int32) + logprobs = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32) + exp = Experience(token_ids=token_ids, logprobs=logprobs, reward=1.0, prompt_length=1) + self.assertEqual(exp.experience_type.name, "SINGLE_TURN") + self.assertTrue(torch.equal(exp.token_ids, token_ids)) + self.assertTrue(torch.equal(exp.logprobs, logprobs)) + self.assertEqual(exp.reward, 1.0) + self.assertEqual(exp.prompt_length, 1) + self.assertTrue(torch.equal(exp.action_mask, torch.tensor([0, 1, 1], dtype=torch.bool))) + + def test_multi_turn_experience(self): + token_ids = torch.tensor([1, 2, 3, 4]) + logprobs = torch.tensor([0.1, 0.2, 0.3, 0.4]) + action_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool) + exp = Experience( + token_ids=token_ids, logprobs=logprobs, reward=2.0, action_mask=action_mask + ) + self.assertEqual(exp.experience_type.name, "MULTI_TURN") + self.assertTrue(torch.equal(exp.action_mask, action_mask)) + self.assertEqual(exp.prompt_length, 1) + + def test_dpo_experience(self): + token_ids = torch.tensor([1, 2]) + chosen_ids = torch.tensor([3, 4]) + rejected_ids = torch.tensor([5, 6]) + exp = Experience( + token_ids=token_ids, chosen_ids=chosen_ids, rejected_ids=rejected_ids, reward=0.5 + ) + self.assertEqual(exp.experience_type.name, "DPO") + self.assertTrue(torch.equal(exp.chosen_ids, chosen_ids)) + 
self.assertTrue(torch.equal(exp.rejected_ids, rejected_ids)) + self.assertEqual(exp.prompt_length, 2) + + def test_serialize_deserialize(self): + token_ids = torch.tensor([1, 2, 3]) + exp = Experience(token_ids=token_ids, reward=1.23, prompt_length=1) + data = exp.serialize() + exp2 = Experience.deserialize(data) + self.assertTrue(torch.equal(exp.token_ids, exp2.token_ids)) + self.assertEqual(exp.reward, exp2.reward) + self.assertEqual(exp.prompt_length, exp2.prompt_length) + self.assertEqual(exp.experience_type, exp2.experience_type) + + def test_to_dict(self): + token_ids = torch.tensor([1, 2, 3]) + exp = Experience( + token_ids=token_ids, reward=2.5, prompt_length=1, prompt_text="hi", response_text="yo" + ) + d = exp.to_dict() + self.assertIn("eid", d) + self.assertIn("type", d) + self.assertIn("reward", d) + self.assertEqual(d["prompt_text"], "hi") + self.assertEqual(d["response_text"], "yo") + self.assertEqual(d["reward"], 2.5) + + def test_gather(self): + # test empty gathering + batch = Experiences.gather_experiences([]) + self.assertEqual(batch.token_ids.numel(), 0) + self.assertEqual(batch.rewards.numel(), 0) + self.assertEqual(batch.eids, []) + + # test single experience gathering + exp = Experience(token_ids=torch.tensor([1, 2, 3]), reward=1.0, prompt_length=1) + batch = Experiences.gather_experiences([exp]) + self.assertEqual(batch.batch_size, 1) + self.assertTrue( + torch.equal(batch.token_ids[0], torch.tensor([0, 1, 2, 3], dtype=torch.int64)[-3:]) + ) + self.assertEqual(batch.prompt_length, 1) + self.assertEqual(batch.rewards[0], 1.0) + + # test multiple experiences gathering + exps = [ + Experience(token_ids=torch.tensor([1, 2]), reward=0.1, prompt_length=1), + Experience(token_ids=torch.tensor([3, 4, 5]), reward=0.2, prompt_length=2), + ] + batch = Experiences.gather_experiences(exps) + self.assertEqual(batch.batch_size, 2) + self.assertEqual(batch.prompt_length, 2) + self.assertEqual(batch.token_ids.shape[1], 3) + self.assertEqual(batch.rewards[0], 0.1) + self.assertEqual(batch.rewards[1], 0.2) + + def test_action_mask_and_logprobs_type(self): + exp = Experience(token_ids=[1, 2, 3], logprobs=[0.1, 0.2, 0.3], prompt_length=1) + self.assertIsInstance(exp.token_ids, torch.Tensor) + self.assertIsInstance(exp.logprobs, torch.Tensor) + self.assertIsInstance(exp.action_mask, torch.Tensor) + + def test_assertions(self): + # prompt_length must be > 0 + with self.assertRaises(AssertionError): + Experience(token_ids=[1, 2, 3], prompt_length=0) + # token_ids must be longer than prompt_length for single-turn + with self.assertRaises(AssertionError): + Experience(token_ids=[1, 2], prompt_length=2) + # DPO: token_ids must match prompt_length + exp = Experience(token_ids=[1, 2], chosen_ids=[3], rejected_ids=[4], prompt_length=1) + exp.prompt_length = 2 # should automatically adjust + + class TestExperienceConversion(unittest.TestCase): """Test cases for ExperienceModel""" def test_experience_model_experience_conversion(self): """Test the conversion between Experience and ExperienceModel""" - tokens = torch.tensor([1, 2, 3], dtype=torch.int32) + token_ids = torch.tensor([1, 2, 3], dtype=torch.int32) reward = 0.6 prompt_length = 2 logprobs = torch.tensor([0, 0, 0.1], dtype=torch.float32) - action_mask = torch.tensor([1, 0, 1], dtype=torch.bool) experience = Experience( - tokens=tokens, + token_ids=token_ids, reward=reward, prompt_length=prompt_length, logprobs=logprobs, - action_mask=action_mask, ) model = ExperienceModel.from_experience(experience) - experience = model.to_experience() - 
self.assertTrue(torch.equal(experience.tokens, tokens)) - self.assertEqual(experience.prompt_length, prompt_length) - self.assertEqual(experience.reward, reward) - self.assertTrue(torch.equal(experience.logprobs, logprobs)) - self.assertTrue(torch.equal(experience.action_mask, action_mask)) + new_experience = model.to_experience() + self.assertTrue(torch.equal(new_experience.token_ids, token_ids)) + self.assertEqual(new_experience.prompt_length, prompt_length) + self.assertEqual(new_experience.reward, reward) + self.assertTrue(torch.equal(new_experience.logprobs, logprobs)) + self.assertTrue(torch.equal(new_experience.action_mask, experience.action_mask)) def test_batch_conversion(self): exps = [ Experience( - tokens=torch.tensor([1, 2]), + token_ids=torch.tensor([1, 2]), prompt_length=1, reward=float(0.1), logprobs=torch.tensor([0, 0.1]), - action_mask=torch.tensor([1, 0]), ), Experience( - tokens=torch.tensor([1, 2, 3]), + token_ids=torch.tensor([1, 2, 3]), prompt_length=2, reward=float(0.2), logprobs=torch.tensor([0, 0, 0.1]), - action_mask=torch.tensor([1, 0, 1]), ), + ] + batch = Experiences.gather_experiences(exps) + self.assertEqual(batch.batch_size, 2) + self.assertEqual(batch.prompt_length, 2) + prompt_length = batch.prompt_length + for i in range(batch.batch_size): + self.assertEqual(batch.rewards[i], exps[i].reward) + self.assertTrue( + torch.all( + batch.token_ids[i][ + prompt_length + - exps[i].prompt_length : prompt_length + - exps[i].prompt_length + + exps[i].token_ids.size(0) + ] + == exps[i].token_ids + ) + ) + self.assertTrue( + torch.all( + batch.logprobs[i][ + prompt_length + - exps[i].prompt_length : prompt_length + + exps[i].token_ids.size(0) + - exps[i].prompt_length + ] + == exps[i].logprobs + ) + ) + self.assertTrue( + torch.all( + batch.action_masks[i][ + prompt_length + - exps[i].prompt_length : prompt_length + - exps[i].prompt_length + + exps[i].action_mask.size(0) + ] + == exps[i].action_mask + ) + ) + + def test_multiturn_experience_batch_converstion(self): + exps = [ Experience( - tokens=torch.tensor([1, 2, 3, 4]), - prompt_length=2, + token_ids=torch.tensor([1, 2, 3, 4]), reward=float(0.3), logprobs=torch.tensor([0, 0, 0.1, 0.2]), action_mask=torch.tensor([1, 0, 1, 0]), ), Experience( - tokens=torch.tensor([1, 2, 3, 4]), - prompt_length=3, + token_ids=torch.tensor([1, 2, 3, 4]), reward=float(0.4), logprobs=torch.tensor([0, 0, 0, 0.1]), - action_mask=torch.tensor([1, 0, 1, 0]), + action_mask=torch.tensor([1, 0, 0, 1]), ), ] batch = Experiences.gather_experiences(exps) - self.assertEqual(batch.batch_size, 4) - self.assertEqual(batch.prompt_length, 3) + self.assertEqual(batch.batch_size, 2) + self.assertEqual(batch.prompt_length, 1) prompt_length = batch.prompt_length for i in range(batch.batch_size): self.assertEqual(batch.rewards[i], exps[i].reward) @@ -81,9 +251,9 @@ def test_batch_conversion(self): prompt_length - exps[i].prompt_length : prompt_length - exps[i].prompt_length - + exps[i].tokens.size(0) + + exps[i].token_ids.size(0) ] - == exps[i].tokens + == exps[i].token_ids ) ) self.assertTrue( @@ -91,7 +261,7 @@ def test_batch_conversion(self): batch.logprobs[i][ prompt_length - exps[i].prompt_length : prompt_length - + exps[i].tokens.size(0) + + exps[i].token_ids.size(0) - exps[i].prompt_length ] == exps[i].logprobs @@ -109,6 +279,37 @@ def test_batch_conversion(self): ) ) + def test_dpo_experience_batch_conversion(self): + exps = [ + Experience( + token_ids=torch.tensor([1, 2]), + chosen_ids=torch.tensor([3, 4]), + rejected_ids=torch.tensor([5, 6]), + 
), + Experience( + token_ids=torch.tensor([7, 8, 9]), + chosen_ids=torch.tensor([10, 11]), + rejected_ids=torch.tensor([12, 13]), + ), + ] + batch = Experiences.gather_experiences(exps) + self.assertEqual(batch.batch_size, 4) + self.assertEqual(batch.prompt_length, 3) + prompt_length = batch.prompt_length + for i in range(batch.batch_size): + j = i // 2 + self.assertTrue( + torch.all( + batch.token_ids[i][ + prompt_length + - exps[j].prompt_length : prompt_length + - exps[j].prompt_length + + exps[j].token_ids.size(0) + ] + == exps[j].token_ids + ) + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 1c2b97cfd2..1f1342fdad 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -10,7 +10,7 @@ from trinity.buffer.reader.queue_reader import QueueReader from trinity.common.config import GenerationConfig, StorageConfig from trinity.common.constants import StorageType -from trinity.common.experience import Experience, SingleTurnExperience +from trinity.common.experience import Experience from trinity.common.models.model import InferenceModel from trinity.common.workflows import Task from trinity.common.workflows.workflow import WORKFLOWS, Workflow @@ -42,7 +42,7 @@ def run(self) -> List[Experience]: assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 return [ - SingleTurnExperience( + Experience( token_ids=torch.zeros(5), prompt_length=2, prompt_text=self.error_type or "success", diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 673721c056..e2ec4a69b9 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -10,7 +10,7 @@ from trinity.buffer.buffer_reader import BufferReader from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import PromptType, ReadStrategy, TaskType -from trinity.common.experience import DPOExperience, SingleTurnExperience +from trinity.common.experience import Experience from trinity.common.rewards import REWARD_FUNCTIONS from trinity.common.workflows import WORKFLOWS, Task from trinity.utils.registry import Registry @@ -135,7 +135,7 @@ def read( prompt_tokens_ids = self.tokenizer.apply_chat_template( messages[:-1], add_generation_prompt=True, return_tensors="pt" )[0] - experience = SingleTurnExperience( + experience = Experience( token_ids=token_ids, prompt_length=len(prompt_tokens_ids), ) @@ -159,7 +159,7 @@ def read( prompt_messages, add_generation_prompt=True, return_tensors="pt" )[0] - experience = SingleTurnExperience( + experience = Experience( token_ids=token_ids, prompt_length=len(prompt_tokens_ids), ) @@ -172,7 +172,7 @@ def read( response = sample[self.response_key] token_ids = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0] prompt_tokens_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0] - experience = SingleTurnExperience( + experience = Experience( token_ids=token_ids, prompt_length=len(prompt_tokens_ids), ) @@ -250,10 +250,10 @@ def read( add_generation_prompt=False, return_tensors="pt", )[0][prompt_length:] - experience = DPOExperience( + experience = Experience( token_ids=prompt_tokens, - chosen=chosen_tokens, - rejected=rejected_tokens, + chosen_ids=chosen_tokens, + rejected_ids=rejected_tokens, ) exp_list.append(experience) return exp_list diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index d85985e287..f41c8038ef 100644 
--- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -5,7 +5,7 @@ from sqlalchemy import Column, Float, Integer, LargeBinary, String from sqlalchemy.ext.declarative import declarative_base -from trinity.common.experience import Experience, MultiTurnExperience +from trinity.common.experience import Experience Base = declarative_base() @@ -90,7 +90,7 @@ def from_messages( messages=messages, chat_template=chat_template, ) - exp = MultiTurnExperience( + exp = Experience( token_ids=token_ids, action_mask=action_mask, info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])}, @@ -119,8 +119,8 @@ class DPODataModel(Base): # type: ignore def to_experience(self) -> Experience: """Load the experience from the database.""" exp = Experience.deserialize(self.serialized_exp) - exp.chosen = Experience.deserialize(self.chosen) - exp.rejected = Experience.deserialize(self.rejected) + exp.chosen_ids = Experience.deserialize(self.chosen) + exp.rejected_ids = Experience.deserialize(self.rejected) return exp diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 9f6e227dcc..61bcc5d304 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -5,6 +5,7 @@ import pickle import uuid from dataclasses import dataclass, field +from enum import Enum from typing import List, Optional import torch @@ -12,7 +13,7 @@ @dataclass -class EID: +class EID(dict): """Experience ID class to uniquely identify an experience.""" # TODO: do we need to add project/name here to make it unique across different projects? @@ -60,13 +61,113 @@ def __repr__(self): return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})" +class ExperienceType(Enum): + """Enum for experience types.""" + + SINGLE_TURN = "single_turn" # Single-turn experience, e.g., a prompt-response pair + MULTI_TURN = "multi_turn" # Multi-turn experience, e.g., a conversation history + DPO = "dpo" # DPO experience, e.g., a chosen and rejected response pair + + @dataclass class Experience: - token_ids: Tensor # [seq_length] eid: EID = field(default_factory=EID) # Unique identifier for the experience + token_ids: Optional[Tensor] = None # [seq_length] + logprobs: Optional[Tensor] = None # [seq_length] reward: Optional[float] = None - info: Optional[dict] = None - metrics: Optional[dict[str, float]] = None + # Type of the experience, automatically set based on the presence of action_mask or chosen/rejected_ids + experience_type: ExperienceType = ExperienceType.SINGLE_TURN + info: Optional[dict] = field( + default_factory=dict + ) # Additional information about the experience + metrics: Optional[dict[str, float]] = field( + default_factory=dict + ) # Metrics associated with the experience + + # for single-turn experiences + prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks + response_text: Optional[str] = None # Text of the response + prompt_text: Optional[str] = None # Text of the prompt + + # for multi-turn experiences + action_mask: Optional[ + Tensor + ] = None # Action mask which indicates which tokens are generated by the model + messages: Optional[List[dict]] = None # List of messages + + # for dpo experiences + chosen_ids: Optional[Tensor] = None # Token ids of the chosen response [resp_length] + rejected_ids: Optional[Tensor] = None # Token ids of the rejected response [resp_length] + chosen_text: Optional[str] = None # Text of the chosen response + rejected_text: 
Optional[str] = None # Text of the rejected response + + def __init__( + self, + *, + eid=None, + token_ids, + logprobs=None, + reward=None, + info=None, + metrics=None, + prompt_length=1, + response_text=None, + prompt_text=None, + action_mask=None, + messages=None, + chosen_ids=None, + rejected_ids=None, + chosen_text=None, + rejected_text=None, + ): + if action_mask is not None: + experience_type = ExperienceType.MULTI_TURN + elif chosen_ids is not None and rejected_ids is not None: + experience_type = ExperienceType.DPO + else: + experience_type = ExperienceType.SINGLE_TURN + + if experience_type == ExperienceType.SINGLE_TURN: + assert ( + prompt_length > 0 + ), "Prompt length must be greater than 0 for single-turn experiences." + assert ( + len(token_ids) > prompt_length + ), "Token ids must be longer than the prompt length." + action_mask = torch.zeros(len(token_ids), dtype=torch.bool) + action_mask[prompt_length:] = 1 + elif experience_type == ExperienceType.MULTI_TURN: + prompt_length = 1 + elif experience_type == ExperienceType.DPO: + prompt_length = len(token_ids) + + self.eid = eid or EID() + self.token_ids = token_ids + self.logprobs = logprobs + self.reward = reward + self.experience_type = experience_type + self.info = info or {} + self.metrics = metrics or {} + self.prompt_length = prompt_length + self.response_text = response_text + self.prompt_text = prompt_text + self.action_mask = action_mask + self.messages = messages + self.chosen_ids = chosen_ids + self.rejected_ids = rejected_ids + self.chosen_text = chosen_text + self.rejected_text = rejected_text + + if not isinstance(self.token_ids, Tensor): + self.token_ids = torch.tensor(self.token_ids) + if self.logprobs is not None and not isinstance(self.logprobs, Tensor): + self.logprobs = torch.tensor(self.logprobs) + if self.action_mask is not None and not isinstance(self.action_mask, Tensor): + self.action_mask = torch.tensor(self.action_mask) + if self.chosen_ids is not None and not isinstance(self.chosen_ids, Tensor): + self.chosen_ids = torch.tensor(self.chosen_ids) + if self.rejected_ids is not None and not isinstance(self.rejected_ids, Tensor): + self.rejected_ids = torch.tensor(self.rejected_ids) def serialize(self) -> bytes: """Serialize the experience to bytes.""" @@ -76,31 +177,37 @@ def serialize(self) -> bytes: def deserialize(cls, data: bytes) -> Experience: return pickle.loads(data) - -@dataclass -class SingleTurnExperience(Experience): - """A single-turn prompt-response experience.""" - - # Length of the prompt in tokens, used for generating attention masks - prompt_length: int # type: ignore[misc] - response_text: Optional[str] = None # Text of the response - prompt_text: Optional[str] = None # Text of the prompt - logprobs: Optional[Tensor] = None # [seq] - - @property - def action_mask(self) -> Tensor: - """Get the action mask for the single-turn experience.""" - # set the prompt length to 0 and the rest to 1 - action_mask = torch.zeros_like(self.token_ids, dtype=torch.bool) - action_mask[self.prompt_length :] = 1 - return action_mask + def to_dict(self) -> dict: + """Convert the experience to a dictionary.""" + res = { + "eid": self.eid, + "type": self.experience_type.value, + "info": self.info, + "metrics": self.metrics, + } + if self.prompt_text is not None: + res["prompt_text"] = self.prompt_text + if self.response_text is not None: + res["response_text"] = self.response_text + if self.messages is not None: + res["messages"] = self.messages + if self.chosen_ids is not None: + res["chosen_ids"] = 
self.chosen_ids.tolist() + if self.rejected_ids is not None: + res["rejected_ids"] = self.rejected_ids.tolist() + if self.reward is not None: + res["reward"] = float(self.reward) + return res @classmethod - def gather(cls, experiences: List[SingleTurnExperience], pad_token_id: int = 0) -> Experiences: + def gather(cls, experiences: List[Experience], pad_token_id: int = 0) -> Experiences: if len(experiences) == 0: return empty_experiences() + exp_type = experiences[0].experience_type + if exp_type == ExperienceType.DPO: + experiences = split_dpo_experience_to_single_turn(experiences) max_prompt_length = max([exp.prompt_length for exp in experiences]) # type: ignore [type-var] - max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore [operator] + max_response_length = max([len(exp.token_ids) - exp.prompt_length for exp in experiences]) # type: ignore [arg-type] eids = [exp.eid for exp in experiences] # Gather token_ids @@ -140,78 +247,44 @@ def gather(cls, experiences: List[SingleTurnExperience], pad_token_id: int = 0) ) -@dataclass -class MultiTurnExperience(Experience): - """A multi-turn experience, which includes the conversation history in a single tensor.""" - - # Action mask which indicates which tokens are generated by the model - action_mask: Tensor # type: ignore[misc] - logprobs: Optional[Tensor] = None # [seq_length] - messages: Optional[List[dict]] = None # List of messages in the conversation - - @property - def prompt_length(self) -> int: - return 1 # use action mask to determine the response tokens, set prompt_length to 1 to avoid no prompt issues - - @classmethod - def gather(cls, experiences: List[SingleTurnExperience], pad_token_id: int = 0) -> Experiences: - return SingleTurnExperience.gather(experiences, pad_token_id=pad_token_id) - - -@dataclass -class DPOExperience(Experience): - """A DPO experience, which includes the chosen and rejected responses. - - `token_ids` should only contain the prompt tokens, while `chosen` and `rejected` should - contain the response tokens. 
- """ - - # Token ids of the chosen response [resp_length] - chosen: Tensor # type: ignore[misc] - # Token ids of the rejected response [resp_length] - rejected: Tensor # type: ignore[misc] - chosen_text: Optional[str] = None # Text of the chosen response - rejected_text: Optional[str] = None # Text of the rejected response - prompt_text: Optional[str] = None # Text of the prompt - - @classmethod - def gather(cls, experiences: List[DPOExperience], pad_token_id: int = 0) -> Experiences: - """Gather a batch of DPO experiences from a list of experiences.""" - single_turn_experiences = [] - for exp in experiences: - single_turn_experiences.append( - SingleTurnExperience( - eid=EID( - batch=exp.eid.batch, - task=exp.eid.task, - step=exp.eid.step, - ), - token_ids=torch.cat([exp.token_ids, exp.chosen]), - reward=exp.reward, - info=exp.info, - metrics=exp.metrics, - prompt_length=len(exp.token_ids), - prompt_text=exp.prompt_text, - response_text=exp.chosen_text, - ) +def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]: + single_turn_experiences = [] + for exp in experiences: + single_turn_experiences.append( + Experience( + eid=EID( + batch=exp.eid.batch, + task=exp.eid.task, + step=exp.eid.step, + run=exp.eid.run, + ), + token_ids=torch.cat([exp.token_ids, exp.chosen_ids]), + reward=exp.reward, + info=exp.info, + metrics=exp.metrics, + prompt_length=len(exp.token_ids), # type: ignore [arg-type] + prompt_text=exp.prompt_text, + response_text=exp.chosen_text, ) - single_turn_experiences.append( - SingleTurnExperience( - eid=EID( - batch=exp.eid.batch, - task=exp.eid.task, - step=exp.eid.step, - ), - token_ids=torch.cat([exp.token_ids, exp.rejected]), - reward=exp.reward, - info=exp.info, - metrics=exp.metrics, - prompt_length=len(exp.token_ids), - prompt_text=exp.prompt_text, - response_text=exp.rejected_text, - ) + ) + single_turn_experiences.append( + Experience( + eid=EID( + batch=exp.eid.batch, + task=exp.eid.task, + step=exp.eid.step, + run=exp.eid.run, + ), + token_ids=torch.cat([exp.token_ids, exp.rejected_ids]), + reward=exp.reward, + info=exp.info, + metrics=exp.metrics, + prompt_length=len(exp.token_ids), # type: ignore [arg-type] + prompt_text=exp.prompt_text, + response_text=exp.rejected_text, ) - return SingleTurnExperience.gather(single_turn_experiences, pad_token_id=pad_token_id) + ) + return single_turn_experiences @dataclass(frozen=True) @@ -281,9 +354,9 @@ def gather_token_ids( pad_token_id, dtype=token_ids_dtype, ), - exp.tokens, + exp.token_ids, torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), + (max_response_length + exp.prompt_length - len(exp.token_ids),), pad_token_id, dtype=token_ids_dtype, ), @@ -306,7 +379,7 @@ def gather_action_masks(experiences, max_prompt_length: int, max_response_length ), exp.action_mask, torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), + (max_response_length + exp.prompt_length - len(exp.token_ids),), 0, dtype=torch.bool, ), @@ -324,7 +397,7 @@ def gather_attention_masks(experiences, max_prompt_length: int, max_response_len for i, exp in enumerate(experiences): start = max_prompt_length - exp.prompt_length - end = start + len(exp.tokens) + end = start + len(exp.token_ids) attention_masks[i, start:end] = 1 return attention_masks @@ -343,7 +416,7 @@ def gather_logprobs(experiences, max_prompt_length: int, max_response_length: in ), exp.logprobs, torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), + (max_response_length + 
exp.prompt_length - len(exp.token_ids),), 0.0, dtype=logprob_dtype, ), diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index e32e7ceddb..61f6a1af8c 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -11,7 +11,7 @@ import torch from torch import Tensor -from trinity.common.experience import Experience, SingleTurnExperience +from trinity.common.experience import Experience from trinity.utils.log import get_logger @@ -189,7 +189,7 @@ def convert_api_output_to_experience( ) -> List[Experience]: """Convert the API output to a list of experiences.""" return [ - SingleTurnExperience( + Experience( token_ids=torch.cat( ( torch.tensor(output.prompt_token_ids, dtype=torch.int32), diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 856ea169b6..5bdf987c8b 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -12,11 +12,7 @@ from vllm.sampling_params import RequestOutputKind from trinity.common.config import InferenceModelConfig -from trinity.common.experience import ( - Experience, - MultiTurnExperience, - SingleTurnExperience, -) +from trinity.common.experience import Experience from trinity.common.models.model import InferenceModel from trinity.common.models.utils import ( tokenize_and_mask_messages_default, @@ -149,7 +145,7 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: """ output = await self._generate_internal(prompt=prompt, **kwargs) experiences = [ - SingleTurnExperience( + Experience( token_ids=torch.cat( ( torch.tensor(output.prompt_token_ids, dtype=torch.int32), @@ -225,7 +221,7 @@ async def convert_messages_to_experience(self, messages: List[dict]) -> Experien self.tokenizer, messages, self.chat_template ) logprobs = await self.logprobs(token_ids=token_ids.tolist()) - return MultiTurnExperience( + return Experience( token_ids=token_ids, logprobs=logprobs, action_mask=action_mask, diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 97952f2cd3..7536f0ebac 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -10,7 +10,7 @@ import openai from trinity.common.config import FormatConfig, GenerationConfig -from trinity.common.experience import Experience, MultiTurnExperience +from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathRewardFn from trinity.common.rewards.reward_fn import RewardFn @@ -24,7 +24,7 @@ @dataclass -class Task: +class Task(dict): """A Task class that defines a task and its associated reward function / workflow.""" workflow: Type[Workflow] @@ -144,7 +144,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc if isinstance(v, float) or isinstance(v, int): metrics[k] = float(v) - experience = MultiTurnExperience( + experience = Experience( token_ids=token_ids, action_mask=generation_mask, reward=reward, From 0f368e74019e7b6f414c3a522ee2f4f39795d1f5 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 17:00:16 +0800 Subject: [PATCH 05/21] add add_strategy --- tests/common/experience_test.py | 2 +- tests/explorer/scheduler_test.py | 2 +- trinity/algorithm/__init__.py | 3 + trinity/algorithm/add_strategy/__init__.py | 11 +++ .../algorithm/add_strategy/add_strategy.py | 82 +++++++++++++++++++ trinity/buffer/__init__.py | 10 ++- trinity/common/config.py | 9 ++ trinity/common/experience.py | 25 ++++-- 
trinity/common/workflows/workflow.py | 4 +- trinity/explorer/explorer.py | 45 ++++++---- trinity/explorer/scheduler.py | 69 +++++++++------- trinity/explorer/workflow_runner.py | 29 ++++--- 12 files changed, 221 insertions(+), 70 deletions(-) create mode 100644 trinity/algorithm/add_strategy/__init__.py create mode 100644 trinity/algorithm/add_strategy/add_strategy.py diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 097a7ff6a5..2b431fc71e 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -19,7 +19,7 @@ def test_eid_properties(self): self.assertEqual(eid.uid, "1/2/3/4/abc123") self.assertEqual(eid.sid, "1/2/4") self.assertEqual(eid.rid, "1/2/3") - self.assertEqual(eid.gid, "1/2") + self.assertEqual(eid.tid, "1/2") self.assertEqual(str(eid), "1/2/3/4/abc123") self.assertIn("EID(batch=1, task=2, run=3, step=4, uuid=abc123)", repr(eid)) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 1f1342fdad..07387de922 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -436,7 +436,7 @@ async def test_split_tasks(self): self.queue.read(batch_size=1) # test group_id and unique_id - group_ids = [exp.group_id for exp in exp_list] + group_ids = [exp.eid.gid for exp in exp_list] self.assertEqual(len(set(group_ids)), 11) # 4 + 4 + 3 unique_ids = [exp.unique_id for exp in exp_list] self.assertEqual(len(unique_ids), len(set(unique_ids))) diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py index 667aa10d74..b5b03d2075 100644 --- a/trinity/algorithm/__init__.py +++ b/trinity/algorithm/__init__.py @@ -1,3 +1,4 @@ +from trinity.algorithm.add_strategy import ADD_STRATEGY, AddStrategy from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn @@ -18,4 +19,6 @@ "ENTROPY_LOSS_FN", "SampleStrategy", "SAMPLE_STRATEGY", + "AddStrategy", + "ADD_STRATEGY", ] diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py new file mode 100644 index 0000000000..45cd0d6e40 --- /dev/null +++ b/trinity/algorithm/add_strategy/__init__.py @@ -0,0 +1,11 @@ +from trinity.algorithm.add_strategy.sample_strategy import ( + ADD_STRATEGY, + AddStrategy, + RewardVarianceAddStrategy, +) + +__all__ = [ + "ADD_STRATEGY", + "AddStrategy", + "RewardVarianceAddStrategy", +] diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py new file mode 100644 index 0000000000..b0c173f4ad --- /dev/null +++ b/trinity/algorithm/add_strategy/add_strategy.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple, Literal + +from numpy import np + +from trinity.buffer import BufferWriter +from trinity.common.experience import Experience +from trinity.utils.registry import Registry + +ADD_STRATEGY = Registry("add_strategy") + +class AddStrategy(ABC): + + def __init__(self, writer: BufferWriter, **kwargs) -> None: + self.writer = writer + + @abstractmethod + async def add(self, experiences: List[Experience], step: int) -> int: + """Add experiences to the buffer. + + Args: + experiences (`Experience`): The experiences to be added. + step (`int`): The current step number. + + Returns: + `int`: The number of experiences added to the buffer. 
+ """ + + @classmethod + @abstractmethod + def default_args(cls) -> dict: + """Get the default arguments of the add strategy. + + Returns: + `dict`: The default arguments. + """ + + +class RewardVarianceAddStrategy(AddStrategy): + """An example add strategy that filters experiences based on a reward threshold.""" + + def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None: + super().__init__(writer) + + async def add(self, experiences: List[Experience], step: int) -> int: + cnt = 0 + grouped_experiences = group_by(experiences, id_type="task") + for _, group_exps in grouped_experiences.items(): + if len(group_exps) < 2: + continue + # check if the rewards are the same + rewards = [exp.reward for exp in group_exps] + variance = np.var(rewards) + if variance < self.variance_threshold: + continue + cnt += len(group_exps) + await self.writer.write_async(group_exps) + return cnt + + + @classmethod + def default_args(cls) -> dict: + return {"reward_threshold": 0.0} + + +def group_by(experiences: List[Experience], id_type: Literal["task", "run", "step"]) -> Dict[str, List[Experience]]: + """Group experiences by ID.""" + if id_type == "task": + id_type = "tid" + elif id_type == "run": + id_type = "rid" + elif id_type == "step": + id_type = "sid" + else: + raise ValueError(f"Unknown id_type: {id_type}") + grouped = {} + for exp in experiences: + group_id = getattr(exp.eid, id_type) + if group_id not in grouped: + grouped[group_id] = [] + grouped[group_id].append(exp) + return grouped diff --git a/trinity/buffer/__init__.py b/trinity/buffer/__init__.py index 7e11b73a44..e7cc4c1b9b 100644 --- a/trinity/buffer/__init__.py +++ b/trinity/buffer/__init__.py @@ -1,7 +1,15 @@ -from trinity.buffer.buffer import Buffer, get_buffer_reader, get_buffer_writer +from trinity.buffer.buffer import ( + Buffer, + BufferReader, + BufferWriter, + get_buffer_reader, + get_buffer_writer, +) __all__ = [ "Buffer", + "BufferReader", + "BufferWriter", "get_buffer_reader", "get_buffer_writer", ] diff --git a/trinity/common/config.py b/trinity/common/config.py index 5a9494504d..4bee6de577 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -226,6 +226,11 @@ class AlgorithmConfig: # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 + # the strategy for adding experiences to the buffer + add_strategy: Optional[str] = None + add_strategy_args: Optional[dict] = None + + # the strategy for sampling experiences from the buffer sample_strategy: Optional[str] = None sample_strategy_args: Optional[dict] = None @@ -328,6 +333,10 @@ class ExplorerConfig: runner_num: Optional[int] = None # deprecated + # Explorer collects experiences from workflow runners + # some algorithms (e.g., DAPO) need to collect experiences generated by the same task and do some post-processing + collect_experiences: bool = False + # for inference models # for rollout model rollout_model: InferenceModelConfig = field(default_factory=InferenceModelConfig) diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 61bcc5d304..81a7398e70 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -14,13 +14,24 @@ @dataclass class EID(dict): - """Experience ID class to uniquely identify an experience.""" + """Experience ID class to uniquely identify an experience. + + To enable the full functionality of the experience grouping, user should manually set the `run` and `step` fields in custom workflows. 
+ """ # TODO: do we need to add project/name here to make it unique across different projects? - batch: int = 0 # Batch number, e.g., the explorer step num + # Batch number, e.g., the explorer step num + # Automatically set by the workflow runner + batch: int = 0 + # Task number, e.g., the task sequence in the batch, the first task in the batch has task=0 + # Automatically set by the workflow runner task: int = 0 # Task sequence in the batch, e.g., the first task in the batch has task=0 - run: int = 0 # Run id, e.g., the first run in the task has run=0 - step: int = 0 # Step number when running the task, e.g., the first step in the task has step=0 + # Run id, e.g., the first run in the task has run=0 + # User should set this field in custom workflows when creating experiences + run: int = 0 + # Step number when running the task, e.g., the first step in the task has step=0 + # User should set this field in custom workflows when creating experiences + step: int = 0 suffix: str = field( default_factory=lambda: uuid.uuid4().hex[:6] ) # Unique identifier suffix, e.g., a UUID @@ -47,10 +58,10 @@ def rid(self) -> str: return f"{self.batch}/{self.task}/{self.run}" @property - def gid(self) -> str: - """Group ID for the experience. + def tid(self) -> str: + """Task ID for the experience. - For example, experiences generated by a group run in GRPO-like algorithms will have the same gid. + For example, experiences generated by a all run of a same task in GRPO-like algorithms will have the same tid. """ return f"{self.batch}/{self.task}" diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 7536f0ebac..d2418fd99f 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -36,7 +36,9 @@ class Task(dict): reward_fn: Optional[Type[RewardFn]] = None raw_task: Optional[dict] = None # The raw data sample - group_id: Optional[str] = None # for GRPO-like algorithms, automatically assigned + # automatically assigned ids + batch_id: int = 0 + task_id: int = 0 def to_workflow( self, model: Any, auxiliary_models: Optional[List[openai.OpenAI]] = None diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index c22516be42..13d945c6e5 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -11,6 +11,7 @@ import torch +from trinity.algorithm import ADD_STRATEGY from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.buffer import get_buffer_writer from trinity.buffer.buffer import get_buffer_reader @@ -80,6 +81,12 @@ def __init__(self, config: Config): self.status = RunningStatus.RUNNING self.logger.info("Finished initializing Explorer.") self._ready_to_sync_condition = asyncio.Condition() + self.collect_experiences = self.config.explorer.collect_experiences + self.generated_experience_cnt = 0 + if self.collect_experiences: + self.add_strategy = ADD_STRATEGY.get( + self.config.algorithm.add_strategy, self.config.algorithm.trainer_type + ) async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None @@ -280,12 +287,12 @@ async def benchmark(self) -> bool: if self.config.explorer.bench_on_latest_checkpoint: self.explore_step_num = await self._checkpoint_weights_update() self.eval() - await self._log_eval_metrics(prefix="bench") + await self._finish_eval_step(prefix="bench") return True # benchmark on base model if self.config.explorer.eval_on_startup: - await self._log_eval_metrics(prefix="bench") + await 
self._finish_eval_step(prefix="bench")
 
         # benchmark on all checkpoints
         all_ckp_steps = sorted(
@@ -299,16 +306,17 @@
         for step_num in all_ckp_steps:
             self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num)
             self.eval()
-            await self._log_eval_metrics(prefix="bench")
+            await self._finish_eval_step(prefix="bench")
         return True
 
     async def save_checkpoint(self, sync_weight: bool = False) -> None:
-        # wait for all tasks to complete
-        self.logger.info("Waiting for all tasks to complete")
-        await self.scheduler.wait_all()
-        self.logger.info(f"All tasks before step {self.explore_step_num} have completed.")
+        if not self.config.explorer.collect_experiences:
+            # wait for all tasks to complete
+            self.logger.info("Waiting for all tasks to complete")
+            await self.scheduler.wait_all()
+            self.logger.info(f"All tasks before step {self.explore_step_num} have completed.")
         log_task = asyncio.create_task(
-            self._log_metrics(self.last_sync_step + 1, self.explore_step_num)
+            self._finish_steps(self.last_sync_step + 1, self.explore_step_num)
         )
 
         if sync_weight:
@@ -335,19 +343,22 @@ async def sync_weight(self) -> None:
         # call this method before training start to load the latest model weights
         await self.save_checkpoint(sync_weight=True)
 
-    async def _log_metrics(self, start_step: int, end_step: int) -> None:
+    async def _finish_steps(self, start_step: int, end_step: int) -> None:
         for step in range(start_step, end_step + 1):
             self.logger.info(f"Log metrics of step {step}")
-            await self._log_explore_metrics(step=step)
-            await self._log_eval_metrics(step=step)
-
-    async def _log_explore_metrics(self, step: int) -> None:
-        results = await self.scheduler.get_results(batch_id=step)
-        if results:
-            metric = gather_metrics([status.metric for status in results], "rollout")
+            await self._finish_explore_step(step=step)
+            await self._finish_eval_step(step=step)
+
+    async def _finish_explore_step(self, step: int) -> None:
+        statuses, exps = await self.scheduler.get_results(batch_id=step)
+        if self.config.explorer.collect_experiences:
+            exp_cnt = await self.add_strategy.add(exps, step)
+            self.generated_experience_cnt += exp_cnt
+        if statuses:
+            metric = gather_metrics([status.metric for status in statuses], "rollout")
             self.monitor.log(metric, step=step)
 
-    async def _log_eval_metrics(self, step: Optional[int] = None, prefix: str = "eval") -> None:
+    async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
         if not self.pending_eval_tasks:
             return
         step = step or self.explore_step_num
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index 459793d2c8..7a79f9a8f0 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -11,6 +11,7 @@
 import ray
 
 from trinity.common.config import Config
+from trinity.common.experience import Experience
 from trinity.common.models import InferenceModel
 from trinity.common.workflows import Task
 from trinity.explorer.workflow_runner import Status, WorkflowRunner
@@ -59,21 +60,23 @@ def _create_runner(self):
             .remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id)
         )
 
-    async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, int]:
+    async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, List, int]:
         """
         Returns:
             `Status`: The return status of the task.
+            `List`: The experiences generated by the task.
             `int`: The runner_id of current runner.
""" last_exception_msg = None await self.runner.__ray_ready__.remote() start_time = time.time() status = Status(ok=False, metric=dict()) + exps = [] try: for attempt in range(self.retry_times + 1): try: task.task.rollout_args.n = task.repeat_times - status = await asyncio.wait_for( + status, exps = await asyncio.wait_for( self.runner.run_task.remote(task.task), self.timeout ) if status.ok: @@ -93,7 +96,7 @@ async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, int]: finally: end_time = time.time() status.metric["task_run_time"] = end_time - start_time - return status, self.runner_id + return status, exps, self.runner_id def restart_runner(self): old_runner = self.runner @@ -147,7 +150,9 @@ def __init__( set ) # batch_id -> futures self.running_task_map: Dict[asyncio.Future, TaskWrapper] = dict() # future -> task - self.completed_tasks: Dict[Union[int, str], deque[Status]] = defaultdict( + self.completed_tasks: Dict[ + Union[int, str], deque[Tuple[Status, List[Experience]]] + ] = defaultdict( deque ) # batch_id -> results @@ -225,13 +230,11 @@ def task_done_callback(self, async_task: asyncio.Task): self.logger.error(f"Task {task.task_id} failed: {async_task.exception()}") return else: - task_result, runner_id = async_task.result() - self.completed_tasks[task.batch_id].appendleft(task_result) + status, exps, runner_id = async_task.result() + self.completed_tasks[task.batch_id].appendleft((status, exps)) self.busy_runners.pop(runner_id) self.idle_runners.add(runner_id) - self.logger.debug( - f"Task completed (batch_id {task.batch_id}), success: {task_result.ok}" - ) + self.logger.debug(f"Task completed (batch_id {task.batch_id}), success: {status.ok}") if task.batch_id in self.running_tasks: self.running_tasks[task.batch_id].remove(async_task) @@ -294,8 +297,8 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: def _split_and_submit_tasks(self, tasks: List[Task], batch_id: Union[int, str]) -> None: for i, task in enumerate(tasks): - group_id = f"{batch_id}/{i}" - task.group_id = group_id + task.batch_id = batch_id + task.task_id = i if self.max_repeat_times is None: self.pending_tasks[batch_id].appendleft( TaskWrapper( @@ -321,7 +324,7 @@ async def get_results( min_num: Optional[int] = None, timeout: Optional[float] = None, clear_timeout_tasks: bool = True, - ) -> List[Status]: + ) -> Tuple[List[Status], List[Experience]]: """Get the result of tasks at the specific batch_id. 
Args: @@ -333,18 +336,19 @@ async def get_results( timeout = timeout or self.default_timeout start_time = time.time() if min_num is None: - min_num = 0 - if batch_id in self.pending_tasks: - min_num += len(self.pending_tasks[batch_id]) - if batch_id in self.running_tasks: - min_num += len(self.running_tasks[batch_id]) - if batch_id in self.completed_tasks: - min_num += len(self.completed_tasks[batch_id]) + min_num = sum( + len(tasks) # type: ignore [misc] + for tasks in ( + self.pending_tasks.get(batch_id, []), + self.running_tasks.get(batch_id, []), + self.completed_tasks.get(batch_id, []), + ) + ) self.logger.debug(f"Waiting for {min_num} tasks to complete...") while time.time() - start_time < timeout: - completed_count = len(self.completed_tasks[batch_id]) + completed_count = len(self.completed_tasks.get(batch_id, [])) if completed_count >= min_num: break await asyncio.sleep(0.1) @@ -353,25 +357,32 @@ async def get_results( self.logger.error(f"Timed out waiting for tasks to complete after {timeout} seconds") if clear_timeout_tasks: self._clear_timeout_tasks(batch_id=batch_id) - for runner_id in list(self.busy_runners.keys()): - if self.busy_runners[runner_id].batch_id == batch_id: + for runner_id, task in list(self.busy_runners.items()): + if task.batch_id == batch_id: self._restart_runner(runner_id) - results = [] + statuses = [] + experiences = [] + completed_queue = self.completed_tasks.get(batch_id, deque()) for _ in range(min_num): - if len(self.completed_tasks[batch_id]) > 0: - results.append(self.completed_tasks[batch_id].pop()) - - if not self.completed_tasks[batch_id]: + if completed_queue: + status, exps = completed_queue.pop() + statuses.append(status) + if isinstance(exps, list): + experiences.extend(exps) + else: + experiences.append(exps) + + if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]: del self.completed_tasks[batch_id] - completed_count = len(results) + completed_count = len(statuses) if completed_count < min_num: self.logger.warning( f"Timeout reached, only {completed_count}/{min_num} tasks completed" ) - return results + return statuses, experiences def has_step(self, batch_id: Union[int, str]) -> bool: return ( diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index b973a114d2..eb8e75205f 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -1,11 +1,10 @@ # -*- coding: utf-8 -*- -"""The Workflow Runner Moudle.""" +"""The Workflow Runner Module.""" import time import traceback -import uuid from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple from trinity.buffer import get_buffer_writer from trinity.common.config import Config @@ -25,7 +24,7 @@ class Status: class WorkflowRunner: - """A Ray remote actor to run the workflow and put the returned experiences into the buffer.""" + """A Ray remote actor to run the workflow and generate experiences.""" def __init__( self, @@ -55,6 +54,7 @@ def __init__( self.logger = get_logger(__name__) self.workflow_instance = None self.runner_id = runner_id + self.return_experiences = self.config.explorer.collect_experiences def is_alive(self): return True @@ -73,8 +73,9 @@ def _run_task(self, task: Task) -> List[Experience]: self.workflow_instance.reset(task) return self.workflow_instance.run() - def run_task(self, task: Task) -> Status: + def run_task(self, task: Task) -> Tuple[Status, List[Experience]]: """Run the task and return the 
states.""" + # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead try: st = time.time() exps = self._run_task(task) @@ -82,11 +83,8 @@ def run_task(self, task: Task) -> Status: metrics: dict[str, List[float]] = defaultdict(list) # set group id for idx, exp in enumerate(exps): - setattr(exp, "group_id", task.group_id) - setattr( - exp, "unique_id", f"{task.group_id}/{self.runner_id}/{str(uuid.uuid4())[:6]}" - ) - + exp.eid.batch = task.batch_id + exp.eid.task = task.task_id if not hasattr(exp, "info") or exp.info is None: exp.info = {} exp.info["model_version"] = self.model_wrapper.model_version @@ -102,10 +100,15 @@ def run_task(self, task: Task) -> Status: if metrics: for k, v in metrics.items(): metric[k] = sum(v) / len(v) # type: ignore - if not task.is_eval: + if task.is_eval: + # If the task is an evaluation task, we do not record the experiences to the buffer + return Status(True, metric=metric), [] + if self.return_experiences: self.experience_buffer.write(exps) - return Status(True, metric=metric) + return Status(True, metric=metric), [] + else: + return Status(True, metric=metric), exps except Exception as e: error_trace_back = traceback.format_exc() self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}") - return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)) + return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), [] From 890bbc55967ab918682b65fd8de4e5b9127d59ab Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 17:25:37 +0800 Subject: [PATCH 06/21] check config --- .../algorithm/add_strategy/add_strategy.py | 9 +-- trinity/common/config.py | 55 +++++++------------ 2 files changed, 24 insertions(+), 40 deletions(-) diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py index b0c173f4ad..067f96483b 100644 --- a/trinity/algorithm/add_strategy/add_strategy.py +++ b/trinity/algorithm/add_strategy/add_strategy.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple, Literal +from typing import Dict, List, Literal from numpy import np @@ -9,8 +9,8 @@ ADD_STRATEGY = Registry("add_strategy") -class AddStrategy(ABC): +class AddStrategy(ABC): def __init__(self, writer: BufferWriter, **kwargs) -> None: self.writer = writer @@ -57,13 +57,14 @@ async def add(self, experiences: List[Experience], step: int) -> int: await self.writer.write_async(group_exps) return cnt - @classmethod def default_args(cls) -> dict: return {"reward_threshold": 0.0} -def group_by(experiences: List[Experience], id_type: Literal["task", "run", "step"]) -> Dict[str, List[Experience]]: +def group_by( + experiences: List[Experience], id_type: Literal["task", "run", "step"] +) -> Dict[str, List[Experience]]: """Group experiences by ID.""" if id_type == "task": id_type = "tid" diff --git a/trinity/common/config.py b/trinity/common/config.py index 4bee6de577..f8d6fffd24 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -335,6 +335,7 @@ class ExplorerConfig: # Explorer collects experiences from workflow runners # some algorithms (e.g., DAPO) need to collect experiences generated by the same task and do some post-processing + # will automatically set to True if `algorithm.add_strategy` is not None collect_experiences: bool = False # for inference models @@ -647,6 +648,7 @@ def _check_buffer(self) -> None: # noqa: C901 def _check_algorithm(self) -> None: from 
trinity.algorithm import ( + ADD_STRATEGY, ADVANTAGE_FN, ENTROPY_LOSS_FN, KL_FN, @@ -670,42 +672,23 @@ def _check_algorithm(self) -> None: if getattr(self.algorithm, key, None) is None: setattr(self.algorithm, key, value) - # TODO: simplify the following code - sample_strategy_cls = SAMPLE_STRATEGY.get(self.algorithm.sample_strategy) - if sample_strategy_cls is None: - raise ValueError(f"Invalid sample_strategy: {self.algorithm.sample_strategy}") - if self.algorithm.sample_strategy_args is None: - self.algorithm.sample_strategy_args = sample_strategy_cls.default_args() - - policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn) - if policy_fn_cls is None: - raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}") - if self.algorithm.policy_loss_fn_args is None: - self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args() - - advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn) - if advantage_fn_cls is None: - raise ValueError(f"Invalid advantage_fn: {self.algorithm.advantage_fn}") - if self.algorithm.advantage_fn_args is None: - self.algorithm.advantage_fn_args = advantage_fn_cls.default_args() - - kl_loss_fn_cls = KL_FN.get(self.algorithm.kl_loss_fn) - if kl_loss_fn_cls is None: - raise ValueError(f"Invalid kl_loss_fn: {self.algorithm.kl_loss_fn}") - if self.algorithm.kl_loss_fn_args is None: - self.algorithm.kl_loss_fn_args = kl_loss_fn_cls.default_args() - - kl_penalty_fn_cls = KL_FN.get(self.algorithm.kl_penalty_fn) - if kl_penalty_fn_cls is None: - raise ValueError(f"Invalid kl_penalty_fn: {self.algorithm.kl_penalty_fn}") - if self.algorithm.kl_penalty_fn_args is None: - self.algorithm.kl_penalty_fn_args = kl_penalty_fn_cls.default_args() - - entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.algorithm.entropy_loss_fn) - if entropy_loss_fn_cls is None: - raise ValueError(f"Invalid entropy_loss_fn: {self.algorithm.entropy_loss_fn}") - if self.algorithm.entropy_loss_fn_args is None: - self.algorithm.entropy_loss_fn_args = entropy_loss_fn_cls.default_args() + def check_and_set(name, registry, args_attr): + fn_cls = registry.get(getattr(self.algorithm, name)) + if fn_cls is None: + raise ValueError(f"Invalid {name}: {getattr(self.algorithm, name)}") + if getattr(self.algorithm, args_attr) is None: + setattr(self.algorithm, args_attr, fn_cls.default_args()) + return fn_cls + + if self.algorithm.add_strategy is not None: + check_and_set("add_strategy", ADD_STRATEGY, "add_strategy_args") + self.explorer.collect_experiences = True + check_and_set("sample_strategy", SAMPLE_STRATEGY, "sample_strategy_args") + check_and_set("policy_loss_fn", POLICY_LOSS_FN, "policy_loss_fn_args") + check_and_set("advantage_fn", ADVANTAGE_FN, "advantage_fn_args") + check_and_set("kl_loss_fn", KL_FN, "kl_loss_fn_args") + check_and_set("kl_penalty_fn", KL_FN, "kl_penalty_fn_args") + check_and_set("entropy_loss_fn", ENTROPY_LOSS_FN, "entropy_loss_fn_args") def check_and_update(self) -> None: # noqa: C901 """Check and update the config.""" From dbf2cd03af05abb55c8997063fc257ad25f49d2b Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 18:14:04 +0800 Subject: [PATCH 07/21] fix scheduler tests --- .../source/tutorial/example_mix_algo.md | 6 +- .../tutorial/trinity_programming_guide.md | 4 +- tests/common/vllm_test.py | 12 +- tests/explorer/scheduler_test.py | 112 ++++++++++-------- trinity/algorithm/add_strategy/__init__.py | 2 +- .../algorithm/add_strategy/add_strategy.py | 2 +- .../sample_strategy/mix_sample_strategy.py | 2 +- 
.../workflows/customized_math_workflows.py | 2 +- trinity/explorer/explorer.py | 6 +- trinity/explorer/scheduler.py | 2 +- trinity/explorer/workflow_runner.py | 4 +- 11 files changed, 84 insertions(+), 70 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 0e08b8648b..6f91481829 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -113,7 +113,7 @@ class MixSampleStrategy(SampleStrategy): expert_exp_list = self.expert_exp_buffer.read() for exp in expert_exp_list: exp.reward = 0.0 - exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + exp.logprobs = torch.zeros_like(exp.token_ids, dtype=torch.float32) if exp.info is None: exp.info = {} exp.info["is_expert"] = True @@ -144,8 +144,8 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to batch_dict = { "uid": np.array(experiences.group_ids), "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "input_ids": experiences.token_ids.long(), + "responses": experiences.token_ids[:, experiences.prompt_length :].long(), "attention_mask": attention_mask.long(), "response_mask": ( experiences.action_masks[:, experiences.prompt_length :].long() diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index a7c92bef61..b1ca7d2dc8 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -178,7 +178,7 @@ class ExampleWorkflow(Workflow): # construct Experience experiences.append( Experience( - tokens=response.tokens, + token_ids=response.tokens, prompt_length=response.prompt_length, reward=reward, logprobs=response.logprobs, @@ -275,7 +275,7 @@ class ExampleWorkflow(Workflow): # construct Experience experiences.append( Experience( - tokens=response.tokens, + token_ids=response.tokens, prompt_length=response.prompt_length, reward=reward, logprobs=response.logprobs, diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 71ccf32b7f..b3dcc025ee 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -130,7 +130,7 @@ async def test_generate( self.assertEqual(len(history_experiences), len(generate_results)) for exp, history_exp in zip(generate_results, history_experiences): self.assertEqual(exp.response_text, history_exp.response_text) - self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.token_ids.tolist(), history_exp.token_ids.tolist()) self.assertEqual(exp.prompt_length, history_exp.prompt_length) self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) else: @@ -155,7 +155,7 @@ async def test_generate( self.assertEqual(len(history_experiences) - len(generate_results), len(results)) for exp, history_exp in zip(results, history_experiences[len(generate_results) :]): self.assertEqual(exp.response_text, history_exp.response_text) - self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.token_ids.tolist(), history_exp.token_ids.tolist()) self.assertEqual(exp.prompt_length, history_exp.prompt_length) self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) for result in results: @@ -164,10 +164,10 @@ async def test_generate( self.assertTrue(torch.all(input_logprobs == 0)) 
self.assertTrue(torch.any(output_logprobs != 0)) if self.use_async: - logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist()) + logprobs = await self.model_wrapper.logprobs_async(results[0].token_ids.tolist()) else: - logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) - self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) + logprobs = self.model_wrapper.logprobs(results[0].token_ids.tolist()) + self.assertEqual(logprobs.shape[0], results[0].token_ids.shape[0]) if self.config.explorer.rollout_model.enable_history: history_experiences = self.model_wrapper.extract_experience_from_history() self.assertTrue(len(history_experiences) == 0) @@ -191,7 +191,7 @@ async def test_generate( return_dict=True, ) self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask)) - self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) + self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.token_ids)) self.assertRaises(ValueError, self.model_wrapper.get_openai_client) if self.config.explorer.rollout_model.enable_history: history_experiences = self.model_wrapper.extract_experience_from_history() diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 07387de922..83d4656fbe 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -19,8 +19,8 @@ @WORKFLOWS.register_module("dummy_workflow") class DummyWorkflow(Workflow): - def __init__(self, model, task, auxiliary_models): - super().__init__(model, task, auxiliary_models) + def __init__(self, *, task, model, auxiliary_models): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.error_type = task.raw_task.get("error_type", "") self.seconds = None self.repeat_times = task.rollout_args.n @@ -176,13 +176,14 @@ async def test_get_results(self): tasks = generate_tasks(8) scheduler.schedule(tasks, batch_id=0) - results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) - self.assertEqual(len(results), 8) + statuses, exps = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + self.assertEqual(len(statuses), 8) + self.assertEqual(len(exps), 0) self.assertEqual(len(self.queue.read(batch_size=8)), 8) with self.assertRaises(TimeoutError): self.queue.read(batch_size=1) - for result in results: + for result in statuses: self.assertTrue(result.ok) for batch_id in range(1, 4): @@ -191,8 +192,9 @@ async def test_get_results(self): for batch_id in range(1, 4): self.assertTrue(scheduler.has_step(batch_id)) - results = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10) - self.assertEqual(len(results), 4) + statuses, exps = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10) + self.assertEqual(len(statuses), 4) + self.assertEqual(len(exps), 0) self.assertFalse(scheduler.has_step(batch_id)) self.assertEqual(len(self.queue.read(batch_size=4)), 4) with self.assertRaises(TimeoutError): @@ -201,8 +203,9 @@ async def test_get_results(self): tasks = generate_tasks(3) scheduler.schedule(tasks, batch_id=4) self.assertTrue(scheduler.has_step(4)) - results = await scheduler.get_results(batch_id=4) - self.assertEqual(len(results), 3) + statuses, exps = await scheduler.get_results(batch_id=4) + self.assertEqual(len(statuses), 3) + self.assertEqual(len(exps), 0) self.assertFalse(scheduler.has_step(4)) self.assertEqual(len(self.queue.read(batch_size=3)), 3) @@ -211,11 +214,11 @@ async def test_get_results(self): scheduler.schedule(tasks, batch_id=0) 
start_time = time.time() - results = await scheduler.get_results(batch_id=0, min_num=4, timeout=3) + statuses, exps = await scheduler.get_results(batch_id=0, min_num=4, timeout=3) end_time = time.time() self.assertLessEqual(end_time - start_time, 5) - self.assertEqual(len(results), 2) + self.assertEqual(len(statuses), 2) self.assertEqual(len(self.queue.read(batch_size=2)), 2) # test run tasks after timeout @@ -223,10 +226,10 @@ async def test_get_results(self): scheduler.schedule(tasks, batch_id=0) # actor restart is slow, set a big timeout - results = await scheduler.get_results(batch_id=0, timeout=20) - self.assertEqual(len(results), 4) + statuses, exps = await scheduler.get_results(batch_id=0, timeout=20) + self.assertEqual(len(statuses), 4) - success_count = sum(1 for r in results if r.ok) + success_count = sum(1 for r in statuses if r.ok) self.assertEqual(success_count, 4) self.assertEqual(len(self.queue.read(batch_size=4)), 4) with self.assertRaises(TimeoutError): @@ -235,10 +238,10 @@ async def test_get_results(self): # test exception tasks tasks = generate_tasks(1, exception_num=3) scheduler.schedule(tasks, batch_id=1) - results = await scheduler.get_results(batch_id=1, timeout=5) - self.assertEqual(len(results), 4) + statuses, exps = await scheduler.get_results(batch_id=1, timeout=5) + self.assertEqual(len(statuses), 4) - success_count = sum(1 for r in results if r.ok) + success_count = sum(1 for r in statuses if r.ok) self.assertEqual(success_count, 1) self.assertEqual(len(self.queue.read(batch_size=1)), 1) with self.assertRaises(TimeoutError): @@ -247,11 +250,16 @@ async def test_get_results(self): # test clear_timeout_tasks tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) scheduler.schedule(tasks, batch_id=2) - results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False) - self.assertEqual(len(results), 3) + statuses, exps = await scheduler.get_results( + batch_id=2, timeout=2, clear_timeout_tasks=False + ) + self.assertEqual(len(statuses), 3) self.assertEqual(len(self.queue.read(batch_size=3)), 3) - results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False) - self.assertEqual(len(results), 1) + statuses, exps = await scheduler.get_results( + batch_id=2, timeout=2, clear_timeout_tasks=False + ) + self.assertEqual(len(statuses), 1) + self.assertEqual(len(exps), 0) self.assertEqual(len(self.queue.read(batch_size=1)), 1) with self.assertRaises(TimeoutError): self.queue.read(batch_size=1) @@ -277,10 +285,10 @@ async def test_wait_all(self): self.assertEqual(len(scheduler.pending_tasks), 0) self.assertEqual(len(scheduler.running_tasks), 0) - results0 = await scheduler.get_results(batch_id=0, min_num=4, timeout=1) - results1 = await scheduler.get_results(batch_id=1, min_num=3, timeout=1) - self.assertEqual(len(results0), 4) - self.assertEqual(len(results1), 3) + status0, exps0 = await scheduler.get_results(batch_id=0, min_num=4, timeout=1) + status1, exps1 = await scheduler.get_results(batch_id=1, min_num=3, timeout=1) + self.assertEqual(len(status0), 4) + self.assertEqual(len(status1), 3) # test timeout tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10) @@ -342,9 +350,9 @@ async def schedule_tasks(batch_id, num_tasks): schedule_tasks(2, 2), ) - self.assertEqual(len(results[0]), 3) - self.assertEqual(len(results[1]), 4) - self.assertEqual(len(results[2]), 2) + self.assertEqual(len(results[0][0]), 3) + self.assertEqual(len(results[1][0]), 4) + self.assertEqual(len(results[2][0]), 2) await scheduler.stop() @@ 
-354,47 +362,55 @@ async def test_scheduler_restart_after_stop(self): await scheduler.start() tasks = generate_tasks(2) scheduler.schedule(tasks, batch_id=0) - results = await scheduler.get_results(batch_id=0, min_num=2, timeout=10) + results, exps = await scheduler.get_results(batch_id=0, min_num=2, timeout=10) self.assertEqual(len(results), 2) + self.assertEqual(len(exps), 0) await scheduler.stop() + self.config.explorer.collect_experiences = True await scheduler.start() - tasks = generate_tasks(3) + tasks = generate_tasks(3, repeat_times=2) scheduler.schedule(tasks, batch_id=1) - results = await scheduler.get_results(batch_id=1, min_num=3, timeout=10) + results, exps = await scheduler.get_results(batch_id=1, min_num=3, timeout=10) self.assertEqual(len(results), 3) + self.assertEqual(len(exps), 3 * 2) await scheduler.stop() async def test_scheduler_all_methods(self): + self.config.explorer.collect_experiences = True scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() tasks = generate_tasks(8) scheduler.schedule(tasks, batch_id=0) self.assertTrue(scheduler.has_step(0)) - results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) - self.assertEqual(len(results), 8) + statuses, exps = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + self.assertEqual(len(statuses), 8) + self.assertEqual(len(exps), 8) scheduler.schedule(tasks, batch_id=1) scheduler.schedule(tasks[:4], batch_id=2) self.assertFalse(scheduler.has_step(0)) - results = await scheduler.get_results(batch_id=0, min_num=8) + statuses, exps = await scheduler.get_results(batch_id=0, min_num=8) self.assertFalse(scheduler.has_step(0)) - self.assertEqual(len(results), 0) # batch_id 0 has no more tasks + self.assertEqual(len(statuses), 0) # batch_id 0 has no more tasks + self.assertEqual(len(exps), 0) self.assertFalse(scheduler.has_step(0)) self.assertTrue(scheduler.has_step(1)) self.assertTrue(scheduler.has_step(2)) await scheduler.wait_all() st = time.time() - results = await scheduler.get_results(batch_id=1) + statuses, exps = await scheduler.get_results(batch_id=1) et = time.time() self.assertTrue(et - st < 1.0) - self.assertEqual(len(results), 8) + self.assertEqual(len(statuses), 8) + self.assertEqual(len(exps), 8) self.assertFalse(scheduler.has_step(1)) self.assertTrue(scheduler.has_step(2)) st = time.time() - results = await scheduler.get_results(batch_id=2) + statuses, exps = await scheduler.get_results(batch_id=2) et = time.time() self.assertTrue(et - st < 1.0) - self.assertEqual(len(results), 4) + self.assertEqual(len(statuses), 4) + self.assertEqual(len(exps), 4) self.assertFalse(scheduler.has_step(2)) await scheduler.stop() @@ -407,8 +423,8 @@ async def test_split_tasks(self): tasks = generate_tasks(4, repeat_times=8) # ceil(8 / 2) == 4 scheduler.schedule(tasks, batch_id=1) - results = await scheduler.get_results(batch_id=1) - self.assertEqual(len(results), 4 * 4) + statuses, exps = await scheduler.get_results(batch_id=1) + self.assertEqual(len(statuses), 4 * 4) exps = self.queue.read(batch_size=4 * 8) self.assertEqual(len(exps), 4 * 8) exp_list.extend(exps) @@ -417,8 +433,8 @@ async def test_split_tasks(self): tasks = generate_tasks(4, repeat_times=5) # ceil(5 / 2) == 3 scheduler.schedule(tasks, batch_id=2) - results = await scheduler.get_results(batch_id=2) - self.assertEqual(len(results), 4 * 3) + statuses, exps = await scheduler.get_results(batch_id=2) + self.assertEqual(len(statuses), 4 * 3) exps = self.queue.read(batch_size=4 * 5) 
self.assertEqual(len(exps), 4 * 5) exp_list.extend(exps) @@ -427,18 +443,18 @@ async def test_split_tasks(self): tasks = generate_tasks(3, repeat_times=1) # ceil(1 / 2) == 1 scheduler.schedule(tasks, batch_id=3) - results = await scheduler.get_results(batch_id=3) - self.assertEqual(len(results), 3 * 1) + statuses, exps = await scheduler.get_results(batch_id=3) + self.assertEqual(len(statuses), 3 * 1) exps = self.queue.read(batch_size=3 * 1) self.assertEqual(len(exps), 3 * 1) exp_list.extend(exps) with self.assertRaises(TimeoutError): self.queue.read(batch_size=1) - # test group_id and unique_id - group_ids = [exp.eid.gid for exp in exp_list] + # test task_id and unique_id + group_ids = [exp.eid.tid for exp in exp_list] self.assertEqual(len(set(group_ids)), 11) # 4 + 4 + 3 - unique_ids = [exp.unique_id for exp in exp_list] + unique_ids = [exp.eid.uid for exp in exp_list] self.assertEqual(len(unique_ids), len(set(unique_ids))) await scheduler.stop() diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py index 45cd0d6e40..d1bbc84e1c 100644 --- a/trinity/algorithm/add_strategy/__init__.py +++ b/trinity/algorithm/add_strategy/__init__.py @@ -1,4 +1,4 @@ -from trinity.algorithm.add_strategy.sample_strategy import ( +from trinity.algorithm.add_strategy.add_strategy import ( ADD_STRATEGY, AddStrategy, RewardVarianceAddStrategy, diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py index 067f96483b..18f91cbb2a 100644 --- a/trinity/algorithm/add_strategy/add_strategy.py +++ b/trinity/algorithm/add_strategy/add_strategy.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Literal -from numpy import np +import numpy as np from trinity.buffer import BufferWriter from trinity.common.experience import Experience diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 8ebd335246..a750988a5f 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -57,7 +57,7 @@ def sample(self, step: int) -> Tuple[Any, Dict, List]: expert_exp_list = self.expert_exp_buffer.read() for exp in expert_exp_list: exp.reward = 0.0 - exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + exp.logprobs = torch.zeros_like(exp.token_ids, dtype=torch.float32) if exp.info is None: exp.info = {} exp.info["is_expert"] = True diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py index c2762ae43c..4b04524a74 100644 --- a/trinity/common/workflows/customized_math_workflows.py +++ b/trinity/common/workflows/customized_math_workflows.py @@ -81,7 +81,7 @@ def run(self) -> List[Experience]: truth=self.truth, with_think=self.with_think, format_score_coef=self.format_score_coef, - response_token=response.tokens[response.prompt_length :], + response_token=response.token_ids[response.prompt_length :], ) if response.metrics is None: diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 13d945c6e5..84dcf596e7 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -84,9 +84,7 @@ def __init__(self, config: Config): self.collect_experiences = self.config.explorer.collect_experiences self.generated_experience_cnt = 0 if self.collect_experiences: - self.add_strategy = ADD_STRATEGY.get( - self.config.algorithm.add_strategy, 
self.config.algorithm.trainer_type - ) + self.add_strategy = ADD_STRATEGY.get(self.config.algorithm.add_strategy) async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None @@ -369,7 +367,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva if eval_step != step: return self.pending_eval_tasks.popleft() - eval_results = await self.scheduler.get_results(f"{step}/{eval_task_name}") + eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}") metric.update( gather_metrics( [status.metric for status in eval_results], f"{prefix}/{eval_task_name}" diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 7a79f9a8f0..5580d20fa0 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -347,7 +347,7 @@ async def get_results( self.logger.debug(f"Waiting for {min_num} tasks to complete...") - while time.time() - start_time < timeout: + while time.time() - start_time <= timeout: completed_count = len(self.completed_tasks.get(batch_id, [])) if completed_count >= min_num: break diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index eb8e75205f..8e7a31e41e 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -104,10 +104,10 @@ def run_task(self, task: Task) -> Tuple[Status, List[Experience]]: # If the task is an evaluation task, we do not record the experiences to the buffer return Status(True, metric=metric), [] if self.return_experiences: + return Status(True, metric=metric), exps + else: self.experience_buffer.write(exps) return Status(True, metric=metric), [] - else: - return Status(True, metric=metric), exps except Exception as e: error_trace_back = traceback.format_exc() self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}") From 2776146feec57af5808e03dded5fabed49bc84b0 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 18:22:13 +0800 Subject: [PATCH 08/21] use tokens --- .../source/tutorial/example_mix_algo.md | 6 +- .../tutorial/trinity_programming_guide.md | 4 +- tests/buffer/queue_test.py | 14 +-- tests/buffer/sql_test.py | 4 +- tests/common/experience_test.py | 92 +++++++++---------- tests/common/vllm_test.py | 26 +++--- tests/explorer/scheduler_test.py | 2 +- .../sample_strategy/mix_sample_strategy.py | 6 +- trinity/algorithm/sample_strategy/utils.py | 4 +- trinity/buffer/reader/file_reader.py | 14 +-- trinity/buffer/schema/sql_schema.py | 4 +- trinity/common/experience.py | 56 ++++++----- trinity/common/models/api/vllm_patch.py | 12 +-- trinity/common/models/model.py | 4 +- trinity/common/models/utils.py | 4 +- trinity/common/models/vllm_model.py | 16 ++-- .../workflows/customized_math_workflows.py | 2 +- trinity/common/workflows/workflow.py | 6 +- trinity/trainer/verl_trainer.py | 4 +- 19 files changed, 136 insertions(+), 144 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 6f91481829..0e08b8648b 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -113,7 +113,7 @@ class MixSampleStrategy(SampleStrategy): expert_exp_list = self.expert_exp_buffer.read() for exp in expert_exp_list: exp.reward = 0.0 - exp.logprobs = torch.zeros_like(exp.token_ids, dtype=torch.float32) + exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) if exp.info is None: exp.info = {} 
exp.info["is_expert"] = True @@ -144,8 +144,8 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to batch_dict = { "uid": np.array(experiences.group_ids), "position_ids": position_ids, - "input_ids": experiences.token_ids.long(), - "responses": experiences.token_ids[:, experiences.prompt_length :].long(), + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), "attention_mask": attention_mask.long(), "response_mask": ( experiences.action_masks[:, experiences.prompt_length :].long() diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index b1ca7d2dc8..a7c92bef61 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -178,7 +178,7 @@ class ExampleWorkflow(Workflow): # construct Experience experiences.append( Experience( - token_ids=response.tokens, + tokens=response.tokens, prompt_length=response.prompt_length, reward=reward, logprobs=response.logprobs, @@ -275,7 +275,7 @@ class ExampleWorkflow(Workflow): # construct Experience experiences.append( Experience( - token_ids=response.tokens, + tokens=response.tokens, prompt_length=response.prompt_length, reward=reward, logprobs=response.logprobs, diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index c591cb6b05..11526c223f 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -42,7 +42,7 @@ async def test_queue_buffer(self, name, use_priority_queue): self.assertEqual(await writer.acquire(), 1) exps = [ Experience( - token_ids=torch.tensor([float(j) for j in range(i + 1)]), + tokens=torch.tensor([float(j) for j in range(i + 1)]), prompt_length=i, reward=float(i), logprobs=torch.tensor([0.1]), @@ -59,7 +59,7 @@ async def test_queue_buffer(self, name, use_priority_queue): print(f"finish read {self.read_batch_size} experience") exps = [ Experience( - token_ids=torch.tensor([float(j) for j in range(i + 1)]), + tokens=torch.tensor([float(j) for j in range(i + 1)]), reward=float(i), logprobs=torch.tensor([0.1]), action_mask=torch.tensor([j % 2 for j in range(i + 1)]), @@ -99,7 +99,7 @@ async def test_priority_queue_capacity(self): writer.write( [ Experience( - token_ids=torch.tensor([1, 2, 3]), + tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": i, "use_count": 0}, ), @@ -163,12 +163,12 @@ async def test_priority_queue_buffer_reuse(self): writer.write( [ Experience( - token_ids=torch.tensor([1, 2, 3]), + tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": i, "use_count": 0}, ), Experience( - token_ids=torch.tensor([1, 2, 3]), + tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": i, "use_count": 0}, ), @@ -180,12 +180,12 @@ def replace_call(): writer.write( [ Experience( - token_ids=torch.tensor([1, 2, 3]), + tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 4, "use_count": 0}, ), Experience( - token_ids=torch.tensor([1, 2, 3]), + tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 4, "use_count": 0}, ), diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 389c5195be..7d43d04168 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -34,7 +34,7 @@ async def test_create_sql_buffer(self) -> None: sql_reader = SQLReader(meta, config) exps = [ Experience( - token_ids=torch.tensor([float(j) for j in range(i + 1)]), + 
tokens=torch.tensor([float(j) for j in range(i + 1)]), prompt_length=i, reward=float(i), logprobs=torch.tensor([0.1]), @@ -52,7 +52,7 @@ async def test_create_sql_buffer(self) -> None: sql_writer.write( [ Experience( - token_ids=torch.tensor([float(j) for j in range(i + 1)]), + tokens=torch.tensor([float(j) for j in range(i + 1)]), reward=float(i), logprobs=torch.tensor([0.1]), action_mask=torch.tensor([j % 2 for j in range(i + 1)]), diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 2b431fc71e..be12219387 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -42,33 +42,31 @@ def test_eid_properties(self): class TestExperience(unittest.TestCase): def test_single_turn_experience(self): - token_ids = torch.tensor([10, 11, 12], dtype=torch.int32) + tokens = torch.tensor([10, 11, 12], dtype=torch.int32) logprobs = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32) - exp = Experience(token_ids=token_ids, logprobs=logprobs, reward=1.0, prompt_length=1) + exp = Experience(tokens=tokens, logprobs=logprobs, reward=1.0, prompt_length=1) self.assertEqual(exp.experience_type.name, "SINGLE_TURN") - self.assertTrue(torch.equal(exp.token_ids, token_ids)) + self.assertTrue(torch.equal(exp.tokens, tokens)) self.assertTrue(torch.equal(exp.logprobs, logprobs)) self.assertEqual(exp.reward, 1.0) self.assertEqual(exp.prompt_length, 1) self.assertTrue(torch.equal(exp.action_mask, torch.tensor([0, 1, 1], dtype=torch.bool))) def test_multi_turn_experience(self): - token_ids = torch.tensor([1, 2, 3, 4]) + tokens = torch.tensor([1, 2, 3, 4]) logprobs = torch.tensor([0.1, 0.2, 0.3, 0.4]) action_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool) - exp = Experience( - token_ids=token_ids, logprobs=logprobs, reward=2.0, action_mask=action_mask - ) + exp = Experience(tokens=tokens, logprobs=logprobs, reward=2.0, action_mask=action_mask) self.assertEqual(exp.experience_type.name, "MULTI_TURN") self.assertTrue(torch.equal(exp.action_mask, action_mask)) self.assertEqual(exp.prompt_length, 1) def test_dpo_experience(self): - token_ids = torch.tensor([1, 2]) + tokens = torch.tensor([1, 2]) chosen_ids = torch.tensor([3, 4]) rejected_ids = torch.tensor([5, 6]) exp = Experience( - token_ids=token_ids, chosen_ids=chosen_ids, rejected_ids=rejected_ids, reward=0.5 + tokens=tokens, chosen_ids=chosen_ids, rejected_ids=rejected_ids, reward=0.5 ) self.assertEqual(exp.experience_type.name, "DPO") self.assertTrue(torch.equal(exp.chosen_ids, chosen_ids)) @@ -76,19 +74,19 @@ def test_dpo_experience(self): self.assertEqual(exp.prompt_length, 2) def test_serialize_deserialize(self): - token_ids = torch.tensor([1, 2, 3]) - exp = Experience(token_ids=token_ids, reward=1.23, prompt_length=1) + tokens = torch.tensor([1, 2, 3]) + exp = Experience(tokens=tokens, reward=1.23, prompt_length=1) data = exp.serialize() exp2 = Experience.deserialize(data) - self.assertTrue(torch.equal(exp.token_ids, exp2.token_ids)) + self.assertTrue(torch.equal(exp.tokens, exp2.tokens)) self.assertEqual(exp.reward, exp2.reward) self.assertEqual(exp.prompt_length, exp2.prompt_length) self.assertEqual(exp.experience_type, exp2.experience_type) def test_to_dict(self): - token_ids = torch.tensor([1, 2, 3]) + tokens = torch.tensor([1, 2, 3]) exp = Experience( - token_ids=token_ids, reward=2.5, prompt_length=1, prompt_text="hi", response_text="yo" + tokens=tokens, reward=2.5, prompt_length=1, prompt_text="hi", response_text="yo" ) d = exp.to_dict() self.assertIn("eid", d) @@ -101,47 +99,47 @@ def 
test_to_dict(self): def test_gather(self): # test empty gathering batch = Experiences.gather_experiences([]) - self.assertEqual(batch.token_ids.numel(), 0) + self.assertEqual(batch.tokens.numel(), 0) self.assertEqual(batch.rewards.numel(), 0) self.assertEqual(batch.eids, []) # test single experience gathering - exp = Experience(token_ids=torch.tensor([1, 2, 3]), reward=1.0, prompt_length=1) + exp = Experience(tokens=torch.tensor([1, 2, 3]), reward=1.0, prompt_length=1) batch = Experiences.gather_experiences([exp]) self.assertEqual(batch.batch_size, 1) self.assertTrue( - torch.equal(batch.token_ids[0], torch.tensor([0, 1, 2, 3], dtype=torch.int64)[-3:]) + torch.equal(batch.tokens[0], torch.tensor([0, 1, 2, 3], dtype=torch.int64)[-3:]) ) self.assertEqual(batch.prompt_length, 1) self.assertEqual(batch.rewards[0], 1.0) # test multiple experiences gathering exps = [ - Experience(token_ids=torch.tensor([1, 2]), reward=0.1, prompt_length=1), - Experience(token_ids=torch.tensor([3, 4, 5]), reward=0.2, prompt_length=2), + Experience(tokens=torch.tensor([1, 2]), reward=0.1, prompt_length=1), + Experience(tokens=torch.tensor([3, 4, 5]), reward=0.2, prompt_length=2), ] batch = Experiences.gather_experiences(exps) self.assertEqual(batch.batch_size, 2) self.assertEqual(batch.prompt_length, 2) - self.assertEqual(batch.token_ids.shape[1], 3) + self.assertEqual(batch.tokens.shape[1], 3) self.assertEqual(batch.rewards[0], 0.1) self.assertEqual(batch.rewards[1], 0.2) def test_action_mask_and_logprobs_type(self): - exp = Experience(token_ids=[1, 2, 3], logprobs=[0.1, 0.2, 0.3], prompt_length=1) - self.assertIsInstance(exp.token_ids, torch.Tensor) + exp = Experience(tokens=[1, 2, 3], logprobs=[0.1, 0.2, 0.3], prompt_length=1) + self.assertIsInstance(exp.tokens, torch.Tensor) self.assertIsInstance(exp.logprobs, torch.Tensor) self.assertIsInstance(exp.action_mask, torch.Tensor) def test_assertions(self): # prompt_length must be > 0 with self.assertRaises(AssertionError): - Experience(token_ids=[1, 2, 3], prompt_length=0) - # token_ids must be longer than prompt_length for single-turn + Experience(tokens=[1, 2, 3], prompt_length=0) + # tokens must be longer than prompt_length for single-turn with self.assertRaises(AssertionError): - Experience(token_ids=[1, 2], prompt_length=2) - # DPO: token_ids must match prompt_length - exp = Experience(token_ids=[1, 2], chosen_ids=[3], rejected_ids=[4], prompt_length=1) + Experience(tokens=[1, 2], prompt_length=2) + # DPO: tokens must match prompt_length + exp = Experience(tokens=[1, 2], chosen_ids=[3], rejected_ids=[4], prompt_length=1) exp.prompt_length = 2 # should automatically adjust @@ -150,12 +148,12 @@ class TestExperienceConversion(unittest.TestCase): def test_experience_model_experience_conversion(self): """Test the conversion between Experience and ExperienceModel""" - token_ids = torch.tensor([1, 2, 3], dtype=torch.int32) + tokens = torch.tensor([1, 2, 3], dtype=torch.int32) reward = 0.6 prompt_length = 2 logprobs = torch.tensor([0, 0, 0.1], dtype=torch.float32) experience = Experience( - token_ids=token_ids, + tokens=tokens, reward=reward, prompt_length=prompt_length, logprobs=logprobs, @@ -163,7 +161,7 @@ def test_experience_model_experience_conversion(self): model = ExperienceModel.from_experience(experience) new_experience = model.to_experience() - self.assertTrue(torch.equal(new_experience.token_ids, token_ids)) + self.assertTrue(torch.equal(new_experience.tokens, tokens)) self.assertEqual(new_experience.prompt_length, prompt_length) 
self.assertEqual(new_experience.reward, reward) self.assertTrue(torch.equal(new_experience.logprobs, logprobs)) @@ -172,13 +170,13 @@ def test_experience_model_experience_conversion(self): def test_batch_conversion(self): exps = [ Experience( - token_ids=torch.tensor([1, 2]), + tokens=torch.tensor([1, 2]), prompt_length=1, reward=float(0.1), logprobs=torch.tensor([0, 0.1]), ), Experience( - token_ids=torch.tensor([1, 2, 3]), + tokens=torch.tensor([1, 2, 3]), prompt_length=2, reward=float(0.2), logprobs=torch.tensor([0, 0, 0.1]), @@ -192,13 +190,13 @@ def test_batch_conversion(self): self.assertEqual(batch.rewards[i], exps[i].reward) self.assertTrue( torch.all( - batch.token_ids[i][ + batch.tokens[i][ prompt_length - exps[i].prompt_length : prompt_length - exps[i].prompt_length - + exps[i].token_ids.size(0) + + exps[i].tokens.size(0) ] - == exps[i].token_ids + == exps[i].tokens ) ) self.assertTrue( @@ -206,7 +204,7 @@ def test_batch_conversion(self): batch.logprobs[i][ prompt_length - exps[i].prompt_length : prompt_length - + exps[i].token_ids.size(0) + + exps[i].tokens.size(0) - exps[i].prompt_length ] == exps[i].logprobs @@ -227,13 +225,13 @@ def test_batch_conversion(self): def test_multiturn_experience_batch_converstion(self): exps = [ Experience( - token_ids=torch.tensor([1, 2, 3, 4]), + tokens=torch.tensor([1, 2, 3, 4]), reward=float(0.3), logprobs=torch.tensor([0, 0, 0.1, 0.2]), action_mask=torch.tensor([1, 0, 1, 0]), ), Experience( - token_ids=torch.tensor([1, 2, 3, 4]), + tokens=torch.tensor([1, 2, 3, 4]), reward=float(0.4), logprobs=torch.tensor([0, 0, 0, 0.1]), action_mask=torch.tensor([1, 0, 0, 1]), @@ -247,13 +245,13 @@ def test_multiturn_experience_batch_converstion(self): self.assertEqual(batch.rewards[i], exps[i].reward) self.assertTrue( torch.all( - batch.token_ids[i][ + batch.tokens[i][ prompt_length - exps[i].prompt_length : prompt_length - exps[i].prompt_length - + exps[i].token_ids.size(0) + + exps[i].tokens.size(0) ] - == exps[i].token_ids + == exps[i].tokens ) ) self.assertTrue( @@ -261,7 +259,7 @@ def test_multiturn_experience_batch_converstion(self): batch.logprobs[i][ prompt_length - exps[i].prompt_length : prompt_length - + exps[i].token_ids.size(0) + + exps[i].tokens.size(0) - exps[i].prompt_length ] == exps[i].logprobs @@ -282,12 +280,12 @@ def test_multiturn_experience_batch_converstion(self): def test_dpo_experience_batch_conversion(self): exps = [ Experience( - token_ids=torch.tensor([1, 2]), + tokens=torch.tensor([1, 2]), chosen_ids=torch.tensor([3, 4]), rejected_ids=torch.tensor([5, 6]), ), Experience( - token_ids=torch.tensor([7, 8, 9]), + tokens=torch.tensor([7, 8, 9]), chosen_ids=torch.tensor([10, 11]), rejected_ids=torch.tensor([12, 13]), ), @@ -300,13 +298,13 @@ def test_dpo_experience_batch_conversion(self): j = i // 2 self.assertTrue( torch.all( - batch.token_ids[i][ + batch.tokens[i][ prompt_length - exps[j].prompt_length : prompt_length - exps[j].prompt_length - + exps[j].token_ids.size(0) + + exps[j].tokens.size(0) ] - == exps[j].token_ids + == exps[j].tokens ) ) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index b3dcc025ee..d84fba885d 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -130,7 +130,7 @@ async def test_generate( self.assertEqual(len(history_experiences), len(generate_results)) for exp, history_exp in zip(generate_results, history_experiences): self.assertEqual(exp.response_text, history_exp.response_text) - self.assertEqual(exp.token_ids.tolist(), history_exp.token_ids.tolist()) + 
self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) self.assertEqual(exp.prompt_length, history_exp.prompt_length) self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) else: @@ -155,7 +155,7 @@ async def test_generate( self.assertEqual(len(history_experiences) - len(generate_results), len(results)) for exp, history_exp in zip(results, history_experiences[len(generate_results) :]): self.assertEqual(exp.response_text, history_exp.response_text) - self.assertEqual(exp.token_ids.tolist(), history_exp.token_ids.tolist()) + self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) self.assertEqual(exp.prompt_length, history_exp.prompt_length) self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) for result in results: @@ -164,10 +164,10 @@ async def test_generate( self.assertTrue(torch.all(input_logprobs == 0)) self.assertTrue(torch.any(output_logprobs != 0)) if self.use_async: - logprobs = await self.model_wrapper.logprobs_async(results[0].token_ids.tolist()) + logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist()) else: - logprobs = self.model_wrapper.logprobs(results[0].token_ids.tolist()) - self.assertEqual(logprobs.shape[0], results[0].token_ids.shape[0]) + logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) + self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) if self.config.explorer.rollout_model.enable_history: history_experiences = self.model_wrapper.extract_experience_from_history() self.assertTrue(len(history_experiences) == 0) @@ -191,7 +191,7 @@ async def test_generate( return_dict=True, ) self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask)) - self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.token_ids)) + self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) self.assertRaises(ValueError, self.model_wrapper.get_openai_client) if self.config.explorer.rollout_model.enable_history: history_experiences = self.model_wrapper.extract_experience_from_history() @@ -242,8 +242,8 @@ def test_api(self): self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) self.assertTrue(hasattr(response, "prompt_token_ids")) self.assertTrue(len(response.prompt_token_ids) > 0) - self.assertTrue(hasattr(response.choices[0], "token_ids")) - self.assertTrue(len(response.choices[0].token_ids) > 0) + self.assertTrue(hasattr(response.choices[0], "tokens")) + self.assertTrue(len(response.choices[0].tokens) > 0) exps = self.model_wrapper.extract_experience_from_history() self.assertEqual(len(exps), 3) response = openai_client.chat.completions.create( @@ -261,8 +261,8 @@ def test_api(self): model=model_id, messages=messages, n=2 ) self.assertEqual(2, len(response.choices)) - self.assertTrue(hasattr(response.choices[0], "token_ids")) - self.assertTrue(len(response.choices[0].token_ids) > 0) + self.assertTrue(hasattr(response.choices[0], "tokens")) + self.assertTrue(len(response.choices[0].tokens) > 0) with self.assertRaises(ValueError): self.model_wrapper_no_history.extract_experience_from_history() self.assertEqual(len(self.model_wrapper_no_history.history), 0) @@ -284,7 +284,7 @@ def test_assistant_token_mask(self): }, ] tokenizer = AutoTokenizer.from_pretrained(get_model_path()) - token_ids, action_mask = tokenize_and_mask_messages_default( + tokens, action_mask = tokenize_and_mask_messages_default( tokenizer=tokenizer, messages=messages, chat_template=CHAT_TEMPLATE, @@ -294,7 +294,7 @@ def test_assistant_token_mask(self): 
messages=messages, chat_template=CHAT_TEMPLATE, ) - self.assertEqual(token_ids.shape, token_ids_hf.shape) + self.assertEqual(tokens.shape, token_ids_hf.shape) self.assertEqual(action_mask.shape, action_mask_hf.shape) - self.assertTrue(torch.equal(token_ids, token_ids_hf)) + self.assertTrue(torch.equal(tokens, token_ids_hf)) self.assertTrue(torch.equal(action_mask, action_mask_hf)) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 83d4656fbe..6a88e2a125 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -43,7 +43,7 @@ def run(self) -> List[Experience]: return [ Experience( - token_ids=torch.zeros(5), + tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type or "success", info={"repeat_times": self.repeat_times}, diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index a750988a5f..80a4af7d49 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -57,7 +57,7 @@ def sample(self, step: int) -> Tuple[Any, Dict, List]: expert_exp_list = self.expert_exp_buffer.read() for exp in expert_exp_list: exp.reward = 0.0 - exp.logprobs = torch.zeros_like(exp.token_ids, dtype=torch.float32) + exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) if exp.info is None: exp.info = {} exp.info["is_expert"] = True @@ -94,8 +94,8 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor): "uid": np.array(experiences.group_ids), "unique_ids": np.array(experiences.unique_ids), "position_ids": position_ids, - "input_ids": experiences.token_ids.long(), - "responses": experiences.token_ids[:, experiences.prompt_length :].long(), + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), "attention_mask": attention_mask.long(), "response_mask": ( experiences.action_masks[:, experiences.prompt_length :].long() diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py index 8bd868c22c..f9df00ee4e 100644 --- a/trinity/algorithm/sample_strategy/utils.py +++ b/trinity/algorithm/sample_strategy/utils.py @@ -16,8 +16,8 @@ def to_data_proto(experiences: Experiences) -> DataProto: "uid": np.array(experiences.group_ids), "unique_ids": np.array(experiences.unique_ids), "position_ids": position_ids, - "input_ids": experiences.token_ids.long(), - "responses": experiences.token_ids[:, experiences.prompt_length :].long(), + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), "attention_mask": attention_mask.long(), "response_mask": ( experiences.action_masks[:, experiences.prompt_length :].long() diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index e2ec4a69b9..1121ebfa63 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -129,14 +129,14 @@ def read( if self.prompt_type == PromptType.MESSAGES: for sample in samples: messages = sample[self.messages_key] - token_ids = self.tokenizer.apply_chat_template( + tokens = self.tokenizer.apply_chat_template( messages, add_generation_prompt=False, return_tensors="pt" )[0] prompt_tokens_ids = self.tokenizer.apply_chat_template( messages[:-1], add_generation_prompt=True, return_tensors="pt" )[0] experience = Experience( - token_ids=token_ids, + tokens=tokens, 
prompt_length=len(prompt_tokens_ids), ) exp_list.append(experience) @@ -151,7 +151,7 @@ def read( response_messages = [response_messages] full_messages = prompt_messages + response_messages - token_ids = self.tokenizer.apply_chat_template( + tokens = self.tokenizer.apply_chat_template( full_messages, add_generation_prompt=False, return_tensors="pt" )[0] @@ -160,7 +160,7 @@ def read( )[0] experience = Experience( - token_ids=token_ids, + tokens=tokens, prompt_length=len(prompt_tokens_ids), ) exp_list.append(experience) @@ -170,10 +170,10 @@ def read( for sample in samples: prompt = sample[self.prompt_key] response = sample[self.response_key] - token_ids = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0] + tokens = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0] prompt_tokens_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0] experience = Experience( - token_ids=token_ids, + tokens=tokens, prompt_length=len(prompt_tokens_ids), ) exp_list.append(experience) @@ -251,7 +251,7 @@ def read( return_tensors="pt", )[0][prompt_length:] experience = Experience( - token_ids=prompt_tokens, + tokens=prompt_tokens, chosen_ids=chosen_tokens, rejected_ids=rejected_tokens, ) diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index f41c8038ef..9fe128c690 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -85,13 +85,13 @@ def from_messages( """Convert a list of messages into a single instance of SFT data.""" from trinity.common.models.utils import tokenize_and_mask_messages_hf - token_ids, action_mask = tokenize_and_mask_messages_hf( + tokens, action_mask = tokenize_and_mask_messages_hf( tokenizer=tokenizer, messages=messages, chat_template=chat_template, ) exp = Experience( - token_ids=token_ids, + tokens=tokens, action_mask=action_mask, info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])}, ) diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 81a7398e70..aaae2c59e3 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -83,7 +83,7 @@ class ExperienceType(Enum): @dataclass class Experience: eid: EID = field(default_factory=EID) # Unique identifier for the experience - token_ids: Optional[Tensor] = None # [seq_length] + tokens: Optional[Tensor] = None # [seq_length] logprobs: Optional[Tensor] = None # [seq_length] reward: Optional[float] = None # Type of the experience, automatically set based on the presence of action_mask or chosen/rejected_ids @@ -116,7 +116,7 @@ def __init__( self, *, eid=None, - token_ids, + tokens, logprobs=None, reward=None, info=None, @@ -142,18 +142,16 @@ def __init__( assert ( prompt_length > 0 ), "Prompt length must be greater than 0 for single-turn experiences." - assert ( - len(token_ids) > prompt_length - ), "Token ids must be longer than the prompt length." - action_mask = torch.zeros(len(token_ids), dtype=torch.bool) + assert len(tokens) > prompt_length, "Token ids must be longer than the prompt length." 
+ action_mask = torch.zeros(len(tokens), dtype=torch.bool) action_mask[prompt_length:] = 1 elif experience_type == ExperienceType.MULTI_TURN: prompt_length = 1 elif experience_type == ExperienceType.DPO: - prompt_length = len(token_ids) + prompt_length = len(tokens) self.eid = eid or EID() - self.token_ids = token_ids + self.tokens = tokens self.logprobs = logprobs self.reward = reward self.experience_type = experience_type @@ -169,8 +167,8 @@ def __init__( self.chosen_text = chosen_text self.rejected_text = rejected_text - if not isinstance(self.token_ids, Tensor): - self.token_ids = torch.tensor(self.token_ids) + if not isinstance(self.tokens, Tensor): + self.tokens = torch.tensor(self.tokens) if self.logprobs is not None and not isinstance(self.logprobs, Tensor): self.logprobs = torch.tensor(self.logprobs) if self.action_mask is not None and not isinstance(self.action_mask, Tensor): @@ -218,13 +216,11 @@ def gather(cls, experiences: List[Experience], pad_token_id: int = 0) -> Experie if exp_type == ExperienceType.DPO: experiences = split_dpo_experience_to_single_turn(experiences) max_prompt_length = max([exp.prompt_length for exp in experiences]) # type: ignore [type-var] - max_response_length = max([len(exp.token_ids) - exp.prompt_length for exp in experiences]) # type: ignore [arg-type] + max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore [arg-type] eids = [exp.eid for exp in experiences] - # Gather token_ids - token_ids = gather_token_ids( - experiences, max_prompt_length, max_response_length, pad_token_id - ) + # Gather tokens + tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id) # Gather rewards if experiences[0].reward is not None: @@ -249,7 +245,7 @@ def gather(cls, experiences: List[Experience], pad_token_id: int = 0) -> Experie return Experiences( eids=eids, - token_ids=token_ids, + tokens=tokens, rewards=rewards, attention_masks=attention_masks, action_masks=action_masks, @@ -269,11 +265,11 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E step=exp.eid.step, run=exp.eid.run, ), - token_ids=torch.cat([exp.token_ids, exp.chosen_ids]), + tokens=torch.cat([exp.tokens, exp.chosen_ids]), reward=exp.reward, info=exp.info, metrics=exp.metrics, - prompt_length=len(exp.token_ids), # type: ignore [arg-type] + prompt_length=len(exp.tokens), # type: ignore [arg-type] prompt_text=exp.prompt_text, response_text=exp.chosen_text, ) @@ -286,11 +282,11 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E step=exp.eid.step, run=exp.eid.run, ), - token_ids=torch.cat([exp.token_ids, exp.rejected_ids]), + tokens=torch.cat([exp.tokens, exp.rejected_ids]), reward=exp.reward, info=exp.info, metrics=exp.metrics, - prompt_length=len(exp.token_ids), # type: ignore [arg-type] + prompt_length=len(exp.tokens), # type: ignore [arg-type] prompt_text=exp.prompt_text, response_text=exp.rejected_text, ) @@ -305,7 +301,7 @@ class Experiences: Example: >>> |<- prompt_length ->| | - >>> token_ids: ('P' represents prompt, 'O' represents output) + >>> tokens: ('P' represents prompt, 'O' represents output) >>> exp1: |........PPPPPPPPPPP|OOOOOOOOOO.....| >>> exp2: |......PPPPPPPPPPPPP|OOOOOOO........| >>> @@ -315,7 +311,7 @@ class Experiences: """ eids: List[EID] # Experience IDs of each experience in the batch - token_ids: Tensor + tokens: Tensor rewards: Tensor attention_masks: Tensor action_masks: Optional[Tensor] @@ -325,7 +321,7 @@ class Experiences: @property 
def batch_size(self) -> int: """Get the batch size.""" - return self.token_ids.size(0) + return self.tokens.size(0) @classmethod def gather_experiences( @@ -342,7 +338,7 @@ def gather_experiences( def empty_experiences() -> Experiences: return Experiences( - token_ids=torch.empty(0, dtype=torch.int32), + tokens=torch.empty(0, dtype=torch.int32), rewards=torch.empty(0, dtype=torch.float32), attention_masks=torch.empty(0, dtype=torch.bool), action_masks=torch.empty(0, dtype=torch.bool), @@ -355,7 +351,7 @@ def empty_experiences() -> Experiences: def gather_token_ids( experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int ) -> Tensor: - token_ids_dtype = experiences[0].token_ids.dtype + token_ids_dtype = experiences[0].tokens.dtype return torch.stack( [ torch.cat( @@ -365,9 +361,9 @@ def gather_token_ids( pad_token_id, dtype=token_ids_dtype, ), - exp.token_ids, + exp.tokens, torch.full( - (max_response_length + exp.prompt_length - len(exp.token_ids),), + (max_response_length + exp.prompt_length - len(exp.tokens),), pad_token_id, dtype=token_ids_dtype, ), @@ -390,7 +386,7 @@ def gather_action_masks(experiences, max_prompt_length: int, max_response_length ), exp.action_mask, torch.full( - (max_response_length + exp.prompt_length - len(exp.token_ids),), + (max_response_length + exp.prompt_length - len(exp.tokens),), 0, dtype=torch.bool, ), @@ -408,7 +404,7 @@ def gather_attention_masks(experiences, max_prompt_length: int, max_response_len for i, exp in enumerate(experiences): start = max_prompt_length - exp.prompt_length - end = start + len(exp.token_ids) + end = start + len(exp.tokens) attention_masks[i, start:end] = 1 return attention_masks @@ -427,7 +423,7 @@ def gather_logprobs(experiences, max_prompt_length: int, max_response_length: in ), exp.logprobs, torch.full( - (max_response_length + exp.prompt_length - len(exp.token_ids),), + (max_response_length + exp.prompt_length - len(exp.tokens),), 0.0, dtype=logprob_dtype, ), diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/api/vllm_patch.py index 438636e35e..559c00578c 100644 --- a/trinity/common/models/api/vllm_patch.py +++ b/trinity/common/models/api/vllm_patch.py @@ -1,7 +1,7 @@ """Patch for vllm OpenAI API server. 1. Mocks the `add_signal_handler` method to do nothing. -2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`. +2. Adds `tokens` and `prompt_token_ids` to the `ChatCompletionResponse`. 
""" import asyncio import functools @@ -40,7 +40,7 @@ class PatchedChatCompletionResponseChoice(ChatCompletionResponseChoice): - token_ids: list[int] = Field(default_factory=list) + tokens: list[int] = Field(default_factory=list) class PatchedChatCompletionResponse(ChatCompletionResponse): @@ -78,13 +78,13 @@ async def chat_completion_full_generator( # noqa C901 role = self.get_chat_request_role(request) for output in final_res.outputs: - token_ids = output.token_ids + tokens = output.tokens out_logprobs = output.logprobs if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, "Did not output logprobs" logprobs = self._create_chat_logprobs( - token_ids=token_ids, + tokens=tokens, top_logprobs=out_logprobs, num_output_top_logprobs=request.top_logprobs, tokenizer=tokenizer, @@ -219,7 +219,7 @@ async def chat_completion_full_generator( # noqa C901 if output.finish_reason else "stop", stop_reason=output.stop_reason, - token_ids=output.token_ids, + tokens=output.tokens, ) choices.append(choice_data) @@ -238,7 +238,7 @@ async def chat_completion_full_generator( # noqa C901 num_prompt_tokens = len(final_res.prompt_token_ids) if final_res.encoder_prompt_token_ids is not None: num_prompt_tokens += len(final_res.encoder_prompt_token_ids) - num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs) + num_generated_tokens = sum(len(output.tokens) for output in final_res.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 61f6a1af8c..c5dd27193c 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -190,10 +190,10 @@ def convert_api_output_to_experience( """Convert the API output to a list of experiences.""" return [ Experience( - token_ids=torch.cat( + tokens=torch.cat( ( torch.tensor(output.prompt_token_ids, dtype=torch.int32), - torch.tensor(choice.token_ids, dtype=torch.int32), + torch.tensor(choice.tokens, dtype=torch.int32), ) ), logprobs=torch.cat( diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index 087b190e86..05d072ccb1 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -25,7 +25,7 @@ def tokenize_and_mask_messages_hf( messages (List[dict]): Messages with `role` and `content` fields. Returns: - Tuple[torch.Tensor, torch.Tensor]: The token_ids (sequence_length) + Tuple[torch.Tensor, torch.Tensor]: The tokens (sequence_length) and assistant_masks (sequence_length). """ token_dict = tokenizer.apply_chat_template( @@ -55,7 +55,7 @@ def tokenize_and_mask_messages_default( messages (List[dict]): Messages with `role` and `content` fields. Returns: - Tuple[torch.Tensor, torch.Tensor]: The token_ids (sequence_length) + Tuple[torch.Tensor, torch.Tensor]: The tokens (sequence_length) and assistant_masks (sequence_length). 
Note: diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 5bdf987c8b..bcd88d9cf0 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -146,10 +146,10 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: output = await self._generate_internal(prompt=prompt, **kwargs) experiences = [ Experience( - token_ids=torch.cat( + tokens=torch.cat( ( torch.tensor(output.prompt_token_ids, dtype=torch.int32), - torch.tensor(output.outputs[i].token_ids, dtype=torch.int32), + torch.tensor(output.outputs[i].tokens, dtype=torch.int32), ) ), logprobs=torch.cat( @@ -176,10 +176,10 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: ] return experiences - async def logprobs(self, token_ids: List[int]) -> torch.Tensor: + async def logprobs(self, tokens: List[int]) -> torch.Tensor: """Calculate the logprobs of the given tokens in async.""" output = await self._generate_internal( - prompt={"prompt_token_ids": token_ids}, + prompt={"prompt_token_ids": tokens}, n=1, max_tokens=1, prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token @@ -217,12 +217,10 @@ async def convert_messages_to_experience(self, messages: List[dict]) -> Experien self.tokenizer = await self.async_llm.get_tokenizer() if self.chat_template is None: self.chat_template = self.tokenizer.get_chat_template() - token_ids, action_mask = self.action_mask_method( - self.tokenizer, messages, self.chat_template - ) - logprobs = await self.logprobs(token_ids=token_ids.tolist()) + tokens, action_mask = self.action_mask_method(self.tokenizer, messages, self.chat_template) + logprobs = await self.logprobs(tokens=tokens.tolist()) return Experience( - token_ids=token_ids, + tokens=tokens, logprobs=logprobs, action_mask=action_mask, ) diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py index 4b04524a74..c2762ae43c 100644 --- a/trinity/common/workflows/customized_math_workflows.py +++ b/trinity/common/workflows/customized_math_workflows.py @@ -81,7 +81,7 @@ def run(self) -> List[Experience]: truth=self.truth, with_think=self.with_think, format_score_coef=self.format_score_coef, - response_token=response.token_ids[response.prompt_length :], + response_token=response.tokens[response.prompt_length :], ) if response.metrics is None: diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index d2418fd99f..ceabdc771a 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -133,13 +133,13 @@ def run(self) -> List[Experience]: def process_messages_to_experience(self, messages, reward, info={}) -> Experience: converted_experience = self.model.convert_messages_to_experience(messages) - token_ids = converted_experience.token_ids + tokens = converted_experience.tokens log_probs = converted_experience.logprobs assert converted_experience.action_mask is not None generation_mask = converted_experience.action_mask log_probs = log_probs * generation_mask - assert token_ids.shape == log_probs.shape + assert tokens.shape == log_probs.shape metrics = {} for k, v in info.items(): @@ -147,7 +147,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc metrics[k] = float(v) experience = Experience( - token_ids=token_ids, + tokens=tokens, action_mask=generation_mask, reward=reward, logprobs=log_probs, diff --git a/trinity/trainer/verl_trainer.py 
b/trinity/trainer/verl_trainer.py index a2c195d146..2198f4d2d1 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -405,10 +405,10 @@ def _log_single_experience( ) -> None: reward = experiences.rewards[idx] attn_mask = experiences.attention_masks[idx].bool() - prompt_token = experiences.token_ids[idx][: experiences.prompt_length][ + prompt_token = experiences.tokens[idx][: experiences.prompt_length][ attn_mask[: experiences.prompt_length] ] - response_token = experiences.token_ids[idx][experiences.prompt_length :][ + response_token = experiences.tokens[idx][experiences.prompt_length :][ attn_mask[experiences.prompt_length :] ] prompt_text = self.tokenizer.decode(prompt_token, skip_special_tokens=skip_special_tokens) From d8045ca332f0bf57dbd67e1c3b7c019c7f8bf6bf Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 18:31:14 +0800 Subject: [PATCH 09/21] fix naming --- tests/common/experience_test.py | 22 +++++++-------- tests/common/vllm_test.py | 14 +++++----- trinity/buffer/reader/file_reader.py | 4 +-- trinity/buffer/schema/sql_schema.py | 4 +-- trinity/common/experience.py | 36 ++++++++++++------------- trinity/common/models/api/vllm_patch.py | 10 +++---- trinity/common/models/model.py | 2 +- trinity/common/models/utils.py | 4 +-- trinity/common/models/vllm_model.py | 14 +++++----- 9 files changed, 55 insertions(+), 55 deletions(-) diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index be12219387..44f96e16f7 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -63,14 +63,12 @@ def test_multi_turn_experience(self): def test_dpo_experience(self): tokens = torch.tensor([1, 2]) - chosen_ids = torch.tensor([3, 4]) - rejected_ids = torch.tensor([5, 6]) - exp = Experience( - tokens=tokens, chosen_ids=chosen_ids, rejected_ids=rejected_ids, reward=0.5 - ) + chosen = torch.tensor([3, 4]) + rejected = torch.tensor([5, 6]) + exp = Experience(tokens=tokens, chosen=chosen, rejected=rejected, reward=0.5) self.assertEqual(exp.experience_type.name, "DPO") - self.assertTrue(torch.equal(exp.chosen_ids, chosen_ids)) - self.assertTrue(torch.equal(exp.rejected_ids, rejected_ids)) + self.assertTrue(torch.equal(exp.chosen, chosen)) + self.assertTrue(torch.equal(exp.rejected, rejected)) self.assertEqual(exp.prompt_length, 2) def test_serialize_deserialize(self): @@ -139,7 +137,7 @@ def test_assertions(self): with self.assertRaises(AssertionError): Experience(tokens=[1, 2], prompt_length=2) # DPO: tokens must match prompt_length - exp = Experience(tokens=[1, 2], chosen_ids=[3], rejected_ids=[4], prompt_length=1) + exp = Experience(tokens=[1, 2], chosen=[3], rejected=[4], prompt_length=1) exp.prompt_length = 2 # should automatically adjust @@ -281,13 +279,13 @@ def test_dpo_experience_batch_conversion(self): exps = [ Experience( tokens=torch.tensor([1, 2]), - chosen_ids=torch.tensor([3, 4]), - rejected_ids=torch.tensor([5, 6]), + chosen=torch.tensor([3, 4]), + rejected=torch.tensor([5, 6]), ), Experience( tokens=torch.tensor([7, 8, 9]), - chosen_ids=torch.tensor([10, 11]), - rejected_ids=torch.tensor([12, 13]), + chosen=torch.tensor([10, 11]), + rejected=torch.tensor([12, 13]), ), ] batch = Experiences.gather_experiences(exps) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index d84fba885d..4e61ec8f6d 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -242,8 +242,8 @@ def test_api(self): self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) 
self.assertTrue(hasattr(response, "prompt_token_ids")) self.assertTrue(len(response.prompt_token_ids) > 0) - self.assertTrue(hasattr(response.choices[0], "tokens")) - self.assertTrue(len(response.choices[0].tokens) > 0) + self.assertTrue(hasattr(response.choices[0], "tokens_ids")) + self.assertTrue(len(response.choices[0].token_ids) > 0) exps = self.model_wrapper.extract_experience_from_history() self.assertEqual(len(exps), 3) response = openai_client.chat.completions.create( @@ -261,8 +261,8 @@ def test_api(self): model=model_id, messages=messages, n=2 ) self.assertEqual(2, len(response.choices)) - self.assertTrue(hasattr(response.choices[0], "tokens")) - self.assertTrue(len(response.choices[0].tokens) > 0) + self.assertTrue(hasattr(response.choices[0], "token_ids")) + self.assertTrue(len(response.choices[0].token_ids) > 0) with self.assertRaises(ValueError): self.model_wrapper_no_history.extract_experience_from_history() self.assertEqual(len(self.model_wrapper_no_history.history), 0) @@ -284,7 +284,7 @@ def test_assistant_token_mask(self): }, ] tokenizer = AutoTokenizer.from_pretrained(get_model_path()) - tokens, action_mask = tokenize_and_mask_messages_default( + token_ids, action_mask = tokenize_and_mask_messages_default( tokenizer=tokenizer, messages=messages, chat_template=CHAT_TEMPLATE, @@ -294,7 +294,7 @@ def test_assistant_token_mask(self): messages=messages, chat_template=CHAT_TEMPLATE, ) - self.assertEqual(tokens.shape, token_ids_hf.shape) + self.assertEqual(token_ids.shape, token_ids_hf.shape) self.assertEqual(action_mask.shape, action_mask_hf.shape) - self.assertTrue(torch.equal(tokens, token_ids_hf)) + self.assertTrue(torch.equal(token_ids, token_ids_hf)) self.assertTrue(torch.equal(action_mask, action_mask_hf)) diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 1121ebfa63..de3481224c 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -252,8 +252,8 @@ def read( )[0][prompt_length:] experience = Experience( tokens=prompt_tokens, - chosen_ids=chosen_tokens, - rejected_ids=rejected_tokens, + chosen=chosen_tokens, + rejected=rejected_tokens, ) exp_list.append(experience) return exp_list diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index 9fe128c690..5ac3da2666 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -119,8 +119,8 @@ class DPODataModel(Base): # type: ignore def to_experience(self) -> Experience: """Load the experience from the database.""" exp = Experience.deserialize(self.serialized_exp) - exp.chosen_ids = Experience.deserialize(self.chosen) - exp.rejected_ids = Experience.deserialize(self.rejected) + exp.chosen = Experience.deserialize(self.chosen) + exp.rejected = Experience.deserialize(self.rejected) return exp diff --git a/trinity/common/experience.py b/trinity/common/experience.py index aaae2c59e3..c23c7b94e9 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -86,7 +86,7 @@ class Experience: tokens: Optional[Tensor] = None # [seq_length] logprobs: Optional[Tensor] = None # [seq_length] reward: Optional[float] = None - # Type of the experience, automatically set based on the presence of action_mask or chosen/rejected_ids + # Type of the experience, automatically set based on the presence of action_mask or chosen/rejected experience_type: ExperienceType = ExperienceType.SINGLE_TURN info: Optional[dict] = field( default_factory=dict @@ -107,8 +107,8 @@ class 
Experience: messages: Optional[List[dict]] = None # List of messages # for dpo experiences - chosen_ids: Optional[Tensor] = None # Token ids of the chosen response [resp_length] - rejected_ids: Optional[Tensor] = None # Token ids of the rejected response [resp_length] + chosen: Optional[Tensor] = None # Token ids of the chosen response [resp_length] + rejected: Optional[Tensor] = None # Token ids of the rejected response [resp_length] chosen_text: Optional[str] = None # Text of the chosen response rejected_text: Optional[str] = None # Text of the rejected response @@ -126,14 +126,14 @@ def __init__( prompt_text=None, action_mask=None, messages=None, - chosen_ids=None, - rejected_ids=None, + chosen=None, + rejected=None, chosen_text=None, rejected_text=None, ): if action_mask is not None: experience_type = ExperienceType.MULTI_TURN - elif chosen_ids is not None and rejected_ids is not None: + elif chosen is not None and rejected is not None: experience_type = ExperienceType.DPO else: experience_type = ExperienceType.SINGLE_TURN @@ -162,8 +162,8 @@ def __init__( self.prompt_text = prompt_text self.action_mask = action_mask self.messages = messages - self.chosen_ids = chosen_ids - self.rejected_ids = rejected_ids + self.chosen = chosen + self.rejected = rejected self.chosen_text = chosen_text self.rejected_text = rejected_text @@ -173,10 +173,10 @@ def __init__( self.logprobs = torch.tensor(self.logprobs) if self.action_mask is not None and not isinstance(self.action_mask, Tensor): self.action_mask = torch.tensor(self.action_mask) - if self.chosen_ids is not None and not isinstance(self.chosen_ids, Tensor): - self.chosen_ids = torch.tensor(self.chosen_ids) - if self.rejected_ids is not None and not isinstance(self.rejected_ids, Tensor): - self.rejected_ids = torch.tensor(self.rejected_ids) + if self.chosen is not None and not isinstance(self.chosen, Tensor): + self.chosen = torch.tensor(self.chosen) + if self.rejected is not None and not isinstance(self.rejected, Tensor): + self.rejected = torch.tensor(self.rejected) def serialize(self) -> bytes: """Serialize the experience to bytes.""" @@ -200,10 +200,10 @@ def to_dict(self) -> dict: res["response_text"] = self.response_text if self.messages is not None: res["messages"] = self.messages - if self.chosen_ids is not None: - res["chosen_ids"] = self.chosen_ids.tolist() - if self.rejected_ids is not None: - res["rejected_ids"] = self.rejected_ids.tolist() + if self.chosen is not None: + res["chosen"] = self.chosen.tolist() + if self.rejected is not None: + res["rejected"] = self.rejected.tolist() if self.reward is not None: res["reward"] = float(self.reward) return res @@ -265,7 +265,7 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E step=exp.eid.step, run=exp.eid.run, ), - tokens=torch.cat([exp.tokens, exp.chosen_ids]), + tokens=torch.cat([exp.tokens, exp.chosen]), reward=exp.reward, info=exp.info, metrics=exp.metrics, @@ -282,7 +282,7 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E step=exp.eid.step, run=exp.eid.run, ), - tokens=torch.cat([exp.tokens, exp.rejected_ids]), + tokens=torch.cat([exp.tokens, exp.rejected]), reward=exp.reward, info=exp.info, metrics=exp.metrics, diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/api/vllm_patch.py index 559c00578c..ec19e9871e 100644 --- a/trinity/common/models/api/vllm_patch.py +++ b/trinity/common/models/api/vllm_patch.py @@ -40,7 +40,7 @@ class 
PatchedChatCompletionResponseChoice(ChatCompletionResponseChoice): - tokens: list[int] = Field(default_factory=list) + token_ids: list[int] = Field(default_factory=list) class PatchedChatCompletionResponse(ChatCompletionResponse): @@ -78,13 +78,13 @@ async def chat_completion_full_generator( # noqa C901 role = self.get_chat_request_role(request) for output in final_res.outputs: - tokens = output.tokens + token_ids = output.token_ids out_logprobs = output.logprobs if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, "Did not output logprobs" logprobs = self._create_chat_logprobs( - tokens=tokens, + token_ids=token_ids, top_logprobs=out_logprobs, num_output_top_logprobs=request.top_logprobs, tokenizer=tokenizer, @@ -219,7 +219,7 @@ async def chat_completion_full_generator( # noqa C901 if output.finish_reason else "stop", stop_reason=output.stop_reason, - tokens=output.tokens, + token_ids=output.token_ids, ) choices.append(choice_data) @@ -238,7 +238,7 @@ async def chat_completion_full_generator( # noqa C901 num_prompt_tokens = len(final_res.prompt_token_ids) if final_res.encoder_prompt_token_ids is not None: num_prompt_tokens += len(final_res.encoder_prompt_token_ids) - num_generated_tokens = sum(len(output.tokens) for output in final_res.outputs) + num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index c5dd27193c..96d700678e 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -193,7 +193,7 @@ def convert_api_output_to_experience( tokens=torch.cat( ( torch.tensor(output.prompt_token_ids, dtype=torch.int32), - torch.tensor(choice.tokens, dtype=torch.int32), + torch.tensor(choice.token_ids, dtype=torch.int32), ) ), logprobs=torch.cat( diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index 05d072ccb1..087b190e86 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -25,7 +25,7 @@ def tokenize_and_mask_messages_hf( messages (List[dict]): Messages with `role` and `content` fields. Returns: - Tuple[torch.Tensor, torch.Tensor]: The tokens (sequence_length) + Tuple[torch.Tensor, torch.Tensor]: The token_ids (sequence_length) and assistant_masks (sequence_length). """ token_dict = tokenizer.apply_chat_template( @@ -55,7 +55,7 @@ def tokenize_and_mask_messages_default( messages (List[dict]): Messages with `role` and `content` fields. Returns: - Tuple[torch.Tensor, torch.Tensor]: The tokens (sequence_length) + Tuple[torch.Tensor, torch.Tensor]: The token_ids (sequence_length) and assistant_masks (sequence_length). 
Note: diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index bcd88d9cf0..30c3e00b3e 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -149,7 +149,7 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: tokens=torch.cat( ( torch.tensor(output.prompt_token_ids, dtype=torch.int32), - torch.tensor(output.outputs[i].tokens, dtype=torch.int32), + torch.tensor(output.outputs[i].token_ids, dtype=torch.int32), ) ), logprobs=torch.cat( @@ -176,10 +176,10 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: ] return experiences - async def logprobs(self, tokens: List[int]) -> torch.Tensor: + async def logprobs(self, token_ids: List[int]) -> torch.Tensor: """Calculate the logprobs of the given tokens in async.""" output = await self._generate_internal( - prompt={"prompt_token_ids": tokens}, + prompt={"prompt_token_ids": token_ids}, n=1, max_tokens=1, prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token @@ -217,10 +217,12 @@ async def convert_messages_to_experience(self, messages: List[dict]) -> Experien self.tokenizer = await self.async_llm.get_tokenizer() if self.chat_template is None: self.chat_template = self.tokenizer.get_chat_template() - tokens, action_mask = self.action_mask_method(self.tokenizer, messages, self.chat_template) - logprobs = await self.logprobs(tokens=tokens.tolist()) + token_ids, action_mask = self.action_mask_method( + self.tokenizer, messages, self.chat_template + ) + logprobs = await self.logprobs(token_ids=token_ids.tolist()) return Experience( - tokens=tokens, + tokens=token_ids, logprobs=logprobs, action_mask=action_mask, ) From fcf93e9dac1692499139ff1b27c7dedd024904de Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 18:32:28 +0800 Subject: [PATCH 10/21] fix naming --- tests/common/vllm_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 4e61ec8f6d..71ccf32b7f 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -242,7 +242,7 @@ def test_api(self): self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) self.assertTrue(hasattr(response, "prompt_token_ids")) self.assertTrue(len(response.prompt_token_ids) > 0) - self.assertTrue(hasattr(response.choices[0], "tokens_ids")) + self.assertTrue(hasattr(response.choices[0], "token_ids")) self.assertTrue(len(response.choices[0].token_ids) > 0) exps = self.model_wrapper.extract_experience_from_history() self.assertEqual(len(exps), 3) From 5303505cca92b82c3332fe8bc4127f85d0e9f6ed Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 18:59:34 +0800 Subject: [PATCH 11/21] update dependencies --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 621e5a989e..53ae5e7f4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.10" dependencies = [ "verl==0.4.1", "ray[default]>=2.45.0", - "vllm==0.9.2", + "vllm>=0.9.1", "tensordict==0.6.2", "wandb", "omegaconf", From 4802e9df43ee9a57fb79493dfc244e0385e178eb Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 19:42:21 +0800 Subject: [PATCH 12/21] fix ids --- docs/sphinx_doc/source/tutorial/example_mix_algo.md | 2 +- tests/utils/plugin_test.py | 4 ++-- tests/utils/plugins/my_workflow.py | 4 ++-- trinity/algorithm/sample_strategy/mix_sample_strategy.py | 4 ++-- trinity/algorithm/sample_strategy/utils.py | 4 ++-- 
trinity/trainer/verl/fsdp_workers.py | 3 +-- 6 files changed, 10 insertions(+), 11 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 0e08b8648b..63fdc63629 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -142,7 +142,7 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { - "uid": np.array(experiences.group_ids), + "uid": np.array(experiences.eid.uid), "position_ids": position_ids, "input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py index 01aa2f3967..e35e6eab0d 100644 --- a/tests/utils/plugin_test.py +++ b/tests/utils/plugin_test.py @@ -11,7 +11,7 @@ class PluginActor: def run(self): my_plugin_cls = WORKFLOWS.get("my_workflow") - return my_plugin_cls(None, None).run() + return my_plugin_cls(task=None, model=None).run() class TestPluginLoader(unittest.TestCase): @@ -22,7 +22,7 @@ def test_load_plugins(self): load_plugins(Path(__file__).resolve().parent / "plugins") my_plugin_cls = WORKFLOWS.get("my_workflow") self.assertIsNotNone(my_plugin_cls) - my_plugin = my_plugin_cls(None, None, None) + my_plugin = my_plugin_cls(task=None, model=None, auxiliary_models=None) self.assertTrue(my_plugin.__module__.startswith("trinity.plugins")) res = my_plugin.run() self.assertEqual(res[0], "Hello world") diff --git a/tests/utils/plugins/my_workflow.py b/tests/utils/plugins/my_workflow.py index b999590a01..624969ee89 100644 --- a/tests/utils/plugins/my_workflow.py +++ b/tests/utils/plugins/my_workflow.py @@ -5,8 +5,8 @@ @WORKFLOWS.register_module("my_workflow") class MyWorkflow(Workflow): - def __init__(self, model, task, auxiliary_models=None): - super().__init__(model, task, auxiliary_models) + def __init__(self, *, task, model, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) def run(self) -> List: return ["Hello world", "Hi"] diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 80a4af7d49..5b654900fa 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -91,8 +91,8 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor): cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { - "uid": np.array(experiences.group_ids), - "unique_ids": np.array(experiences.unique_ids), + "uid": np.array(experiences.eid.tid), + "unique_ids": np.array(experiences.eid.uid), "position_ids": position_ids, "input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py index f9df00ee4e..741b7a69e7 100644 --- a/trinity/algorithm/sample_strategy/utils.py +++ b/trinity/algorithm/sample_strategy/utils.py @@ -13,8 +13,8 @@ def to_data_proto(experiences: Experiences) -> DataProto: cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { - "uid": np.array(experiences.group_ids), - "unique_ids": 
np.array(experiences.unique_ids), + "uid": np.array(experiences.eid.tid), + "unique_ids": np.array(experiences.eid.uid), "position_ids": position_ids, "input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 75e6c57f46..00a67ee002 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -27,8 +27,7 @@ import torch import torch.distributed import torch.distributed as dist - -# import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically. +import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically. from codetiming import Timer from omegaconf import DictConfig, OmegaConf, open_dict from peft import LoraConfig, TaskType, get_peft_model From 5279415fd5f1e567c9f4aab5899ce3e9fe5e6bf4 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 19:58:45 +0800 Subject: [PATCH 13/21] fix uids --- docs/sphinx_doc/source/tutorial/example_mix_algo.md | 3 ++- trinity/algorithm/sample_strategy/mix_sample_strategy.py | 4 ++-- trinity/algorithm/sample_strategy/utils.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 63fdc63629..37e51d3e4c 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -142,7 +142,8 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { - "uid": np.array(experiences.eid.uid), + "uid": np.array([eid.tid for eid in experiences.eids]), + "unique_ids": np.array(eid.uid for eid in experiences.eids), "position_ids": position_ids, "input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 5b654900fa..07e351539c 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -91,8 +91,8 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor): cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { - "uid": np.array(experiences.eid.tid), - "unique_ids": np.array(experiences.eid.uid), + "uid": np.array([eid.tid for eid in experiences.eids]), + "unique_ids": np.array(eid.uid for eid in experiences.eids), "position_ids": position_ids, "input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py index 741b7a69e7..cba97e6d9e 100644 --- a/trinity/algorithm/sample_strategy/utils.py +++ b/trinity/algorithm/sample_strategy/utils.py @@ -13,8 +13,8 @@ def to_data_proto(experiences: Experiences) -> DataProto: cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { - "uid": np.array(experiences.eid.tid), - "unique_ids": np.array(experiences.eid.uid), + "uid": np.array([eid.tid for eid in experiences.eids]), + "unique_ids": np.array([eid.uid for eid in experiences.eids]), "position_ids": position_ids, 
"input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), From 57f4f830bcaf92132deee4124a1cfdd261eb1b6c Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 20:02:03 +0800 Subject: [PATCH 14/21] record step in workflow --- trinity/common/workflows/step_wise_workflow.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/trinity/common/workflows/step_wise_workflow.py b/trinity/common/workflows/step_wise_workflow.py index 634a54030c..789bc5f596 100644 --- a/trinity/common/workflows/step_wise_workflow.py +++ b/trinity/common/workflows/step_wise_workflow.py @@ -28,7 +28,11 @@ def run(self) -> list[Experience]: # Collect experiences data of the current step exps = self.model.extract_experience_from_history() # Calculate the reward for the current step - exps = self.reward(exps, step_num=step) + reward = self.reward(exps, step_num=step) + for exp in exps: + exp.reward = reward + # set the step number in each experience + exp.eid.step = step # Store the step experiences experiences.extend(exps) if not continue_run: @@ -83,11 +87,16 @@ def run(self) -> list[Experience]: continue_run = self.step(step_num=step) # Collect experiences data of the current step exps = self.model.extract_experience_from_history() + # set the step number in each experience + for exp in exps: + exp.eid.step = step # Store the step experiences experiences.extend(exps) if not continue_run: break - self.reward(experiences) + reward = self.reward(experiences) + for exp in experiences: + exp.reward = reward return experiences @abstractmethod From 2227ed3c4ec5256c556e1fe7fc3ac20f8cf76adc Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 21 Jul 2025 20:46:15 +0800 Subject: [PATCH 15/21] add tests for add strategy --- tests/explorer/explorer_test.py | 53 +++++++++++++++++++ .../algorithm/add_strategy/add_strategy.py | 4 +- trinity/explorer/explorer.py | 13 +++-- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index cd697606cf..5b79d81417 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -11,6 +11,7 @@ get_template_config, get_unittest_dataset_config, ) +from trinity.buffer import get_buffer_reader from trinity.cli.launcher import explore @@ -71,3 +72,55 @@ def test_explorer(self): eval_metrics = parser.metric_list("eval") self.assertTrue(len(eval_metrics) == 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + + +class TestExplorerWithAddStrategy(BaseExplorerCase): + def test_explorer(self): + import ray + + from trinity.explorer.explorer import Explorer + + self.config.buffer.total_epochs = 1 + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + self.config.buffer.explorer_input.add_strategy = "random" + self.config.name = f"explore-add-strategy-{datetime.now().strftime('%Y%m%d%H%M%S')}" + # some step may be skipped due to same reward + self.config.algorithm.add_strategy = "reward_variance" + self.config.check_and_update() + explorer = ( + ray.remote(Explorer) + .options( + name=self.config.explorer.name, + namespace=ray.get_runtime_context().namespace, + ) + .remote(self.config) + ) + ray.get(explorer.prepare.remote()) + ray.get(explorer.sync_weight.remote()) + ray.get(explorer.explore.remote()) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + 
eval_metrics = parser.metric_list("eval") + self.assertTrue(len(eval_metrics) == 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + self.assertTrue(parser.metric_exist("rollout/experience_count")) + experience_counts = parser.metric_values("rollout/experience_count") + self.assertTrue(len(experience_counts) == 4) + for count in experience_counts: + self.assertTrue(count >= 0) + self.assertTrue(count <= 2 * 4) # repeat_times * batch_size + self.assertTrue(count % (2 * 4) == 0) # should be multiple of repeat_times * batch_size + + reader = get_buffer_reader( + self.config.buffer.trainer_input.experience_buffer, self.config.buffer + ) + exps = [] + try: + batch = reader.read() + exps.extend(batch) + except StopIteration: + pass + self.assertTrue(len(exps) <= 4 * 2 * 4) # step * repeat_times * batch_size + self.assertTrue(len(exps) % (2 * 4) == 0) # should be multiple of repeat_times * batch_size + ray.get(explorer.shutdown.remote()) diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py index 18f91cbb2a..530886b1c2 100644 --- a/trinity/algorithm/add_strategy/add_strategy.py +++ b/trinity/algorithm/add_strategy/add_strategy.py @@ -36,11 +36,13 @@ def default_args(cls) -> dict: """ +@ADD_STRATEGY.register_module("reward_variance") class RewardVarianceAddStrategy(AddStrategy): """An example add strategy that filters experiences based on a reward threshold.""" def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None: super().__init__(writer) + self.variance_threshold = variance_threshold async def add(self, experiences: List[Experience], step: int) -> int: cnt = 0 @@ -59,7 +61,7 @@ async def add(self, experiences: List[Experience], step: int) -> int: @classmethod def default_args(cls) -> dict: - return {"reward_threshold": 0.0} + return {"variance_threshold": 0.0} def group_by( diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 84dcf596e7..dc2b796246 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -84,7 +84,12 @@ def __init__(self, config: Config): self.collect_experiences = self.config.explorer.collect_experiences self.generated_experience_cnt = 0 if self.collect_experiences: - self.add_strategy = ADD_STRATEGY.get(self.config.algorithm.add_strategy) + assert ( + self.experience_buffer is not None + ), "Experience buffer is required when collect_experiences is True." 
+ self.add_strategy = ADD_STRATEGY.get(self.config.algorithm.add_strategy)( + self.experience_buffer, **self.config.algorithm.add_strategy_args + ) async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None @@ -349,11 +354,13 @@ async def _finish_steps(self, start_step: int, end_step: int) -> None: async def _finish_explore_step(self, step: int) -> None: statuses, exps = await self.scheduler.get_results(batch_id=step) + metric = {} if self.config.explorer.collect_experiences: - exp_cnt = self.add_strategy.add(exps, step) + exp_cnt = await self.add_strategy.add(exps, step) self.generated_experience_cnt += exp_cnt + metric["rollout/experience_count"] = exp_cnt if statuses: - metric = gather_metrics([status.metric for status in statuses], "rollout") + metric.update(gather_metrics([status.metric for status in statuses], "rollout")) self.monitor.log(metric, step=step) async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: From 3b5d8445670477c43c1e26374ef39040cc7c3117 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 22 Jul 2025 10:08:39 +0800 Subject: [PATCH 16/21] fix dpo sample_strategy --- trinity/algorithm/algorithm.py | 2 +- .../sample_strategy/sample_strategy.py | 20 ------------------- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 1f0de5be64..cf2aaa823c 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -145,7 +145,7 @@ class DPOAlgorithm(AlgorithmType): @classmethod def default_config(cls) -> Dict: return { - "sample_strategy": "dpo", + "sample_strategy": "warmup", "policy_loss_fn": "dpo", "kl_loss_fn": "k2", "entropy_loss_fn": "default", diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 6e530d32ce..b923ab17a6 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -120,23 +120,3 @@ def warmup_state(self, step: int) -> Tuple[bool, bool]: @classmethod def default_args(cls) -> dict: return {} - - -@SAMPLE_STRATEGY.register_module("dpo") -class DPOSampleStrategy(WarmupSampleStrategy): - def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: - metrics = {} - with Timer(metrics, "read_time"): - if step <= self.sft_warmup_steps: - exp_list = self.sft_buffer.read() - else: - exp_list = self.exp_buffer.read() - repr_samples = representative_sample(exp_list) - with Timer(metrics, "gather_time"): - exps = Experiences.gather_dpo_experiences(exp_list, pad_token_id=self.pad_token_id) # type: ignore - if self.trainer_type == "verl": - with Timer(metrics, "convert_time"): - data = to_data_proto(exps) - return data, metrics, repr_samples - else: - raise NotImplementedError(f"backend {self.trainer_type} is not supported") From 9c7d9986f7c791ab101c29a7e9a756057f782db7 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 22 Jul 2025 10:33:49 +0800 Subject: [PATCH 17/21] add docs --- README.md | 17 +++++++++-------- README_zh.md | 15 ++++++++------- docs/sphinx_doc/source/tutorial/faq.md | 2 +- .../source/tutorial/trinity_configs.md | 3 ++- trinity/common/config.py | 11 ++++++----- trinity/common/models/api/vllm_patch.py | 2 +- 6 files changed, 27 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 0684d5a364..eee0b6ca45 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,12 @@ It is designed to support diverse application scenarios and 
serve as a unified p ### Step 1: installation +Requirements: +- Python version >= 3.10, <= 3.12 +- CUDA version >= 12.4, <= 12.8 +- At least 2 GPUs + + Installation from source **(recommended)**: ```shell @@ -181,13 +187,15 @@ pip install -e .[flash_attn] # for zsh pip install -e .\[flash_attn\] # Try the following command if you encounter errors during flash-attn installation -# pip install flash-attn -v --no-build-isolation +# pip install flash-attn==2.8.0.post2 -v --no-build-isolation ``` Installation using pip: ```shell pip install trinity-rft==0.2.0 +# install flash-attn separately +pip install flash-attn==2.8.0.post2 ``` Installation from docker: @@ -206,13 +214,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest . docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest ``` - -**Requirements:** -Python version >= 3.10, -CUDA version >= 12.4, -and at least 2 GPUs. - - ### Step 2: prepare dataset and model diff --git a/README_zh.md b/README_zh.md index 0f652aa22f..6d0f8df8f2 100644 --- a/README_zh.md +++ b/README_zh.md @@ -151,6 +151,11 @@ Trinity-RFT是一个通用、灵活且易于使用的大语言模型强化微调 ### 第一步:安装 +环境要求: +- Python >= 3.10, <= 3.12 +- CUDA >= 12.4, <= 12.8 +- 至少 2 块 GPU + 源码安装 **(推荐)**: @@ -181,13 +186,15 @@ pip install -e .[flash_attn] # 适用于 zsh pip install -e .\[flash_attn\] # 如果安装 flash-attn 时遇到错误,可以尝试以下命令 -# pip install flash-attn -v --no-build-isolation +# pip install flash-attn==2.8.0.post2 -v --no-build-isolation ``` 使用 pip 安装: ```shell pip install trinity-rft==0.2.0 +# flash-attn 需要单独安装 +pip install flash-attn==2.8.0.post2 ``` 使用 Docker 安装: @@ -207,12 +214,6 @@ docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v = 3.10, -CUDA 版本 >= 12.4, -以及至少 2 块 GPU。 - - ### 第二步:准备数据集和模型 diff --git a/docs/sphinx_doc/source/tutorial/faq.md b/docs/sphinx_doc/source/tutorial/faq.md index e84639255c..cc6c3c461b 100644 --- a/docs/sphinx_doc/source/tutorial/faq.md +++ b/docs/sphinx_doc/source/tutorial/faq.md @@ -65,7 +65,7 @@ File ".../flash_attn/flash_attn_interface.py", line 15, in ‹module> ImportError: ... ``` -**A:** The `flash-attn` module is not properly installed. Try to fix it by running `pip install flash-attn` or `pip install flash-attn -v --no-build-isolation`. +**A:** The `flash-attn` module is not properly installed. Try to fix it by running `pip install flash-attn==2.8.0.post2` or `pip install flash-attn==2.8.0.post2 -v --no-build-isolation`. --- diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 230e55592e..ea4b966d47 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -90,6 +90,7 @@ algorithm: kl_penalty_fn: "none" kl_loss_fn: "k2" entropy_loss_fn: "default" + add_strategy: null ``` - `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`, `sft`, `mix`. @@ -99,7 +100,7 @@ algorithm: - `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward. - `kl_loss_fn`: The KL loss function used for computing KL loss. - `entropy_loss_fn`: The entropy loss function used for computing entropy loss. - +- `add_strategy`: Strategy for adding new experiences to the experience buffer. If set, explorer will collect experiences from workflow runners and pre-process them before adding to the buffer. 
--- diff --git a/trinity/common/config.py b/trinity/common/config.py index f8d6fffd24..13b37ac26d 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -333,11 +333,6 @@ class ExplorerConfig: runner_num: Optional[int] = None # deprecated - # Explorer collects experiences from workflow runners - # some algorithms (e.g., DAPO) need to collect experiences generated by the same task and do some post-processing - # will automatically set to True if `algorithm.add_strategy` is not None - collect_experiences: bool = False - # for inference models # for rollout model rollout_model: InferenceModelConfig = field(default_factory=InferenceModelConfig) @@ -351,6 +346,12 @@ class ExplorerConfig: # for benchmark bench_on_latest_checkpoint: bool = False # only benchmark the latest checkpoint + # ! DO NOT SET + # Explorer collects experiences from workflow runners + # some algorithms (e.g., DAPO) need to collect experiences generated by the same task and do some post-processing + # will automatically set to True if `algorithm.add_strategy` is not None + collect_experiences: bool = False + @dataclass class TrainerConfig: diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/api/vllm_patch.py index ec19e9871e..438636e35e 100644 --- a/trinity/common/models/api/vllm_patch.py +++ b/trinity/common/models/api/vllm_patch.py @@ -1,7 +1,7 @@ """Patch for vllm OpenAI API server. 1. Mocks the `add_signal_handler` method to do nothing. -2. Adds `tokens` and `prompt_token_ids` to the `ChatCompletionResponse`. +2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`. """ import asyncio import functools From ca1b23b34da2a50bfcfdb093cd2b18abfa36b13a Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 22 Jul 2025 11:10:36 +0800 Subject: [PATCH 18/21] update comments --- trinity/algorithm/add_strategy/add_strategy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py index 530886b1c2..ca81d6927b 100644 --- a/trinity/algorithm/add_strategy/add_strategy.py +++ b/trinity/algorithm/add_strategy/add_strategy.py @@ -38,7 +38,7 @@ def default_args(cls) -> dict: @ADD_STRATEGY.register_module("reward_variance") class RewardVarianceAddStrategy(AddStrategy): - """An example add strategy that filters experiences based on a reward threshold.""" + """An example AddStrategy that filters experiences based on a reward variance threshold.""" def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None: super().__init__(writer) @@ -53,7 +53,7 @@ async def add(self, experiences: List[Experience], step: int) -> int: # check if the rewards are the same rewards = [exp.reward for exp in group_exps] variance = np.var(rewards) - if variance < self.variance_threshold: + if variance <= self.variance_threshold: continue cnt += len(group_exps) await self.writer.write_async(group_exps) From 619f0fd9cccc143903c27d716318bd226defe2ad Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 22 Jul 2025 11:19:04 +0800 Subject: [PATCH 19/21] fix comments --- tests/explorer/explorer_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 5b79d81417..2c9c94f1bc 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -80,6 +80,7 @@ def test_explorer(self): from trinity.explorer.explorer import Explorer + self.config.algorithm.repeat_times = 2 
self.config.buffer.total_epochs = 1 self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") self.config.buffer.explorer_input.add_strategy = "random" @@ -110,7 +111,7 @@ def test_explorer(self): for count in experience_counts: self.assertTrue(count >= 0) self.assertTrue(count <= 2 * 4) # repeat_times * batch_size - self.assertTrue(count % (2 * 4) == 0) # should be multiple of repeat_times * batch_size + self.assertTrue(count % 2 == 0) # should be multiple of repeat_times reader = get_buffer_reader( self.config.buffer.trainer_input.experience_buffer, self.config.buffer @@ -122,5 +123,5 @@ def test_explorer(self): except StopIteration: pass self.assertTrue(len(exps) <= 4 * 2 * 4) # step * repeat_times * batch_size - self.assertTrue(len(exps) % (2 * 4) == 0) # should be multiple of repeat_times * batch_size + self.assertTrue(len(exps) % (2 * 4) == 0) # should be multiple of repeat_times ray.get(explorer.shutdown.remote()) From 8373cdf66fb09243089192a4346853190870da50 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 22 Jul 2025 11:30:26 +0800 Subject: [PATCH 20/21] fix comments --- docs/sphinx_doc/source/tutorial/example_mix_algo.md | 2 +- trinity/algorithm/sample_strategy/mix_sample_strategy.py | 2 +- trinity/common/experience.py | 4 +++- trinity/common/workflows/step_wise_workflow.py | 4 ++-- trinity/explorer/workflow_runner.py | 6 ++++-- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 37e51d3e4c..52aa212b0a 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -143,7 +143,7 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { "uid": np.array([eid.tid for eid in experiences.eids]), - "unique_ids": np.array(eid.uid for eid in experiences.eids), + "unique_ids": np.array([eid.uid for eid in experiences.eids]), "position_ids": position_ids, "input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 07e351539c..60f908afe2 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -92,7 +92,7 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor): position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { "uid": np.array([eid.tid for eid in experiences.eids]), - "unique_ids": np.array(eid.uid for eid in experiences.eids), + "unique_ids": np.array([eid.uid for eid in experiences.eids]), "position_ids": position_ids, "input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), diff --git a/trinity/common/experience.py b/trinity/common/experience.py index c23c7b94e9..0c5f98e89c 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -142,7 +142,9 @@ def __init__( assert ( prompt_length > 0 ), "Prompt length must be greater than 0 for single-turn experiences." - assert len(tokens) > prompt_length, "Token ids must be longer than the prompt length." + assert ( + len(tokens) > prompt_length + ), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}." 
action_mask = torch.zeros(len(tokens), dtype=torch.bool) action_mask[prompt_length:] = 1 elif experience_type == ExperienceType.MULTI_TURN: diff --git a/trinity/common/workflows/step_wise_workflow.py b/trinity/common/workflows/step_wise_workflow.py index 789bc5f596..db571a7b71 100644 --- a/trinity/common/workflows/step_wise_workflow.py +++ b/trinity/common/workflows/step_wise_workflow.py @@ -22,7 +22,7 @@ def __init__(self, *, task: Task, model: ModelWrapper, auxiliary_models=None): def run(self) -> list[Experience]: """Run the workflow and return a list of experiences with step-wise rewards.""" experiences = [] - for step in self.max_step_num: + for step in range(self.max_step_num): # Run a single step of the agent application continue_run = self.step(step_num=step) # Collect experiences data of the current step @@ -82,7 +82,7 @@ def __init__(self, *, task: Task, model: ModelWrapper, auxiliary_models=None): def run(self) -> list[Experience]: """Run the workflow and return a list of experiences with step-wise rewards.""" experiences = [] - for step in self.max_step_num: + for step in range(self.max_step_num): # Run a single step of the agent application continue_run = self.step(step_num=step) # Collect experiences data of the current step diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 8e7a31e41e..608db16abb 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -82,7 +82,7 @@ def run_task(self, task: Task) -> Tuple[Status, List[Experience]]: assert exps is not None and len(exps) > 0, "An empty experience is generated" metrics: dict[str, List[float]] = defaultdict(list) # set group id - for idx, exp in enumerate(exps): + for _, exp in enumerate(exps): exp.eid.batch = task.batch_id exp.eid.task = task.task_id if not hasattr(exp, "info") or exp.info is None: @@ -100,14 +100,16 @@ def run_task(self, task: Task) -> Tuple[Status, List[Experience]]: if metrics: for k, v in metrics.items(): metric[k] = sum(v) / len(v) # type: ignore + if task.is_eval: # If the task is an evaluation task, we do not record the experiences to the buffer return Status(True, metric=metric), [] - if self.return_experiences: + elif self.return_experiences: return Status(True, metric=metric), exps else: self.experience_buffer.write(exps) return Status(True, metric=metric), [] + except Exception as e: error_trace_back = traceback.format_exc() self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}") From 209fe5946875ad5c51c747a6d393b0cb8d060c83 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 22 Jul 2025 15:11:47 +0800 Subject: [PATCH 21/21] fix unittest --- .github/workflows/unittest.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index e03426c46b..e488067f28 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -18,6 +18,7 @@ jobs: steps: - uses: actions/checkout@v4 with: + fetch-depth: 0 path: trinity-${{ github.run_id }} ref: refs/pull/${{ github.event.issue.number }}/head
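
The step-wise workflow API touched throughout this series (keyword-only constructors, the `eid.step` bookkeeping, and a single reward propagated over `range(self.max_step_num)`) is easiest to see in a concrete subclass. The sketch below is illustrative only and is not part of the patch series: the class name `GuessNumberWorkflow`, the task payload, the five-step budget, and the model-id lookup are hypothetical, and the `reward` / `max_step_num` signatures are inferred from the `run()` loop patched above, so they should be checked against `trinity/common/workflows/step_wise_workflow.py`.

```python
# Hypothetical example -- not part of this patch series.
from typing import List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.step_wise_workflow import RewardPropagationWorkflow
from trinity.common.workflows.workflow import Task


class GuessNumberWorkflow(RewardPropagationWorkflow):
    """Toy multi-step workflow: the model repeatedly guesses a target number."""

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        # Keyword-only arguments, matching the refactored Workflow.__init__
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.target = (task.raw_task or {}).get("target", 42)  # hypothetical payload
        self.last_guess: Optional[int] = None

    @property
    def max_step_num(self) -> int:
        # Hypothetical fixed budget; run() iterates over range(self.max_step_num)
        return 5

    def step(self, step_num: int) -> bool:
        # self.client is the rollout model's OpenAI client; every call is
        # recorded because enable_history must be True for these workflows.
        response = self.client.chat.completions.create(
            model=self.client.models.list().data[0].id,
            messages=[
                {
                    "role": "user",
                    "content": f"Step {step_num}: guess an integer between 1 and 100.",
                }
            ],
        )
        content = response.choices[0].message.content or ""
        try:
            self.last_guess = int(content.strip())
        except ValueError:
            self.last_guess = None
        # Returning False stops the rollout early, e.g. once the guess is correct.
        return self.last_guess != self.target

    def reward(self, exps: List[Experience]) -> float:
        # One trajectory-level reward; the base class assigns it to every step,
        # and each Experience already carries its step number in exp.eid.step.
        return 1.0 if self.last_guess == self.target else 0.0
```

Registering such a class with `WORKFLOWS.register_module(...)`, as the plugin test in this series does for `my_workflow`, is what makes it selectable from the config.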
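
Similarly, the `algorithm.add_strategy` extension point exercised by the new explorer test admits small custom strategies. The following is a hedged sketch of one possible strategy that writes a task's experiences only when their mean reward clears a threshold; the name `reward_mean`, the default threshold, and the exact import paths for `ADD_STRATEGY` and `BufferWriter` are assumptions, while the registry decorator, constructor, async `add`, and `default_args` shape follow `RewardVarianceAddStrategy` from this series.

```python
# Hypothetical example -- not part of this patch series; import paths may need adjusting.
from collections import defaultdict
from typing import List

from trinity.algorithm.add_strategy.add_strategy import ADD_STRATEGY, AddStrategy
from trinity.buffer import BufferWriter
from trinity.common.experience import Experience


@ADD_STRATEGY.register_module("reward_mean")
class RewardMeanAddStrategy(AddStrategy):
    """Write a group of experiences only if its mean reward exceeds a threshold."""

    def __init__(self, writer: BufferWriter, mean_threshold: float = 0.0, **kwargs) -> None:
        super().__init__(writer)
        self.mean_threshold = mean_threshold

    async def add(self, experiences: List[Experience], step: int) -> int:
        # Group by task id so all repeats of one task are kept or dropped together.
        groups = defaultdict(list)
        for exp in experiences:
            groups[exp.eid.tid].append(exp)
        cnt = 0
        for group_exps in groups.values():
            rewards = [exp.reward for exp in group_exps]
            if sum(rewards) / len(rewards) <= self.mean_threshold:
                continue
            cnt += len(group_exps)
            await self.writer.write_async(group_exps)
        return cnt

    @classmethod
    def default_args(cls) -> dict:
        return {"mean_threshold": 0.0}
```

Selecting it would mirror the test above: set `algorithm.add_strategy` to the registered name and pass any thresholds through `algorithm.add_strategy_args`, which the explorer forwards to the strategy constructor.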