From c726efbf106707ee7e97d00436e14c88f56047ff Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 15 Jul 2025 11:46:01 +0800
Subject: [PATCH 01/12] use async model by default

---
 tests/common/vllm_test.py                 |  16 -
 trinity/common/models/__init__.py         |  16 +-
 trinity/common/models/model.py            |  73 ++--
 trinity/common/models/vllm_async_model.py | 364 -------------------
 trinity/common/models/vllm_model.py       | 406 +++++++++++++---------
 5 files changed, 273 insertions(+), 602 deletions(-)
 delete mode 100644 trinity/common/models/vllm_async_model.py

diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 0146eb075c..fe9c13e74e 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -130,22 +130,6 @@ def test_generate(self):
         self.assertRaises(ValueError, self.model_wrapper.get_openai_client)
 
 
-class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
-    def setUp(self):
-        self.config = get_template_config()
-        self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
-        self.config.explorer.rollout_model.engine_type = "vllm"
-        self.config.explorer.rollout_model.tensor_parallel_size = 1
-        self.config.explorer.rollout_model.engine_num = 2
-        self.config.explorer.rollout_model.use_v1 = False
-        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
-        self.config.algorithm.repeat_times = 2
-        self.config.check_and_update()
-        self.engines, self.auxiliary_engines = create_inference_models(self.config)
-        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
-
-
 class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
         self.config = get_template_config()
diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index f9d092807c..6a704b8793 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -43,24 +43,14 @@ def create_inference_models(
     from ray.util.placement_group import placement_group, placement_group_table
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
-    from trinity.common.models.vllm_async_model import vLLMAysncRolloutModel
     from trinity.common.models.vllm_model import vLLMRolloutModel
 
     engine_num = config.explorer.rollout_model.engine_num
     tensor_parallel_size = config.explorer.rollout_model.tensor_parallel_size
-    if (
-        config.explorer.rollout_model.enable_openai_api
-        and config.explorer.rollout_model.engine_type != "vllm_async"
-    ):
-        raise ValueError("OpenAI API is only supported for vllm_async engine")
-
     rollout_engines = []
-
-    if config.explorer.rollout_model.engine_type == "vllm":
+    if config.explorer.rollout_model.engine_type.startswith("vllm"):
         engine_cls = vLLMRolloutModel
-    elif config.explorer.rollout_model.engine_type == "vllm_async":
-        engine_cls = vLLMAysncRolloutModel
     else:
         raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}")
 
@@ -122,10 +112,10 @@ def create_inference_models(
         for _ in range(model_config.engine_num):
             bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size)
             model_config.enable_openai_api = True
-            model_config.engine_type = "vllm_async"
+            model_config.engine_type = "vllm"
            model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine])
            engines.append(
-                ray.remote(vLLMAysncRolloutModel)
+                ray.remote(vLLMRolloutModel)
                .options(
                    num_cpus=0,
                    num_gpus=0 if model_config.tensor_parallel_size > 1 else 1,
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 1e3eb87058..74aece6414 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 """Base Model Class"""
+import asyncio
 import socket
 import time
 from abc import ABC, abstractmethod
@@ -16,35 +17,19 @@ class InferenceModel(ABC):
     """A model for high performance for rollout inference."""
 
-    def generate(self, prompts: List[str], **kwargs) -> List[Experience]:
-        """Generate a batch of responses from a batch of prompts."""
-        raise NotImplementedError
-
-    def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
-        """Generate experiences from a list of history chat messages."""
-        raise NotImplementedError
-
-    def logprobs(self, token_ids: List[int]) -> Tensor:
-        """Generate logprobs for a list of tokens."""
-        raise NotImplementedError
-
-    def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
-        """Convert a list of messages into an experience."""
-        raise NotImplementedError
-
-    async def generate_async(self, prompt: str, **kwargs) -> List[Experience]:
+    async def generate(self, prompt: str, **kwargs) -> List[Experience]:
         """Generate a responses from a prompt in async."""
         raise NotImplementedError
 
-    async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]:
+    async def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
         """Generate experiences from a list of history chat messages in async."""
         raise NotImplementedError
 
-    async def logprobs_async(self, tokens: List[int]) -> Tensor:
+    async def logprobs(self, tokens: List[int]) -> Tensor:
         """Generate logprobs for a list of tokens in async."""
         raise NotImplementedError
 
-    async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience:
+    async def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
         """Convert a list of messages into an experience in async."""
         raise NotImplementedError
 
@@ -66,38 +51,46 @@ class ModelWrapper:
 
     # TODO: check model_type inside __init__
     def __init__(self, model: Any, model_type: str = "vllm"):
+        assert model_type.startswith("vllm"), "Only vLLM model is supported for now."
         self.model = model
-        self.use_async = model_type == "vllm_async"
         self.openai_client: openai.OpenAI = None
         self.logger = get_logger(__name__)
 
     def generate(self, prompts: List[str], **kwargs) -> List[Experience]:
-        if self.use_async:
-            results = ray.get(
-                [self.model.generate_async.remote(prompt, **kwargs) for prompt in prompts]
-            )
-            return [exp for exps in results for exp in exps]
-        else:
-            return ray.get(self.model.generate.remote(prompts, **kwargs))
+        """Generate a list of experiences from a list of prompts."""
+        results = ray.get([self.model.generate.remote(prompt, **kwargs) for prompt in prompts])
+        return [exp for exps in results for exp in exps]
+
+    async def generate_async(self, prompts: List[str], **kwargs) -> List[Experience]:
+        """Generate a list of experiences from a list of prompts in async."""
+        results = await asyncio.gather(
+            *[self.model.generate.remote(prompt, **kwargs) for prompt in prompts]
+        )
+        return [exp for exps in results for exp in exps]
 
     def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
-        if self.use_async:
-            return ray.get(self.model.chat_async.remote(messages, **kwargs))
-        else:
-            return ray.get(self.model.chat.remote(messages, **kwargs))
+        """Generate a list of experiences from a list of messages."""
+        return ray.get(self.model.chat.remote(messages, **kwargs))
+
+    async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]:
+        """Generate a list of experiences from a list of messages in async."""
+        return await self.model.chat.remote(messages, **kwargs)
 
     def logprobs(self, tokens: List[int]) -> Tensor:
-        if self.use_async:
-            return ray.get(self.model.logprobs_async.remote(tokens))
-        else:
-            return ray.get(self.model.logprobs.remote(tokens))
+        """Calculate the logprobs of the given tokens."""
+        return ray.get(self.model.logprobs.remote(tokens))
+
+    async def logprobs_async(self, tokens: List[int]) -> Tensor:
+        """Calculate the logprobs of the given tokens in async."""
+        return await self.model.logprobs.remote(tokens)
 
     def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
         """Convert a list of messages into an experience."""
-        if self.use_async:
-            return ray.get(self.model.convert_messages_to_experience_async.remote(messages))
-        else:
-            return ray.get(self.model.convert_messages_to_experience.remote(messages))
+        return ray.get(self.model.convert_messages_to_experience.remote(messages))
+
+    async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience:
+        """Convert a list of messages into an experience in async."""
+        return await self.model.convert_messages_to_experience.remote(messages)
 
     @property
     def model_version(self) -> int:
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
deleted file mode 100644
index f3253bcf4b..0000000000
--- a/trinity/common/models/vllm_async_model.py
+++ /dev/null
@@ -1,364 +0,0 @@
-"""vLLM AsyncEngine wrapper.
- -Modified from Ray python/ray/llm/_internal/batch/stages/vllm_engine_stage.py -""" - -import os -import re -from typing import Any, Dict, List, Optional, Tuple, Union - -import aiohttp -import ray -import torch -import vllm -from vllm.sampling_params import RequestOutputKind - -from trinity.common.config import InferenceModelConfig -from trinity.common.experience import Experience -from trinity.common.models.model import InferenceModel -from trinity.common.models.utils import ( - tokenize_and_mask_messages_default, - tokenize_and_mask_messages_hf, -) -from trinity.utils.log import get_logger - -logger = get_logger(__name__) - - -# TODO: merge into vLLMRolloutModel -# TODO: remove V0 when V1 is stable -class vLLMAysncRolloutModel(InferenceModel): - """Wrapper around the vLLM engine to handle async requests. - - Args: - config (Config): The config. - kwargs (dict): The keyword arguments for the engine. - """ - - def __init__( - self, - config: InferenceModelConfig, - ) -> None: - self.logger = get_logger(__name__) - self.config = config - self.use_v1 = config.use_v1 - if config.tensor_parallel_size != 1: - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices - if not vllm.envs.is_set("VLLM_USE_V1"): - self.logger.info(f"Using vLLM v{int(config.use_v1)} engine") - os.environ["VLLM_USE_V1"] = str(int(config.use_v1)) - if config.use_v1: - os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1)) - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - self.default_sampling_params = vllm.SamplingParams( - n=1, - temperature=0.0, - max_tokens=config.max_response_tokens, - min_tokens=1, - truncate_prompt_tokens=config.max_prompt_tokens, - skip_special_tokens=True, - include_stop_str_in_output=False, - output_kind=RequestOutputKind.FINAL_ONLY, - logprobs=0, - ) - self.enable_thinking = config.enable_thinking - self.request_id = 0 - max_model_len = None - if config.max_prompt_tokens is not None and config.max_response_tokens is not None: - max_model_len = config.max_prompt_tokens + config.max_response_tokens - engine_args = vllm.AsyncEngineArgs( - model=config.model_path, - enforce_eager=config.enforce_eager, - worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension", - tensor_parallel_size=config.tensor_parallel_size, - seed=config.seed, - distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"), - max_model_len=max_model_len, - enable_prefix_caching=config.enable_prefix_caching, - dtype=config.dtype, - trust_remote_code=True, - task="generate", - disable_log_requests=True, - gpu_memory_utilization=config.gpu_memory_utilization, - enable_chunked_prefill=config.enable_chunked_prefill, - # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage - ) - self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) - self.tokenizer = None - self.chat_template = None - if self.config.chat_template: - self.chat_template = self.config.chat_template - if self.chat_template is None or not re.search( - r"\{\%-?\s*generation\s*-?\%\}", self.chat_template - ): - self.logger.warning( - "The provided chat template does not support `return_assitant_tokens_mask`. " - "The default assistant mask method will be used, which may cause performance " - "degradation and lead to incorrect results." 
- ) - self.action_mask_method = tokenize_and_mask_messages_default - else: - self.action_mask_method = tokenize_and_mask_messages_hf - self.state_dict_meta = None - self.model_version = 0 # TODO: resume the value from the checkpoint - self.api_server_host = None - self.api_server_port = None - - async def chat_async(self, messages: List[Dict], **kwargs) -> List[Experience]: - """Chat with the model with a list of messages in async. - - Args: - messages (List[dict]): The input history messages. - kwargs (dict): A dictionary of sampling parameters. - - Returns: - A list of experiences. - """ - if self.tokenizer is None: - self.tokenizer = await self.async_llm.get_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() - if messages[-1]["role"] == "assistant": - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - continue_final_message=True, - chat_template=self.chat_template, - ) - else: - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - chat_template=self.chat_template, - enable_thinking=self.enable_thinking, - ) - return await self.generate_async(prompt=prompt, **kwargs) - - async def generate_async(self, prompt: str, **kwargs) -> List[Experience]: - """Generate a response from the provided prompt in async. - - Args: - prompt (str): The input prompt. - kwargs (dict): A dictionary of sampling parameters. - - Returns: - A list of experiences. - """ - output = await self._generate_internal(prompt=prompt, **kwargs) - experiences = [ - Experience( - tokens=torch.cat( - ( - torch.tensor(output.prompt_token_ids, dtype=torch.int32), - torch.tensor(output.outputs[i].token_ids, dtype=torch.int32), - ) - ), - logprobs=torch.cat( - ( - torch.full( - (len(output.prompt_token_ids),), - 0.0, - dtype=torch.float32, - ), - torch.tensor( - [ - list(logprob_dict.values())[0].logprob - for logprob_dict in output.outputs[i].logprobs - ], - dtype=torch.float32, - ), - ) - ), - prompt_length=len(output.prompt_token_ids), - prompt_text=output.prompt, - response_text=output.outputs[i].text, - ) - for i in range(len(output.outputs)) - ] - return experiences - - async def logprobs_async(self, token_ids: List[int]) -> torch.Tensor: - """Calculate the logprobs of the given tokens in async.""" - output = await self._generate_internal( - prompt={"prompt_token_ids": token_ids}, - n=1, - max_tokens=1, - prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token - ) - return torch.tensor( - [0] - + [ - list(logprob_dict.values())[0].logprob - for logprob_dict in output.prompt_logprobs[1:] - ], - dtype=torch.float32, - ) - - async def _generate_internal(self, prompt: Any, **kwargs) -> Any: - # Send the request to the LLM engine. - self.request_id += 1 - stream = self.async_llm.generate( - request_id=str(self.request_id), - prompt=prompt, - sampling_params=self._create_sampling_params(**kwargs), - ) - - # Consume the stream until the request is finished. - async for request_output in stream: - if request_output.finished: - # Bypass the original full prompt. - # request_output.prompt = request.prompt - return request_output - - raise RuntimeError("[vLLM] The request is not finished. 
This should not happen.") - - async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience: - """Convert a list of messages into an experience.""" - if self.tokenizer is None: - self.tokenizer = await self.async_llm.get_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() - token_ids, action_mask = self.action_mask_method( - self.tokenizer, messages, self.chat_template - ) - logprobs = await self.logprobs_async(token_ids=token_ids.tolist()) - return Experience( - tokens=token_ids, - prompt_length=len(token_ids), - logprobs=logprobs, - action_mask=action_mask, - ) - - def shutdown(self): - """Shutdown the vLLM v1 engine. This kills child processes forked - by the vLLM engine. If not called, the child processes will be - orphaned and will not be killed when the parent process exits, - and they won't be able to be tracked by Ray anymore. - """ - if hasattr(self.async_llm, "shutdown"): - logger.info("Shutting down vLLM engine") - self.async_llm.shutdown() - - def _create_sampling_params(self, **kwargs): - """Create sampling params.""" - if len(kwargs) == 0: - return self.default_sampling_params - params = self.default_sampling_params.clone() - for k, v in kwargs.items(): - if hasattr(params, k): - setattr(params, k, v) - return params - - async def _collective_rpc( - self, - method: str, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None, - ): - if self.use_v1: - return await self.async_llm.collective_rpc(method, timeout, args, kwargs) - else: - return self.async_llm.engine.model_executor.collective_rpc( - method, timeout, args, kwargs - ) - - async def sync_model( - self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None - ) -> bool: - """Sync model weights to vLLM.""" - if update_weight_args_list is not None: - await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) - await self._collective_rpc("update_weight") - self.logger.info("Sync model weights to vLLM successfully.") - self.model_version = model_version - return True - - async def init_process_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - explorer_name: str, - backend: str = "nccl", - timeout: int = 1200, - update_with_checkpoint: bool = True, - state_dict_meta: dict = None, - ): - return await self._collective_rpc( - "init_process_group", - args=( - master_address, - master_port, - rank_offset, - world_size, - group_name, - backend, - timeout, - update_with_checkpoint, - state_dict_meta, - explorer_name, - ray.get_runtime_context().namespace, - ), - ) - - async def run_api_server(self): - """Run the OpenAI API server in a Ray actor. - - Note: - Do not use `ray.get()` on this method. - This method will run forever until the server is shut down. - """ - if not (self.api_server_host is None or self.api_server_port is None): - raise RuntimeError("API server is already running.") - from trinity.common.models.openai_api import run_api_server_in_ray_actor - - self.api_server_host, self.api_server_port = self.get_available_address() - await run_api_server_in_ray_actor( - self.async_llm, self.api_server_host, self.api_server_port, self.config.model_path - ) - - async def has_api_server(self) -> bool: - return self.config.enable_openai_api - - async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]: - """Check if the OpenAI API server is ready. 
- - Returns: - api_url (str): The URL of the OpenAI API server. - model_path (str): The path of the model. - """ - if not await self.has_api_server(): - return None, None - try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://{self.api_server_host}:{self.api_server_port}/health" - ) as response: - if response.status == 200: - return ( - f"http://{self.api_server_host}:{self.api_server_port}/v1", - self.config.model_path, - ) - else: - return None, None - except Exception as e: - self.logger.error(e) - return None, None - - async def reset_prefix_cache(self) -> None: - await self.async_llm.reset_prefix_cache() - - def get_model_version(self) -> int: - return self.model_version - - async def sleep(self, level: int = 1) -> None: - await self.async_llm.sleep(level=level) - - async def wake_up(self) -> None: - await self.async_llm.wake_up() diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 643a124f72..a91ae58a8e 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -1,20 +1,15 @@ -# -*- coding: utf-8 -*- -"""vLLM related modules. - -Modified from OpenRLHF openrlhf/trainer/ray/vllm_engine.py +"""A wrapper around the vllm.AsyncEngine to handle async requests. """ -from __future__ import annotations import os import re -import threading -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union +import aiohttp import ray import torch import vllm -from vllm import LLM -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind from trinity.common.config import InferenceModelConfig from trinity.common.experience import Experience @@ -25,13 +20,25 @@ ) from trinity.utils.log import get_logger +logger = get_logger(__name__) + +# TODO: remove V0 when V1 is stable class vLLMRolloutModel(InferenceModel): - """Actor for vLLM.""" + """Wrapper around the vLLM engine to handle async requests. + + Args: + config (Config): The config. + kwargs (dict): The keyword arguments for the engine. 
+ """ - def __init__(self, config: InferenceModelConfig): + def __init__( + self, + config: InferenceModelConfig, + ) -> None: self.logger = get_logger(__name__) self.config = config + self.use_v1 = config.use_v1 if config.tensor_parallel_size != 1: os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices @@ -42,7 +49,7 @@ def __init__(self, config: InferenceModelConfig): os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1)) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - self.default_sampling_params = SamplingParams( + self.default_sampling_params = vllm.SamplingParams( n=1, temperature=0.0, max_tokens=config.max_response_tokens, @@ -50,13 +57,15 @@ def __init__(self, config: InferenceModelConfig): truncate_prompt_tokens=config.max_prompt_tokens, skip_special_tokens=True, include_stop_str_in_output=False, + output_kind=RequestOutputKind.FINAL_ONLY, logprobs=0, ) + self.enable_thinking = config.enable_thinking + self.request_id = 0 max_model_len = None if config.max_prompt_tokens is not None and config.max_response_tokens is not None: max_model_len = config.max_prompt_tokens + config.max_response_tokens - self.llm = LLM( - # TODO: check checkpoint path + engine_args = vllm.AsyncEngineArgs( model=config.model_path, enforce_eager=config.enforce_eager, worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension", @@ -67,16 +76,20 @@ def __init__(self, config: InferenceModelConfig): enable_prefix_caching=config.enable_prefix_caching, dtype=config.dtype, trust_remote_code=True, + task="generate", + disable_log_requests=True, gpu_memory_utilization=config.gpu_memory_utilization, enable_chunked_prefill=config.enable_chunked_prefill, - # max_num_batched_tokens=256, + # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage ) - self.tokenizer = self.llm.get_tokenizer() - self.chat_template = self.tokenizer.get_chat_template() - self.enable_thinking = config.enable_thinking + self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) + self.tokenizer = None + self.chat_template = None if self.config.chat_template: self.chat_template = self.config.chat_template - if not re.search(r"\{\%-?\s*generation\s*-?\%\}", self.chat_template): + if self.chat_template is None or not re.search( + r"\{\%-?\s*generation\s*-?\%\}", self.chat_template + ): self.logger.warning( "The provided chat template does not support `return_assitant_tokens_mask`. 
" "The default assistant mask method will be used, which may cause performance " @@ -85,146 +98,25 @@ def __init__(self, config: InferenceModelConfig): self.action_mask_method = tokenize_and_mask_messages_default else: self.action_mask_method = tokenize_and_mask_messages_hf - self.lock = threading.Lock() self.state_dict_meta = None self.model_version = 0 # TODO: resume the value from the checkpoint + self.api_server_host = None + self.api_server_port = None - def init_process_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - explorer_name: str, - backend: str = "nccl", - timeout: int = 1200, - update_with_checkpoint: bool = True, - state_dict_meta: dict = None, - ): - return self.llm.collective_rpc( - "init_process_group", - args=( - master_address, - master_port, - rank_offset, - world_size, - group_name, - backend, - timeout, - update_with_checkpoint, - state_dict_meta, - explorer_name, - ray.get_runtime_context().namespace, - ), - ) - - def reset_prefix_cache(self): - self.llm.llm_engine.reset_prefix_cache() - - def sleep(self, level=1): - self.llm.sleep(level=level) - - def wake_up(self): - self.llm.wake_up() - - def _create_sampling_params(self, **kwargs): - """Create sampling params.""" - if len(kwargs) == 0: - return self.default_sampling_params - params = self.default_sampling_params.clone() - for k, v in kwargs.items(): - if hasattr(params, k): - setattr(params, k, v) - return params - - def generate(self, prompts: List[str], **kwargs) -> List: - """ - Generate a batch of responses from a batch of prompts. - - Note: - - This method will not apply chat template. - You need to apply chat template before calling this method. + async def chat(self, messages: List[Dict], **kwargs) -> List[Experience]: + """Chat with the model with a list of messages in async. Args: - prompts (List[str]): A list of prompts. + messages (List[dict]): The input history messages. kwargs (dict): A dictionary of sampling parameters. Returns: - List[Experience]: A list of experiences. - - Example: - - >>> # config.algorithm.repeat_times == 2 or kwargs["n"] == 2 - >>> - >>> prompts = [ - >>> "Hello, world!", - >>> "How are you?" - >>> ] - >>> experiences = model.generate(prompts) - >>> print(experiences) - [ - Experience(tokens=tensor()...), # first sequnece for prompts[0] - Experience(tokens=tensor()...), # second sequnece for prompts[0] - Experience(tokens=tensor()...), # first sequence for prompts[1] - Experience(tokens=tensor()...) # second sequence for prompts[1] - ] + A list of experiences. """ - with self.lock: - sampling_params = self._create_sampling_params(**kwargs) - outputs = self.llm.generate(prompts, sampling_params, use_tqdm=False) - experiences = [] - for output in outputs: - for i in range(sampling_params.n): - experiences.append( - Experience( - tokens=torch.cat( - ( - torch.tensor(output.prompt_token_ids, dtype=torch.int32), - torch.tensor(output.outputs[i].token_ids, dtype=torch.int32), - ) - ), - logprobs=torch.cat( - ( - torch.full( - (len(output.prompt_token_ids),), - 0.0, - dtype=torch.float32, - ), - torch.tensor( - [ - list(logprob_dict.values())[0].logprob - for logprob_dict in output.outputs[i].logprobs - ], - dtype=torch.float32, - ), - ) - ), - prompt_length=len(output.prompt_token_ids), - prompt_text=output.prompt, - response_text=output.outputs[i].text, - ) - ) - return experiences - - def chat(self, messages: List[dict], **kwargs) -> List[Experience]: - """Chat with the model with a list of messages. 
- - Args: - messages (List[dict]): A list of messages. - - Example: - - >>> [ - >>> {"role": "system", "content": "You are a helpful assistant."}, - >>> {"role": "user", "content": "Hello, world!"}, - >>> ] - - Returns: - List[Experience]: A list of experiences containing the response text. - """ - # TODO: support tools and other fields + if self.tokenizer is None: + self.tokenizer = await self.async_llm.get_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() if messages[-1]["role"] == "assistant": prompt = self.tokenizer.apply_chat_template( messages, @@ -240,34 +132,96 @@ def chat(self, messages: List[dict], **kwargs) -> List[Experience]: chat_template=self.chat_template, enable_thinking=self.enable_thinking, ) - return self.generate([prompt], **kwargs) - - def logprobs(self, token_ids: List[int]) -> torch.Tensor: - with self.lock: - outputs = self.llm.generate( - prompts={"prompt_token_ids": token_ids}, - sampling_params=self._create_sampling_params( - n=1, - max_tokens=1, - prompt_logprobs=0, + return await self.generate(prompt=prompt, **kwargs) + + async def generate(self, prompt: str, **kwargs) -> List[Experience]: + """Generate a response from the provided prompt in async. + + Args: + prompt (str): The input prompt. + kwargs (dict): A dictionary of sampling parameters. + + Returns: + A list of experiences. + """ + output = await self._generate_internal(prompt=prompt, **kwargs) + experiences = [ + Experience( + tokens=torch.cat( + ( + torch.tensor(output.prompt_token_ids, dtype=torch.int32), + torch.tensor(output.outputs[i].token_ids, dtype=torch.int32), + ) ), - use_tqdm=False, + logprobs=torch.cat( + ( + torch.full( + (len(output.prompt_token_ids),), + 0.0, + dtype=torch.float32, + ), + torch.tensor( + [ + list(logprob_dict.values())[0].logprob + for logprob_dict in output.outputs[i].logprobs + ], + dtype=torch.float32, + ), + ) + ), + prompt_length=len(output.prompt_token_ids), + prompt_text=output.prompt, + response_text=output.outputs[i].text, ) + for i in range(len(output.outputs)) + ] + return experiences + + async def logprobs(self, token_ids: List[int]) -> torch.Tensor: + """Calculate the logprobs of the given tokens in async.""" + output = await self._generate_internal( + prompt={"prompt_token_ids": token_ids}, + n=1, + max_tokens=1, + prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token + ) return torch.tensor( [0] + [ list(logprob_dict.values())[0].logprob - for logprob_dict in outputs[0].prompt_logprobs[1:] + for logprob_dict in output.prompt_logprobs[1:] ], dtype=torch.float32, ) - def convert_messages_to_experience(self, messages: List[dict]) -> Experience: + async def _generate_internal(self, prompt: Any, **kwargs) -> Any: + # Send the request to the LLM engine. + self.request_id += 1 + stream = self.async_llm.generate( + request_id=str(self.request_id), + prompt=prompt, + sampling_params=self._create_sampling_params(**kwargs), + ) + + # Consume the stream until the request is finished. + async for request_output in stream: + if request_output.finished: + # Bypass the original full prompt. + # request_output.prompt = request.prompt + return request_output + + raise RuntimeError("[vLLM] The request is not finished. 
This should not happen.") + + async def convert_messages_to_experience(self, messages: List[dict]) -> Experience: """Convert a list of messages into an experience.""" + if self.tokenizer is None: + self.tokenizer = await self.async_llm.get_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() token_ids, action_mask = self.action_mask_method( self.tokenizer, messages, self.chat_template ) - logprobs = self.logprobs(token_ids=token_ids.tolist()) + logprobs = await self.logprobs(token_ids=token_ids.tolist()) return Experience( tokens=token_ids, prompt_length=len(token_ids), @@ -275,19 +229,133 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience: action_mask=action_mask, ) - def has_api_server(self) -> bool: - return False + def shutdown(self): + """Shutdown the vLLM v1 engine. This kills child processes forked + by the vLLM engine. If not called, the child processes will be + orphaned and will not be killed when the parent process exits, + and they won't be able to be tracked by Ray anymore. + """ + if hasattr(self.async_llm, "shutdown"): + logger.info("Shutting down vLLM engine") + self.async_llm.shutdown() - def sync_model( + def _create_sampling_params(self, **kwargs): + """Create sampling params.""" + if len(kwargs) == 0: + return self.default_sampling_params + params = self.default_sampling_params.clone() + for k, v in kwargs.items(): + if hasattr(params, k): + setattr(params, k, v) + return params + + async def _collective_rpc( + self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + ): + if self.use_v1: + return await self.async_llm.collective_rpc(method, timeout, args, kwargs) + else: + return self.async_llm.engine.model_executor.collective_rpc( + method, timeout, args, kwargs + ) + + async def sync_model( self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None ) -> bool: """Sync model weights to vLLM.""" if update_weight_args_list is not None: - self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) - self._collective_rpc("update_weight") + await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) + await self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.model_version = model_version return True + async def init_process_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + explorer_name: str, + backend: str = "nccl", + timeout: int = 1200, + update_with_checkpoint: bool = True, + state_dict_meta: dict = None, + ): + return await self._collective_rpc( + "init_process_group", + args=( + master_address, + master_port, + rank_offset, + world_size, + group_name, + backend, + timeout, + update_with_checkpoint, + state_dict_meta, + explorer_name, + ray.get_runtime_context().namespace, + ), + ) + + async def run_api_server(self): + """Run the OpenAI API server in a Ray actor. + + Note: + Do not use `ray.get()` on this method. + This method will run forever until the server is shut down. 
+ """ + if not (self.api_server_host is None or self.api_server_port is None): + raise RuntimeError("API server is already running.") + from trinity.common.models.openai_api import run_api_server_in_ray_actor + + self.api_server_host, self.api_server_port = self.get_available_address() + await run_api_server_in_ray_actor( + self.async_llm, self.api_server_host, self.api_server_port, self.config.model_path + ) + + async def has_api_server(self) -> bool: + return self.config.enable_openai_api + + async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]: + """Check if the OpenAI API server is ready. + + Returns: + api_url (str): The URL of the OpenAI API server. + model_path (str): The path of the model. + """ + if not await self.has_api_server(): + return None, None + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://{self.api_server_host}:{self.api_server_port}/health" + ) as response: + if response.status == 200: + return ( + f"http://{self.api_server_host}:{self.api_server_port}/v1", + self.config.model_path, + ) + else: + return None, None + except Exception as e: + self.logger.error(e) + return None, None + + async def reset_prefix_cache(self) -> None: + await self.async_llm.reset_prefix_cache() + def get_model_version(self) -> int: return self.model_version + + async def sleep(self, level: int = 1) -> None: + await self.async_llm.sleep(level=level) + + async def wake_up(self) -> None: + await self.async_llm.wake_up() From 440114923acecd004df5a895802d06f1ef0dd6e4 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 15 Jul 2025 13:38:25 +0800 Subject: [PATCH 02/12] get model name from openai api --- .../tutorial/trinity_programming_guide.md | 1 + tests/common/vllm_test.py | 7 +++--- tests/explorer/scheduler_test.py | 6 ++--- trinity/common/models/model.py | 23 ++++++++++++------- trinity/common/models/vllm_model.py | 14 ++++------- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index fb75d084b1..a7c92bef61 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -122,6 +122,7 @@ During initialization, `Workflow` receives the following parameters: ```{tip} You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow. +And the `model` field when calling openai API can be obtained via `openai_client.models.list().data[0].id`. ``` Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization. 
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index fe9c13e74e..417253f18a 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -213,13 +213,12 @@ def test_api(self): {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is your name?"}, ] - response = openai_client.chat.completions.create( - model=self.config.model.model_path, messages=messages, n=1 - ) + model_id = openai_client.models.list().data[0].id + response = openai_client.chat.completions.create(model=model_id, messages=messages, n=1) self.assertEqual(1, len(response.choices)) self.assertTrue(len(response.choices[0].message.content) > 0) response = openai_client.chat.completions.create( - model=self.config.model.model_path, + model=model_id, messages=messages, n=2, temperature=0.5, diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 20ecb179c9..d25c394bff 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -1,7 +1,7 @@ import asyncio import time import unittest -from typing import List, Tuple +from typing import List import ray import torch @@ -98,8 +98,8 @@ def init_process_group( def has_api_server(self) -> bool: return True - def api_server_ready(self) -> Tuple[str, str]: - return "http://localhosts:12345", "placeholder" + def api_server_ready(self) -> str: + return "http://localhosts:12345" def generate_tasks( diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 74aece6414..5d143bbb72 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -50,11 +50,12 @@ class ModelWrapper: """A wrapper for the InferenceModel Ray Actor""" # TODO: check model_type inside __init__ - def __init__(self, model: Any, model_type: str = "vllm"): + def __init__(self, model: Any, model_type: str = "vllm", record_history: bool = True): assert model_type.startswith("vllm"), "Only vLLM model is supported for now." self.model = model self.openai_client: openai.OpenAI = None self.logger = get_logger(__name__) + self.record_history = record_history def generate(self, prompts: List[str], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of prompts.""" @@ -110,9 +111,9 @@ def get_openai_client(self) -> openai.OpenAI: "OpenAI API server is not running on current model." "Please set `enable_openai_api` to `True`." ) - api_address, model_path = None, None + api_address = None while True: - api_address, model_path = ray.get(self.model.api_server_ready.remote()) + api_address = ray.get(self.model.api_server_ready.remote()) if api_address is not None: break else: @@ -123,9 +124,15 @@ def get_openai_client(self) -> openai.OpenAI: "Failed to connect to the API server. Please check the API server is running." 
) self.logger.info(f"Successfully connect to API server at {api_address}") - self.openai_client = openai.OpenAI( - base_url=api_address, - api_key="EMPTY", - ) - setattr(self.openai_client, "model_path", model_path) # TODO: may be removed + if self.record_history: + # add a decorator to the openai client to record history + self.openai_client = openai.OpenAI( + base_url=api_address, + api_key="EMPTY", + ) + else: + self.openai_client = openai.OpenAI( + base_url=api_address, + api_key="EMPTY", + ) return self.openai_client diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index a91ae58a8e..88b3a7ebc7 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -323,30 +323,26 @@ async def run_api_server(self): async def has_api_server(self) -> bool: return self.config.enable_openai_api - async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]: + async def api_server_ready(self) -> Union[str, None]: """Check if the OpenAI API server is ready. Returns: api_url (str): The URL of the OpenAI API server. - model_path (str): The path of the model. """ if not await self.has_api_server(): - return None, None + return None try: async with aiohttp.ClientSession() as session: async with session.get( f"http://{self.api_server_host}:{self.api_server_port}/health" ) as response: if response.status == 200: - return ( - f"http://{self.api_server_host}:{self.api_server_port}/v1", - self.config.model_path, - ) + return f"http://{self.api_server_host}:{self.api_server_port}/v1" else: - return None, None + return None except Exception as e: self.logger.error(e) - return None, None + return None async def reset_prefix_cache(self) -> None: await self.async_llm.reset_prefix_cache() From 009bba5d0f8348695ef779ed52c7674f52973034 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 15 Jul 2025 15:56:20 +0800 Subject: [PATCH 03/12] record history in model wrapper --- tests/common/vllm_test.py | 24 +++++++++++++++-- trinity/common/models/model.py | 49 +++++++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 417253f18a..ae5071c7b6 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -86,8 +86,17 @@ class BaseTestModelWrapper: def test_generate(self): prompts = ["Hello, world!", "Hello, my name is"] n = self.config.algorithm.repeat_times - results = self.model_wrapper.generate(prompts, n=n, temperature=1.0) - self.assertEqual(len(results), len(prompts) * n) + generate_results = self.model_wrapper.generate(prompts, n=n, temperature=1.0) + self.assertEqual(len(generate_results), len(prompts) * n) + history_experiences = self.model_wrapper.extract_experience_from_history( + clear_history=False + ) + self.assertEqual(len(history_experiences), len(generate_results)) + for exp, history_exp in zip(generate_results, history_experiences): + self.assertEqual(exp.response_text, history_exp.response_text) + self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.prompt_length, history_exp.prompt_length) + self.assertEqual(exp.logprobs, history_exp.logprobs) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What's the weather like today?"}, @@ -99,6 +108,13 @@ def test_generate(self): ] results = self.model_wrapper.chat(messages, n=n, temperature=1.0) self.assertEqual(len(results), n) + history_experiences = 
self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(history_experiences) - len(generate_results), len(results)) + for exp, history_exp in zip(results, history_experiences[len(generate_results) :]): + self.assertEqual(exp.response_text, history_exp.response_text) + self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.prompt_length, history_exp.prompt_length) + self.assertEqual(exp.logprobs, history_exp.logprobs) for result in results: input_logprobs = result.logprobs[: result.prompt_length] output_logprobs = result.logprobs[result.prompt_length :] @@ -106,6 +122,8 @@ def test_generate(self): self.assertTrue(torch.any(output_logprobs != 0)) logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertTrue(len(history_experiences) == 0) messages.append( { "role": "assistant", @@ -128,6 +146,8 @@ def test_generate(self): self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask)) self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) self.assertRaises(ValueError, self.model_wrapper.get_openai_client) + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertTrue(len(history_experiences) == 0) class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase): diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 5d143bbb72..13bec96a48 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -4,7 +4,7 @@ import socket import time from abc import ABC, abstractmethod -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Union import openai import ray @@ -46,22 +46,52 @@ def get_available_address(self) -> Tuple[str, int]: return address, port +def _history_recorder(func): + """Decorator to record history of the model calls.""" + + async def async_wrapper(self, *args, **kwargs): + result = await func(self, *args, **kwargs) + if self.enable_history: + self.history.append(result) + return result + + def sync_wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + if self.enable_history: + self._record_history(result) + return result + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + + class ModelWrapper: """A wrapper for the InferenceModel Ray Actor""" # TODO: check model_type inside __init__ - def __init__(self, model: Any, model_type: str = "vllm", record_history: bool = True): + def __init__(self, model: Any, model_type: str = "vllm", enable_history: bool = True): assert model_type.startswith("vllm"), "Only vLLM model is supported for now." 
self.model = model self.openai_client: openai.OpenAI = None self.logger = get_logger(__name__) - self.record_history = record_history + self.enable_history = enable_history + self.history = [] + + def _record_history(self, exps: Union[Experience, List[Experience]]) -> None: + """Record experiences to history.""" + if isinstance(exps, Experience): + self.history.append(exps) + elif isinstance(exps, list): + self.history.extend(exps) + else: + raise TypeError("Expected Experience or List[Experience], got {}".format(type(exps))) + @_history_recorder def generate(self, prompts: List[str], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of prompts.""" results = ray.get([self.model.generate.remote(prompt, **kwargs) for prompt in prompts]) return [exp for exps in results for exp in exps] + @_history_recorder async def generate_async(self, prompts: List[str], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of prompts in async.""" results = await asyncio.gather( @@ -69,10 +99,12 @@ async def generate_async(self, prompts: List[str], **kwargs) -> List[Experience] ) return [exp for exps in results for exp in exps] + @_history_recorder def chat(self, messages: List[dict], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of messages.""" return ray.get(self.model.chat.remote(messages, **kwargs)) + @_history_recorder async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of messages in async.""" return await self.model.chat.remote(messages, **kwargs) @@ -124,7 +156,7 @@ def get_openai_client(self) -> openai.OpenAI: "Failed to connect to the API server. Please check the API server is running." ) self.logger.info(f"Successfully connect to API server at {api_address}") - if self.record_history: + if self.enable_history: # add a decorator to the openai client to record history self.openai_client = openai.OpenAI( base_url=api_address, @@ -136,3 +168,12 @@ def get_openai_client(self) -> openai.OpenAI: api_key="EMPTY", ) return self.openai_client + + def extract_experience_from_history(self, clear_history: bool = True) -> List[Experience]: + """Extract experiences from the history.""" + if not self.enable_history: + raise ValueError("History recording is not enabled.") + exps = [exp for exp in self.history] + if clear_history: + self.history.clear() + return exps From e4ea9f89a09d7b41d3d047d260962c294bc8c3f7 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 15 Jul 2025 16:33:26 +0800 Subject: [PATCH 04/12] add tests for enable history --- .../source/tutorial/trinity_configs.md | 2 + tests/common/vllm_test.py | 57 ++++++++++++------- trinity/common/config.py | 3 + trinity/common/models/__init__.py | 7 ++- trinity/common/models/model.py | 2 +- 5 files changed, 47 insertions(+), 24 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index f34f1dffd4..5ccf755d0e 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -324,6 +324,7 @@ explorer: engine_type: vllm_async engine_num: 1 tensor_parallel_size: 1 + enable_history: False auxiliary_models: - model_path: /PATH/TO/MODEL tensor_parallel_size: 1 @@ -339,6 +340,7 @@ explorer: - `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`. - `rollout_model.engine_num`: Number of inference engines. 
- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism. +- `rollout_model.enable_history`: Whether to enable model call history recording. If set to `True`, the model wrapper automatically records the return experiences of model calls. Please periodically extract the history via `extract_experience_from_history` to avoid out-of-memory issues. Default is `False`. - `auxiliary_models`: Additional models used for custom workflows. - `eval_interval`: Interval (in steps) for evaluating the model. - `eval_on_startup`: Whether to evaluate the model on startup. More precisely, at step 0 with the original model, so it will not be triggered when restarting. diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index ae5071c7b6..8337e465ad 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -88,15 +88,19 @@ def test_generate(self): n = self.config.algorithm.repeat_times generate_results = self.model_wrapper.generate(prompts, n=n, temperature=1.0) self.assertEqual(len(generate_results), len(prompts) * n) - history_experiences = self.model_wrapper.extract_experience_from_history( - clear_history=False - ) - self.assertEqual(len(history_experiences), len(generate_results)) - for exp, history_exp in zip(generate_results, history_experiences): - self.assertEqual(exp.response_text, history_exp.response_text) - self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) - self.assertEqual(exp.prompt_length, history_exp.prompt_length) - self.assertEqual(exp.logprobs, history_exp.logprobs) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history( + clear_history=False + ) + self.assertEqual(len(history_experiences), len(generate_results)) + for exp, history_exp in zip(generate_results, history_experiences): + self.assertEqual(exp.response_text, history_exp.response_text) + self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.prompt_length, history_exp.prompt_length) + self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) + else: + with self.assertRaises(ValueError): + self.model_wrapper.extract_experience_from_history(clear_history=False) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What's the weather like today?"}, @@ -108,13 +112,14 @@ def test_generate(self): ] results = self.model_wrapper.chat(messages, n=n, temperature=1.0) self.assertEqual(len(results), n) - history_experiences = self.model_wrapper.extract_experience_from_history() - self.assertEqual(len(history_experiences) - len(generate_results), len(results)) - for exp, history_exp in zip(results, history_experiences[len(generate_results) :]): - self.assertEqual(exp.response_text, history_exp.response_text) - self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) - self.assertEqual(exp.prompt_length, history_exp.prompt_length) - self.assertEqual(exp.logprobs, history_exp.logprobs) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(history_experiences) - len(generate_results), len(results)) + for exp, history_exp in zip(results, history_experiences[len(generate_results) :]): + self.assertEqual(exp.response_text, history_exp.response_text) + self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.prompt_length, history_exp.prompt_length) + 
self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) for result in results: input_logprobs = result.logprobs[: result.prompt_length] output_logprobs = result.logprobs[result.prompt_length :] @@ -122,8 +127,9 @@ def test_generate(self): self.assertTrue(torch.any(output_logprobs != 0)) logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) - history_experiences = self.model_wrapper.extract_experience_from_history() - self.assertTrue(len(history_experiences) == 0) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertTrue(len(history_experiences) == 0) messages.append( { "role": "assistant", @@ -146,8 +152,9 @@ def test_generate(self): self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask)) self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) self.assertRaises(ValueError, self.model_wrapper.get_openai_client) - history_experiences = self.model_wrapper.extract_experience_from_history() - self.assertTrue(len(history_experiences) == 0) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertTrue(len(history_experiences) == 0) class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase): @@ -161,9 +168,12 @@ def setUp(self): self.config.explorer.rollout_model.use_v1 = False self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE self.config.algorithm.repeat_times = 2 + self.config.explorer.rollout_model.enable_history = True self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") + self.model_wrapper = ModelWrapper( + self.engines[0], model_type="vllm_async", enable_history=True + ) class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase): @@ -192,9 +202,12 @@ def setUp(self): self.config.explorer.rollout_model.use_v1 = True self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE self.config.algorithm.repeat_times = 2 + self.config.explorer.rollout_model.enable_history = True self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") + self.model_wrapper = ModelWrapper( + self.engines[0], model_type="vllm_async", enable_history=True + ) class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase): diff --git a/trinity/common/config.py b/trinity/common/config.py index 1e0bcc5e9d..7360f0be5c 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -203,6 +203,9 @@ class InferenceModelConfig: # For Qwen3 enable_thinking: bool = False + # For history recording + enable_history: bool = False + # For OpenAI API enable_openai_api: bool = False diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 6a704b8793..fb85590b06 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -45,6 +45,7 @@ def create_inference_models( from trinity.common.models.vllm_model import vLLMRolloutModel + logger = get_logger(__name__) engine_num = config.explorer.rollout_model.engine_num tensor_parallel_size = config.explorer.rollout_model.tensor_parallel_size @@ -105,7 +106,11 @@ def create_inference_models( if 
config.explorer.rollout_model.enable_openai_api: for engine in rollout_engines: engine.run_api_server.remote() - + if config.explorer.rollout_model.enable_history: + logger.info( + "Model History recording is enabled. Please periodically extract " + "history via `extract_experience_from_history` to avoid out-of-memory issues." + ) # create auxiliary models for model_config in config.explorer.auxiliary_models: engines = [] diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 13bec96a48..e7e53c6f39 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -68,7 +68,7 @@ class ModelWrapper: """A wrapper for the InferenceModel Ray Actor""" # TODO: check model_type inside __init__ - def __init__(self, model: Any, model_type: str = "vllm", enable_history: bool = True): + def __init__(self, model: Any, model_type: str = "vllm", enable_history: bool = False): assert model_type.startswith("vllm"), "Only vLLM model is supported for now." self.model = model self.openai_client: openai.OpenAI = None From d27977f01c9b25225ba16cc3cb968986a990b3c6 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 16 Jul 2025 19:45:49 +0800 Subject: [PATCH 05/12] patch vllm to get prompt id in openai api --- pyproject.toml | 2 +- tests/common/vllm_test.py | 2 + trinity/__init__.py | 2 +- .../sample_strategy/mix_sample_strategy.py | 5 +- trinity/common/models/api/vllm_patch.py | 330 ++++++++++++++++++ trinity/common/models/model.py | 21 +- trinity/common/models/openai_api.py | 79 ----- trinity/common/models/vllm_model.py | 2 +- 8 files changed, 350 insertions(+), 93 deletions(-) create mode 100644 trinity/common/models/api/vllm_patch.py delete mode 100644 trinity/common/models/openai_api.py diff --git a/pyproject.toml b/pyproject.toml index efcc8a7aab..4fa8d764df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "trinity-rft" -version = "0.2.0" +version = "0.2.1.dev0" authors = [ {name="Trinity-RFT Team", email="trinity-rft@outlook.com"}, ] diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 8337e465ad..4128f34efd 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -262,6 +262,8 @@ def test_api(self): self.assertTrue(response.choices[0].logprobs is not None) self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs)) self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) + self.assertTrue(hasattr(response, "prompt_token_ids")) + self.assertTrue(len(response.prompt_token_ids) > 0) class TestTokenizer(unittest.TestCase): diff --git a/trinity/__init__.py b/trinity/__init__.py index dc3b8ca098..63f1db4fdc 100644 --- a/trinity/__init__.py +++ b/trinity/__init__.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- """Trinity-RFT (Reinforcement Fine-Tuning)""" -__version__ = "0.2.0" +__version__ = "0.2.1.dev0" diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index c6858931b1..80a4af7d49 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -4,7 +4,6 @@ import numpy as np import torch -from verl.trainer.ppo.ray_trainer import DataProto from trinity.algorithm.sample_strategy.sample_strategy import ( SAMPLE_STRATEGY, @@ -85,7 +84,9 @@ def default_args(cls) -> Dict: } -def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: +def 
to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor): + from verl.trainer.ppo.ray_trainer import DataProto + attention_mask = experiences.attention_masks cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/api/vllm_patch.py new file mode 100644 index 0000000000..6a4b5fb647 --- /dev/null +++ b/trinity/common/models/api/vllm_patch.py @@ -0,0 +1,330 @@ +"""Patch for vllm OpenAI API server. + +1. Mocks the `add_signal_handler` method to do nothing. +2. Adds `prompt` and `prompt_token_ids` to the `ChatCompletionResponse`. +""" +import asyncio +import functools +import json +import time +from typing import Optional, Union + +from pydantic import TypeAdapter +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.openai.api_server import ( + build_app, + create_server_socket, + init_app_state, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + ErrorResponse, + FunctionCall, + FunctionDefinition, + PromptTokenUsageInfo, + ToolCall, + UsageInfo, +) +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_engine import clamp_prompt_logprobs +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall +from vllm.outputs import RequestOutput +from vllm.transformers_utils.tokenizer import MistralTokenizer +from vllm.utils import FlexibleArgumentParser, set_ulimit + +from trinity.utils.log import get_logger + + +class PatchedChatCompletionResponse(ChatCompletionResponse): + prompt_token_ids: list[int] = [] + + def __init__(self, *args, prompt_token_ids=None, **kwargs): + super().__init__(*args, **kwargs) + self.prompt_token_ids = prompt_token_ids or [] + + +# TODO: add patch to stream generator +async def chat_completion_full_generator( # noqa C901 + self, + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, +) -> Union[ErrorResponse, ChatCompletionResponse]: + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + logger = get_logger(__name__) + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert final_res is not None + + choices: list[ChatCompletionResponseChoice] = [] + + role = self.get_chat_request_role(request) + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, + return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + auto_tools_called = False + + if self.reasoning_parser: + try: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + return self.create_error_response(str(e)) + # If the reasoning parser is enabled, + # tool calls are extracted 
exclusively from the content. + reasoning_content, content = reasoning_parser.extract_reasoning_content( + output.text, request=request + ) + else: + reasoning_content = None + content = output.text + + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if (not self.enable_auto_tools or not self.tool_parser) and ( + not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) + and request.tool_choice != "required" + ): + message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) + + # if the request uses tools and specified a tool choice + elif ( + request.tool_choice and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam + ): + tool_call_class = ( + MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall + ) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content="", + tool_calls=[ + tool_call_class( + function=FunctionCall( + name=request.tool_choice.function.name, arguments=content + ) + ) + ], + ) + + elif request.tool_choice and request.tool_choice == "required": + tool_call_class = ( + MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall + ) + + # the fields of FunctionDefinition are a superset of the + # tool call outputs and can be used for parsing + assert content is not None + tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) + message = ChatMessage( + role=role, + content="", + tool_calls=[ + tool_call_class( + function=FunctionCall( + name=tool_call.name, + arguments=json.dumps(tool_call.parameters, ensure_ascii=False), + ) + ) + for tool_call in tool_calls + ], + ) + + # if the request doesn't use tool choice + # OR specifies to not use a tool + elif not request.tool_choice or request.tool_choice == "none": + message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) + + # handle when there are tools and tool choice is auto + elif ( + request.tools + and (request.tool_choice == "auto" or request.tool_choice is None) + and self.enable_auto_tools + and self.tool_parser + ): + try: + tool_parser = self.tool_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in tool parser creation.") + return self.create_error_response(str(e)) + + tool_call_info = tool_parser.extract_tool_calls( + content if content is not None else "", request=request + ) + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool + # call. The same is not true for named function calls + auto_tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls, + ) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. + message = ChatMessage( + role=role, reasoning_content=reasoning_content, content=content + ) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion." 
+ ) + message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" + if auto_tools_called + else output.finish_reason + if output.finish_reason + else "stop", + stop_reason=output.stop_reason, + ) + choices.append(choice_data) + + if request.echo: + last_msg_content: Union[str, list[dict[str, str]]] = "" + if conversation and "content" in conversation[-1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + if isinstance(last_msg_content, list): + last_msg_content = "\n".join(msg["text"] for msg in last_msg_content) + + for choice in choices: + full_message = last_msg_content + (choice.message.content or "") + choice.message.content = full_message + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens + ) + + request_metadata.final_usage_info = usage + + print(str(final_res)) + + response = PatchedChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + kv_transfer_params=final_res.kv_transfer_params, + prompt_token_ids=final_res.prompt_token_ids, + ) + + return response + + +async def run_server_in_ray(args, engine_client): + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. 
+ # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host, args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + app = build_app(args) + + vllm_config = await engine_client.get_vllm_config() + await init_app_state(engine_client, vllm_config, app.state, args) + + await patch_and_serve_http(app, sock, args) + + # # NB: Await server shutdown only after the backend context is exited + # try: + # await shutdown_task + # finally: + # sock.close() + + +def dummy_add_signal_handler(self, *args, **kwargs): + # DO NOTHING HERE + pass + + +async def patch_and_serve_http(app, sock, args): + """Patch the add_signal_handler method and serve the app.""" + loop = asyncio.get_event_loop() + original_add_signal_handler = loop.add_signal_handler + loop.add_signal_handler = functools.partial(dummy_add_signal_handler, loop) + OpenAIServingChat.chat_completion_full_generator = chat_completion_full_generator + + try: + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level="info", + access_log=True, + timeout_keep_alive=10, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + ) + await shutdown_task + finally: + loop.add_signal_handler = original_add_signal_handler + sock.close() + + +async def run_api_server_in_ray_actor(async_llm, host: str, port: int, model_path: str): + parser = FlexibleArgumentParser(description="Run the OpenAI API server.") + args = make_arg_parser(parser) + args = parser.parse_args(["--host", str(host), "--port", str(port), "--model", model_path]) + print(args) + await run_server_in_ray(args, async_llm) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index e7e53c6f39..af6c22349a 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -156,17 +156,20 @@ def get_openai_client(self) -> openai.OpenAI: "Failed to connect to the API server. Please check the API server is running." ) self.logger.info(f"Successfully connect to API server at {api_address}") + self.openai_client = openai.OpenAI( + base_url=api_address, + api_key="EMPTY", + ) if self.enable_history: # add a decorator to the openai client to record history - self.openai_client = openai.OpenAI( - base_url=api_address, - api_key="EMPTY", - ) - else: - self.openai_client = openai.OpenAI( - base_url=api_address, - api_key="EMPTY", - ) + ori_create = self.openai_client.chat.completions.create + + def record_chat_completions(*args, **kwargs): + response = ori_create(*args, **kwargs) + return response + + self.openai_client.chat.completions.create = record_chat_completions + return self.openai_client def extract_experience_from_history(self, clear_history: bool = True) -> List[Experience]: diff --git a/trinity/common/models/openai_api.py b/trinity/common/models/openai_api.py deleted file mode 100644 index c26b0ca54b..0000000000 --- a/trinity/common/models/openai_api.py +++ /dev/null @@ -1,79 +0,0 @@ -"""OpenAI API server related tools. 
- -Modified from vllm/entrypoints/openai/api_server.py -""" -import asyncio -import functools - -from vllm.entrypoints.launcher import serve_http -from vllm.entrypoints.openai.api_server import ( - build_app, - create_server_socket, - init_app_state, -) -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.utils import FlexibleArgumentParser, set_ulimit - - -async def run_server_in_ray(args, engine_client): - # workaround to make sure that we bind the port before the engine is set up. - # This avoids race conditions with ray. - # see https://github.com/vllm-project/vllm/issues/8204 - sock_addr = (args.host, args.port) - sock = create_server_socket(sock_addr) - - # workaround to avoid footguns where uvicorn drops requests with too - # many concurrent requests active - set_ulimit() - app = build_app(args) - - vllm_config = await engine_client.get_vllm_config() - await init_app_state(engine_client, vllm_config, app.state, args) - - await patch_and_serve_http(app, sock, args) - - # # NB: Await server shutdown only after the backend context is exited - # try: - # await shutdown_task - # finally: - # sock.close() - - -def dummy_add_signal_handler(self, *args, **kwargs): - # DO NOTHING HERE - pass - - -async def patch_and_serve_http(app, sock, args): - """Patch the add_signal_handler method and serve the app.""" - loop = asyncio.get_event_loop() - original_add_signal_handler = loop.add_signal_handler - loop.add_signal_handler = functools.partial(dummy_add_signal_handler, loop) - - try: - shutdown_task = await serve_http( - app, - sock=sock, - enable_ssl_refresh=args.enable_ssl_refresh, - host=args.host, - port=args.port, - log_level="info", - access_log=True, - timeout_keep_alive=10, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - ) - await shutdown_task - finally: - loop.add_signal_handler = original_add_signal_handler - sock.close() - - -async def run_api_server_in_ray_actor(async_llm, host: str, port: int, model_path: str): - parser = FlexibleArgumentParser(description="Run the OpenAI API server.") - args = make_arg_parser(parser) - args = parser.parse_args(["--host", str(host), "--port", str(port), "--model", model_path]) - print(args) - await run_server_in_ray(args, async_llm) diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 88b3a7ebc7..01b8135511 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -313,7 +313,7 @@ async def run_api_server(self): """ if not (self.api_server_host is None or self.api_server_port is None): raise RuntimeError("API server is already running.") - from trinity.common.models.openai_api import run_api_server_in_ray_actor + from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor self.api_server_host, self.api_server_port = self.get_available_address() await run_api_server_in_ray_actor( From 73904de688dbaa56e148c1cee993e952750248a7 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 16 Jul 2025 21:06:40 +0800 Subject: [PATCH 06/12] record history in chat completion --- tests/common/vllm_test.py | 8 ++++- trinity/common/models/api/vllm_patch.py | 14 +++++++-- trinity/common/models/model.py | 42 +++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 4128f34efd..380dedcf7b 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -238,7 +238,9 @@ def 
setUp(self): self.config.explorer.rollout_model.enable_openai_api = True self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") + self.model_wrapper = ModelWrapper( + self.engines[0], model_type="vllm_async", enable_history=True + ) def test_api(self): openai_client = self.model_wrapper.get_openai_client() @@ -264,6 +266,10 @@ def test_api(self): self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) self.assertTrue(hasattr(response, "prompt_token_ids")) self.assertTrue(len(response.prompt_token_ids) > 0) + self.assertTrue(hasattr(response.choices[0], "token_ids")) + self.assertTrue(len(response.choices[0].token_ids) > 0) + exps = self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(exps), 3) class TestTokenizer(unittest.TestCase): diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/api/vllm_patch.py index 6a4b5fb647..f928735e6a 100644 --- a/trinity/common/models/api/vllm_patch.py +++ b/trinity/common/models/api/vllm_patch.py @@ -42,9 +42,18 @@ class PatchedChatCompletionResponse(ChatCompletionResponse): prompt_token_ids: list[int] = [] - def __init__(self, *args, prompt_token_ids=None, **kwargs): + def __init__(self, *args, prompt_token_ids=None, response_token_ids=None, **kwargs): super().__init__(*args, **kwargs) self.prompt_token_ids = prompt_token_ids or [] + self.response_token_ids = response_token_ids or [] + + +class PatchedChatCompletionResponseChoice(ChatCompletionResponseChoice): + token_ids: list[int] = [] + + def __init__(self, *args, token_ids=None, **kwargs): + super().__init__(*args, **kwargs) + self.token_ids = token_ids or [] # TODO: add patch to stream generator @@ -218,6 +227,7 @@ async def chat_completion_full_generator( # noqa C901 if output.finish_reason else "stop", stop_reason=output.stop_reason, + token_ids=output.token_ids, ) choices.append(choice_data) @@ -249,8 +259,6 @@ async def chat_completion_full_generator( # noqa C901 request_metadata.final_usage_info = usage - print(str(final_res)) - response = PatchedChatCompletionResponse( id=request_id, created=created_time, diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index af6c22349a..c02d797d06 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -8,6 +8,7 @@ import openai import ray +import torch from torch import Tensor from trinity.common.experience import Experience @@ -166,6 +167,7 @@ def get_openai_client(self) -> openai.OpenAI: def record_chat_completions(*args, **kwargs): response = ori_create(*args, **kwargs) + self.history.extend(convert_api_output_to_experience(response)) return response self.openai_client.chat.completions.create = record_chat_completions @@ -180,3 +182,43 @@ def extract_experience_from_history(self, clear_history: bool = True) -> List[Ex if clear_history: self.history.clear() return exps + + +def convert_api_output_to_experience( + output, +) -> List[Experience]: + """Convert the API output to a list of experiences.""" + return [ + Experience( + tokens=torch.cat( + ( + torch.tensor(output.prompt_token_ids, dtype=torch.int32), + torch.tensor(choice.token_ids, dtype=torch.int32), + ) + ), + logprobs=torch.cat( + ( + torch.full( + (len(output.prompt_token_ids),), + 0.0, + dtype=torch.float32, + ), + extract_logprobs(choice), + ) + ), + prompt_length=len(output.prompt_token_ids), + prompt_text=None, + response_text=choice.message.content, + 
) + for choice in output.choices + ] + + +def extract_logprobs(choice) -> Tensor: + """Extract logprobs from a list of logprob dictionaries.""" + if not hasattr(choice, "logprobs") or choice.logprobs is None: + return torch.tensor([], dtype=torch.float32) + return torch.tensor( + [logprob.logprob for logprob in choice.logprobs.content], + dtype=torch.float32, + ) From 5c6c400862694aca1b4ed137c4d9a70e40cd5386 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 17 Jul 2025 14:52:52 +0800 Subject: [PATCH 07/12] add more tests --- tests/common/vllm_test.py | 198 ++++++++++++++++++++++++++------------ 1 file changed, 139 insertions(+), 59 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 380dedcf7b..12621d0028 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -2,9 +2,10 @@ import unittest import torch +from parameterized import parameterized from transformers import AutoTokenizer -from tests.tools import RayUnittestBase, get_template_config +from tests.tools import RayUnittestBase, RayUnittestBaseAysnc, get_template_config from trinity.common.models import create_inference_models from trinity.common.models.model import ModelWrapper from trinity.common.models.utils import ( @@ -82,8 +83,39 @@ def get_model_path() -> str: """ -class BaseTestModelWrapper: - def test_generate(self): +class ModelWrapperSyncTest(RayUnittestBase): + @parameterized.expand( + [ + (1, 2, False, 2, True), + (2, 2, False, 1, False), + (2, 2, True, 2, True), + (1, 2, True, 1, False), + ] + ) + def test_generate( + self, + tensor_parallel_size, + engine_num, + use_v1, + repeat_times, + enable_history, + ): + # configure the model + self.config = get_template_config() + self.config.mode = "explore" + self.config.model.model_path = get_model_path() + self.config.explorer.rollout_model.engine_num = engine_num + self.config.explorer.rollout_model.tensor_parallel_size = tensor_parallel_size + self.config.explorer.rollout_model.use_v1 = use_v1 + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE + self.config.algorithm.repeat_times = repeat_times + self.config.explorer.rollout_model.enable_history = enable_history + self.config.check_and_update() + self.engines, self.auxiliary_engines = create_inference_models(self.config) + self.model_wrapper = ModelWrapper( + self.engines[0], model_type="vllm_async", enable_history=True + ) + # tests prompts = ["Hello, world!", "Hello, my name is"] n = self.config.algorithm.repeat_times generate_results = self.model_wrapper.generate(prompts, n=n, temperature=1.0) @@ -157,72 +189,110 @@ def test_generate(self): self.assertTrue(len(history_experiences) == 0) -class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase): - def setUp(self): +class ModelWrapperAsyncTest(RayUnittestBaseAysnc): + @parameterized.expand( + [ + (1, 2, False, 2, True), + (2, 2, False, 1, False), + (2, 2, True, 2, True), + (1, 2, True, 1, False), + ] + ) + async def test_generate( + self, + tensor_parallel_size, + engine_num, + use_v1, + repeat_times, + enable_history, + ): + # configure the model self.config = get_template_config() self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" - self.config.explorer.rollout_model.engine_num = 2 - self.config.explorer.rollout_model.tensor_parallel_size = 1 - self.config.explorer.rollout_model.use_v1 = False + self.config.explorer.rollout_model.engine_num = engine_num + 
self.config.explorer.rollout_model.tensor_parallel_size = tensor_parallel_size + self.config.explorer.rollout_model.use_v1 = use_v1 self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.algorithm.repeat_times = 2 - self.config.explorer.rollout_model.enable_history = True + self.config.algorithm.repeat_times = repeat_times + self.config.explorer.rollout_model.enable_history = enable_history self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper( self.engines[0], model_type="vllm_async", enable_history=True ) - - -class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase): - def setUp(self): - self.config = get_template_config() - self.config.mode = "explore" - self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" - self.config.explorer.rollout_model.engine_num = 2 - self.config.explorer.rollout_model.tensor_parallel_size = 2 - self.config.explorer.rollout_model.use_v1 = False - self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.check_and_update() - self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") - - -class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase): - def setUp(self): - self.config = get_template_config() - self.config.mode = "explore" - self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" - self.config.explorer.rollout_model.engine_num = 2 - self.config.explorer.rollout_model.tensor_parallel_size = 2 - self.config.explorer.rollout_model.use_v1 = True - self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.algorithm.repeat_times = 2 - self.config.explorer.rollout_model.enable_history = True - self.config.check_and_update() - self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=True + # tests + prompts = ["Hello, world!", "Hello, my name is"] + n = self.config.algorithm.repeat_times + generate_results = await self.model_wrapper.generate_async(prompts, n=n, temperature=1.0) + self.assertEqual(len(generate_results), len(prompts) * n) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history( + clear_history=False + ) + self.assertEqual(len(history_experiences), len(generate_results)) + for exp, history_exp in zip(generate_results, history_experiences): + self.assertEqual(exp.response_text, history_exp.response_text) + self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.prompt_length, history_exp.prompt_length) + self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) + else: + with self.assertRaises(ValueError): + self.model_wrapper.extract_experience_from_history(clear_history=False) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather like today?"}, + { + "role": "assistant", + "content": "I'm sorry, but as an AI language model, I don't have access to real-time weather information. 
To get accurate weather information for your location, you can check a weather website or app, or look outside if possible.", + }, + {"role": "user", "content": "OK, thanks!"}, + ] + results = await self.model_wrapper.chat_async(messages, n=n, temperature=1.0) + self.assertEqual(len(results), n) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(history_experiences) - len(generate_results), len(results)) + for exp, history_exp in zip(results, history_experiences[len(generate_results) :]): + self.assertEqual(exp.response_text, history_exp.response_text) + self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.prompt_length, history_exp.prompt_length) + self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) + for result in results: + input_logprobs = result.logprobs[: result.prompt_length] + output_logprobs = result.logprobs[result.prompt_length :] + self.assertTrue(torch.all(input_logprobs == 0)) + self.assertTrue(torch.any(output_logprobs != 0)) + logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist()) + self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertTrue(len(history_experiences) == 0) + messages.append( + { + "role": "assistant", + "content": results[0].response_text, + } ) - - -class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase): - def setUp(self): - self.config = get_template_config() - self.config.mode = "explore" - self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" - self.config.explorer.rollout_model.engine_num = 2 - self.config.explorer.rollout_model.tensor_parallel_size = 1 - self.config.explorer.rollout_model.use_v1 = True - self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.check_and_update() - self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") + exp = self.model_wrapper.convert_messages_to_experience(messages) + tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path) + result_dict = tokenizer.apply_chat_template( + messages, + chat_template=CHAT_TEMPLATE, + add_generation_prompt=False, + padding=False, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + return_assistant_tokens_mask=True, + return_dict=True, + ) + self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask)) + self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) + self.assertRaises(ValueError, self.model_wrapper.get_openai_client) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertTrue(len(history_experiences) == 0) class TestAPIServer(RayUnittestBase): @@ -270,6 +340,16 @@ def test_api(self): self.assertTrue(len(response.choices[0].token_ids) > 0) exps = self.model_wrapper.extract_experience_from_history() self.assertEqual(len(exps), 3) + response = openai_client.chat.completions.create( + model=model_id, + messages=messages, + n=4, + temperature=0.5, + logprobs=True, + top_logprobs=0, + ) + exps = self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(exps), 4) class 
TestTokenizer(unittest.TestCase): From 73f3826bd7b2feb83a4df912306c12fcf70ca1ae Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 17 Jul 2025 15:27:28 +0800 Subject: [PATCH 08/12] add more tests --- .../source/tutorial/trinity_configs.md | 2 +- tests/common/vllm_test.py | 161 ++++-------------- 2 files changed, 33 insertions(+), 130 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 5ccf755d0e..bdd881182e 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -337,7 +337,7 @@ explorer: - `max_timeout`: Maximum time (in seconds) for a workflow to complete. - `max_retry_times`: Maximum number of retries for a workflow. - `env_vars`: Environment variables to be set for every workflow runners. -- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`. +- `rollout_model.engine_type`: Type of inference engine. For now, only `vllm_async`/`vllm` is supported. - `rollout_model.engine_num`: Number of inference engines. - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism. - `rollout_model.enable_history`: Whether to enable model call history recording. If set to `True`, the model wrapper automatically records the return experiences of model calls. Please periodically extract the history via `extract_experience_from_history` to avoid out-of-memory issues. Default is `False`. diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 12621d0028..9434971b81 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -2,7 +2,7 @@ import unittest import torch -from parameterized import parameterized +from parameterized import parameterized_class from transformers import AutoTokenizer from tests.tools import RayUnittestBase, RayUnittestBaseAysnc, get_template_config @@ -83,148 +83,45 @@ def get_model_path() -> str: """ -class ModelWrapperSyncTest(RayUnittestBase): - @parameterized.expand( - [ - (1, 2, False, 2, True), - (2, 2, False, 1, False), - (2, 2, True, 2, True), - (1, 2, True, 1, False), - ] - ) - def test_generate( - self, - tensor_parallel_size, - engine_num, - use_v1, - repeat_times, - enable_history, - ): +@parameterized_class( + ("tensor_parallel_size", "engine_num", "use_v1", "repeat_times", "enable_history", "use_async"), + [ + (1, 2, False, 2, True, False), + (2, 2, False, 1, False, True), + (2, 2, True, 2, True, False), + (1, 2, True, 1, False, True), + (2, 1, True, 3, True, True), + ], +) +class ModelWrapperTest(RayUnittestBaseAysnc): + def setUp(self): # configure the model self.config = get_template_config() self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_num = engine_num - self.config.explorer.rollout_model.tensor_parallel_size = tensor_parallel_size - self.config.explorer.rollout_model.use_v1 = use_v1 + self.config.explorer.rollout_model.engine_num = self.engine_num + self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size + self.config.explorer.rollout_model.use_v1 = self.use_v1 self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.algorithm.repeat_times = repeat_times - self.config.explorer.rollout_model.enable_history = enable_history + self.config.algorithm.repeat_times = self.repeat_times + self.config.explorer.rollout_model.enable_history = self.enable_history self.config.check_and_update() self.engines, 
self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper( self.engines[0], model_type="vllm_async", enable_history=True ) - # tests - prompts = ["Hello, world!", "Hello, my name is"] - n = self.config.algorithm.repeat_times - generate_results = self.model_wrapper.generate(prompts, n=n, temperature=1.0) - self.assertEqual(len(generate_results), len(prompts) * n) - if self.config.explorer.rollout_model.enable_history: - history_experiences = self.model_wrapper.extract_experience_from_history( - clear_history=False - ) - self.assertEqual(len(history_experiences), len(generate_results)) - for exp, history_exp in zip(generate_results, history_experiences): - self.assertEqual(exp.response_text, history_exp.response_text) - self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) - self.assertEqual(exp.prompt_length, history_exp.prompt_length) - self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) - else: - with self.assertRaises(ValueError): - self.model_wrapper.extract_experience_from_history(clear_history=False) - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What's the weather like today?"}, - { - "role": "assistant", - "content": "I'm sorry, but as an AI language model, I don't have access to real-time weather information. To get accurate weather information for your location, you can check a weather website or app, or look outside if possible.", - }, - {"role": "user", "content": "OK, thanks!"}, - ] - results = self.model_wrapper.chat(messages, n=n, temperature=1.0) - self.assertEqual(len(results), n) - if self.config.explorer.rollout_model.enable_history: - history_experiences = self.model_wrapper.extract_experience_from_history() - self.assertEqual(len(history_experiences) - len(generate_results), len(results)) - for exp, history_exp in zip(results, history_experiences[len(generate_results) :]): - self.assertEqual(exp.response_text, history_exp.response_text) - self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) - self.assertEqual(exp.prompt_length, history_exp.prompt_length) - self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) - for result in results: - input_logprobs = result.logprobs[: result.prompt_length] - output_logprobs = result.logprobs[result.prompt_length :] - self.assertTrue(torch.all(input_logprobs == 0)) - self.assertTrue(torch.any(output_logprobs != 0)) - logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) - self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) - if self.config.explorer.rollout_model.enable_history: - history_experiences = self.model_wrapper.extract_experience_from_history() - self.assertTrue(len(history_experiences) == 0) - messages.append( - { - "role": "assistant", - "content": results[0].response_text, - } - ) - exp = self.model_wrapper.convert_messages_to_experience(messages) - tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path) - result_dict = tokenizer.apply_chat_template( - messages, - chat_template=CHAT_TEMPLATE, - add_generation_prompt=False, - padding=False, - truncation=True, - return_tensors="pt", - add_special_tokens=False, - return_assistant_tokens_mask=True, - return_dict=True, - ) - self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask)) - self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) - self.assertRaises(ValueError, self.model_wrapper.get_openai_client) - if 
self.config.explorer.rollout_model.enable_history: - history_experiences = self.model_wrapper.extract_experience_from_history() - self.assertTrue(len(history_experiences) == 0) - -class ModelWrapperAsyncTest(RayUnittestBaseAysnc): - @parameterized.expand( - [ - (1, 2, False, 2, True), - (2, 2, False, 1, False), - (2, 2, True, 2, True), - (1, 2, True, 1, False), - ] - ) async def test_generate( self, - tensor_parallel_size, - engine_num, - use_v1, - repeat_times, - enable_history, ): - # configure the model - self.config = get_template_config() - self.config.mode = "explore" - self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_num = engine_num - self.config.explorer.rollout_model.tensor_parallel_size = tensor_parallel_size - self.config.explorer.rollout_model.use_v1 = use_v1 - self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.algorithm.repeat_times = repeat_times - self.config.explorer.rollout_model.enable_history = enable_history - self.config.check_and_update() - self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=True - ) - # tests prompts = ["Hello, world!", "Hello, my name is"] n = self.config.algorithm.repeat_times - generate_results = await self.model_wrapper.generate_async(prompts, n=n, temperature=1.0) + if self.use_async: + generate_results = await self.model_wrapper.generate_async( + prompts, n=n, temperature=1.0 + ) + else: + generate_results = self.model_wrapper.generate(prompts, n=n, temperature=1.0) self.assertEqual(len(generate_results), len(prompts) * n) if self.config.explorer.rollout_model.enable_history: history_experiences = self.model_wrapper.extract_experience_from_history( @@ -248,7 +145,10 @@ async def test_generate( }, {"role": "user", "content": "OK, thanks!"}, ] - results = await self.model_wrapper.chat_async(messages, n=n, temperature=1.0) + if self.use_async: + results = await self.model_wrapper.chat_async(messages, n=n, temperature=1.0) + else: + results = self.model_wrapper.chat(messages, n=n, temperature=1.0) self.assertEqual(len(results), n) if self.config.explorer.rollout_model.enable_history: history_experiences = self.model_wrapper.extract_experience_from_history() @@ -263,7 +163,10 @@ async def test_generate( output_logprobs = result.logprobs[result.prompt_length :] self.assertTrue(torch.all(input_logprobs == 0)) self.assertTrue(torch.any(output_logprobs != 0)) - logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist()) + if self.use_async: + logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist()) + else: + logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) if self.config.explorer.rollout_model.enable_history: history_experiences = self.model_wrapper.extract_experience_from_history() From 6d254e302d9a5502056b33f086a1e23f3c41256e Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 17 Jul 2025 15:31:26 +0800 Subject: [PATCH 09/12] fix doc --- docs/sphinx_doc/source/tutorial/trinity_configs.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index bdd881182e..1c8ba2e36e 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -321,7 +321,7 @@ explorer: 
max_retry_times: 2 env_vars: {} rollout_model: - engine_type: vllm_async + engine_type: vllm engine_num: 1 tensor_parallel_size: 1 enable_history: False @@ -337,7 +337,7 @@ explorer: - `max_timeout`: Maximum time (in seconds) for a workflow to complete. - `max_retry_times`: Maximum number of retries for a workflow. - `env_vars`: Environment variables to be set for every workflow runners. -- `rollout_model.engine_type`: Type of inference engine. For now, only `vllm_async`/`vllm` is supported. +- `rollout_model.engine_type`: Type of inference engine. For now, only `vllm_async` and `vllm` is supported, they have the same meaning and both use the asynchronous engine. In subsequent versions, only `vllm` may be retained for simplicity. - `rollout_model.engine_num`: Number of inference engines. - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism. - `rollout_model.enable_history`: Whether to enable model call history recording. If set to `True`, the model wrapper automatically records the return experiences of model calls. Please periodically extract the history via `extract_experience_from_history` to avoid out-of-memory issues. Default is `False`. From ff9fcbe6c08275accf72eb52e3a6ae2d8e5b2ca3 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 17 Jul 2025 15:52:42 +0800 Subject: [PATCH 10/12] fix tests --- tests/common/vllm_test.py | 9 ++++++++- trinity/common/models/model.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 9434971b81..e9160f4c6f 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -108,7 +108,7 @@ def setUp(self): self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=True + self.engines[0], model_type="vllm_async", enable_history=self.enable_history ) async def test_generate( @@ -214,6 +214,9 @@ def setUp(self): self.model_wrapper = ModelWrapper( self.engines[0], model_type="vllm_async", enable_history=True ) + self.model_wrapper_no_history = ModelWrapper( + self.engines[0], model_type="vllm_async", enable_history=False + ) def test_api(self): openai_client = self.model_wrapper.get_openai_client() @@ -253,6 +256,10 @@ def test_api(self): ) exps = self.model_wrapper.extract_experience_from_history() self.assertEqual(len(exps), 4) + self.assertTrue(len(self.model_wrapper.extract_experience_from_history()), 0) + response = self.model_wrapper_no_history.get_openai_client().chat.completions.create(model_id, messages=messages, n=2) + self.assertEqual(2, len(response.choices)) + self.assertEqual(len(self.model_wrapper_no_history.extract_experience_from_history()), 0) class TestTokenizer(unittest.TestCase): diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index c02d797d06..1213bf2da3 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -53,7 +53,7 @@ def _history_recorder(func): async def async_wrapper(self, *args, **kwargs): result = await func(self, *args, **kwargs) if self.enable_history: - self.history.append(result) + self._record_history(result) return result def sync_wrapper(self, *args, **kwargs): From 167c388c2303d6cd4f68a6cd9c65bc662fe901cb Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 17 Jul 2025 17:01:29 +0800 Subject: [PATCH 11/12] fix openai api --- tests/common/vllm_test.py | 12 +++++++++--- trinity/common/models/api/vllm_patch.py | 24 
++++++++---------------- trinity/common/models/model.py | 2 +- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index e9160f4c6f..71ccf32b7f 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -256,10 +256,16 @@ def test_api(self): ) exps = self.model_wrapper.extract_experience_from_history() self.assertEqual(len(exps), 4) - self.assertTrue(len(self.model_wrapper.extract_experience_from_history()), 0) - response = self.model_wrapper_no_history.get_openai_client().chat.completions.create(model_id, messages=messages, n=2) + self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0) + response = self.model_wrapper_no_history.get_openai_client().chat.completions.create( + model=model_id, messages=messages, n=2 + ) self.assertEqual(2, len(response.choices)) - self.assertEqual(len(self.model_wrapper_no_history.extract_experience_from_history()), 0) + self.assertTrue(hasattr(response.choices[0], "token_ids")) + self.assertTrue(len(response.choices[0].token_ids) > 0) + with self.assertRaises(ValueError): + self.model_wrapper_no_history.extract_experience_from_history() + self.assertEqual(len(self.model_wrapper_no_history.history), 0) class TestTokenizer(unittest.TestCase): diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/api/vllm_patch.py index f928735e6a..438636e35e 100644 --- a/trinity/common/models/api/vllm_patch.py +++ b/trinity/common/models/api/vllm_patch.py @@ -1,7 +1,7 @@ """Patch for vllm OpenAI API server. 1. Mocks the `add_signal_handler` method to do nothing. -2. Adds `prompt` and `prompt_token_ids` to the `ChatCompletionResponse`. +2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`. """ import asyncio import functools @@ -9,7 +9,7 @@ import time from typing import Optional, Union -from pydantic import TypeAdapter +from pydantic import Field, TypeAdapter from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( build_app, @@ -39,21 +39,13 @@ from trinity.utils.log import get_logger -class PatchedChatCompletionResponse(ChatCompletionResponse): - prompt_token_ids: list[int] = [] - - def __init__(self, *args, prompt_token_ids=None, response_token_ids=None, **kwargs): - super().__init__(*args, **kwargs) - self.prompt_token_ids = prompt_token_ids or [] - self.response_token_ids = response_token_ids or [] - - class PatchedChatCompletionResponseChoice(ChatCompletionResponseChoice): - token_ids: list[int] = [] + token_ids: list[int] = Field(default_factory=list) + - def __init__(self, *args, token_ids=None, **kwargs): - super().__init__(*args, **kwargs) - self.token_ids = token_ids or [] +class PatchedChatCompletionResponse(ChatCompletionResponse): + prompt_token_ids: list[int] = Field(default_factory=list) + choices: list[PatchedChatCompletionResponseChoice] = list[ChatCompletionResponseChoice] # TODO: add patch to stream generator @@ -217,7 +209,7 @@ async def chat_completion_full_generator( # noqa C901 ) message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) - choice_data = ChatCompletionResponseChoice( + choice_data = PatchedChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 1213bf2da3..2568c88005 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -171,7 +171,7 @@ def record_chat_completions(*args, **kwargs): 
return response self.openai_client.chat.completions.create = record_chat_completions - + setattr(self.openai_client, "model_path", self.openai_client.models.list().data[0].id) return self.openai_client def extract_experience_from_history(self, clear_history: bool = True) -> List[Experience]: From a90fe9563750fd0eb0f2fdc21fbebd0bdc6868d3 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 17 Jul 2025 18:10:52 +0800 Subject: [PATCH 12/12] simplify config --- trinity/common/config.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index 7360f0be5c..3a8bcd7f91 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -313,8 +313,6 @@ class ExplorerConfig: name: str = EXPLORER_NAME # for workflow runner # number of workflow runners. - # For sync engine (vllm), it should be `1`. - # For async engine (vllm_async), it could be a large number. runner_per_model: int = 8 # number of runners per each rollout model max_timeout: int = 1800 # wait each task for 30 minutes max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout @@ -722,11 +720,6 @@ def check_and_update(self) -> None: # noqa: C901 self.model.critic_model_path = self.model.model_path # check explorer - if ( - self.explorer.rollout_model.engine_type != "vllm_async" - and self.explorer.rollout_model.enable_openai_api - ): - raise ValueError("OpenAI API server only support `vllm_async` engine.") if self.explorer.rollout_model.max_prompt_tokens is None: self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens if self.explorer.rollout_model.max_response_tokens is None:
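[Editorial sketch, not part of the patch] For context on the `enable_history` flag that this series threads through the config, `ModelWrapper`, and tests, a rough usage sketch under stated assumptions (the config helper and model path are placeholders; the authoritative setup is in tests/common/vllm_test.py):

from tests.tools import get_template_config  # test helper used throughout vllm_test.py

from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper

config = get_template_config()
config.mode = "explore"
config.model.model_path = "/path/to/model"  # placeholder
config.explorer.rollout_model.enable_history = True
config.check_and_update()

engines, _auxiliary = create_inference_models(config)
wrapper = ModelWrapper(engines[0], model_type="vllm_async", enable_history=True)

# Any model call is recorded when history is enabled.
wrapper.chat([{"role": "user", "content": "Hello!"}], n=2, temperature=1.0)

# Drain the buffer periodically to avoid unbounded memory growth;
# clear_history=True (the default) empties it after extraction.
experiences = wrapper.extract_experience_from_history()
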