Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/docker/docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
services:
trinity-node-1:
image: trinity-rft-unittest:20250918
image: trinity-rft-unittest:20250924
pull_policy: never
command: sh -c "pip install -e .[dev] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block"
environment:
Expand Down Expand Up @@ -28,7 +28,7 @@ services:
capabilities: [gpu]

trinity-node-2:
image: trinity-rft-unittest:20250918
image: trinity-rft-unittest:20250924
pull_policy: never
command: sh -c "pip install -e .[dev] && ray start --address=trinity-node-1:6379 --block"
environment:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion examples/grpo_vlm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ qwen_vl_utils

For other detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md).

The config file is located in [`vlm.yaml`](vlm.yaml).
The config file is located in [`vlm.yaml`](vlm.yaml), and the reward curve is shown below.

![vlm](../../docs/sphinx_doc/assets/geometry3k_qwen25_vl_3b_reward.png)
42 changes: 34 additions & 8 deletions tests/buffer/formatter_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import unittest

from datasets import load_dataset
from transformers import AutoTokenizer

from tests.tools import get_model_path
from tests.tools import (
get_model_path,
get_unittest_dataset_config,
get_vision_language_model_path,
)
from trinity.buffer.schema.formatter import FORMATTER
from trinity.common.config import FormatConfig, StorageConfig
from trinity.common.constants import PromptType
Expand All @@ -18,7 +23,7 @@ def test_sft_messages_formatter(self):
prompt_type=PromptType.MESSAGES,
messages_key="message_list",
)
formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)
sample = {
"message_list": [
{"role": "user", "content": "Hi"},
Expand Down Expand Up @@ -100,7 +105,7 @@ def test_sft_messages_formatter(self):
tools_key="tools",
enable_concatenated_multi_turn=False,
)
formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)
exp = formatter.format(sample)
self.assertIsInstance(exp, Experience)
self.assertIsNotNone(exp.tokens)
Expand All @@ -125,7 +130,7 @@ def test_sft_messages_formatter(self):
tools_key="tools",
enable_concatenated_multi_turn=True,
)
formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)
exp = formatter.format(sample)
self.assertIsInstance(exp, Experience)
self.assertIsNotNone(exp.tokens)
Expand Down Expand Up @@ -157,7 +162,7 @@ def test_sft_plaintext_formatter(self):
prompt_key="prompt",
response_key="response",
)
formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)
sample = {
"system": "You are a helpful assistant.",
"prompt": "What is 2+2?",
Expand All @@ -181,7 +186,7 @@ def test_sft_plaintext_formatter(self):
prompt_key="prompt",
response_key="response",
)
formatter = FORMATTER.get("sft")(tokenizer=self.tokenizer, format_config=config)
formatter = FORMATTER.get("sft")(tokenizer_path=get_model_path(), format_config=config)

exp = formatter.format(sample)
self.assertIsInstance(exp, Experience)
Expand All @@ -201,7 +206,7 @@ def test_dpo_plaintext_formatter(self):
chosen_key="chosen",
rejected_key="rejected",
)
formatter = FORMATTER.get("dpo")(tokenizer=self.tokenizer, format_config=config)
formatter = FORMATTER.get("dpo")(tokenizer_path=get_model_path(), format_config=config)
sample = {"prompt": "What is 2+2?", "chosen": "2+2=4", "rejected": "2+2=5"}
exp = formatter.format(sample)
self.assertIsInstance(exp, Experience)
Expand All @@ -227,7 +232,7 @@ def test_dpo_messages_formatter(self):
chosen_key="chosen",
rejected_key="rejected",
)
formatter = FORMATTER.get("dpo")(tokenizer=self.tokenizer, format_config=config)
formatter = FORMATTER.get("dpo")(tokenizer_path=get_model_path(), format_config=config)
sample = {
"messages": [
{"role": "user", "content": "What is your name?"},
Expand Down Expand Up @@ -308,3 +313,24 @@ def test_task_formatter(self):
self.assertTrue(task.workflow_args.get("use_base"))
self.assertFalse(task.workflow_args.get("with_think"))
self.assertEqual(task.raw_task, sample)

def test_multi_modal_sft_formatter(self):
    """Check that the SFT formatter handles every sample of the geometry dataset.

    Each sample is formatted with a formatter built from the vision-language
    model path, and the resulting experience must contain the image
    placeholder token, a prompt strictly shorter than the full token
    sequence, and non-empty multi-modal inputs.
    """
    # Image placeholder token id; only valid for Qwen2.5-VL.
    # If the model changes, please update this test.
    IMAGE_TOKEN_ID = 151655
    storage_config = get_unittest_dataset_config("geometry")

    formatter = FORMATTER.get("sft")(
        tokenizer_path=get_vision_language_model_path(), format_config=storage_config.format
    )
    ds = load_dataset(storage_config.path, split=storage_config.split)
    count = 0
    for sample in ds:
        exp = formatter.format(sample)
        self.assertIsInstance(exp, Experience)
        self.assertIsNotNone(exp.tokens)
        self.assertIn(IMAGE_TOKEN_ID, exp.tokens)
        self.assertIsNotNone(exp.prompt_length)
        # assertLess / assertGreater report both operands on failure,
        # unlike assertTrue on a pre-evaluated comparison.
        self.assertLess(exp.prompt_length, len(exp.tokens))
        self.assertIsNotNone(exp.multi_modal_inputs)
        self.assertGreater(len(exp.multi_modal_inputs), 0)
        count += 1
    self.assertEqual(count, 8)  # the geometry dataset contains 8 samples in total
26 changes: 16 additions & 10 deletions tests/manager/synchronizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,32 @@ async def new_finish_explore_step(self, step: int, model_version: int) -> None:

def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
trainer_monkey_patch(config, max_steps, intervals)
train(config)
ray.shutdown(_exiting_interpreter=True)
try:
trainer_monkey_patch(config, max_steps, intervals)
train(config)
finally:
ray.shutdown(_exiting_interpreter=True)


def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
explorer_monkey_patch(config, max_steps, intervals)
explore(config)
ray.shutdown(_exiting_interpreter=True)
try:
explorer_monkey_patch(config, max_steps, intervals)
explore(config)
finally:
ray.shutdown(_exiting_interpreter=True)


def run_both(
config: Config, max_steps: int, trainer_intervals: List[int], explorer_intervals: List[int]
) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
trainer_monkey_patch(config, max_steps, trainer_intervals)
explorer_monkey_patch(config, max_steps, explorer_intervals)
both(config)
ray.shutdown(_exiting_interpreter=True)
try:
trainer_monkey_patch(config, max_steps, trainer_intervals)
explorer_monkey_patch(config, max_steps, explorer_intervals)
both(config)
finally:
ray.shutdown(_exiting_interpreter=True)


class BaseTestSynchronizer(unittest.TestCase):
Expand Down
Binary file modified tests/template/data/geometry/train.parquet
Binary file not shown.
3 changes: 2 additions & 1 deletion tests/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_checkpoint_path() -> str:
return path


def get_vision_languge_model_path() -> str:
def get_vision_language_model_path() -> str:
path = os.environ.get(VLM_MODEL_PATH_ENV_VAR)
if not path:
raise EnvironmentError(
Expand Down Expand Up @@ -147,6 +147,7 @@ def get_unittest_dataset_config(
path=os.path.join(os.path.dirname(__file__), "template", "data", "geometry"),
split="train",
format=FormatConfig(
prompt_type=PromptType.PLAINTEXT,
prompt_key="problem",
response_key="answer",
image_key="images",
Expand Down
51 changes: 44 additions & 7 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
get_model_path,
get_template_config,
get_unittest_dataset_config,
get_vision_languge_model_path,
get_vision_language_model_path,
)
from trinity.cli.launcher import bench, both, explore, run, train
from trinity.common.config import (
Expand Down Expand Up @@ -682,23 +682,21 @@ def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)


class TestTrainerMultiModal(BaseTrainerCase):
class TestMultiModalGRPO(BaseTrainerCase):
@unittest.skip("Require specific vllm/transformers version")
def test_trainer(self):
"""Test both mode with multi-modal data."""
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config(
"geometry"
) # Total 8 tasks
self.config.model.model_path = get_vision_languge_model_path()
self.config.model.model_path = get_vision_language_model_path()
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.kl_loss_fn = "none"
self.config.algorithm.repeat_times = 4
self.config.buffer.batch_size = 4
self.config.buffer.total_epochs = 1
self.config.trainer.save_interval = 1
self.config.cluster.node_num = 1
self.config.cluster.gpu_per_node = 4
self.config.trainer.save_interval = 2
self.config.check_and_update()
both(self.config)
# check metrics are available
Expand All @@ -712,7 +710,46 @@ def test_trainer(self):
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 2)
ray.shutdown(_exiting_interpreter=True)
# check that the latest checkpoint is saved
checkpoint_step_2, step_num = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.checkpoint_job_dir,
trainer_type=self.config.trainer.trainer_type,
)
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))) > 0)
self.assertEqual(step_num, 2)

def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)


class TestMultiModalSFT(BaseTrainerCase):
@unittest.skip("Require specific vllm/transformers version")
def test_trainer(self):
"""Test SFT mode with multi-modal data."""
self.config.mode = "train"
self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config(
"geometry"
) # Total 8 tasks
self.config.model.model_path = get_vision_language_model_path()
self.config.algorithm.algorithm_type = "sft"
self.config.algorithm.policy_loss_fn = "sft"
self.config.algorithm.policy_loss_fn_args = {}
self.config.algorithm.kl_loss_fn = "none"
self.config.algorithm.entropy_loss_fn = "none"
self.config.buffer.train_batch_size = 4
self.config.buffer.total_epochs = 1
self.config.trainer.save_interval = 2
self.config.check_and_update()
train(self.config)
# check metrics are available
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 2)
# check that the latest checkpoint is saved
checkpoint_step_2, step_num = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.checkpoint_job_dir,
Expand Down
6 changes: 2 additions & 4 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import List, Optional

import datasets
import transformers
from datasets import Dataset, load_dataset

from trinity.buffer.buffer_reader import BufferReader
Expand Down Expand Up @@ -100,12 +99,11 @@ async def read_async(self, batch_size: Optional[int] = None):


class ExperienceFileReader(BaseFileReader):
"""Reader for SFT file data."""
"""Reader for SFT / DPO file data."""

def __init__(self, meta: StorageConfig, config: BufferConfig):
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
self.formatter = FORMATTER.get(meta.schema_type)(
tokenizer=self.tokenizer, format_config=meta.format
tokenizer_path=config.tokenizer_path, format_config=meta.format
)
self.read_batch_size = config.train_batch_size
self.dataset = _HFBatchReader(
Expand Down
Loading