From f338804cc4e0c1f1628a2ad379fe8a2304852b4f Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 23 Sep 2025 18:24:12 +0800
Subject: [PATCH 1/7] support mm sft

---
 tests/buffer/formatter_test.py               |  14 +-
 trinity/buffer/reader/file_reader.py         |   6 +-
 trinity/buffer/schema/formatter.py           | 145 +++++++++++++-----
 trinity/buffer/storage/sql.py                |   7 +-
 trinity/common/models/mm_utils.py            |  98 ++++++------
 trinity/common/models/model.py               |  34 ++--
 trinity/common/models/vllm_model.py          |  46 +++---
 .../common/workflows/simple_mm_workflow.py   |  32 ++--
 trinity/common/workflows/workflow.py         |   6 +-
 trinity/trainer/verl_trainer.py              |   1 -
 10 files changed, 224 insertions(+), 165 deletions(-)

diff --git a/tests/buffer/formatter_test.py b/tests/buffer/formatter_test.py
index 7ddc0f88d1..4e807f5ae5 100644
--- a/tests/buffer/formatter_test.py
+++ b/tests/buffer/formatter_test.py
@@ -18,7 +18,7 @@ def test_sft_messages_formatter(self):
             prompt_type=PromptType.MESSAGES,
             messages_key="message_list",
         )
-        formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
+        formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)
         sample = {
             "message_list": [
                 {"role": "user", "content": "Hi"},
@@ -100,7 +100,7 @@ def test_sft_messages_formatter(self):
             tools_key="tools",
             enable_concatenated_multi_turn=False,
         )
-        formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
+        formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)
         exp = formatter.format(sample)
         self.assertIsInstance(exp, Experience)
         self.assertIsNotNone(exp.tokens)
@@ -125,7 +125,7 @@ def test_sft_messages_formatter(self):
             tools_key="tools",
             enable_concatenated_multi_turn=True,
         )
-        formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
+        formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)
         exp = formatter.format(sample)
         self.assertIsInstance(exp, Experience)
         self.assertIsNotNone(exp.tokens)
@@ -157,7 +157,7 @@ def test_sft_plaintext_formatter(self):
             prompt_key="prompt",
             response_key="response",
         )
-        formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
+        formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)
         sample = {
             "system": "You are a helpful assistant.",
             "prompt": "What is 2+2?",
@@ -181,7 +181,7 @@ def test_sft_plaintext_formatter(self):
             prompt_key="prompt",
             response_key="response",
         )
-        formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
+        formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)
         exp = formatter.format(sample)
         self.assertIsInstance(exp, Experience)
@@ -201,7 +201,7 @@ def test_dpo_plaintext_formatter(self):
             chosen_key="chosen",
             rejected_key="rejected",
         )
-        formatter = FORMATTER.get("dpo")(tokenizer=self.tokenizer, format_config=config)
+        formatter = FORMATTER.get("dpo")(tokenizer_path=get_model_path(), format_config=config)
         sample = {"prompt": "What is 2+2?", "chosen": "2+2=4", "rejected": "2+2=5"}
         exp = formatter.format(sample)
         self.assertIsInstance(exp, Experience)
@@ -227,7 +227,7 @@ def test_dpo_messages_formatter(self):
             chosen_key="chosen",
             rejected_key="rejected",
         )
-        formatter = FORMATTER.get("dpo")(tokenizer=self.tokenizer, format_config=config)
+        formatter = FORMATTER.get("dpo")(tokenizer_path=get_model_path(), format_config=config)
         sample = {
             "messages": [
                 {"role": "user", "content": "What is your name?"},
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index f0a7e7a185..b79e87285d 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -3,7 +3,6 @@
 from typing import List, Optional
 
 import datasets
-import transformers
 from datasets import Dataset, load_dataset
 
 from trinity.buffer.buffer_reader import BufferReader
@@ -100,12 +99,11 @@ async def read_async(self, batch_size: Optional[int] = None):
 
 
 class ExperienceFileReader(BaseFileReader):
-    """Reader for SFT file data."""
+    """Reader for SFT / DPO file data."""
 
     def __init__(self, meta: StorageConfig, config: BufferConfig):
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
         self.formatter = FORMATTER.get(meta.schema_type)(
-            tokenizer=self.tokenizer, format_config=meta.format
+            tokenizer_path=config.tokenizer_path, format_config=meta.format
         )
         self.read_batch_size = config.train_batch_size
         self.dataset = _HFBatchReader(
diff --git a/trinity/buffer/schema/formatter.py b/trinity/buffer/schema/formatter.py
index cdf1447f00..856d3554b7 100644
--- a/trinity/buffer/schema/formatter.py
+++ b/trinity/buffer/schema/formatter.py
@@ -2,6 +2,8 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional
 
+import transformers
+
 from trinity.common.config import FormatConfig, StorageConfig
 from trinity.common.constants import PromptType
 from trinity.common.experience import Experience
@@ -99,16 +101,26 @@ class SFTFormatter(ExperienceFormatter):
     }
     """
 
-    def __init__(self, tokenizer, format_config: FormatConfig):
+    def __init__(self, tokenizer_path: str, format_config: FormatConfig):
         self.logger = get_logger("sft_dataset_formatter", in_ray_actor=True)
-        self.tokenizer = tokenizer
         self.prompt_type = format_config.prompt_type
         self.enable_concatenated_multi_turn = format_config.enable_concatenated_multi_turn
-        self.chat_template = format_config.chat_template or tokenizer.chat_template
+        self.tools_key = format_config.tools_key
+        self.image_key = format_config.image_key
+        self.video_key = format_config.video_key
+        if self.image_key is not None or self.video_key is not None:
+            assert (
+                self.enable_concatenated_multi_turn is False
+            ), "Concatenated multi-turn not supported for multi-modal data yet."
+            self.processor = transformers.AutoProcessor.from_pretrained(tokenizer_path)
+            self.tokenizer = self.processor.tokenizer
+        else:
+            self.processor = None
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)
+        self.chat_template = format_config.chat_template or self.tokenizer.chat_template
         # For messages type
         if self.prompt_type == PromptType.MESSAGES:
             self.messages_key = format_config.messages_key
-            self.tools_key = format_config.tools_key
             if format_config.enable_concatenated_multi_turn:
                 self.action_mask_method = get_action_mask_method(self.chat_template)
         # For plaintext type
@@ -123,27 +135,20 @@ def __init__(self, tokenizer, format_config: FormatConfig):
 
     def _messages_to_experience(
         self,
-        messages: List[Dict] | str,  # or could be str from json dumps
-        tools: Optional[List[Dict] | str] = None,  # or could also be str from json dumps
+        messages: List[Dict],
+        tools: Optional[List[Dict] | str] = None,
+        mm_data: Optional[Dict] = None,
     ) -> Experience:
         """Convert messages and tools into an Experience object.
 
         Args:
-            messages (List[Dict]|str): The list of message dictionaries or a JSON string.
+            messages (List[Dict]): The list of message dictionaries.
             tools (Optional[List[Dict]|str], optional): The list of tool dictionaries or a JSON string. Defaults to None.
+            mm_data (Optional[Dict], optional): Multi-modal data such as images or videos. Defaults to None.
 
         Returns:
             Experience: The resulting Experience object.
         """
-        if isinstance(messages, str):
-            try:
-                messages = json.loads(messages)
-            except json.JSONDecodeError:
-                self.logger.error(
-                    "[SFT Data Error] Failed to decode 'messages' JSON. please check your data format."
-                )
-                raise ValueError("Invalid JSON format for messages")
-
         # Warning if tools is accidentally provided as list of dicts (with Huggingface datasets this may cause schema issues)
         if tools is not None and isinstance(tools, list):
             self.logger.warning(
                 "[SFT Data Warning] 'tools' is provided as a list of dictionaries. "
@@ -160,13 +165,6 @@ def _messages_to_experience(
                     "[SFT Data Error] Failed to decode 'tools' JSON. Please check your data format."
                 )
                 raise ValueError("Invalid JSON format for tools")
-        tokens = self.tokenizer.apply_chat_template(
-            messages,
-            tools=tools,
-            add_generation_prompt=False,
-            return_tensors="pt",
-            chat_template=self.chat_template,
-        )[0]
         if self.enable_concatenated_multi_turn:
             token_ids, action_mask, prompt_length = self.action_mask_method(
                 tokenizer=self.tokenizer,
@@ -180,23 +178,91 @@ def _messages_to_experience(
                 prompt_length=prompt_length,
                 messages=messages,
             )
-        else:
-            prompt_tokens_ids = self.tokenizer.apply_chat_template(
-                messages[:-1],
-                tools=tools,
-                add_generation_prompt=True,
-                return_tensors="pt",
-                chat_template=self.chat_template,
-            )[0]
-            return Experience(
-                tokens=tokens,
-                prompt_length=len(prompt_tokens_ids),
-                messages=messages,
-            )
+        if mm_data:
+            return self.convert_mm_data_to_experiences(messages=messages, mm_data=mm_data)
+        token_ids = self.tokenizer.apply_chat_template(
+            messages,
+            tools=tools,
+            add_generation_prompt=False,
+            return_tensors="pt",
+            chat_template=self.chat_template,
+        )[0]
+        prompt_tokens_ids = self.tokenizer.apply_chat_template(
+            messages[:-1],
+            tools=tools,
+            add_generation_prompt=True,
+            return_tensors="pt",
+            chat_template=self.chat_template,
+        )[0]
+        return Experience(
+            tokens=token_ids,
+            prompt_length=len(prompt_tokens_ids),
+            messages=messages,
+        )
+
+    def load_mm_data(self, sample: Dict) -> Dict:
+        """Load multi-modal data such as images or videos.
+
+        Returns:
+            Dict: A dictionary containing multi-modal data.
+ """ + from verl.utils.dataset.vision_utils import process_image, process_video + + mm_data = {} + if self.image_key: + mm_data["images"] = [process_image(img) for img in sample[self.image_key]] + if self.video_key: + mm_data["videos"] = [process_video(vid).numpy() for vid in sample[self.video_key]] + return mm_data + + def convert_mm_data_to_experiences( + self, + messages: List[Dict], + mm_data: Dict, + ) -> Experience: + from trinity.common.models.mm_utils import build_multi_modal_inputs + + sequence: str = self.processor.apply_chat_template( + messages, + add_generation_prompt=False, + chat_template=self.chat_template, + ) + prompt: str = self.processor.apply_chat_template( + messages[:-1], + add_generation_prompt=True, + chat_template=self.chat_template, + ) + sequence_data = build_multi_modal_inputs( + prompt=sequence, + images=mm_data.get("images", None), + videos=mm_data.get("videos", None), + processor=self.processor, + ) + prompt_data = build_multi_modal_inputs( + prompt=prompt, + images=mm_data.get("images", None), + videos=mm_data.get("videos", None), + processor=self.processor, + ) + return Experience( + tokens=sequence_data["prompt_token_ids"], + prompt_length=len(prompt_data["prompt_token_ids"]), + messages=messages, + multi_modal_inputs=sequence_data["multi_modal_inputs"], + ) def format(self, sample: Dict) -> Experience: if self.prompt_type == PromptType.MESSAGES: messages = sample[self.messages_key] + # load messages from json string if needed + if isinstance(messages, str): + try: + messages = json.loads(messages) + except json.JSONDecodeError: + self.logger.error( + "[SFT Data Error] Failed to decode 'messages' JSON. please check your data format." + ) + raise ValueError("Invalid JSON format for messages") elif self.prompt_type == PromptType.PLAINTEXT: messages = [] if self.system_prompt_key is not None: @@ -210,7 +276,8 @@ def format(self, sample: Dict) -> Experience: else: raise ValueError(f"Unsupported prompt_type: {self.prompt_type}") tools = sample.get(self.tools_key, None) - return self._messages_to_experience(messages, tools) + mm_data = self.load_mm_data(sample) if self.image_key or self.video_key else None + return self._messages_to_experience(messages, tools, mm_data) @FORMATTER.register_module("dpo") @@ -244,8 +311,8 @@ class DPOFormatter(ExperienceFormatter): } """ - def __init__(self, tokenizer, format_config: FormatConfig): - self.tokenizer = tokenizer + def __init__(self, tokenizer_path: str, format_config: FormatConfig): + self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path) self.prompt_type = format_config.prompt_type self.chat_template = format_config.chat_template if self.prompt_type == PromptType.PLAINTEXT: diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index e7e42378de..de4e34f6ff 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -194,14 +194,13 @@ def read(self, batch_size: Optional[int] = None) -> List[Experience]: def load_from_dataset( cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig ) -> "SQLExperienceStorage": - import transformers - - tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) storage = cls( storage_config=storage_config, config=config, ) - formatter = FORMATTER.get(storage_config.schema_type)(tokenizer, storage_config.format) + formatter = FORMATTER.get(storage_config.schema_type)( + tokenizer_path=config.tokenizer_path, format_config=storage_config.format + ) batch_size = storage.batch_size batch = 
diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py
index e7e42378de..de4e34f6ff 100644
--- a/trinity/buffer/storage/sql.py
+++ b/trinity/buffer/storage/sql.py
@@ -194,14 +194,13 @@ def read(self, batch_size: Optional[int] = None) -> List[Experience]:
     def load_from_dataset(
         cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig
     ) -> "SQLExperienceStorage":
-        import transformers
-
-        tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
         storage = cls(
             storage_config=storage_config,
             config=config,
         )
-        formatter = FORMATTER.get(storage_config.schema_type)(tokenizer, storage_config.format)
+        formatter = FORMATTER.get(storage_config.schema_type)(
+            tokenizer_path=config.tokenizer_path, format_config=storage_config.format
+        )
         batch_size = storage.batch_size
         batch = []
         for item in dataset:
diff --git a/trinity/common/models/mm_utils.py b/trinity/common/models/mm_utils.py
index 6a66db261a..65f996f880 100644
--- a/trinity/common/models/mm_utils.py
+++ b/trinity/common/models/mm_utils.py
@@ -1,78 +1,70 @@
-from typing import Any, Dict
+"""Multi-modal utilities for processing and handling multi-modal data such as images and videos.
+Only supports the Qwen2.5 VL series.
+
+Modified from: verl/utils/dataset/rl_dataset.py
+"""
+import re
+from typing import Any, Dict, List
 
 
 def build_multi_modal_inputs(
     prompt: str,
-    raw_mm_data: Dict[str, Any],
+    images: List,
+    videos: List,
     processor: Any,
-    **kwargs,
 ) -> Dict[str, Any]:
     """
     Preprocess multi-modal data and build multi-modal inputs
-
-    Adapted from: verl/utils/dataset/rl_dataset.py
     """
-    from verl.utils.dataset.vision_utils import process_image, process_video
-
     if prompt is None:
         raise ValueError("Prompt is required for build multi-modal inputs")
 
-    raw_images, raw_videos = None, None
-    if "image" in raw_mm_data:
-        raw_images = raw_mm_data["image"]
-    if "video" in raw_mm_data:
-        raw_videos = raw_mm_data["video"]
-
     multi_modal_data = {}
-    images, videos = None, None
-    if raw_images is not None:
-        images = [process_image(image) for image in raw_images]
+    if images:
         multi_modal_data["image"] = images
-    if raw_videos is not None:
-        videos = [process_video(video) for video in raw_videos]
-        multi_modal_data["video"] = [video.numpy() for video in videos]
+    if videos:
+        multi_modal_data["video"] = videos
 
-    model_inputs = processor(text=[prompt], images=images, videos=videos, return_tensors="pt")
+    model_inputs = processor(
+        text=[prompt],
+        images=multi_modal_data.get("image", None),
+        videos=multi_modal_data.get("video", None),
+        return_tensors="pt",
+    )
 
-    model_inputs.pop("input_ids", None)  # TODO: check
-    model_inputs.pop("attention_mask", None)
+    input_ids = model_inputs.pop("input_ids")[0]
+    model_inputs.pop("attention_mask")
 
-    # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
-    multi_modal_inputs = dict(model_inputs)
+    if "second_per_grid_ts" in model_inputs:
+        model_inputs.pop("second_per_grid_ts")
 
     return {
         "prompt": prompt,
-        "multi_modal_inputs": multi_modal_inputs,
+        "prompt_token_ids": input_ids,
         "multi_modal_data": multi_modal_data,
+        "multi_modal_inputs": dict(model_inputs),
     }
 
 
-def attach_images_to_messages(messages, raw_mm_data):
-    new_msgs = [dict(m) for m in messages]
-    imgs = (raw_mm_data or {}).get("image") or []
-    if not imgs:
-        return new_msgs
-
-    for i in range(len(new_msgs) - 1, -1, -1):
-        if new_msgs[i].get("role") == "user":
-            content = new_msgs[i].get("content", "")
-            items = []
-            if isinstance(content, str):
-                text = content.replace("<image>", "").replace("<|image_pad|>", "").strip()
-                if text:
-                    items.append({"type": "text", "text": text})
-            elif isinstance(content, list):
-                for c in content:
-                    if isinstance(c, str):
-                        t = c.replace("<image>", "").replace("<|image_pad|>", "").strip()
-                        if t:
-                            items.append({"type": "text", "text": t})
-                    elif isinstance(c, dict):
-                        items.append(c)
-
-            for img in imgs:
-                items.append({"type": "image", "image": img})
-
-            new_msgs[i]["content"] = items
-            break
+def convert_messages_to_mm_format(messages: List[Dict]) -> List[Dict]:
+    for message in messages:
+        content = message["content"]
+        content_list = []
+        segments = re.split("(<image>|
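
A sketch of the new build_multi_modal_inputs contract, mirroring the return dictionary visible in the hunk above. The checkpoint path and the stand-in image are assumptions for illustration:

    from PIL import Image
    from transformers import AutoProcessor

    from trinity.common.models.mm_utils import build_multi_modal_inputs

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # assumed checkpoint
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown here?"},
            ],
        }
    ]
    # apply_chat_template inserts the image placeholder tokens the processor expects.
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    image = Image.new("RGB", (224, 224))  # stand-in for a real image
    out = build_multi_modal_inputs(prompt=prompt, images=[image], videos=None, processor=processor)
    out["prompt_token_ids"]    # 1-D tensor of token ids for the templated prompt
    out["multi_modal_data"]    # {"image": [image]}, the raw-media view of the inputs
    out["multi_modal_inputs"]  # plain dict of processor tensors (not a BatchFeature)
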