Merged
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yaml
@@ -119,7 +119,7 @@ jobs:
run: |
REPORT=report.json
if [ -f "$REPORT" ]; then
jq '(.results.tests[] | .duration, .start, .stop) |= (. * 1000) | (.results.summary.start, .results.summary.stop) |= (. * 1000)' "$REPORT" > "$REPORT.tmp" && mv "$REPORT.tmp" "$REPORT"
jq '(.results.tests[] | .start, .stop) |= (. * 1000) | (.results.summary.start, .results.summary.stop) |= (. * 1000)' "$REPORT" > "$REPORT.tmp" && mv "$REPORT.tmp" "$REPORT"
fi

- name: Clean checkpoint dir
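For reference, a minimal Python sketch of what the updated jq filter does, assuming the report layout implied by the filter (a `results` object with per-test entries and a `summary`): only the `start`/`stop` timestamps are scaled from seconds to milliseconds, and each test's `duration` is no longer multiplied.

```python
# Sketch of the post-processing the workflow step performs with jq; the
# report structure here is inferred from the filter itself, not from docs.
import json

def scale_report_times(path: str = "report.json") -> None:
    with open(path) as f:
        report = json.load(f)
    for test in report["results"]["tests"]:
        test["start"] *= 1000   # seconds -> milliseconds
        test["stop"] *= 1000
        # test["duration"] is intentionally left unscaled after this change
    summary = report["results"]["summary"]
    summary["start"] *= 1000
    summary["stop"] *= 1000
    with open(path, "w") as f:
        json.dump(report, f)
```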
1 change: 1 addition & 0 deletions pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
"matplotlib",
"transformers>=4.51.0",
"datasets>=4.0.0",
"typer>=0.20.1",
]

[project.scripts]
177 changes: 95 additions & 82 deletions tests/cli/launcher_test.py
@@ -7,6 +7,8 @@
from unittest import mock
from unittest.mock import MagicMock

from typer.testing import CliRunner as TyperCliRunner

from tests.tools import (
get_checkpoint_path,
get_model_path,
@@ -27,6 +29,8 @@
)
from trinity.common.models import get_debug_explorer_model

runner = TyperCliRunner()


class TestLauncherMain(unittest.TestCase):
def setUp(self):
@@ -73,13 +77,11 @@ def test_main_run_command(
for mode in ["explore", "train", "both", "bench", "serve"]:
config.mode = mode
mock_load.return_value = config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="run", config="dummy.yaml", dlc=False, plugin_dir=None
),
):
launcher.main()
result = runner.invoke(
launcher.app,
["run", "--config", "dummy.yaml"],
)
self.assertEqual(result.exit_code, 0, msg=result.output)
mock_load.assert_called_once_with("dummy.yaml")
mapping[mode].assert_called_once_with(config)
mock_load.reset_mock()
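For context, the tests above replace patched `argparse` namespaces with typer's `CliRunner`. A minimal, self-contained sketch of that pattern follows; the `run` command and `--config` option are illustrative stand-ins, not the launcher's actual signature.

```python
# Minimal sketch of testing a typer CLI with CliRunner (names are hypothetical).
import typer
from typer.testing import CliRunner

app = typer.Typer()

@app.callback()
def main() -> None:
    """Top-level group; keeps `run` as an explicit sub-command."""

@app.command()
def run(config: str = typer.Option(..., "--config")) -> None:
    """Pretend entry point that just echoes the config path."""
    typer.echo(f"loaded {config}")

runner = CliRunner()

def test_run_command() -> None:
    result = runner.invoke(app, ["run", "--config", "dummy.yaml"])
    assert result.exit_code == 0, result.output
    assert "loaded dummy.yaml" in result.output
```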
@@ -104,13 +106,11 @@ def test_main_run_in_dlc(self, mock_init, mock_load, mock_both, mock_setup, mock
"both": mock_both,
},
):
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="run", config="dummy.yaml", dlc=True, plugin_dir="/path/to/plugins"
),
):
launcher.main()
result = runner.invoke(
launcher.app,
["run", "--config", "dummy.yaml", "--dlc", "--plugin-dir", "/path/to/plugins"],
)
self.assertEqual(result.exit_code, 0, msg=result.output)
mock_init.assert_called_once()
mock_init.assert_called_once_with(
address="auto",
@@ -134,14 +134,20 @@ def test_main_run_in_dlc(self, mock_init, mock_load, mock_both, mock_setup, mock
namespace=namespace,
)

@mock.patch("trinity.cli.launcher.studio")
def test_main_studio_command(self, mock_studio):
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(command="studio", port=9999),
):
launcher.main()
mock_studio.assert_called_once_with(9999)
@mock.patch("trinity.manager.config_manager.ConfigManager.run")
def test_main_studio_command(self, mock_studio_fn):
result = runner.invoke(
launcher.app,
["studio", "--port", "9999"],
)
self.assertEqual(result.exit_code, 0, msg=result.output)
mock_studio_fn.assert_called_once()
# Typer calls the function with keyword args; verify port was passed
call_kwargs = mock_studio_fn.call_args
# The typer-decorated function receives port=9999
self.assertEqual(
call_kwargs.kwargs.get("port", call_kwargs.args[0] if call_kwargs.args else None), 9999
)
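A brief aside on the fallback assertion above: `mock.call_args` exposes a call's positional and keyword arguments separately, which is what the `kwargs.get(...)`-with-`args[0]` fallback guards against. A standalone illustration, unrelated to the launcher itself:

```python
# call_args.kwargs and call_args.args are available on Python 3.8+.
from unittest import mock

m = mock.MagicMock()
m(port=9999)                     # called with a keyword argument
assert m.call_args.kwargs == {"port": 9999}

m.reset_mock()
m(9999)                          # called positionally
assert m.call_args.args == (9999,)
```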

@mock.patch("trinity.trainer.verl.utils.get_latest_hf_checkpoint_path")
@mock.patch("trinity.cli.launcher.both")
@@ -194,13 +200,17 @@ def test_multi_stage_run(
"both": mock_both,
},
):
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="run", config="dummy.yaml", dlc=False, plugin_dir="/path/to/plugins"
),
):
launcher.main()
result = runner.invoke(
launcher.app,
[
"run",
"--config",
"dummy.yaml",
"--plugin-dir",
"/path/to/plugins",
],
)
self.assertEqual(result.exit_code, 0, msg=result.output)
self.assertEqual(mock_init.call_count, 2)
self.assertEqual(mock_shutdown.call_count, 2)
mock_train.assert_called_once()
@@ -268,62 +278,66 @@ def test_debug_mode(self, mock_load):
output_dir = os.path.join(self.config.checkpoint_job_dir, "debug_output")
self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
mock_load.return_value = self.config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
enable_profiling=True,
disable_overwrite=False,
output_dir=output_dir,
output_file=output_file,
plugin_dir="",
),
):
launcher.main()

# First run: workflow with profiling enabled
result = runner.invoke(
launcher.app,
[
"debug",
"--config",
"dummy.yaml",
"--module",
"workflow",
"--enable-profiling",
"--output-dir",
output_dir,
],
)
self.assertEqual(result.exit_code, 0, msg=result.output)

self.assertFalse(os.path.exists(output_file))
self.assertTrue(os.path.exists(output_dir))
self.assertTrue(os.path.exists(os.path.join(output_dir, "profiling.html")))
self.assertTrue(os.path.exists(os.path.join(output_dir, "experiences.db")))

# add a dummy file to test overwrite behavior
with open(os.path.join(output_dir, "dummy.txt"), "w") as f:
f.write("not empty")

with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
enable_profiling=False,
disable_overwrite=False,
output_dir=output_dir,
output_file=output_file,
plugin_dir="",
),
):
launcher.main()
# Second run: workflow without profiling, overwrite allowed (default)
result = runner.invoke(
launcher.app,
[
"debug",
"--config",
"dummy.yaml",
"--module",
"workflow",
"--output-dir",
output_dir,
],
)
self.assertEqual(result.exit_code, 0, msg=result.output)

dirs = os.listdir(self.config.checkpoint_job_dir)
target_output_dir = [d for d in dirs if d.startswith("debug_output_")]
self.assertEqual(len(target_output_dir), 0)

with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
enable_profiling=False,
disable_overwrite=True,
output_dir=output_dir,
output_file=output_file,
plugin_dir="",
),
):
launcher.main()
# Third run: workflow without profiling, overwrite disabled
result = runner.invoke(
launcher.app,
[
"debug",
"--config",
"dummy.yaml",
"--module",
"workflow",
"--disable-overwrite",
"--output-dir",
output_dir,
],
)
self.assertEqual(result.exit_code, 0, msg=result.output)

self.assertFalse(os.path.exists(output_file))
# test the original files are not overwritten
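The three invocations above exercise `--enable-profiling` and `--disable-overwrite` as plain boolean flags. A hypothetical sketch of how such flags are typically declared in typer; the real `debug` command's parameters and defaults may differ:

```python
# Hypothetical declaration; option names mirror the flags used in the tests,
# while the types and defaults are assumptions, not the launcher's signature.
from typing import Optional

import typer

app = typer.Typer()

@app.callback()
def main() -> None:
    """CLI group."""

@app.command()
def debug(
    config: str = typer.Option(..., "--config"),
    module: str = typer.Option("workflow", "--module"),
    enable_profiling: bool = typer.Option(False, "--enable-profiling"),
    disable_overwrite: bool = typer.Option(False, "--disable-overwrite"),
    output_dir: Optional[str] = typer.Option(None, "--output-dir"),
) -> None:
    """Placeholder body; the real command wires these into the debugger."""
    ...
```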
@@ -357,14 +371,13 @@ def debug_inference_model_process():
config.model.model_path = get_model_path()
config.check_and_update()
with mock.patch("trinity.cli.launcher.load_config", return_value=config):
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="inference_model",
plugin_dir=None,
output_file=None,
),
):
launcher.main()
runner.invoke(
launcher.app,
[
"debug",
"--config",
"dummy.yaml",
"--module",
"inference_model",
],
)
31 changes: 22 additions & 9 deletions tests/explorer/workflow_test.py
@@ -680,7 +680,9 @@ async def test_workflow_runner(self):
workflow_args={"output_format": "json"},
)

status, exps = await runner.run_task(task, repeat_times=3, run_id_base=0)
status, exps = await runner.run_task(
task, batch_id="test", repeat_times=3, run_id_base=0
)

self.assertTrue(status.ok)
self.assertIsInstance(exps, list)
@@ -693,7 +695,9 @@ async def test_workflow_runner(self):
workflow_args={"output_format": "yaml"},
)

status, exps = await runner.run_task(task, repeat_times=2, run_id_base=0)
status, exps = await runner.run_task(
task, batch_id="test", repeat_times=2, run_id_base=0
)
self.assertTrue(status.ok)
self.assertIsInstance(exps, list)
self.assertEqual(len(exps), 2)
@@ -750,7 +754,10 @@ async def monitor_routine():
return count

await asyncio.gather(
*[monitor_routine(), runner.run_task(task, repeat_times=3, run_id_base=0)]
*[
monitor_routine(),
runner.run_task(task, batch_id="test", repeat_times=3, run_id_base=0),
]
)

async def test_workflow_with_openai(self):
@@ -784,13 +791,15 @@ async def test_workflow_with_openai(self):
]

status, exps = await runner.run_task(
tasks[0], repeat_times=2, run_id_base=0
tasks[0], batch_id="test", repeat_times=2, run_id_base=0
) # test exception handling
self.assertEqual(status.ok, False)
self.assertEqual(len(exps), 0)
exps = runner.model_wrapper.extract_experience_from_history(clear_history=False)
self.assertEqual(len(exps), 1)
status, exps = await runner.run_task(tasks[1], repeat_times=2, run_id_base=0) # normal run
status, exps = await runner.run_task(
tasks[1], batch_id="test", repeat_times=2, run_id_base=0
) # normal run
self.assertEqual(status.ok, True)
self.assertEqual(len(exps), 2)
exps = runner.model_wrapper.extract_experience_from_history(clear_history=False)
@@ -870,19 +879,23 @@ async def test_concurrent_workflow_runner(self):
raw_task={"text": "Hello, world!"},
)
# warmup
async_status, async_exps = await async_runner.run_task(task, repeat_times=2, run_id_base=0)
async_status, async_exps = await async_runner.run_task(
task, batch_id="test", repeat_times=2, run_id_base=0
)

st = time.time()
async_status, async_exps = await async_runner.run_task(task, repeat_times=4, run_id_base=0)
async_status, async_exps = await async_runner.run_task(
task, batch_id="test", repeat_times=4, run_id_base=0
)
async_runtime = time.time() - st
st = time.time()
thread_status, thread_exps = await thread_runner.run_task(
task, repeat_times=4, run_id_base=0
task, batch_id="test", repeat_times=4, run_id_base=0
)
thread_runtime = time.time() - st
st = time.time()
sequential_status, sequential_exps = await sequential_runner.run_task(
task, repeat_times=4, run_id_base=0
task, batch_id="test", repeat_times=4, run_id_base=0
)
sequential_runtime = time.time() - st
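
Across this file, `run_task` gains a `batch_id` argument alongside `repeat_times` and `run_id_base`. A hedged sketch of the updated call shape; the runner and task construction are elided:

```python
# Sketch only: every run_task call in this file now passes batch_id, which
# tags the repeated runs of a task; "test" is the placeholder these tests use.
async def run_once(runner, task):
    status, exps = await runner.run_task(
        task,
        batch_id="test",
        repeat_times=4,
        run_id_base=0,
    )
    return status, exps
```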

34 changes: 34 additions & 0 deletions trinity/buffer/viewer.py
@@ -1,4 +1,6 @@
import argparse
import sys
from pathlib import Path
from typing import List

import streamlit as st
@@ -37,6 +39,38 @@ def total_experiences(self) -> int:
count = session.query(self.table_model_cls).count()
return count

@staticmethod
def run_viewer(model_path: str, db_url: str, table_name: str, port: int):
"""Start the Streamlit viewer.

Args:
model_path (str): Path to the tokenizer/model directory.
db_url (str): Database URL for the experience database.
table_name (str): Name of the experience table in the database.
port (int): Port number to run the Streamlit app on.
"""

from streamlit.web import cli

viewer_path = Path(__file__)
sys.argv = [
"streamlit",
"run",
str(viewer_path.resolve()),
"--server.port",
str(port),
"--server.fileWatcherType",
"none",
"--",
"--db-url",
db_url,
"--table",
table_name,
"--tokenizer",
model_path,
]
sys.exit(cli.main())


st.set_page_config(page_title="Trinity-RFT Experience Visualizer", layout="wide")
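
For orientation, a hypothetical usage sketch of the new helper; the containing class name is not visible in this hunk, so `ExperienceViewer` is a stand-in, and the paths and URL are illustrative:

```python
# Hypothetical caller; only run_viewer's signature is taken from the diff.
from trinity.buffer.viewer import ExperienceViewer  # class name assumed

ExperienceViewer.run_viewer(
    model_path="/path/to/tokenizer",
    db_url="sqlite:///experiences.db",
    table_name="experiences",
    port=8501,
)
# Internally this rebuilds sys.argv as a `streamlit run` invocation, passes
# --db-url/--table/--tokenizer after the `--` separator so Streamlit hands
# them to the viewer script, and exits the process when Streamlit returns.
```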
