diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 3cf1850ee65a..14d5c8959b2f 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -3,7 +3,7 @@ import uuid from dataclasses import dataclass, field from enum import Enum -from typing import List, Mapping, Optional, Union, overload +from typing import Any, List, Mapping, Optional, Union, overload from typing_extensions import deprecated @@ -36,6 +36,7 @@ class RPCProcessRequest: trace_headers: Optional[Mapping[str, str]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None priority: int = 0 + extra_args: Optional[Mapping[str, Any]] = None @overload def __init__( @@ -47,6 +48,7 @@ def __init__( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + extra_args: Optional[Mapping[str, Any]] = None, ) -> None: ... @@ -62,6 +64,7 @@ def __init__( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + extra_args: Optional[Mapping[str, Any]] = None, ) -> None: ... @@ -78,6 +81,7 @@ def __init__( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + extra_args: Optional[Mapping[str, Any]] = None, *, inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: @@ -95,6 +99,7 @@ def __init__( self.trace_headers = trace_headers self.prompt_adapter_request = prompt_adapter_request self.priority = priority + self.extra_args = extra_args @dataclass diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 85b5f31e3a4a..f2ae3665e24a 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -442,6 +442,7 @@ def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + extra_args: Optional[Mapping[str, Any]] = None, ) -> AsyncGenerator[RequestOutput, None]: ... @@ -457,6 +458,7 @@ def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + extra_args: Optional[Mapping[str, Any]] = None, ) -> AsyncGenerator[RequestOutput, None]: ... @@ -473,6 +475,7 @@ def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + extra_args: Optional[Mapping[str, Any]] = None, *, inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[RequestOutput, None]: @@ -502,7 +505,7 @@ def generate( return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, - prompt_adapter_request, priority) + prompt_adapter_request, priority, extra_args) @overload def encode( @@ -586,6 +589,7 @@ async def _process_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + extra_args: Optional[Mapping[str, Any]] = None, ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ PoolingRequestOutput, None]]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" @@ -639,6 +643,7 @@ async def _process_request( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=priority, + extra_args=extra_args, )) # 3) Send the RPCGenerateRequest to the MQLLMEngine.