Skip to content

Commit

Permalink
[Core] [Frontend] Priority scheduling for embeddings and in the OpenA…
Browse files Browse the repository at this point in the history
…I-API (#8965)
  • Loading branch information
schoennenbeck authored Oct 1, 2024
1 parent 1fe0a42 commit 35bd215
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 5 deletions.
4 changes: 4 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,7 @@ async def encode(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
Expand All @@ -1057,6 +1058,8 @@ async def encode(
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
Expand Down Expand Up @@ -1109,6 +1112,7 @@ async def encode(
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

Expand Down
5 changes: 5 additions & 0 deletions vllm/engine/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0

@overload # DEPRECATED
def __init__(
Expand All @@ -41,6 +42,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...

Expand All @@ -53,6 +55,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...

Expand All @@ -68,6 +71,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
Expand All @@ -84,6 +88,7 @@ def __init__(
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority


@dataclass
Expand Down
20 changes: 16 additions & 4 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...

Expand All @@ -392,6 +393,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...

Expand All @@ -407,6 +409,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
Expand All @@ -425,6 +428,9 @@ def generate(
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
if inputs is not None:
prompt = inputs
Expand All @@ -433,7 +439,7 @@ def generate(

return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)
prompt_adapter_request, priority)

@overload # DEPRECATED
def encode(
Expand All @@ -444,6 +450,7 @@ def encode(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...

Expand All @@ -455,6 +462,7 @@ def encode(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...

Expand All @@ -469,6 +477,7 @@ def encode(
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
Expand Down Expand Up @@ -496,7 +505,7 @@ def encode(
and request_id is not None)

return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers)
lora_request, trace_headers, priority)

async def _process_request(
self,
Expand All @@ -505,7 +514,8 @@ async def _process_request(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
Expand Down Expand Up @@ -550,7 +560,9 @@ async def _process_request(
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
prompt_adapter_request=prompt_adapter_request,
priority=priority,
))

# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
Expand Down
4 changes: 3 additions & 1 deletion vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def generate(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...
Expand All @@ -52,6 +53,7 @@ def encode(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
...
Expand Down
22 changes: 22 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))

# doc: end-chat-completion-extra-params

Expand Down Expand Up @@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))

# doc: end-completion-extra-params

Expand Down Expand Up @@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel):

# doc: end-embedding-pooling-params

# doc: begin-embedding-extra-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))

# doc: end-embedding-extra-params

def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)

Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ async def create_chat_completion(
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ async def create_completion(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=request.priority,
)

generators.append(generator)
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ async def create_embedding(
pooling_params,
request_id_item,
lora_request=lora_request,
priority=request.priority,
)

generators.append(generator)
Expand Down

0 comments on commit 35bd215

Please sign in to comment.