Add backwards compatibility for #8673
DarkLight1337 committed Sep 24, 2024
1 parent 530821d commit 79955a7
Showing 5 changed files with 247 additions and 12 deletions.
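Every file touched here follows the same pattern: the public methods keep the `prompt` parameter that replaced `inputs`, re-accept the old `inputs` keyword as a keyword-only argument, and route it through `vllm.utils.deprecate_kwargs` so legacy callers get a warning instead of a `TypeError`. The decorator itself is not part of this diff; the sketch below is a minimal, hypothetical approximation of how such a decorator can work, not the actual `vllm.utils` implementation.

```python
import warnings
from functools import wraps
from typing import Any, Callable, Optional


def deprecate_kwargs(*kw_names: str,
                     additional_message: Optional[str] = None) -> Callable:
    """Warn whenever one of the named keyword arguments is passed.

    Minimal sketch only; the real vllm.utils.deprecate_kwargs may differ.
    """

    def decorator(func: Callable) -> Callable:

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for name in kw_names:
                if kwargs.get(name) is not None:
                    msg = f"The keyword argument {name!r} is deprecated."
                    if additional_message is not None:
                        msg += f" {additional_message}"
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```

The paired `@overload` stubs in the diff exist purely for type checkers: the keyword-only `inputs` overload describes the legacy call style, while the positional `prompt` overload describes the new one; only the undecorated implementation runs at runtime.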
80 changes: 77 additions & 3 deletions vllm/engine/async_llm_engine.py
@@ -3,7 +3,7 @@
import weakref
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
Mapping, Optional, Set, Tuple, Type, Union, overload)
from weakref import ReferenceType

import vllm.envs as envs
@@ -28,7 +28,7 @@
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import weak_bind
from vllm.utils import deprecate_kwargs, weak_bind

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -402,6 +402,21 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()

@overload
async def add_request_async(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@overload
async def add_request_async(
self,
request_id: str,
@@ -411,8 +426,30 @@ async def add_request_async(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request_async(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Async version of :meth:`add_request`."""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None

if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
@@ -774,6 +811,21 @@ async def run_engine_loop(engine_ref: ReferenceType):

# This method does not need to be async, but kept that way
# for backwards compatibility.
@overload
async def add_request(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
...

@overload
async def add_request(
self,
request_id: str,
@@ -782,8 +834,30 @@ async def add_request(
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
...

@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None

if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
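With the overloads above, both calling conventions still type-check against `AsyncLLMEngine.add_request`; at runtime the deprecated keyword is copied onto `prompt` before the assertion. A usage sketch, assuming `engine` is an already-constructed `AsyncLLMEngine`; the request IDs and prompt text are placeholders:

```python
from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def submit(engine: AsyncLLMEngine) -> None:
    params = SamplingParams(temperature=0.8, max_tokens=32)

    # New-style call: the prompt is passed through the 'prompt' parameter.
    new_stream = await engine.add_request("req-1", "Hello, world!", params)

    # Old-style call: 'inputs' is still accepted, but only as a keyword
    # argument, and it now triggers the deprecation warning.
    old_stream = await engine.add_request("req-2", inputs="Hello, world!",
                                          params=params)

    # Both return the same kind of async generator of request outputs.
    async for output in new_stream:
        print(output)
```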
41 changes: 39 additions & 2 deletions vllm/engine/llm_engine.py
@@ -6,7 +6,7 @@
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union
from typing import Set, Type, Union, overload

import torch
from typing_extensions import TypeVar
@@ -51,7 +51,7 @@
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, weak_bind
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -686,6 +686,21 @@ def _add_processed_request(
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()

@overload
def add_request(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@overload
def add_request(
self,
request_id: str,
@@ -695,6 +710,24 @@ def add_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Add a request to the engine's request pool.
@@ -737,6 +770,10 @@ def add_request(
>>> # continue the request processing
>>> ...
"""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None

if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
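A quick way to exercise the same shim on the synchronous engine is to call `add_request` both ways and check that only the legacy keyword warns. A hedged sketch, assuming `deprecate_kwargs` emits a standard `DeprecationWarning`; the model name, request IDs, and prompt are placeholders, and loading a real model is required for it to run:

```python
import warnings

from vllm import EngineArgs, LLMEngine, SamplingParams

# Placeholder model; any small model works for this check.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
params = SamplingParams(temperature=0.0, max_tokens=8)

# New-style call: no warning expected.
engine.add_request("req-0", "What is the capital of France?", params)

# Deprecated call: 'inputs' is still routed to 'prompt' internally,
# but it should now emit a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    engine.add_request("req-1",
                       inputs="What is the capital of France?",
                       params=params)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```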
57 changes: 56 additions & 1 deletion vllm/engine/multiprocessing/__init__.py
@@ -1,13 +1,14 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Union, overload

from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs

VLLM_RPC_SUCCESS_STR = "SUCCESS"

@@ -30,6 +31,60 @@ class RPCProcessRequest:
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None

@overload
def __init__(
self,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@overload
def __init__(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def __init__(
self,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
if inputs is not None:
prompt = inputs
assert (prompt is not None and params is not None
and request_id is not None)

super().__init__()

self.prompt = prompt
self.params = params
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request


@dataclass
class RPCError:
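The new `__init__` takes `prompt`, `params`, and `request_id` positionally in that order, so positional construction follows the new argument order while the keyword-only `inputs` spelling remains usable. A small sketch with placeholder values:

```python
from vllm import SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest

params = SamplingParams(max_tokens=16)

# New-style construction: positional 'prompt' comes first.
req_new = RPCProcessRequest("Hello!", params, "req-42")

# Old-style construction: keyword-only 'inputs' is still accepted,
# but deprecate_kwargs flags it with a deprecation warning.
req_old = RPCProcessRequest(inputs="Hello!", params=params,
                            request_id="req-42")

# Both spellings end up populating the same 'prompt' field.
assert req_new.prompt == req_old.prompt == "Hello!"
```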