3 changes: 3 additions & 0 deletions tests/entrypoints/openai/test_lora_resolvers.py
@@ -122,6 +122,9 @@ async def mock_generate(*args, **kwargs):
models,
request_logger=None)

serving_completion._process_inputs = AsyncMock(return_value=(MagicMock(
name="engine_request"), {}))

return mock_engine, serving_completion


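A minimal, self-contained sketch of the stubbing pattern used in the fixture above, assuming only the standard library: _process_inputs is replaced with an AsyncMock so the serving path can run without a real tokenizer or engine. The serving object below is a stand-in, not the real OpenAIServingCompletion.

import asyncio
from unittest.mock import AsyncMock, MagicMock

async def demo() -> None:
    # Stand-in for the serving object patched in the fixture above.
    serving = MagicMock()
    serving._process_inputs = AsyncMock(
        return_value=(MagicMock(name="engine_request"), {}))

    # Awaiting the patched coroutine returns the canned (request, kwargs) pair,
    # mirroring what the real _process_inputs is expected to produce.
    engine_request, tokenization_kwargs = await serving._process_inputs(
        "req-0", {"prompt_token_ids": [1, 2, 3]}, None,
        lora_request=None, trace_headers=None, priority=0)
    assert tokenization_kwargs == {}

asyncio.run(demo())
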
98 changes: 34 additions & 64 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -7,7 +7,7 @@
from contextlib import suppress
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio
@@ -230,6 +230,7 @@ class MockHFConfig:
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
@@ -244,11 +245,33 @@ class MockModelConfig:
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}


def _build_serving_chat(engine: AsyncLLM,
model_config: MockModelConfig) -> OpenAIServingChat:
models = OpenAIServingModels(engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=model_config)
serving_chat = OpenAIServingChat(engine,
model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)

async def _fake_process_inputs(request_id, engine_prompt, sampling_params,
*, lora_request, trace_headers, priority):
return dict(engine_prompt), {}

serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_chat


@dataclass
class MockEngine:

@@ -282,16 +305,7 @@ async def test_serving_chat_returns_correct_model_name():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
messages = [{"role": "user", "content": "what is 1+1?"}]

async def return_model_name(*args):
@@ -318,16 +332,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())

req = ChatCompletionRequest(
model=MODEL_NAME,
@@ -361,16 +366,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)

# Test Case 1: No max_tokens specified in request
req = ChatCompletionRequest(
@@ -415,16 +411,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)

# Test case 1: No max_tokens specified, defaults to context_window
req = ChatCompletionRequest(
@@ -471,16 +458,7 @@ async def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)

req = ChatCompletionRequest(
model=MODEL_NAME,
@@ -525,17 +503,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)

# Test cache_salt
req = ChatCompletionRequest(
@@ -549,10 +517,12 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
# By default, cache_salt in the engine prompt is not set
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
assert "cache_salt" not in engine_prompt

# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
assert engine_prompt.get("cache_salt") == "test_salt"
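A self-contained illustration of the await_args_list inspection used in the cache-salt assertions above (standard-library mocks only): every awaited call on an AsyncMock records its arguments, so the test can read back the engine prompt that create_chat_completion forwarded to the patched _process_inputs.

import asyncio
from unittest.mock import AsyncMock

async def demo() -> None:
    # Like _fake_process_inputs above: echo the engine prompt back as a dict.
    process_inputs = AsyncMock(
        side_effect=lambda *args, **kwargs: (dict(args[1]), {}))

    await process_inputs("req-0", {"prompt_token_ids": [1, 2]}, None)
    await process_inputs(
        "req-1", {"prompt_token_ids": [3], "cache_salt": "salt"}, None)

    # await_args_list keeps one call record per awaited invocation.
    first_prompt = process_inputs.await_args_list[0].args[1]
    second_prompt = process_inputs.await_args_list[1].args[1]
    assert "cache_salt" not in first_prompt
    assert second_prompt.get("cache_salt") == "salt"

asyncio.run(demo())
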
7 changes: 6 additions & 1 deletion vllm/engine/protocol.py
@@ -19,6 +19,7 @@
from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)

@@ -49,12 +50,16 @@ def dead_error(self) -> BaseException:
@abstractmethod
def generate(
self,
prompt: PromptType,
prompt: Union[EngineCoreRequest, PromptType],
sampling_params: SamplingParams,
request_id: str,
*,
prompt_text: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...
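A hedged sketch of what the widened generate signature implies for implementers; the class below is a toy stand-in, not vLLM's EngineClient, and the prompt type is simplified to dict-or-str in place of EngineCoreRequest and PromptType. A concrete client can now be handed either a pre-processed engine request or a raw prompt, together with the original prompt text and tokenization kwargs.

from collections.abc import AsyncGenerator, Mapping
from typing import Any, Optional, Union

class ToyEngineClient:
    """Toy client mirroring the shape of the abstract generate signature above."""

    async def generate(
        self,
        prompt: Union[dict, str],  # simplified stand-in for EngineCoreRequest | PromptType
        sampling_params: Any,
        request_id: str,
        *,
        prompt_text: Optional[str] = None,
        lora_request: Any = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[str, None]:
        # A real client would schedule the request on the engine core and
        # stream RequestOutput objects; this toy just echoes one string.
        yield f"{request_id}: {prompt_text if prompt_text is not None else prompt}"
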
18 changes: 16 additions & 2 deletions vllm/entrypoints/openai/serving_chat.py
@@ -274,7 +274,8 @@ async def create_chat_completion(
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
prompt_text, _, _ = (self._get_prompt_components(
request_prompts[i]))

if self.default_sampling_params is None:
self.default_sampling_params = {}
@@ -285,6 +286,7 @@
input_length=len(engine_prompt["prompt_token_ids"]),
default_sampling_params=self.default_sampling_params)

sampling_params: Union[SamplingParams, BeamSearchParams]
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
@@ -309,13 +311,25 @@
lora_request=lora_request,
)
else:
engine_request, tokenization_kwargs = (
await self._process_inputs(
request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
))

generator = self.engine_client.generate(
engine_prompt,
engine_request,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
)

generators.append(generator)
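The change above splits request submission into two steps: _process_inputs builds the engine request, and generate receives that request plus the original prompt text and tokenization kwargs. A distilled, runnable sketch of that control flow using standard-library mocks (submit and demo are illustrative names, not vLLM functions):

import asyncio
from unittest.mock import AsyncMock, MagicMock

async def submit(serving, engine_client, request_id, engine_prompt,
                 sampling_params, *, lora_request=None, trace_headers=None,
                 priority=0):
    # Step 1: turn the raw engine prompt into a processed engine request.
    engine_request, tokenization_kwargs = await serving._process_inputs(
        request_id, engine_prompt, sampling_params,
        lora_request=lora_request, trace_headers=trace_headers,
        priority=priority)
    # Step 2: hand the processed request (not the raw prompt) to the engine,
    # forwarding the original prompt text and tokenization kwargs.
    return engine_client.generate(
        engine_request, sampling_params, request_id,
        lora_request=lora_request, trace_headers=trace_headers,
        priority=priority, prompt_text=engine_prompt.get("prompt"),
        tokenization_kwargs=tokenization_kwargs)

async def demo() -> None:
    serving = MagicMock()
    serving._process_inputs = AsyncMock(return_value=({"processed": True}, {}))
    engine_client = MagicMock()
    await submit(serving, engine_client, "chatcmpl-0",
                 {"prompt_token_ids": [1, 2, 3]}, sampling_params=None)
    assert engine_client.generate.call_args.args[0] == {"processed": True}

asyncio.run(demo())
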
44 changes: 24 additions & 20 deletions vllm/entrypoints/openai/serving_completion.py
@@ -9,7 +9,6 @@

import jinja2
from fastapi import Request
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@@ -32,8 +31,7 @@
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt)
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
@@ -157,23 +155,16 @@ async def create_completion(
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
# Mypy does not infer that engine_prompt will have only one of
# "prompt_token_ids" or "prompt_embeds" defined, and both of
# these as Union[object, the expected type], where it infers
# object if engine_prompt is a subclass of one of the
# typeddicts that defines both keys. Worse, because of
# https://github.com/python/mypy/issues/8586, mypy does not
# infer the type of engine_prompt correctly because of the
# enumerate. So we need an unnecessary cast here.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if is_embeds_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_embeds"])
elif is_tokens_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_token_ids"])
prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt))

input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
assert_never(engine_prompt)
raise NotImplementedError

if self.default_sampling_params is None:
self.default_sampling_params = {}
@@ -185,6 +176,7 @@
default_sampling_params=self.default_sampling_params,
)

sampling_params: Union[SamplingParams, BeamSearchParams]
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
@@ -220,13 +212,25 @@
lora_request=lora_request,
)
else:
engine_request, tokenization_kwargs = (
await self._process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
))

generator = self.engine_client.generate(
engine_prompt,
engine_request,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
)

generators.append(generator)
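The rewritten branch above derives the prompt length from whichever component _get_prompt_components returns, instead of casting the prompt and dispatching on its type. A minimal sketch of that length rule, with the triple's shape assumed from the diff:

from typing import Optional, Sequence

def input_length_from_components(
        prompt_token_ids: Optional[Sequence[int]],
        prompt_embeds: Optional[Sequence]) -> int:
    # Token IDs take precedence; otherwise fall back to the embeddings length,
    # mirroring the branch introduced in create_completion above.
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    if prompt_embeds is not None:
        return len(prompt_embeds)
    raise NotImplementedError("prompt carries neither token IDs nor embeddings")

assert input_length_from_components([11, 42, 7], None) == 3
assert input_length_from_components(None, [[0.1, 0.2], [0.3, 0.4]]) == 2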