diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index a32f3e258d..bd7d1857cf 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -72,3 +72,16 @@ trainer: log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size +# stages: # Uncomment to add a SFT warmup stage before RFT +# - stage_name: sft_warmup +# mode: train +# algorithm: +# algorithm_type: sft +# buffer: +# train_batch_size: 128 +# total_steps: 10 +# trainer_input: +# experience_buffer: +# name: sft_warmup_dataset +# path: ${oc.env:TRINITY_SFT_DATASET_PATH} +# - stage_name: rft # leave empty to use the original configs for RFT diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml index bc989a0842..86622eaaa0 100644 --- a/examples/mix_chord/mix_chord.yaml +++ b/examples/mix_chord/mix_chord.yaml @@ -58,6 +58,7 @@ buffer: total_epochs: 25 name: SFT_data storage_type: file + schema_type: sft path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts} split: 'train' format: diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml index f0fb123516..987538e1e3 100644 --- a/examples/mix_math/mix_math.yaml +++ b/examples/mix_math/mix_math.yaml @@ -58,6 +58,7 @@ buffer: total_epochs: 10 name: math_sft storage_type: file + schema_type: sft path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts} split: 'train' format: diff --git a/tests/cli/launcher_test.py b/tests/cli/launcher_test.py index b66ffc1963..dca5da8759 100644 --- a/tests/cli/launcher_test.py +++ b/tests/cli/launcher_test.py @@ -32,33 +32,47 @@ def setUp(self): def tearDown(self): sys.argv = self._orig_argv + @mock.patch("trinity.cli.launcher.serve") @mock.patch("trinity.cli.launcher.explore") @mock.patch("trinity.cli.launcher.train") @mock.patch("trinity.cli.launcher.both") @mock.patch("trinity.cli.launcher.bench") @mock.patch("trinity.cli.launcher.load_config") - def test_main_run_command(self, mock_load, mock_bench, mock_both, mock_train, mock_explore): + def test_main_run_command( + self, mock_load, mock_bench, mock_both, mock_train, mock_explore, mock_serve + ): config = get_template_config() mapping = { "explore": mock_explore, "train": mock_train, "both": mock_both, "bench": mock_bench, + "serve": mock_serve, } - for mode in ["explore", "train", "both", "bench"]: - config.mode = mode - mock_load.return_value = config - with mock.patch( - "argparse.ArgumentParser.parse_args", - return_value=mock.Mock( - command="run", config="dummy.yaml", dlc=False, plugin_dir=None - ), - ): - launcher.main() - mock_load.assert_called_once_with("dummy.yaml") - mapping[mode].assert_called_once_with(config) - mock_load.reset_mock() - mapping[mode].reset_mock() + with mock.patch.dict( + launcher.MODE_MAP, + { + "explore": mock_explore, + "train": mock_train, + "both": mock_both, + "bench": mock_bench, + "serve": mock_serve, + }, + ): + for mode in ["explore", "train", "both", "bench", "serve"]: + config.mode = mode + mock_load.return_value = config + with mock.patch( + "argparse.ArgumentParser.parse_args", + return_value=mock.Mock( + command="run", config="dummy.yaml", dlc=False, plugin_dir=None + ), + ): + launcher.main() + mock_load.assert_called_once_with("dummy.yaml") + mapping[mode].assert_called_once_with(config) + mock_load.reset_mock() + mapping[mode].reset_mock() @mock.patch("trinity.cli.launcher.stop_ray_cluster") @mock.patch("trinity.cli.launcher.setup_ray_cluster") @@ -73,35 +87,41 @@ def test_main_run_in_dlc(self, mock_init, mock_load, mock_both, mock_setup, mock config.log.group_by_node = True mock_setup.return_value = "auto" mock_load.return_value = config - with mock.patch( - "argparse.ArgumentParser.parse_args", - return_value=mock.Mock( - command="run", config="dummy.yaml", dlc=True, plugin_dir="/path/to/plugins" - ), - ): - launcher.main() - mock_init.assert_called_once() - mock_init.assert_called_once_with( - address="auto", - ignore_reinit_error=True, - namespace=config.ray_namespace, - runtime_env={ - "env_vars": { - launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins", - launcher.LOG_DIR_ENV_VAR: config.log.save_dir, - launcher.LOG_LEVEL_ENV_VAR: config.log.level, - launcher.LOG_NODE_IP_ENV_VAR: "1", - } + with mock.patch.dict( + launcher.MODE_MAP, + { + "both": mock_both, }, - ) - mock_load.assert_called_once_with("dummy.yaml") - mock_both.assert_called_once_with(config) - mock_setup.assert_called_once_with( - namespace=namespace, - ) - mock_stop.assert_called_once_with( - namespace=namespace, - ) + ): + with mock.patch( + "argparse.ArgumentParser.parse_args", + return_value=mock.Mock( + command="run", config="dummy.yaml", dlc=True, plugin_dir="/path/to/plugins" + ), + ): + launcher.main() + mock_init.assert_called_once() + mock_init.assert_called_once_with( + address="auto", + ignore_reinit_error=True, + namespace=config.ray_namespace, + runtime_env={ + "env_vars": { + launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins", + launcher.LOG_DIR_ENV_VAR: config.log.save_dir, + launcher.LOG_LEVEL_ENV_VAR: config.log.level, + launcher.LOG_NODE_IP_ENV_VAR: "1", + } + }, + ) + mock_load.assert_called_once_with("dummy.yaml") + mock_both.assert_called_once_with(config) + mock_setup.assert_called_once_with( + namespace=namespace, + ) + mock_stop.assert_called_once_with( + namespace=namespace, + ) @mock.patch("trinity.cli.launcher.studio") def test_main_studio_command(self, mock_studio): @@ -156,60 +176,70 @@ def test_multi_stage_run( ] mock_load.return_value = config mock_checkpoint_path.return_value = "/path/to/hf/checkpoint" - with mock.patch( - "argparse.ArgumentParser.parse_args", - return_value=mock.Mock( - command="run", config="dummy.yaml", dlc=False, plugin_dir="/path/to/plugins" - ), + with mock.patch.dict( + launcher.MODE_MAP, + { + "train": mock_train, + "both": mock_both, + }, ): - launcher.main() - self.assertEqual(mock_init.call_count, 2) - self.assertEqual(mock_shutdown.call_count, 2) - mock_train.assert_called_once() - mock_both.assert_called_once() - expected_calls = [ - mock.call( - address="auto", - ignore_reinit_error=True, - namespace=f"{config.project}/{config.name}/sft_warmup", - runtime_env={ - "env_vars": { - launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins", - launcher.LOG_DIR_ENV_VAR: os.path.join( - config.checkpoint_root_dir, - config.project, - f"{config.name}/sft_warmup", - "log", - ), - launcher.LOG_LEVEL_ENV_VAR: config.log.level, - launcher.LOG_NODE_IP_ENV_VAR: "0", - } - }, - ), - mock.call( - address="auto", - ignore_reinit_error=True, - namespace=f"{config.project}/{config.name}/grpo", - runtime_env={ - "env_vars": { - launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins", - launcher.LOG_DIR_ENV_VAR: os.path.join( - config.checkpoint_root_dir, config.project, f"{config.name}/grpo", "log" - ), - launcher.LOG_LEVEL_ENV_VAR: config.log.level, - launcher.LOG_NODE_IP_ENV_VAR: "0", - } - }, - ), - ] - mock_init.assert_has_calls(expected_calls) - self.assertEqual(mock_checkpoint_path.call_count, 2) - self.assertEqual(mock_train.call_args[0][0].model.model_path, config.model.model_path) - self.assertEqual(mock_both.call_args[0][0].model.model_path, "/path/to/hf/checkpoint") - self.assertEqual( - mock_both.call_args[0][0].trainer.trainer_config.actor_rollout_ref.model.path, - "/path/to/hf/checkpoint", - ) + with mock.patch( + "argparse.ArgumentParser.parse_args", + return_value=mock.Mock( + command="run", config="dummy.yaml", dlc=False, plugin_dir="/path/to/plugins" + ), + ): + launcher.main() + self.assertEqual(mock_init.call_count, 2) + self.assertEqual(mock_shutdown.call_count, 2) + mock_train.assert_called_once() + mock_both.assert_called_once() + expected_calls = [ + mock.call( + address="auto", + ignore_reinit_error=True, + namespace=f"{config.project}/{config.name}/sft_warmup", + runtime_env={ + "env_vars": { + launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins", + launcher.LOG_DIR_ENV_VAR: os.path.join( + config.checkpoint_root_dir, + config.project, + f"{config.name}/sft_warmup", + "log", + ), + launcher.LOG_LEVEL_ENV_VAR: config.log.level, + launcher.LOG_NODE_IP_ENV_VAR: "0", + } + }, + ), + mock.call( + address="auto", + ignore_reinit_error=True, + namespace=f"{config.project}/{config.name}/grpo", + runtime_env={ + "env_vars": { + launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins", + launcher.LOG_DIR_ENV_VAR: os.path.join( + config.checkpoint_root_dir, + config.project, + f"{config.name}/grpo", + "log", + ), + launcher.LOG_LEVEL_ENV_VAR: config.log.level, + launcher.LOG_NODE_IP_ENV_VAR: "0", + } + }, + ), + ] + mock_init.assert_has_calls(expected_calls) + self.assertEqual(mock_checkpoint_path.call_count, 2) + self.assertEqual(mock_train.call_args[0][0].model.model_path, config.model.model_path) + self.assertEqual(mock_both.call_args[0][0].model.model_path, "/path/to/hf/checkpoint") + self.assertEqual( + mock_both.call_args[0][0].trainer.trainer_config.actor_rollout_ref.model.path, + "/path/to/hf/checkpoint", + ) if __name__ == "__main__": diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index fb9cdf670d..985d22722f 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -6,7 +6,6 @@ from transformers import AutoTokenizer from tests.tools import ( - RayUnittestBase, RayUnittestBaseAysnc, get_api_model_path, get_model_path, @@ -127,6 +126,7 @@ def setUp(self): async def test_generate( self, ): + await self.model_wrapper.prepare() prompts = ["Hello, world!", "Hello, my name is"] n = self.config.algorithm.repeat_times if self.use_async: @@ -228,7 +228,7 @@ async def test_generate( (20, None, 1), ], ) -class TestModelLen(RayUnittestBase): +class TestModelLen(RayUnittestBaseAysnc): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -242,7 +242,8 @@ def setUp(self): self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) - def test_model_len(self): + async def test_model_len(self): + await self.model_wrapper.prepare() messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What's the weather like today?"}, @@ -272,7 +273,7 @@ def test_model_len(self): self.assertEqual(len(exps[0].tokens), self.max_model_len) -class TestAPIServer(RayUnittestBase): +class TestAPIServer(RayUnittestBaseAysnc): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -291,7 +292,9 @@ def setUp(self): self.engines[0], engine_type="vllm", enable_history=False ) - def test_api(self): + async def test_api(self): + await self.model_wrapper.prepare() + await self.model_wrapper_no_history.prepare() openai_client = self.model_wrapper.get_openai_client() messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -361,6 +364,8 @@ def setUp(self): ) async def test_api_async(self): + await self.model_wrapper.prepare() + await self.model_wrapper_no_history.prepare() openai_client = self.model_wrapper.get_openai_async_client() messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -528,7 +533,7 @@ def test_action_mask_with_tools(self): (False, None), ], ) -class TestAPIServerToolCall(RayUnittestBase): +class TestAPIServerToolCall(RayUnittestBaseAysnc): def setUp(self): self.config = get_template_config() self.config.mode = "explore" @@ -552,13 +557,15 @@ def setUp(self): self.engines[0], engine_type="vllm", enable_history=False ) - def test_api_tool_calls(self): + async def test_api_tool_calls(self): """Tests the full conversation flow of a tool call via the OpenAI API. Note: This test require a model that supports tool calls and thinking mode, e.g. Qwen3-1.7B. """ import json import time + await self.model_wrapper.prepare() + await self.model_wrapper_no_history.prepare() tokenizer = AutoTokenizer.from_pretrained(get_api_model_path()) print_debug("\n\n" + "=" * 30 + " Running test_api_tool_calls " + "=" * 30) start_time = time.time() diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index f0612c0bda..228e888c83 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -1,19 +1,32 @@ """Tests for explorer.""" +import asyncio import json +import multiprocessing import os -from abc import abstractmethod +import random +import shutil from datetime import datetime +import httpx +import openai +import ray + from tests.tools import ( RayUnittestBase, + RayUnittestBaseAysnc, TensorBoardParser, get_checkpoint_path, get_model_path, get_template_config, get_unittest_dataset_config, ) +from trinity.buffer import get_buffer_reader from trinity.buffer.utils import default_storage_path -from trinity.cli.launcher import explore +from trinity.cli.launcher import explore, run_stage +from trinity.common.config import StorageConfig +from trinity.common.constants import StorageType +from trinity.explorer.explorer import Explorer +from trinity.manager.state_manager import StateManager class BaseExplorerCase(RayUnittestBase): @@ -30,10 +43,6 @@ def setUp(self): self.config.synchronizer.sync_interval = 2 self.config.explorer.eval_interval = 4 - @abstractmethod - def test_explorer(self): - """Test explorer""" - class TestExplorerCountdownEval(BaseExplorerCase): def test_explorer(self): @@ -75,10 +84,6 @@ def test_explorer(self): class TestExplorerGSM8k(BaseExplorerCase): def test_explorer(self): - import ray - - from trinity.explorer.explorer import Explorer - self.config.algorithm.repeat_times = 2 self.config.buffer.total_epochs = 1 self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") @@ -122,3 +127,133 @@ def test_explorer(self): exp = json.loads(lines[0]) self.assertEqual(exp["response_length"], 8192) ray.get(explorer.shutdown.remote()) + + +def run_serve(config): + config.check_and_update() + run_stage(config, "auto") + + +def run_agent(base_url, model_path: str): + client = openai.Client(base_url=base_url, api_key="testkey") + contents = [ + "Hello, how are you?", + "What is the capital of China?", + "Tell me a joke.", + "Explain the theory of relativity.", + "What is the meaning of life?", + "How does a computer work?", + "What is the weather like today?", + "Can you recommend a good book?", + "What is the best way to learn programming?", + "Describe the process of photosynthesis.", + ] + response = client.chat.completions.create( + model=model_path, + messages=[{"role": "user", "content": random.choice(contents)}], + ) + return response.choices[0].message.content + + +class ServeTest(RayUnittestBaseAysnc): + def setUp(self): + self.config = get_template_config() + self.config.mode = "serve" + self.config.model.model_path = get_model_path() + self.config.explorer.rollout_model.engine_type = "vllm" + self.config.algorithm.repeat_times = 1 + self.config.monitor.monitor_type = "tensorboard" + self.config.project = "Trinity-unittest" + self.config.explorer.rollout_model.engine_num = 4 + self.config.explorer.rollout_model.enable_openai_api = True + self.config.checkpoint_root_dir = get_checkpoint_path() + self.config.explorer.api_port = 8010 + self.config.explorer.service_status_check_interval = 30 + self.config.buffer.trainer_input.experience_buffer = StorageConfig( + name="experience_buffer", + storage_type=StorageType.SQL, + ) + self.config.check_and_update() + if multiprocessing.get_start_method(allow_none=True) != "spawn": + multiprocessing.set_start_method("spawn", force=True) + + async def test_serve(self): # noqa: C901 + serve_process = multiprocessing.Process(target=run_serve, args=(self.config,)) + serve_process.start() + await asyncio.sleep(10) + + state_manager = StateManager( + path=self.config.checkpoint_job_dir, + explorer_name=self.config.explorer.name, + ) + + # wait for explorer initialization + for i in range(30): + try: + server_url = state_manager.load_explorer_server_url() + except Exception: + server_url = None + if server_url: + break + await asyncio.sleep(3) + if not server_url: + raise RuntimeError("Explorer server URL not found.") + # wait for server setup + for i in range(10): + try: + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/health") + if response.status_code == 200: + break + except Exception: + pass + await asyncio.sleep(2) + + task_num = 10 + apps = [] + for i in range(task_num): + app_process = multiprocessing.Process( + target=run_agent, args=(server_url + "/v1", self.config.model.model_path) + ) + apps.append(app_process) + app_process.start() + + for app in apps: + app.join(timeout=60) + self.assertFalse(app.is_alive()) + + finish_step = None + + for i in range(20): + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/metrics") + self.assertEqual(response.status_code, 200) + metrics = response.json() + metrics_keys = list(metrics.keys()) + self.assertIn("explore_step_num", metrics_keys) + self.assertIn("rollout/total_experience_count", metrics_keys) + self.assertIn("rollout/model_0/total_request_count", metrics_keys) + self.assertIn("rollout/model_3/model_version", metrics_keys) + if not finish_step and metrics["rollout/total_experience_count"] == task_num: + finish_step = metrics["explore_step_num"] + if finish_step and metrics["explore_step_num"] >= finish_step + 1: + # wait for one more step to ensure all data are written to buffer + break + await asyncio.sleep(3) + + serve_process.terminate() + serve_process.join(timeout=10) + + # check buffer + self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 5 + buffer_reader = get_buffer_reader( + self.config.buffer.trainer_input.experience_buffer, + self.config.buffer, + ) + exps = await buffer_reader.read_async(batch_size=10) + for exp in exps: + self.assertTrue(len(exp.tokens) > 0) + self.assertEqual(len(exps), task_num) + + def tearDown(self): + shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index cd520b9fd2..36c73063a9 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -1,7 +1,7 @@ import asyncio import time import unittest -from typing import List +from typing import List, Optional import ray import torch @@ -147,6 +147,12 @@ def init_process_group( ) -> None: pass + def has_api_server(self) -> bool: + return False + + def get_api_server_url(self) -> Optional[str]: + return None + @ray.remote class DummyAuxiliaryModel(InferenceModel): @@ -171,8 +177,8 @@ def init_process_group( def has_api_server(self) -> bool: return True - def api_server_ready(self) -> str: - return "http://localhosts:12345" + def get_api_server_url(self) -> str: + return "http://localhost:12345" def generate_tasks( diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 2aaf61569c..3dc67b692f 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -720,3 +720,7 @@ def test_trainer(self): ) self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))) > 0) self.assertEqual(step_num, 2) + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index a741788a8b..5e535a6d25 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -38,6 +38,18 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): "`buffer_config.trainer_input.auxiliary_buffers` is required in MIX algorithm" ) + if buffer_config.trainer_input.auxiliary_buffers.get(self.sft_dataset_name) is None: + raise ValueError( + f"`{self.sft_dataset_name}` is not found in `buffer_config.trainer_input.auxiliary_buffers`" + ) + expert_storage_config = buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name] + + if expert_storage_config.schema_type != "sft": + self.logger.warning( + f"schema_type of {self.sft_dataset_name} is not `sft`, set it to `sft`" + ) + expert_storage_config.schema_type = "sft" + # expert experience buffer expert_buffer_config = copy.deepcopy(buffer_config) expert_buffer_config.train_batch_size = expert_batch_size diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index 00a2ecb4b0..f92b4c638e 100644 --- a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -114,6 +114,7 @@ async def process(self, exps: List[Experience]) -> Dict: """ if self.input_store is not None: await self.input_store.write_async(exps) + metrics = {} # Process experiences through operators diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index 4cbe083220..e7e42378de 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -32,7 +32,7 @@ class SQLStorage: def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.logger = get_logger(f"sql_{storage_config.name}", in_ray_actor=True) - if storage_config.path is None: + if not storage_config.path: storage_config.path = default_storage_path(storage_config, config) self.engine, self.table_model_cls = init_engine( db_url=storage_config.path, diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 21495cadc6..3b88e85cf7 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -63,6 +63,18 @@ def train(config: Config) -> None: logger.error(f"Trainer failed:\n{traceback.format_exc()}") +def serve(config: Config) -> None: + """Run explorer in server mode.""" + try: + explorer = Explorer.get_actor(config) + ray.get(explorer.prepare.remote()) + ray.get(explorer.sync_weight.remote()) + ray.get(explorer.serve.remote()) + ray.get(explorer.shutdown.remote()) + except Exception: + logger.error(f"Explorer failed:\n{traceback.format_exc()}") + + def both(config: Config) -> None: """Setup both explorer and trainer. @@ -125,6 +137,15 @@ def both(config: Config) -> None: logger.error(f"Explorer or Trainer failed:\n{traceback.format_exc()}") +MODE_MAP = { + "explore": explore, + "train": train, + "both": both, + "bench": bench, + "serve": serve, +} + + def run_stage(config: Config, ray_address: str) -> None: envs = { PLUGIN_DIRS_ENV_VAR: os.environ.get(PLUGIN_DIRS_ENV_VAR, ""), @@ -141,14 +162,7 @@ def run_stage(config: Config, ray_address: str) -> None: pprint(config) try: check_and_run_task_pipeline(config) - if config.mode == "explore": - explore(config) - elif config.mode == "train": - train(config) - elif config.mode == "both": - both(config) - elif config.mode == "bench": - bench(config) + MODE_MAP[config.mode](config) finally: if config.monitor.enable_ray_timeline: timeline_file = os.path.join(config.monitor.cache_dir, "timeline.json") diff --git a/trinity/common/config.py b/trinity/common/config.py index ac2764506f..258bd78b06 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -161,9 +161,8 @@ class ExperiencePipelineConfig: # The list of experience operators to apply, operators will be applied in the order they are defined operators: List[OperatorConfig] = field(default_factory=list) save_input: bool = True # whether to save the input experiences - input_save_path: Optional[ - str - ] = None # the path to save the input experiences, can be a jsonl file or a sqlite database file + # the path to save the input experiences, can be a jsonl file or a sqlite database file + input_save_path: Optional[str] = None # The following fields are experimental, do not set them unless you know what you are doing @@ -420,6 +419,15 @@ class ExplorerConfig: # for benchmark bench_on_latest_checkpoint: bool = False # only benchmark the latest checkpoint + # for serve mode + api_port: int = 8010 + # listen on all interfaces by default + listen_address: str = "0.0.0.0" + # check the running status of the server every 60 seconds + service_status_check_interval: int = 60 + # keep at least 1 model in running status + min_running_model_num: int = 1 + @dataclass class TrainerConfig: @@ -881,7 +889,7 @@ def check_and_update(self) -> Config: # noqa: C901 self._check_algorithm() # check mode - if self.mode not in ["explore", "train", "both", "bench"]: + if self.mode not in ["explore", "train", "both", "bench", "serve"]: raise ValueError(f"Invalid mode: {self.mode}") # prepare for the checkpoint directory @@ -927,7 +935,7 @@ def check_and_update(self) -> Config: # noqa: C901 * self.explorer.rollout_model.tensor_parallel_size ) if ( - self.mode in ["train", "explore", "bench"] + self.mode in ["train", "explore", "bench", "serve"] and self.synchronizer.sync_method == SyncMethod.NCCL ): self.synchronizer.sync_method = SyncMethod.CHECKPOINT diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/api/vllm_patch.py index c9ae2f40a1..fdbb2088a3 100644 --- a/trinity/common/models/api/vllm_patch.py +++ b/trinity/common/models/api/vllm_patch.py @@ -362,6 +362,7 @@ async def run_api_server_in_ray_actor( str(port), "--model", model_path, + "--enable-server-load-tracking", # enable tracking for load balancing ] if enable_auto_tool_choice: cli_args.append("--enable-auto-tool-choice") diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 35c4ec8b2a..80f978f227 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -2,15 +2,16 @@ """Base Model Class""" import asyncio import socket -import time from abc import ABC, abstractmethod -from typing import Any, List, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union +import httpx import openai import ray import torch from torch import Tensor +from trinity.common.constants import RunningStatus from trinity.common.experience import Experience from trinity.utils.log import get_logger @@ -46,6 +47,14 @@ def get_available_address(self) -> Tuple[str, int]: port = s.getsockname()[1] return address, port + def has_api_server(self) -> bool: + """Check if the model has an API server.""" + return False + + def get_api_server_url(self) -> Optional[str]: + """Get the API server URL if available.""" + return None + def _history_recorder(func): """Decorator to record history of the model calls.""" @@ -77,6 +86,31 @@ def __init__(self, model: Any, engine_type: str = "vllm", enable_history: bool = self.logger = get_logger(__name__) self.enable_history = enable_history self.history = [] + self.status = RunningStatus.RUNNING + self.request_count = 0 + + async def prepare(self) -> None: + """Prepare the model wrapper.""" + if await self.model.has_api_server.remote(): + self.api_address = await self.model.get_api_server_url.remote() + if self.api_address is None: + raise RuntimeError( + "Failed to connect to the API server. Please set `enable_openai_api` to `True`." + ) + max_retries = 30 + interval = 2 # seconds + for i in range(max_retries): + try: + async with httpx.AsyncClient() as client: + response = await client.get(self.api_address + "/health", timeout=5) + if response.status_code == 200: + return + except Exception as e: + self.logger.info(f"API server not ready (attempt {i+1}/{max_retries}): {e}") + await asyncio.sleep(interval) + raise RuntimeError( + f"API server at {self.api_address} not ready after {max_retries} attempts." + ) def _record_history(self, exps: Union[Experience, List[Experience]]) -> None: """Record experiences to history.""" @@ -172,31 +206,6 @@ async def model_version_async(self) -> int: """Get the version of the model.""" return await self.model.get_model_version.remote() - def _get_api_server_address(self) -> str: - """Get the address of the API server.""" - if self.api_address: - return self.api_address - if not ray.get(self.model.has_api_server.remote()): - raise ValueError( - "OpenAI API server is not running on current model." - "Please set `enable_openai_api` to `True`." - ) - api_address = None - while True: - api_address = ray.get(self.model.api_server_ready.remote()) - if api_address is not None: - break - else: - self.logger.info("Waiting for OpenAI API server to be ready...") - time.sleep(5) - if api_address is None: - raise RuntimeError( - "Failed to connect to the API server. Please check the API server is running." - ) - self.api_address = api_address - self.logger.info(f"Successfully connect to API server at {api_address}") - return api_address - def get_openai_client(self) -> openai.OpenAI: """Get the openai client. @@ -205,9 +214,12 @@ def get_openai_client(self) -> openai.OpenAI: """ if self.openai_client is not None: return self.openai_client - api_address = self._get_api_server_address() + if not self.api_address: + raise ValueError( + "API server is not enabled for this model. OpenAI client is unavailable." + ) self.openai_client = openai.OpenAI( - base_url=api_address, + base_url=f"{self.api_address}/v1", api_key="EMPTY", ) if self.enable_history: @@ -231,10 +243,13 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI: """ if self.openai_async_client is not None: return self.openai_async_client + if not self.api_address: + raise ValueError( + "API server is not enabled for this model. OpenAI async client is unavailable." + ) # first make sure that we have the sync openai client - api_address = self._get_api_server_address() self.openai_async_client = openai.AsyncOpenAI( - base_url=api_address, + base_url=f"{self.api_address}/v1", api_key="EMPTY", ) if self.enable_history: @@ -252,6 +267,21 @@ async def record_chat_completions(*args, **kwargs): setattr(self.openai_async_client, "model_path", openai_client.models.list().data[0].id) return self.openai_async_client + async def get_current_load(self) -> int: + """Get the current load metrics of the model.""" + if not self.api_address: + raise ValueError( + "API server is not enabled for this model. Load metrics is unavailable." + ) + with httpx.AsyncClient() as client: + response = await client.get(f"{self.api_address}/load") + data = response.json() + return data["server_load"] + + async def sync_model_weights(self, model_version: int) -> None: + """Sync the model weights""" + await self.model.sync_model.remote(model_version) + def extract_experience_from_history(self, clear_history: bool = True) -> List[Experience]: """Extract experiences from the history.""" if not self.enable_history: diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index 27c62b535c..2427eba64e 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -135,6 +135,7 @@ def get_checkpoint_dir_with_step_num( checkpoint_root_path: str, trainer_type: str = "verl", step_num: Optional[int] = None, + raise_error: bool = True, ) -> Tuple[str, int]: """Get the checkpoint directory from a root checkpoint directory. @@ -144,12 +145,16 @@ def get_checkpoint_dir_with_step_num( step_num (Optional[int], optional): The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None. + raise_error (bool): Whether to raise an error if the checkpoint does not exist. Returns: Tuple[str, int]: The checkpoint directory and the step number of the checkpoint. + If the checkpoint does not exist and `raise_error` is False, return (None, 0). """ if trainer_type == "verl": - return get_verl_checkpoint_info(checkpoint_path=checkpoint_root_path, step_num=step_num) + return get_verl_checkpoint_info( + checkpoint_path=checkpoint_root_path, step_num=step_num, raise_error=raise_error + ) else: raise NotImplementedError(f"Unsupported trainer type {trainer_type}") @@ -193,7 +198,7 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): def get_verl_checkpoint_info( - checkpoint_path: str, step_num: Optional[int] = None + checkpoint_path: str, step_num: Optional[int] = None, raise_error: bool = True ) -> Tuple[str, int]: """Get the checkpoint directory from a Verl root checkpoint directory. @@ -202,6 +207,7 @@ def get_verl_checkpoint_info( step_num (Optional[int], optional): The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None. + raise_error (bool): Whether to raise an error if the checkpoint does not exist. Returns: Tuple[str, int]: The checkpoint directory and the step number of the checkpoint. @@ -215,11 +221,16 @@ def get_verl_checkpoint_info( ) as f: # TODO: this file may be modified simultaneously iteration = f.read().strip() return os.path.join(checkpoint_path, f"global_step_{iteration}"), int(iteration) - else: + elif raise_error: raise FileNotFoundError(f"No iteration file found in {checkpoint_path}") + else: + return None, 0 # type: ignore else: # load specific iteration checkpoint - return os.path.join(checkpoint_path, f"global_step_{step_num}"), step_num + path = os.path.join(checkpoint_path, f"global_step_{step_num}") + if not os.path.exists(path) and raise_error: + raise FileNotFoundError(f"Checkpoint {path} not found") + return path, step_num # copy from verl/scripts/model_merger.py diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 3984453e78..d592c98654 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -1,9 +1,9 @@ """A wrapper around the vllm.AsyncEngine to handle async requests.""" +import asyncio import os -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence -import aiohttp import ray import torch import vllm @@ -94,6 +94,7 @@ def __init__( self.model_version = 0 # TODO: resume the value from the checkpoint self.api_server_host = None self.api_server_port = None + self.api_server = None async def _initialize_tokenizer(self): if self.tokenizer is None: @@ -350,12 +351,19 @@ async def convert_messages_to_experience( action_mask=action_mask[prompt_length:], # Exclude the prompt tokens ) - def shutdown(self): + async def shutdown(self): """Shutdown the vLLM v1 engine. This kills child processes forked by the vLLM engine. If not called, the child processes will be orphaned and will not be killed when the parent process exits, and they won't be able to be tracked by Ray anymore. """ + if self.api_server is not None: + self.api_server.cancel() + try: + await self.api_server + except asyncio.CancelledError: + pass + self.api_server = None if hasattr(self.async_llm, "shutdown"): self.logger.info("Shutting down vLLM engine") self.async_llm.shutdown() @@ -384,12 +392,12 @@ async def _collective_rpc( method, timeout, args, kwargs ) - async def sync_model(self, model_version: int) -> bool: + async def sync_model(self, model_version: int) -> int: """Sync model weights to vLLM.""" await self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.model_version = model_version - return True + return model_version async def init_process_group( self, @@ -420,50 +428,36 @@ async def init_process_group( ) async def run_api_server(self): - """Run the OpenAI API server in a Ray actor. - - Note: - Do not use `ray.get()` on this method. - This method will run forever until the server is shut down. - """ + """Run the OpenAI API server in a Ray actor.""" if not (self.api_server_host is None or self.api_server_port is None): raise RuntimeError("API server is already running.") from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor self.api_server_host, self.api_server_port = self.get_available_address() - await run_api_server_in_ray_actor( - self.async_llm, - self.api_server_host, - self.api_server_port, - self.config.model_path, - self.config.enable_auto_tool_choice, - self.config.tool_call_parser, - self.config.reasoning_parser, + self.api_server = asyncio.create_task( + run_api_server_in_ray_actor( + self.async_llm, + self.api_server_host, + self.api_server_port, + self.config.model_path, + self.config.enable_auto_tool_choice, + self.config.tool_call_parser, + self.config.reasoning_parser, + ) ) - async def has_api_server(self) -> bool: + def has_api_server(self) -> bool: return self.config.enable_openai_api - async def api_server_ready(self) -> Union[str, None]: - """Check if the OpenAI API server is ready. + def get_api_server_url(self) -> Optional[str]: + """Get the URL of the OpenAI API server. Returns: api_url (str): The URL of the OpenAI API server. """ - if not await self.has_api_server(): - return None - try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://{self.api_server_host}:{self.api_server_port}/health" - ) as response: - if response.status == 200: - return f"http://{self.api_server_host}:{self.api_server_port}/v1" - else: - return None - except Exception as e: - self.logger.error(e) + if not self.has_api_server(): return None + return f"http://{self.api_server_host}:{self.api_server_port}" async def reset_prefix_cache(self) -> None: await self.async_llm.reset_prefix_cache() diff --git a/trinity/explorer/api/__init__.py b/trinity/explorer/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/trinity/explorer/api/api.py b/trinity/explorer/api/api.py new file mode 100644 index 0000000000..66b5e2e97a --- /dev/null +++ b/trinity/explorer/api/api.py @@ -0,0 +1,65 @@ +import traceback + +import httpx +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response + +app = FastAPI() + + +# Forward openAI requests to a model instance + + +@app.post("/v1/chat/completions") +async def chat_completions(request: Request): + # Currently, we do not support streaming chat completions + body = await request.json() + url = await request.app.state.service.allocate_model() + try: + async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client: + resp = await client.post(f"{url}/v1/chat/completions", json=body) + except Exception: + return Response( + status_code=500, + content=f"Error forwarding request to model at {url}: {traceback.format_exc()}", + ) + resp_data = resp.json() + await request.app.state.service.record_experience(resp_data) + return JSONResponse(content=resp_data) + + +@app.get("/v1/models") +async def show_available_models(request: Request): + body = await request.json() + url = await request.app.state.service.allocate_model(increase_count=False) + async with httpx.AsyncClient() as client: + resp = await client.get(f"{url}/v1/models", json=body) + return JSONResponse(content=resp.json()) + + +@app.get("/health") +async def health(request: Request) -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.get("/metrics") +async def metrics(request: Request): + """Get the metrics of the service.""" + metrics = request.app.state.service.collect_metrics() + metrics["explore_step_num"] = request.app.state.service.explorer.explore_step_num + return JSONResponse(content=metrics) + + +async def serve_http(app: FastAPI, host: str, port: int = None): + config = uvicorn.Config(app, host=host, port=port) + server = uvicorn.Server(config) + await server.serve() + + +async def run_app(service, listen_address: str, port: int = None) -> FastAPI: + app.state.service = service + app.state.inference_timeout = service.explorer.config.synchronizer.sync_timeout + print(f"API server running on {listen_address}:{port}") + await serve_http(app, listen_address, port) diff --git a/trinity/explorer/api/service.py b/trinity/explorer/api/service.py new file mode 100644 index 0000000000..ffdf2cfd9a --- /dev/null +++ b/trinity/explorer/api/service.py @@ -0,0 +1,160 @@ +import asyncio +import time +from collections import deque +from typing import Dict, List + +import torch + +from trinity.common.constants import RunningStatus +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.explorer.explorer import Explorer +from trinity.utils.log import get_logger + + +class ExplorerService: + def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: int = 8010): + self.logger = get_logger(__name__) + self.explorer = explorer + self.app = None + self.port = port + self.listen_address = listen_address + self.running = False + self.models: List[ModelWrapper] = [ModelWrapper(model) for model in explorer.models] + self.min_running_model_num = explorer.config.explorer.min_running_model_num + self.check_interval = explorer.config.explorer.service_status_check_interval + self.max_timeout = explorer.config.explorer.max_timeout + self.running_models: deque[int] = deque() # indices of running models + self.sync_task_map: Dict[asyncio.Future, int] = {} # sync task -> model index + self.latest_model_version = 0 + self.experience_queue = asyncio.Queue() + self.experience_count = 0 + + async def serve(self): + from trinity.explorer.api.api import run_app + + if self.running: + self.logger.warning("Server is already running.") + return + + self.running = True + await asyncio.gather(*[model.prepare() for model in self.models]) + + for i, _ in enumerate(self.models): + self.running_models.append(i) + + self.serve_task = asyncio.create_task( + run_app(service=self, listen_address=self.listen_address, port=self.port) + ) + self.sync_model_weights_task = asyncio.create_task(self.model_weights_sync_loop()) + + async def model_weights_sync_loop(self): + self.logger.info("Starting model weights synchronization loop.") + while self.running: + for idx in list(self.running_models): + if ( + len(self.running_models) > self.explorer.config.explorer.min_running_model_num + and self.models[idx].model_version < self.latest_model_version + ): + self.running_models.remove(idx) + self.models[idx].status = RunningStatus.REQUIRE_SYNC + self.logger.info(f"Model {idx} scheduled for synchronization.") + future = asyncio.create_task(self._wait_for_sync_start(idx)) + self.sync_task_map[future] = idx + future.add_done_callback(self._sync_model_weights) + # wait half interval + await asyncio.sleep(self.check_interval / 2) + self.logger.info("Model weights synchronization loop stopped.") + + def set_latest_model_version(self, version: int) -> None: + if version > self.latest_model_version: + self.latest_model_version = version + self.logger.info(f"Updated latest model version to {version}.") + + async def _wait_for_sync_start(self, index: int): + start_time = time.time() + while time.time() - start_time < self.max_timeout: + current_load = await self.models[index].get_current_load() + if current_load == 0: + self.models[index].status = RunningStatus.WAITING_SYNC + self.logger.info(f"Model {index} begins synchronization.") + return + else: + await asyncio.sleep(2) + raise asyncio.TimeoutError( + f"Timeout waiting for model {index} to be free for synchronization. Current load: {current_load}" + ) + + async def _sync_model_weights(self, task: asyncio.Future): + index = self.sync_task_map.pop(task) + latest_version = self.latest_model_version # capture the latest version + if task.cancelled(): + self.logger.warning(f"Synchronization of model {index} was cancelled.") + elif task.exception(): + self.logger.error(f"Error during synchronization of model {index}: {task.exception()}") + else: + await self.models[index].sync_model_weights(latest_version) + self.logger.info(f"Model {index} synchronized to version {latest_version}.") + self.running_models.append(index) + self.models[index].status = RunningStatus.RUNNING + + async def allocate_model(self, increase_count: bool = True) -> str: + model = self.models[self.running_models[0]] + if increase_count: + model.request_count += 1 + self.running_models.rotate(-1) + return model.api_address + + def collect_metrics(self) -> Dict: + metrics = {} + for i, model in enumerate(self.models): + metrics[f"rollout/model_{i}/total_request_count"] = model.request_count + metrics[f"rollout/model_{i}/model_version"] = model.model_version + metrics["rollout/total_experience_count"] = self.experience_count + return metrics + + async def check_requiring_sync_models(self): + if not self.running: + self.logger.warning("Server is not running.") + return + await asyncio.gather( + *[self._sync_model_weights(idx) for idx in list(self.requiring_sync_models)] + ) + + async def record_experience(self, response): + experiences = [] + for choice in response["choices"]: + exp = Experience( + tokens=torch.cat( + ( + torch.tensor(response["prompt_token_ids"], dtype=torch.int32), + torch.tensor(choice["token_ids"], dtype=torch.int32), + ) + ), + logprobs=choice.get("logprobs", None), + prompt_length=len(response["prompt_token_ids"]), + response_text=choice.get("message", {}).get("content", ""), + ) + experiences.append(exp) + self.experience_count += len(experiences) + for exp in experiences: + await self.experience_queue.put(exp) + + async def get_all_experiences(self) -> List: + experiences = [] + while not self.experience_queue.empty(): + experiences.append(await self.experience_queue.get()) + return experiences + + async def shutdown(self): + if not self.running: + self.logger.warning("Server is not running.") + return + self.sync_model_weights_task.cancel() + self.serve_task.cancel() + try: + await self.serve_task + except asyncio.CancelledError: + pass + self.running = False + self.logger.info("API server shut down.") diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 0d32aa2d10..e90a82cf8c 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -23,9 +23,11 @@ SyncStyle, ) from trinity.common.models import create_inference_models +from trinity.common.models.utils import get_checkpoint_dir_with_step_num from trinity.explorer.scheduler import Scheduler from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer +from trinity.utils.annotations import Experimental from trinity.utils.log import get_logger from trinity.utils.monitor import MONITOR, gather_metrics from trinity.utils.plugin_loader import load_plugins @@ -48,10 +50,12 @@ def __init__(self, config: Config): self.models, self.auxiliary_models = create_inference_models(config) self.experience_pipeline = self._init_experience_pipeline() self.config.buffer.explorer_input.taskset.index = explorer_state.get("latest_task_index", 0) - self.taskset = get_buffer_reader( - self.config.buffer.explorer_input.taskset, self.config.buffer + self.taskset = ( + get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) + if self.config.mode != "serve" + else None ) - self.scheduler = Scheduler(self.config, self.models, self.auxiliary_models) + self.scheduler = None self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, group=self.config.group, @@ -145,16 +149,18 @@ async def prepare(self) -> None: """Preparation before running.""" try: await self.experience_pipeline.prepare.remote() - + self.logger.info("Experience pipeline is ready.") # make sure all rollout models are ready model_ready_ref = [model.__ray_ready__.remote() for model in self.models] await asyncio.gather(*model_ready_ref) + self.logger.info("All rollout models are ready.") if not self.use_nccl_sync: master_address, master_port = await self.models[0].get_available_address.remote() await self.setup_weight_sync_group(master_address, master_port) - - await self.scheduler.start() + if self.config.mode != "serve": + self.scheduler = Scheduler(self.config, self.models, self.auxiliary_models) + await self.scheduler.start() if self.config.explorer.eval_on_startup and self.explore_step_num == 0: await self.eval() @@ -298,7 +304,10 @@ async def benchmark(self) -> bool: return True async def save_checkpoint(self, sync_weight: bool = False) -> None: - await self._finish_steps(self.last_sync_step + 1, self.explore_step_num, self.model_version) + if self.scheduler: + await self._finish_steps( + self.last_sync_step + 1, self.explore_step_num, self.model_version + ) if sync_weight: # sync weights @@ -391,6 +400,55 @@ def _init_experience_pipeline(self) -> ray.actor.ActorHandle: .remote(self.config) ) + @Experimental + async def serve(self) -> None: + """Run the explorer in serving mode. + + In serving mode, the explorer starts an OpenAI compatible server to handle requests. + Agent applications can be deployed separately and interact with the explorer via the API. + + + .. code-block:: python + + import openai + + + client = openai.OpenAI( + base_url=f"{explorer_server_url}/v1", + api_key="EMPTY", + ) + response = client.chat.completions.create( + model=config.model.model_path, + messages=[{"role": "user", "content": "Hello!"}] + ) + """ + from trinity.explorer.api.service import ExplorerService + + self.service = ExplorerService( + self, + listen_address=self.config.explorer.listen_address, + port=self.config.explorer.api_port, + ) + await self.service.serve() + self.server_url = f"http://{ray.util.get_node_ip_address()}:{self.service.port}" + self.logger.info( + f"Explorer API Server is started on {self.server_url} and listening to {self.service.listen_address}." + ) + self.state.save_explorer_server_url(self.server_url) + while True: + self.explore_step_num += 1 + await asyncio.sleep(self.config.explorer.service_status_check_interval) + # process experiences generated in the last interval + exps = await self.service.get_all_experiences() + metrics = await self.experience_pipeline.process.remote(exps) + metrics.update(self.service.collect_metrics()) + self.monitor.log(metrics, self.explore_step_num) + # get the latest checkpoint + _, step_num = get_checkpoint_dir_with_step_num( + self.config.checkpoint_job_dir, raise_error=False + ) + self.service.set_latest_model_version(step_num) + @classmethod def get_actor(cls, config: Config): """Get a Ray actor for the explorer.""" diff --git a/trinity/explorer/explorer_client.py b/trinity/explorer/explorer_client.py new file mode 100644 index 0000000000..311b310038 --- /dev/null +++ b/trinity/explorer/explorer_client.py @@ -0,0 +1,49 @@ +from functools import partial + +import httpx +import openai +import requests + + +class ExplorerClient: + def __init__(self, base_url: str): + self.base_url = base_url + self.session_id = self.init_session() + + def init_session(self) -> str: + response = requests.post(f"{self.base_url}/allocate") + data = response.json() + return data["session_id"] + + def get_openai_client(self) -> openai.OpenAI: + client = openai.OpenAI( + base_url=self.base_url + "/v1", + api_key="EMPTY", + ) + client.chat.completions.create = partial( + client.chat.completions.create, extra_body={"session_id": self.session_id} + ) + return client + + def get_openai_async_client(self) -> openai.AsyncOpenAI: + client = openai.AsyncOpenAI( + base_url=self.base_url + "/v1", + api_key="EMPTY", + ) + client.chat.completions.create = partial( + client.chat.completions.create, extra_body={"session_id": self.session_id} + ) + return client + + def feedback(self, reward: float): + response = requests.post( + f"{self.base_url}/feedback", json={"session_id": self.session_id, "reward": reward} + ) + return response.json() + + async def feedback_async(self, reward: float): + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/feedback", json={"session_id": self.session_id, "reward": reward} + ) + return response.json() diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 6af14b3d83..fecde5e61b 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -62,6 +62,9 @@ def _create_runner(self): .remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id) ) + async def prepare(self): + await self.runner.prepare.remote() + async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, List, int]: """ Returns: @@ -100,9 +103,10 @@ async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, List, int]: status.metric["task_run_time"] = end_time - start_time return status, exps, self.runner_id - def restart_runner(self): + async def restart_runner(self): old_runner = self.runner self.runner = self._create_runner() + await self.runner.prepare.remote() try: ray.kill(old_runner) except Exception: @@ -164,7 +168,7 @@ def __init__( self.total_scheduled = 0 self.total_completed = 0 - def _create_runner( + async def _create_runner( self, runner_id: int, ): @@ -177,10 +181,11 @@ def _create_runner( ], config=self.config, ) + await runner.prepare() self.runners[runner_id] = runner self.idle_runners.add(runner_id) - def _restart_runner(self, runner_id: int): + async def _restart_runner(self, runner_id: int): """Restart a runner.""" self.runners[runner_id].restart_runner() @@ -257,8 +262,7 @@ async def start(self) -> None: if self.running: return self.running = True - for i in range(self.runner_num): - self._create_runner(i) + await asyncio.gather(*[self._create_runner(i) for i in range(self.runner_num)]) self.scheduler_task = asyncio.create_task(self._scheduler_loop()) ready_refs = [runner.runner.__ray_ready__.remote() for runner in self.runners.values()] await asyncio.gather(*ready_refs) @@ -372,7 +376,7 @@ async def get_results( self._clear_timeout_tasks(batch_id=batch_id) for runner_id, task in list(self.busy_runners.items()): if task.batch_id == batch_id: - self._restart_runner(runner_id) + await self._restart_runner(runner_id) statuses = [] experiences = [] @@ -444,6 +448,6 @@ async def wait_all( self._clear_timeout_tasks(batch_id) busy_runner_ids = list(self.busy_runners.keys()) for runner_id in busy_runner_ids: - self._restart_runner(runner_id) + await self._restart_runner(runner_id) raise TimeoutError(error_msg) diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 4a6ed67e30..2d1ddddc31 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """The Workflow Runner Module.""" +import asyncio import time import traceback from collections import defaultdict @@ -40,17 +41,26 @@ def __init__( config.explorer.rollout_model.engine_type, enable_history=config.explorer.rollout_model.enable_history, ) - self.auxiliary_models = [] - if auxiliary_models is not None: - for model in auxiliary_models: - api_client = ModelWrapper( - model, - "vllm_async", - ).get_openai_client() - self.auxiliary_models.append(api_client) + self.auxiliary_models = [ + ModelWrapper( + model, + ) + for model in (auxiliary_models or []) + ] + self.auxiliary_model_clients = [] self.workflow_instance: Workflow = None self.runner_id = runner_id + async def prepare(self) -> None: + """Prepare the runner.""" + await asyncio.gather( + self.model_wrapper.prepare(), + *(aux_model.prepare() for aux_model in self.auxiliary_models), + ) + for model in self.auxiliary_models: + api_client = model.get_openai_client() + self.auxiliary_model_clients.append(api_client) + def is_alive(self): return True @@ -62,15 +72,17 @@ def _create_workflow_instance(self, task: Task) -> None: or not self.workflow_instance.__class__ == task.workflow or not self.workflow_instance.resettable ): - self.workflow_instance = task.to_workflow(self.model_wrapper, self.auxiliary_models) + self.workflow_instance = task.to_workflow( + self.model_wrapper, self.auxiliary_model_clients + ) else: self.workflow_instance.reset(task) - async def _run_workflow(self, workflow_isntance: Workflow) -> List[Experience]: - if workflow_isntance.asynchronous: - exps = await workflow_isntance.run_async() + async def _run_workflow(self, workflow_instance: Workflow) -> List[Experience]: + if workflow_instance.asynchronous: + exps = await workflow_instance.run_async() else: - exps = workflow_isntance.run() + exps = workflow_instance.run() return exps async def _run_task(self, task: Task, repeat_times: int, run_id_base: int) -> List[Experience]: diff --git a/trinity/manager/state_manager.py b/trinity/manager/state_manager.py index 73ad4fe042..e47566d839 100644 --- a/trinity/manager/state_manager.py +++ b/trinity/manager/state_manager.py @@ -21,16 +21,19 @@ def __init__( ): self.logger = get_logger(__name__, in_ray_actor=True) self.cache_dir = path - os.makedirs(self.cache_dir, exist_ok=True) # type: ignore - self.stage_state_path = os.path.join(self.cache_dir, "stage_meta.json") # type: ignore - self.explorer_state_path = os.path.join(self.cache_dir, f"{explorer_name}_meta.json") # type: ignore - self.trainer_state_path = os.path.join(self.cache_dir, f"{trainer_name}_meta.json") # type: ignore + os.makedirs(self.cache_dir, exist_ok=True) + self.stage_state_path = os.path.join(self.cache_dir, "stage_meta.json") + self.explorer_state_path = os.path.join(self.cache_dir, f"{explorer_name}_meta.json") + self.trainer_state_path = os.path.join(self.cache_dir, f"{trainer_name}_meta.json") + self.explorer_server_url_path = os.path.join( + self.cache_dir, f"{explorer_name}_server_url.txt" + ) if check_config and config is not None: self._check_config_consistency(config) def _check_config_consistency(self, config: Config) -> None: """Check if the config is consistent with the cache dir backup.""" - backup_config_path = os.path.join(self.cache_dir, "config.json") # type: ignore + backup_config_path = os.path.join(self.cache_dir, "config.json") if not os.path.exists(backup_config_path): config.save(backup_config_path) else: @@ -75,6 +78,27 @@ def load_explorer(self) -> dict: self.logger.error(f"Failed to load explore state file: {e}") return {} + def save_explorer_server_url(self, url: str) -> None: + with open(self.explorer_server_url_path, "w", encoding="utf-8") as f: + f.write(url) + self.logger.info(f"Saved explorer server URL to {self.explorer_server_url_path}") + + def load_explorer_server_url(self) -> Optional[str]: + if os.path.exists(self.explorer_server_url_path): + try: + with open(self.explorer_server_url_path, "r", encoding="utf-8") as f: + url = f.read().strip() + self.logger.info( + "----------------------------------\n" + "Found existing explorer server URL:\n" + f" > {url}\n" + "----------------------------------" + ) + return url + except Exception as e: + self.logger.error(f"Failed to load explorer server URL file: {e}") + return None + def save_trainer( self, current_exp_index: int,