[Core] Support serving encoder/decoder models #7258

Merged (35 commits) on Aug 9, 2024
Changes from 6 commits
33c9e25
Introduce `is_list_of`
DarkLight1337 Aug 7, 2024
e6dd6f5
Avoid circular imports
DarkLight1337 Aug 7, 2024
f938c86
Refactor prompt parsing and extend this to async engine
DarkLight1337 Aug 7, 2024
6332d1e
Remove unnecessary comments
DarkLight1337 Aug 7, 2024
07b4d21
Enable full async
DarkLight1337 Aug 7, 2024
e29864c
grammar
DarkLight1337 Aug 7, 2024
c9dfb40
Add description
DarkLight1337 Aug 7, 2024
1233192
Fix wrong type annotations
DarkLight1337 Aug 7, 2024
f332275
Merge branch 'upstream' into inputs-parser
DarkLight1337 Aug 7, 2024
dcdebee
Remove redundant docs
DarkLight1337 Aug 7, 2024
65db3f1
Be more strict
DarkLight1337 Aug 7, 2024
9ffeb22
Fix docs
DarkLight1337 Aug 7, 2024
c9e0b08
Fix 2
DarkLight1337 Aug 7, 2024
14bca1f
Disallow multi-modal data for enc/dec models
DarkLight1337 Aug 7, 2024
8fc7099
Improve type narrowing behavior using `TypeIs`
DarkLight1337 Aug 7, 2024
3a8a072
Avoid sequential await
DarkLight1337 Aug 7, 2024
ef5327c
Fix type annotations based on test files
DarkLight1337 Aug 7, 2024
8a835cc
Properly handle `inputs["decoder_prompt"]=None`
DarkLight1337 Aug 7, 2024
e0024c2
Clean
DarkLight1337 Aug 7, 2024
76af172
Clean
DarkLight1337 Aug 7, 2024
5c16f2e
Fix incorrect decoder inputs in singleton case
DarkLight1337 Aug 7, 2024
e239ba9
Clean
DarkLight1337 Aug 7, 2024
4b0e3df
Move functions to a more appropriate place
DarkLight1337 Aug 7, 2024
53f7f50
Remove outdated comment
DarkLight1337 Aug 7, 2024
3afdbc5
Fix mismatch between hf and vllm output text
DarkLight1337 Aug 7, 2024
c61b01f
Factor out duplicate code
DarkLight1337 Aug 7, 2024
f8ed373
Factor out more duplicate code
DarkLight1337 Aug 7, 2024
a4df70a
Remove default values to avoid accidentally miss those arguments
DarkLight1337 Aug 7, 2024
5240bb3
Add test for serving encoder/decoder model with OpenAI server
DarkLight1337 Aug 7, 2024
d321c82
Use two type variables
DarkLight1337 Aug 7, 2024
931d1f6
Merge branch 'upstream' into inputs-parser
DarkLight1337 Aug 7, 2024
a06c67f
Merge branch 'upstream' into inputs-parser
DarkLight1337 Aug 7, 2024
9f64a05
Merge branch 'upstream' into inputs-parser
DarkLight1337 Aug 7, 2024
e4c5c21
Update error message
DarkLight1337 Aug 8, 2024
68fbf5a
Merge branch 'upstream' into inputs-parser
DarkLight1337 Aug 8, 2024
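
The commit title "Avoid sequential await" (3a8a072) names a common asyncio pattern. The corresponding change is not visible in this 6-commit view, so the following is only a sketch of the general idea and not the PR's code. `fake_encode_async` is a hypothetical stand-in for a tokenizer call; the point is that two independent awaits can be issued together with `asyncio.gather` instead of one after the other.

```python
import asyncio
from typing import List, Tuple


async def fake_encode_async(prompt: str, delay: float = 0.01) -> List[int]:
    """Hypothetical tokenizer: sleeps to simulate I/O, then returns fake IDs."""
    await asyncio.sleep(delay)
    return [ord(ch) for ch in prompt]


async def encode_sequential(enc: str, dec: str) -> Tuple[List[int], List[int]]:
    # The second call starts only after the first one has finished.
    enc_ids = await fake_encode_async(enc)
    dec_ids = await fake_encode_async(dec)
    return enc_ids, dec_ids


async def encode_concurrent(enc: str, dec: str) -> Tuple[List[int], List[int]]:
    # Both coroutines are scheduled together; gather returns the results
    # in the same order as its arguments.
    enc_ids, dec_ids = await asyncio.gather(
        fake_encode_async(enc),
        fake_encode_async(dec),
    )
    return enc_ids, dec_ids
```

Both variants return the same values; the concurrent one only helps when the two calls do not depend on each other.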
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yaml
@@ -25,7 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
-pip install mypy==1.9.0
+pip install mypy==1.11.1
pip install types-setuptools
pip install types-PyYAML
pip install types-requests
4 changes: 2 additions & 2 deletions examples/offline_inference_encoder_decoder.py
@@ -4,8 +4,8 @@
'''

from vllm import LLM, SamplingParams
-from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
-from vllm.utils import zip_enc_dec_prompt_lists
+from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+TokensPrompt, zip_enc_dec_prompt_lists)

dtype = "float"

2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
-typing_extensions
+typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
gguf == 0.9.1
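
`TypeIs` (PEP 742) is available from `typing_extensions` 4.10 onward, which is presumably why the version floor is raised here alongside the commits "Introduce `is_list_of`" and "Improve type narrowing behavior using `TypeIs`". The sketch below shows how such a helper can narrow a union in both branches. The body of `is_list_of` and the `describe` function are assumptions written for illustration, not the definitions in vLLM.

```python
from typing import TYPE_CHECKING, List, Type, TypeVar, Union

if TYPE_CHECKING:
    # Only the type checker needs this import, so the sketch runs even
    # where typing_extensions is not installed.
    from typing_extensions import TypeIs

T = TypeVar("T")


def is_list_of(value: object, typ: Type[T]) -> "TypeIs[List[T]]":
    """Hypothetical helper: True if `value` is a list whose items are all `typ`."""
    return isinstance(value, list) and all(isinstance(v, typ) for v in value)


def describe(prompt: Union[str, List[int]]) -> str:
    if is_list_of(prompt, int):
        # A type checker narrows `prompt` to List[int] in this branch.
        return f"{len(prompt)} token ids"
    # Unlike TypeGuard, TypeIs also narrows the negative branch (to str here).
    return f"text of length {len(prompt)}"
```

The return annotation is quoted so that it is never evaluated at runtime; only static analysis sees `TypeIs`.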
2 changes: 1 addition & 1 deletion requirements-lint.txt
@@ -8,7 +8,7 @@ isort==5.13.2
clang-format==18.1.5

# type checking
-mypy==1.9.0
+mypy==1.11.1
types-PyYAML
types-requests
types-setuptools
2 changes: 1 addition & 1 deletion requirements-openvino.txt
@@ -22,7 +22,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
-typing_extensions
+typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
gguf == 0.9.1
6 changes: 3 additions & 3 deletions tests/conftest.py
@@ -21,13 +21,13 @@
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
-from vllm.inputs import TextPrompt
+from vllm.inputs import (TextPrompt, to_enc_dec_tuple_list,
+zip_enc_dec_prompt_lists)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-is_cpu, to_enc_dec_tuple_list,
-zip_enc_dec_prompt_lists)
+is_cpu)

logger = init_logger(__name__)

2 changes: 1 addition & 1 deletion tests/test_inputs.py
@@ -2,7 +2,7 @@

import pytest

-from vllm.inputs import parse_and_batch_prompt
+from vllm.inputs.parse import parse_and_batch_prompt

STRING_INPUTS = [
'',
167 changes: 144 additions, 23 deletions vllm/engine/async_llm_engine.py
@@ -5,6 +5,7 @@
Optional, Set, Tuple, Type, Union)

from transformers import PreTrainedTokenizer
+from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -16,9 +17,12 @@
from vllm.engine.metrics import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
-from vllm.inputs import LLMInputs, PromptInputs
+from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
+SingletonPromptInputs)
+from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
+from vllm.multimodal import MultiModalDataDict
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
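
The imports above bring in `is_explicit_encoder_decoder_prompt`, which the decoder-only branch later uses to reject encoder/decoder prompts. A minimal sketch of how such a check could work follows. The TypedDicts are simplified stand-ins rather than the definitions in `vllm.inputs`, the key-based test is an assumption about the helper's behavior, and `check_decoder_only` is a hypothetical wrapper added for illustration.

```python
from typing import List, Optional, TypedDict, Union


class TextPrompt(TypedDict):
    prompt: str


class TokensPrompt(TypedDict):
    prompt_token_ids: List[int]


SingletonPrompt = Union[str, TextPrompt, TokensPrompt]


class ExplicitEncoderDecoderPrompt(TypedDict):
    encoder_prompt: SingletonPrompt
    decoder_prompt: Optional[SingletonPrompt]


PromptInputs = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]


def is_explicit_encoder_decoder_prompt(inputs: PromptInputs) -> bool:
    # Assumption: an explicit prompt is a dict carrying an "encoder_prompt" key.
    return isinstance(inputs, dict) and "encoder_prompt" in inputs


def check_decoder_only(inputs: PromptInputs) -> PromptInputs:
    # Same guard and error message as the decoder-only branch in the diff.
    if is_explicit_encoder_decoder_prompt(inputs):
        raise ValueError("Cannot pass encoder-decoder prompt "
                         "to decoder-only models")
    return inputs
```

Plain strings and singleton dicts pass through unchanged; only a dict with an `encoder_prompt` key is rejected.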
@@ -291,38 +295,153 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()

async def process_model_inputs_async(
async def _tokenize_prompt_async(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")

return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
lora_request=lora_request)

async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> Tuple[Optional[str], List[int], Optional[MultiModalDataDict]]:
"""Async version of :meth:`_extract_prompt_components`."""
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)

multi_modal_data = inputs.get("multi_modal_data")
else:
assert_never(inputs)

return prompt, prompt_token_ids, multi_modal_data

async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
explicit_inputs = self._to_explicit_encoder_decoder_prompt(inputs)
extracted_encoder_prompt = explicit_inputs["encoder_prompt"]
extracted_decoder_prompt = explicit_inputs["decoder_prompt"]

(
encoder_prompt,
encoder_prompt_token_ids,
_,
) = await self._extract_prompt_components_async(
extracted_encoder_prompt,
request_id=request_id,
)

# Avoid repeated processing if the input was originally in singleton
# form, see self._to_explicit_encoder_decoder_prompt
if extracted_decoder_prompt is extracted_encoder_prompt:
decoder_prompt_token_ids = encoder_prompt_token_ids
decoder_prompt = encoder_prompt
else:
(
decoder_prompt,
decoder_prompt_token_ids,
_,
) = await self._extract_prompt_components_async(
extracted_decoder_prompt,
request_id=request_id,
)

decoder_prompt_token_ids = (
self._prepare_decoder_input_ids_for_generation(
decoder_prompt_token_ids))

return EncoderDecoderLLMInputs(
prompt_token_ids=decoder_prompt_token_ids,
prompt=decoder_prompt,
encoder_prompt_token_ids=encoder_prompt_token_ids,
encoder_prompt=encoder_prompt,
)

async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
"""Async version of :meth:`_process_decoder_only_prompt`."""
(
prompt,
prompt_token_ids,
multi_modal_data,
) = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
lora_request=lora_request,
)

prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)

if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data)

prompt_token_ids = await tokenizer.encode_async(
async def process_model_inputs_async(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Async version of :meth:`process_model_inputs`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = await self._process_encoder_decoder_prompt_async(
inputs,
request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
)
else:
prompt_token_ids = inputs["prompt_token_ids"]

if prompt_adapter_request:
prompt_token_ids = [
0
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
prompt_token_ids
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")

llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
# Decoder-only operation
model_inputs = await self._process_decoder_only_prompt_async(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)

return self.input_processor(llm_inputs)
return self.input_processor(model_inputs)

async def add_request_async(
self,
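
The hunk above can be hard to follow with the indentation lost, so here is a simplified, self-contained sketch of its control flow as it appears in this 6-commit view: dispatch on the three singleton forms (string, pre-tokenized dict, text dict), end with `assert_never`, and reuse the encoder result when the decoder prompt is the same object. `fake_tokenize` is a hypothetical stand-in for the tokenizer group, and multi-modal data, LoRA, and `_prepare_decoder_input_ids_for_generation` are left out. A later commit in the list ("Fix incorrect decoder inputs in singleton case") suggests the singleton handling changed afterwards, so treat this as an illustration of the shown code only.

```python
import asyncio
from typing import Any, Dict, List, Optional, Tuple, Union

try:
    from typing import assert_never  # Python 3.11+
except ImportError:
    def assert_never(value):  # minimal fallback with the same runtime behavior
        raise AssertionError(f"Expected code to be unreachable, but got: {value!r}")

Singleton = Union[str, Dict[str, Any]]


async def fake_tokenize(prompt: str) -> List[int]:
    """Hypothetical tokenizer stand-in."""
    await asyncio.sleep(0)
    return [ord(ch) for ch in prompt]


async def extract_components(
        inputs: Singleton) -> Tuple[Optional[str], List[int]]:
    """Return (prompt text or None, token IDs) for one singleton prompt."""
    if isinstance(inputs, str):
        return inputs, await fake_tokenize(inputs)
    elif isinstance(inputs, dict):
        if "prompt_token_ids" in inputs:
            # Pre-tokenized input: no text and no tokenizer call.
            return None, inputs["prompt_token_ids"]
        return inputs["prompt"], await fake_tokenize(inputs["prompt"])
    else:
        assert_never(inputs)


async def process_enc_dec(encoder: Singleton,
                          decoder: Singleton) -> Dict[str, Any]:
    enc_text, enc_ids = await extract_components(encoder)
    if decoder is encoder:
        # Same object on both sides: reuse the encoder result.
        dec_text, dec_ids = enc_text, enc_ids
    else:
        dec_text, dec_ids = await extract_components(decoder)
    return {
        "prompt": dec_text,
        "prompt_token_ids": dec_ids,
        "encoder_prompt": enc_text,
        "encoder_prompt_token_ids": enc_ids,
    }
```

The identity check (`is`) rather than equality keeps the shortcut limited to the case where one prompt object was deliberately placed on both sides.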
@@ -334,17 +453,19 @@ async def add_request_async(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Async version of :meth:`add_request`."""
Collaborator
This should not change in this PR

Member Author
Sorry, I don't quite get what you mean by this. Could you elaborate?

Collaborator
I don't think this function needs changes in this PR - just a nit

Member Author
I see - this is just to keep the order of arguments consistent with the new ordering of parameters in `process_model_inputs_async` (which has been updated alongside `process_model_inputs`).

if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()

processed_inputs = await self.process_model_inputs_async(
+inputs,
request_id=request_id,
-inputs=inputs,
lora_request=lora_request,
-prompt_adapter_request=prompt_adapter_request)
+prompt_adapter_request=prompt_adapter_request,
+)

self._add_processed_request(
request_id=request_id,