59 changes: 10 additions & 49 deletions vllm/entrypoints/openai/protocol.py
@@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
logit_bias: Optional[dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field(
default=None,
deprecated=
@@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
}

def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
self, max_tokens: int,
default_sampling_params: dict) -> BeamSearchParams:

if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1

# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)

if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
@@ -465,21 +451,10 @@ def to_beam_search_params(

def to_sampling_params(
self,
default_max_tokens: int,
max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
default_sampling_params: dict,
) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens

if default_sampling_params is None:
default_sampling_params = {}

# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)

# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
@@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel):
}

def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
self,
max_tokens: int,
default_sampling_params: Optional[dict] = None,
) -> BeamSearchParams:
max_tokens = self.max_tokens

if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1

# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)

if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)

@@ -928,21 +896,14 @@ def to_beam_search_params(

def to_sampling_params(
self,
default_max_tokens: int,
max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
max_tokens = self.max_tokens

if default_sampling_params is None:
default_sampling_params = {}

# Use minimum of context window, user request & server limit.
max_tokens = min(
val for val in (default_max_tokens, max_tokens,
default_sampling_params.get("max_tokens", None))
if val is not None)

# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
@@ -1813,7 +1774,7 @@ def to_sampling_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API

max_tokens = default_max_tokens

if default_sampling_params is None:
@@ -2029,7 +1990,7 @@ def to_sampling_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API

max_tokens = default_max_tokens

if default_sampling_params is None:
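Note: after this change, `to_beam_search_params` / `to_sampling_params` no longer read `max_completion_tokens` / `max_tokens` off the request or clamp against the context window themselves; the caller passes in an already-resolved budget. A minimal sketch of calling the new `ChatCompletionRequest` signature directly, assuming a vLLM build with this patch (the model name, message, and token budget are illustrative):

```python
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

# Illustrative request; "model" and "messages" values are placeholders.
request = ChatCompletionRequest(
    model="my-model",
    messages=[{"role": "user", "content": "Hello"}],
    max_completion_tokens=256,
)

# The caller now resolves the output budget up front (here a literal 128,
# e.g. the minimum of the user's 256 and the remaining context window)
# and hands the final value to the conversion method.
sampling_params = request.to_sampling_params(
    128,   # max_tokens, already clamped by the caller
    None,  # logits_processor_pattern
    {},    # default_sampling_params (now a plain dict, not Optional)
)
print(sampling_params.max_tokens)  # -> 128
```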
18 changes: 13 additions & 5 deletions vllm/entrypoints/openai/serving_chat.py
@@ -34,6 +34,7 @@
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall)
from vllm.entrypoints.utils import get_max_tokens
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
@@ -233,15 +234,22 @@ async def create_chat_completion(
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])

if self.default_sampling_params is None:
self.default_sampling_params = {}

max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=len(engine_prompt["prompt_token_ids"]),
default_sampling_params=self.default_sampling_params)

if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens, self.default_sampling_params)
max_tokens, self.default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
max_tokens, self.model_config.logits_processor_pattern,
self.default_sampling_params)

self._log_inputs(request_id,
16 changes: 12 additions & 4 deletions vllm/entrypoints/openai/serving_completion.py
@@ -33,6 +33,7 @@
is_text_tokens_prompt)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt)
from vllm.logger import init_logger
@@ -160,15 +161,22 @@ async def create_completion(
input_length = len(engine_prompt["prompt_token_ids"])
else:
assert_never(engine_prompt)
default_max_tokens = self.max_model_len - input_length

if self.default_sampling_params is None:
self.default_sampling_params = {}

max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params)

if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens, self.default_sampling_params)
max_tokens, self.default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
max_tokens, self.model_config.logits_processor_pattern,
self.default_sampling_params)

request_id_item = f"{request_id}-{i}"
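Both serving paths now follow the same shape: normalize `default_sampling_params` to a dict, resolve the token budget once via `get_max_tokens`, then branch on beam search. A condensed sketch that folds the call-site logic from the two diffs above into one hypothetical helper (the wrapper function itself is not part of the PR):

```python
from typing import Optional, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest)
from vllm.entrypoints.utils import get_max_tokens


def build_sampling_params(request: Union[ChatCompletionRequest,
                                         CompletionRequest],
                          prompt_token_ids: list[int],
                          max_model_len: int,
                          logits_processor_pattern: Optional[str] = None,
                          default_sampling_params: Optional[dict] = None):
    # Both call sites materialize the defaults first, because get_max_tokens()
    # and the request conversion methods now expect a plain dict.
    default_sampling_params = default_sampling_params or {}

    max_tokens = get_max_tokens(
        max_model_len=max_model_len,
        request=request,
        input_length=len(prompt_token_ids),
        default_sampling_params=default_sampling_params)

    if request.use_beam_search:
        return request.to_beam_search_params(max_tokens,
                                             default_sampling_params)
    return request.to_sampling_params(max_tokens, logits_processor_pattern,
                                      default_sampling_params)
```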
22 changes: 20 additions & 2 deletions vllm/entrypoints/utils.py
@@ -5,13 +5,17 @@
import asyncio
import functools
import os
from typing import Any, Optional
import sys
from typing import Any, Optional, Union

from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask, BackgroundTasks

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

@@ -181,7 +185,6 @@ def _validate_truncation_size(

def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
subcommand_name: list[str]):
import sys

# Only handle --help=<keyword> for the current subcommand.
# Since subparser_init() runs for all subcommands during CLI setup,
@@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
print(f"\nNo group or parameter matching '{search_keyword}'")
print("Tip: use `--help=listgroup` to view all groups.")
sys.exit(1)


def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest,
CompletionRequest],
input_length: int, default_sampling_params: dict) -> int:

max_tokens = getattr(request, "max_completion_tokens",
None) or request.max_tokens
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)

return min(val
for val in (default_max_tokens, max_tokens, max_output_tokens,
default_sampling_params.get("max_tokens"))
if val is not None)
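
The new helper takes the smallest limit that is actually set: the remaining context window, the user's requested cap (`max_completion_tokens` for chat requests, falling back to `max_tokens`; the `getattr` covers `CompletionRequest`, which has no `max_completion_tokens` field), the platform's output cap, and any server-side default. A standalone mirror of that selection logic with illustrative numbers (the function below is a sketch of the arithmetic, not part of the PR):

```python
import sys
from typing import Optional


def resolve_max_tokens(max_model_len: int,
                       input_length: int,
                       requested_max_tokens: Optional[int],
                       platform_cap: int = sys.maxsize,
                       default_max_tokens: Optional[int] = None) -> int:
    """Take the smallest limit that is set, mirroring get_max_tokens()."""
    context_budget = max_model_len - input_length
    limits = (context_budget, requested_max_tokens, platform_cap,
              default_max_tokens)
    return min(val for val in limits if val is not None)


# 4096-token window, 1000-token prompt, user asked for 2048 output tokens:
print(resolve_max_tokens(4096, 1000, 2048))                    # -> 2048
# Same prompt, no user cap: the remaining context window (3096) wins.
print(resolve_max_tokens(4096, 1000, None))                    # -> 3096
# A platform that caps generation at 512 tokens overrides both.
print(resolve_max_tokens(4096, 1000, 2048, platform_cap=512))  # -> 512
```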
4 changes: 4 additions & 0 deletions vllm/platforms/interface.py
@@ -4,6 +4,7 @@
import os
import platform
import random
import sys
from datetime import timedelta
from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
@@ -164,6 +165,9 @@ def is_neuron(self) -> bool:
def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT

def get_max_output_tokens(self, prompt_len: int) -> int:
return sys.maxsize

def is_cuda_alike(self) -> bool:
"""Stateless version of [torch.cuda.is_available][]."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
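
The base implementation returns `sys.maxsize`, so `get_max_tokens` is unaffected on platforms without an output limit; a platform that does have one can override the hook. A hypothetical out-of-tree override, assuming a fixed total sequence-length budget (the class name and limit are illustrative, not part of this PR):

```python
from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):
    """Hypothetical platform whose runtime caps prompt + output at 8192 tokens."""

    _MAX_SEQ_LEN = 8192  # illustrative hardware/runtime limit

    def get_max_output_tokens(self, prompt_len: int) -> int:
        # Leave at least one token of headroom so generation can start.
        return max(1, self._MAX_SEQ_LEN - prompt_len)
```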