Commit 812b7f5

[Renderer] Move Processor out of AsyncLLM (#24138)
Signed-off-by: Yang <lymailforjob@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: 5f2cacd · Commit: 812b7f5

File tree

7 files changed: +217 -127 lines
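
In short: request preprocessing (the Processor/Renderer step) moves out of AsyncLLM and into the OpenAI serving layer. Each handler now calls self._process_inputs(...) to build an EngineCoreRequest up front and hands that request, together with the original prompt text and tokenization kwargs, to engine_client.generate(...). A minimal sketch of the new per-prompt flow follows; the helper name _submit_one is hypothetical and the body only mirrors the else-branch added in serving_chat.py / serving_completion.py below, it is not the exact handler code:

# Sketch only: hypothetical helper mirroring the new serving-layer flow.
async def _submit_one(self, engine_prompt, prompt_text, request_id,
                      sampling_params, lora_request, trace_headers, priority):
    # Build the EngineCoreRequest in the serving layer; the tokenization
    # kwargs used to be derived inside AsyncLLM.
    engine_request, tokenization_kwargs = await self._process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        lora_request=lora_request,
        trace_headers=trace_headers,
        priority=priority,
    )

    # generate() now takes the pre-built request plus keyword-only metadata.
    return self.engine_client.generate(
        engine_request,
        sampling_params,
        request_id,
        lora_request=lora_request,
        trace_headers=trace_headers,
        priority=priority,
        prompt_text=prompt_text,
        tokenization_kwargs=tokenization_kwargs,
    )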

tests/entrypoints/openai/test_lora_resolvers.py

Lines changed: 3 additions & 0 deletions
@@ -122,6 +122,9 @@ async def mock_generate(*args, **kwargs):
                                                  models,
                                                  request_logger=None)
 
+    serving_completion._process_inputs = AsyncMock(return_value=(MagicMock(
+        name="engine_request"), {}))
+
     return mock_engine, serving_completion
 
 
tests/entrypoints/openai/test_serving_chat.py

Lines changed: 34 additions & 64 deletions
@@ -7,7 +7,7 @@
 from contextlib import suppress
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Optional
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 import pytest_asyncio
@@ -230,6 +230,7 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -244,11 +245,33 @@ class MockModelConfig:
     encoder_config = None
     generation_config: str = "auto"
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False
 
     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}
 
 
+def _build_serving_chat(engine: AsyncLLM,
+                        model_config: MockModelConfig) -> OpenAIServingChat:
+    models = OpenAIServingModels(engine_client=engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=model_config)
+    serving_chat = OpenAIServingChat(engine,
+                                     model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    async def _fake_process_inputs(request_id, engine_prompt, sampling_params,
+                                   *, lora_request, trace_headers, priority):
+        return dict(engine_prompt), {}
+
+    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
+    return serving_chat
+
+
 @dataclass
 class MockEngine:
 
@@ -282,16 +305,7 @@ async def test_serving_chat_returns_correct_model_name():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
 
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=MockModelConfig())
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     MockModelConfig(),
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
     messages = [{"role": "user", "content": "what is 1+1?"}]
 
     async def return_model_name(*args):
@@ -318,16 +332,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
 
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=MockModelConfig())
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     MockModelConfig(),
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
 
     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -361,16 +366,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
 
     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)
 
     # Test Case 1: No max_tokens specified in request
     req = ChatCompletionRequest(
@@ -415,16 +411,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
 
     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)
 
     # Test case 1: No max_tokens specified, defaults to context_window
     req = ChatCompletionRequest(
@@ -471,16 +458,7 @@ async def test_serving_chat_could_load_correct_generation_config():
    mock_engine.errored = False
 
     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)
 
     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -525,17 +503,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
 
-    # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)
 
     # Test cache_salt
     req = ChatCompletionRequest(
@@ -549,10 +517,12 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     # By default, cache_salt in the engine prompt is not set
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
-    assert "cache_salt" not in mock_engine.generate.call_args.args[0]
+    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
+    assert "cache_salt" not in engine_prompt
 
     # Test with certain cache_salt
     req.cache_salt = "test_salt"
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
-    assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
+    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
+    assert engine_prompt.get("cache_salt") == "test_salt"
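
Because the handlers now go through _process_inputs before engine_client.generate, these tests assert against the mocked _process_inputs instead of mock_engine.generate.call_args. A hedged sketch of a test written in this style, assuming the fixtures defined in this file (AsyncLLM, MODEL_NAME, get_tokenizer, MockModelConfig, ChatCompletionRequest, _build_serving_chat); it is illustrative, not part of the diff:

import pytest
from contextlib import suppress
from unittest.mock import MagicMock


@pytest.mark.asyncio
async def test_engine_prompt_is_preprocessed():
    # Hypothetical test reusing the helper introduced above.
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    serving_chat = _build_serving_chat(mock_engine, MockModelConfig())

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)

    # The engine prompt is the second positional argument to _process_inputs.
    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
    assert "prompt_token_ids" in engine_prompt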

vllm/engine/protocol.py

Lines changed: 6 additions & 1 deletion
@@ -19,6 +19,7 @@
 from vllm.tasks import SupportedTask
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Device, collect_from_async_generator, random_uuid
+from vllm.v1.engine import EngineCoreRequest
 
 logger = init_logger(__name__)
 
@@ -49,12 +50,16 @@ def dead_error(self) -> BaseException:
     @abstractmethod
     def generate(
         self,
-        prompt: PromptType,
+        prompt: Union[EngineCoreRequest, PromptType],
         sampling_params: SamplingParams,
         request_id: str,
+        *,
+        prompt_text: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request."""
         ...
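
With this protocol change, generate() still accepts a raw PromptType, while the serving layer can instead pass a pre-built EngineCoreRequest; the new parameters (prompt_text, tokenization_kwargs, data_parallel_rank) are keyword-only. A hedged caller-side sketch, assuming engine_client, sampling_params, request_id, engine_request, and tokenization_kwargs are already in scope and that the raw-prompt path keeps its previous behaviour:

# Existing-style call: a raw prompt; preprocessing assumed to remain engine-side.
stream = engine_client.generate(
    {"prompt": "Hello, world"},      # PromptType
    sampling_params,
    request_id,
)

# New serving-layer call: preprocessing already done via _process_inputs.
stream = engine_client.generate(
    engine_request,                  # EngineCoreRequest
    sampling_params,
    request_id,
    prompt_text="Hello, world",
    tokenization_kwargs=tokenization_kwargs,
)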

vllm/entrypoints/openai/serving_chat.py

Lines changed: 16 additions & 2 deletions
@@ -274,7 +274,8 @@ async def create_chat_completion(
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
-                sampling_params: Union[SamplingParams, BeamSearchParams]
+                prompt_text, _, _ = (self._get_prompt_components(
+                    request_prompts[i]))
 
                 if self.default_sampling_params is None:
                     self.default_sampling_params = {}
@@ -285,6 +286,7 @@ async def create_chat_completion(
                     input_length=len(engine_prompt["prompt_token_ids"]),
                     default_sampling_params=self.default_sampling_params)
 
+                sampling_params: Union[SamplingParams, BeamSearchParams]
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         max_tokens, self.default_sampling_params)
@@ -309,13 +311,25 @@ async def create_chat_completion(
                         lora_request=lora_request,
                     )
                 else:
+                    engine_request, tokenization_kwargs = (
+                        await self._process_inputs(
+                            request_id,
+                            engine_prompt,
+                            sampling_params,
+                            lora_request=lora_request,
+                            trace_headers=trace_headers,
+                            priority=request.priority,
+                        ))
+
                     generator = self.engine_client.generate(
-                        engine_prompt,
+                        engine_request,
                         sampling_params,
                         request_id,
                         lora_request=lora_request,
                         trace_headers=trace_headers,
                         priority=request.priority,
+                        prompt_text=prompt_text,
+                        tokenization_kwargs=tokenization_kwargs,
                     )
 
                     generators.append(generator)

vllm/entrypoints/openai/serving_completion.py

Lines changed: 24 additions & 20 deletions
@@ -9,7 +9,6 @@
 
 import jinja2
 from fastapi import Request
-from typing_extensions import assert_never
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
@@ -32,8 +31,7 @@
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import get_max_tokens
-from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
-                              is_tokens_prompt)
+from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import RequestOutput
@@ -157,23 +155,16 @@ async def create_completion(
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
-                sampling_params: Union[SamplingParams, BeamSearchParams]
-                # Mypy does not infer that engine_prompt will have only one of
-                # "prompt_token_ids" or "prompt_embeds" defined, and both of
-                # these as Union[object, the expected type], where it infers
-                # object if engine_prompt is a subclass of one of the
-                # typeddicts that defines both keys. Worse, because of
-                # https://github.com/python/mypy/issues/8586, mypy does not
-                # infer the type of engine_prompt correctly because of the
-                # enumerate. So we need an unnecessary cast here.
-                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
-                                     engine_prompt)
-                if is_embeds_prompt(engine_prompt):
-                    input_length = len(engine_prompt["prompt_embeds"])
-                elif is_tokens_prompt(engine_prompt):
-                    input_length = len(engine_prompt["prompt_token_ids"])
+                prompt_text, prompt_token_ids, prompt_embeds = (
+                    self._get_prompt_components(engine_prompt))
+
+                input_length = None
+                if prompt_token_ids is not None:
+                    input_length = len(prompt_token_ids)
+                elif prompt_embeds is not None:
+                    input_length = len(prompt_embeds)
                 else:
-                    assert_never(engine_prompt)
+                    raise NotImplementedError
 
                 if self.default_sampling_params is None:
                     self.default_sampling_params = {}
@@ -185,6 +176,7 @@ async def create_completion(
                     default_sampling_params=self.default_sampling_params,
                 )
 
+                sampling_params: Union[SamplingParams, BeamSearchParams]
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         max_tokens, self.default_sampling_params)
@@ -220,13 +212,25 @@ async def create_completion(
                         lora_request=lora_request,
                     )
                 else:
+                    engine_request, tokenization_kwargs = (
+                        await self._process_inputs(
+                            request_id_item,
+                            engine_prompt,
+                            sampling_params,
+                            lora_request=lora_request,
+                            trace_headers=trace_headers,
+                            priority=request.priority,
+                        ))
+
                     generator = self.engine_client.generate(
-                        engine_prompt,
+                        engine_request,
                         sampling_params,
                         request_id_item,
                         lora_request=lora_request,
                         trace_headers=trace_headers,
                         priority=request.priority,
+                        prompt_text=prompt_text,
+                        tokenization_kwargs=tokenization_kwargs,
                     )
 
                     generators.append(generator)
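
The mypy-workaround cast and assert_never branching is replaced by _get_prompt_components, which (as used here and in serving_chat.py) returns a (prompt_text, prompt_token_ids, prompt_embeds) triple with absent components set to None. A small self-contained sketch of the input-length logic built on that assumed contract; the helper name input_length_of is hypothetical:

from typing import Any, Optional


def input_length_of(prompt_token_ids: Optional[list[int]],
                    prompt_embeds: Optional[Any]) -> int:
    # Mirrors the branching above: prefer token ids, fall back to embeddings,
    # and treat anything else as unsupported at this point in the handler.
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    if prompt_embeds is not None:
        return len(prompt_embeds)
    raise NotImplementedError


assert input_length_of([1, 2, 3], None) == 3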
