13 changes: 13 additions & 0 deletions examples/grpo_gsm8k/gsm8k.yaml
@@ -72,3 +72,16 @@ trainer:
log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
# stages: # Uncomment to add a SFT warmup stage before RFT
# - stage_name: sft_warmup
# mode: train
# algorithm:
# algorithm_type: sft
# buffer:
# train_batch_size: 128
# total_steps: 10
# trainer_input:
# experience_buffer:
# name: sft_warmup_dataset
# path: ${oc.env:TRINITY_SFT_DATASET_PATH}
# - stage_name: rft # leave empty to use the original configs for RFT
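# Each stage runs in its own Ray namespace ("<project>/<name>/<stage_name>"),
# and every stage after the first starts from the checkpoint exported by the
# previous stage (see tests/cli/launcher_test.py::test_multi_stage_run below).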
1 change: 1 addition & 0 deletions examples/mix_chord/mix_chord.yaml
@@ -58,6 +58,7 @@ buffer:
total_epochs: 25
name: SFT_data
storage_type: file
schema_type: sft
path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts}
split: 'train'
format:
1 change: 1 addition & 0 deletions examples/mix_math/mix_math.yaml
@@ -58,6 +58,7 @@ buffer:
total_epochs: 10
name: math_sft
storage_type: file
schema_type: sft
path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts}
split: 'train'
format:
222 changes: 126 additions & 96 deletions tests/cli/launcher_test.py
@@ -32,33 +32,47 @@ def setUp(self):
def tearDown(self):
sys.argv = self._orig_argv

@mock.patch("trinity.cli.launcher.serve")
@mock.patch("trinity.cli.launcher.explore")
@mock.patch("trinity.cli.launcher.train")
@mock.patch("trinity.cli.launcher.both")
@mock.patch("trinity.cli.launcher.bench")
@mock.patch("trinity.cli.launcher.load_config")
def test_main_run_command(self, mock_load, mock_bench, mock_both, mock_train, mock_explore):
def test_main_run_command(
self, mock_load, mock_bench, mock_both, mock_train, mock_explore, mock_serve
):
config = get_template_config()
mapping = {
"explore": mock_explore,
"train": mock_train,
"both": mock_both,
"bench": mock_bench,
"serve": mock_serve,
}
for mode in ["explore", "train", "both", "bench"]:
config.mode = mode
mock_load.return_value = config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="run", config="dummy.yaml", dlc=False, plugin_dir=None
),
):
launcher.main()
mock_load.assert_called_once_with("dummy.yaml")
mapping[mode].assert_called_once_with(config)
mock_load.reset_mock()
mapping[mode].reset_mock()
with mock.patch.dict(
launcher.MODE_MAP,
{
"explore": mock_explore,
"train": mock_train,
"both": mock_both,
"bench": mock_bench,
"serve": mock_serve,
},
):
for mode in ["explore", "train", "both", "bench", "serve"]:
config.mode = mode
mock_load.return_value = config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="run", config="dummy.yaml", dlc=False, plugin_dir=None
),
):
launcher.main()
mock_load.assert_called_once_with("dummy.yaml")
mapping[mode].assert_called_once_with(config)
mock_load.reset_mock()
mapping[mode].reset_mock()
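
A note on why this test now patches the dispatch table instead of only the module attributes: a dict like launcher.MODE_MAP captures the function objects when the module is imported, so mock.patch("trinity.cli.launcher.serve") rebinds the module attribute but leaves the dict entry pointing at the real function. A minimal, self-contained sketch of the difference (plain Python, no Trinity imports):

from unittest import mock

def train(cfg):
    return "real"

MODE_MAP = {"train": train}  # captures the function object at definition time

with mock.patch(__name__ + ".train", return_value="mocked"):
    print(MODE_MAP["train"]("cfg"))  # -> "real": the dict entry is unchanged

with mock.patch.dict(MODE_MAP, {"train": lambda cfg: "mocked"}):
    print(MODE_MAP["train"]("cfg"))  # -> "mocked": patch.dict swaps the entry itself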

@mock.patch("trinity.cli.launcher.stop_ray_cluster")
@mock.patch("trinity.cli.launcher.setup_ray_cluster")
@@ -73,35 +87,41 @@ def test_main_run_in_dlc(self, mock_init, mock_load, mock_both, mock_setup, mock
config.log.group_by_node = True
mock_setup.return_value = "auto"
mock_load.return_value = config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="run", config="dummy.yaml", dlc=True, plugin_dir="/path/to/plugins"
),
):
launcher.main()
mock_init.assert_called_once()
mock_init.assert_called_once_with(
address="auto",
ignore_reinit_error=True,
namespace=config.ray_namespace,
runtime_env={
"env_vars": {
launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
launcher.LOG_DIR_ENV_VAR: config.log.save_dir,
launcher.LOG_LEVEL_ENV_VAR: config.log.level,
launcher.LOG_NODE_IP_ENV_VAR: "1",
}
with mock.patch.dict(
launcher.MODE_MAP,
{
"both": mock_both,
},
)
mock_load.assert_called_once_with("dummy.yaml")
mock_both.assert_called_once_with(config)
mock_setup.assert_called_once_with(
namespace=namespace,
)
mock_stop.assert_called_once_with(
namespace=namespace,
)
):
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="run", config="dummy.yaml", dlc=True, plugin_dir="/path/to/plugins"
),
):
launcher.main()
mock_init.assert_called_once()
mock_init.assert_called_once_with(
address="auto",
ignore_reinit_error=True,
namespace=config.ray_namespace,
runtime_env={
"env_vars": {
launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
launcher.LOG_DIR_ENV_VAR: config.log.save_dir,
launcher.LOG_LEVEL_ENV_VAR: config.log.level,
launcher.LOG_NODE_IP_ENV_VAR: "1",
}
},
)
mock_load.assert_called_once_with("dummy.yaml")
mock_both.assert_called_once_with(config)
mock_setup.assert_called_once_with(
namespace=namespace,
)
mock_stop.assert_called_once_with(
namespace=namespace,
)
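
For context, runtime_env={"env_vars": {...}} is standard Ray API: ray.init forwards those variables to every worker process in the job, which is how the launcher ships the plugin directory and log settings cluster-wide. A minimal sketch (the DEMO_LOG_LEVEL name is illustrative, not the launcher's actual constant):

import os
import ray

# env_vars set here are injected into every Ray worker spawned for this job.
ray.init(
    ignore_reinit_error=True,
    namespace="demo_project/demo_run",
    runtime_env={"env_vars": {"DEMO_LOG_LEVEL": "INFO"}},
)

@ray.remote
def read_env() -> str:
    return os.environ.get("DEMO_LOG_LEVEL", "unset")

print(ray.get(read_env.remote()))  # -> "INFO", read on the worker process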

@mock.patch("trinity.cli.launcher.studio")
def test_main_studio_command(self, mock_studio):
@@ -156,60 +176,70 @@ def test_multi_stage_run(
]
mock_load.return_value = config
mock_checkpoint_path.return_value = "/path/to/hf/checkpoint"
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="run", config="dummy.yaml", dlc=False, plugin_dir="/path/to/plugins"
),
with mock.patch.dict(
launcher.MODE_MAP,
{
"train": mock_train,
"both": mock_both,
},
):
launcher.main()
self.assertEqual(mock_init.call_count, 2)
self.assertEqual(mock_shutdown.call_count, 2)
mock_train.assert_called_once()
mock_both.assert_called_once()
expected_calls = [
mock.call(
address="auto",
ignore_reinit_error=True,
namespace=f"{config.project}/{config.name}/sft_warmup",
runtime_env={
"env_vars": {
launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
launcher.LOG_DIR_ENV_VAR: os.path.join(
config.checkpoint_root_dir,
config.project,
f"{config.name}/sft_warmup",
"log",
),
launcher.LOG_LEVEL_ENV_VAR: config.log.level,
launcher.LOG_NODE_IP_ENV_VAR: "0",
}
},
),
mock.call(
address="auto",
ignore_reinit_error=True,
namespace=f"{config.project}/{config.name}/grpo",
runtime_env={
"env_vars": {
launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
launcher.LOG_DIR_ENV_VAR: os.path.join(
config.checkpoint_root_dir, config.project, f"{config.name}/grpo", "log"
),
launcher.LOG_LEVEL_ENV_VAR: config.log.level,
launcher.LOG_NODE_IP_ENV_VAR: "0",
}
},
),
]
mock_init.assert_has_calls(expected_calls)
self.assertEqual(mock_checkpoint_path.call_count, 2)
self.assertEqual(mock_train.call_args[0][0].model.model_path, config.model.model_path)
self.assertEqual(mock_both.call_args[0][0].model.model_path, "/path/to/hf/checkpoint")
self.assertEqual(
mock_both.call_args[0][0].trainer.trainer_config.actor_rollout_ref.model.path,
"/path/to/hf/checkpoint",
)
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="run", config="dummy.yaml", dlc=False, plugin_dir="/path/to/plugins"
),
):
launcher.main()
self.assertEqual(mock_init.call_count, 2)
self.assertEqual(mock_shutdown.call_count, 2)
mock_train.assert_called_once()
mock_both.assert_called_once()
expected_calls = [
mock.call(
address="auto",
ignore_reinit_error=True,
namespace=f"{config.project}/{config.name}/sft_warmup",
runtime_env={
"env_vars": {
launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
launcher.LOG_DIR_ENV_VAR: os.path.join(
config.checkpoint_root_dir,
config.project,
f"{config.name}/sft_warmup",
"log",
),
launcher.LOG_LEVEL_ENV_VAR: config.log.level,
launcher.LOG_NODE_IP_ENV_VAR: "0",
}
},
),
mock.call(
address="auto",
ignore_reinit_error=True,
namespace=f"{config.project}/{config.name}/grpo",
runtime_env={
"env_vars": {
launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
launcher.LOG_DIR_ENV_VAR: os.path.join(
config.checkpoint_root_dir,
config.project,
f"{config.name}/grpo",
"log",
),
launcher.LOG_LEVEL_ENV_VAR: config.log.level,
launcher.LOG_NODE_IP_ENV_VAR: "0",
}
},
),
]
mock_init.assert_has_calls(expected_calls)
self.assertEqual(mock_checkpoint_path.call_count, 2)
self.assertEqual(mock_train.call_args[0][0].model.model_path, config.model.model_path)
self.assertEqual(mock_both.call_args[0][0].model.model_path, "/path/to/hf/checkpoint")
self.assertEqual(
mock_both.call_args[0][0].trainer.trainer_config.actor_rollout_ref.model.path,
"/path/to/hf/checkpoint",
)
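
The assertions above pin down the hand-off contract between stages: the first stage trains from the configured base model, each later stage starts from the HF checkpoint exported by the stage before it, and each stage gets its own Ray namespace. A self-contained sketch of that control flow (all names are illustrative, not the launcher's real API):

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class StageConfig:
    stage_name: str
    model_path: str

def run_stages(stages, run_stage, latest_checkpoint):
    """run_stage(cfg) executes one stage; latest_checkpoint(cfg) returns its HF export."""
    checkpoint = None
    for cfg in stages:
        if checkpoint is not None:
            cfg = replace(cfg, model_path=checkpoint)  # continue from the previous stage
        run_stage(cfg)
        checkpoint = latest_checkpoint(cfg)

run_stages(
    [StageConfig("sft_warmup", "base/model"), StageConfig("grpo", "base/model")],
    run_stage=lambda c: print(f"running {c.stage_name} from {c.model_path}"),
    latest_checkpoint=lambda c: f"/ckpts/{c.stage_name}/hf",
)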


if __name__ == "__main__":
21 changes: 14 additions & 7 deletions tests/common/vllm_test.py
@@ -6,7 +6,6 @@
from transformers import AutoTokenizer

from tests.tools import (
RayUnittestBase,
RayUnittestBaseAysnc,
get_api_model_path,
get_model_path,
@@ -127,6 +126,7 @@ def setUp(self):
async def test_generate(
self,
):
await self.model_wrapper.prepare()
prompts = ["Hello, world!", "Hello, my name is"]
n = self.config.algorithm.repeat_times
if self.use_async:
@@ -228,7 +228,7 @@ async def test_generate(
(20, None, 1),
],
)
class TestModelLen(RayUnittestBase):
class TestModelLen(RayUnittestBaseAysnc):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
@@ -242,7 +242,8 @@ def setUp(self):
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

def test_model_len(self):
async def test_model_len(self):
await self.model_wrapper.prepare()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like today?"},
@@ -272,7 +273,7 @@ def test_model_len(self):
self.assertEqual(len(exps[0].tokens), self.max_model_len)


class TestAPIServer(RayUnittestBase):
class TestAPIServer(RayUnittestBaseAysnc):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
@@ -291,7 +292,9 @@ def setUp(self):
self.engines[0], engine_type="vllm", enable_history=False
)

def test_api(self):
async def test_api(self):
await self.model_wrapper.prepare()
await self.model_wrapper_no_history.prepare()
openai_client = self.model_wrapper.get_openai_client()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -361,6 +364,8 @@ def setUp(self):
)

async def test_api_async(self):
await self.model_wrapper.prepare()
await self.model_wrapper_no_history.prepare()
openai_client = self.model_wrapper.get_openai_async_client()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -528,7 +533,7 @@ def test_action_mask_with_tools(self):
(False, None),
],
)
class TestAPIServerToolCall(RayUnittestBase):
class TestAPIServerToolCall(RayUnittestBaseAysnc):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
@@ -552,13 +557,15 @@ def setUp(self):
self.engines[0], engine_type="vllm", enable_history=False
)

def test_api_tool_calls(self):
async def test_api_tool_calls(self):
"""Tests the full conversation flow of a tool call via the OpenAI API.
Note: This test requires a model that supports tool calls and thinking mode, e.g. Qwen3-1.7B.
"""
import json
import time

await self.model_wrapper.prepare()
await self.model_wrapper_no_history.prepare()
tokenizer = AutoTokenizer.from_pretrained(get_api_model_path())
print_debug("\n\n" + "=" * 30 + " Running test_api_tool_calls " + "=" * 30)
start_time = time.time()
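
A recurring change across this file: vLLM-backed tests now subclass RayUnittestBaseAysnc so their test methods can be coroutines, and every ModelWrapper must be prepared with await model_wrapper.prepare() before first use. A minimal, self-contained sketch of that prepare-before-use pattern (FakeWrapper is illustrative; the real wrapper lives in the repo):

import unittest

class FakeWrapper:
    """Stand-in for a model wrapper whose engine needs async setup."""

    def __init__(self) -> None:
        self.ready = False

    async def prepare(self) -> None:
        self.ready = True  # e.g. connect to the inference engine

    async def generate(self, prompt: str) -> str:
        assert self.ready, "call prepare() before generate()"
        return prompt.upper()

class PrepareBeforeUse(unittest.IsolatedAsyncioTestCase):
    async def test_generate(self) -> None:
        wrapper = FakeWrapper()
        await wrapper.prepare()
        self.assertEqual(await wrapper.generate("hi"), "HI")

if __name__ == "__main__":
    unittest.main()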