diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/openai/test_truncation.py new file mode 100644 index 000000000000..137ed9db8589 --- /dev/null +++ b/tests/entrypoints/openai/test_truncation.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import openai +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2" +max_model_len = 128 + +input = """Immerse yourself in the enchanting chronicle of calculus, a + mathematical domain that has radically transformed our comprehension of + change and motion. Despite its roots in ancient civilizations, the + formal birth of calculus predominantly occurred in the 17th century, + primarily under the influential guidance of Sir Isaac Newton and Gottfried + Wilhelm Leibniz. The earliest traces of calculus concepts are found in + ancient Greek mathematics,most notably in the works of Eudoxus and + Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a + technique for computing areas and volumes through the use of finite sums. + This methodology laid crucial foundational work for integral calculus. + In the 17th century, both Newton and Leibniz independently pioneered + calculus, each contributing unique perspectives that would shape this new + field.""" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", + "embed", + "--dtype", + "bfloat16", + "--enforce-eager", + "--max-model-len", + str(max_model_len), + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_smaller_truncation_size(client: openai.AsyncOpenAI): + truncation_size = 10 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + response = await client.post(path="embeddings", + cast_to=object, + body={**kwargs}) + + assert response["usage"]["prompt_tokens"] == truncation_size + + +@pytest.mark.asyncio +async def test_bigger_truncation_size(client: openai.AsyncOpenAI): + truncation_size = max_model_len + 1 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + with pytest.raises(openai.BadRequestError) as err: + err = await client.post(path="embeddings", + cast_to=object, + body={**kwargs}) + + assert str(err) == f"""openai.BadRequestError: + Error code: 400 - {{'object': 'error', + 'message': 'truncate_prompt_tokens value + ({truncation_size}) + is greater than max_model_len ({max_model_len}). 
+ Please, select a smaller truncation size.', + 'type': 'BadRequestError', + 'param': None, 'code': 400}}""" + + +@pytest.mark.asyncio +async def test_max_truncation_size(client: openai.AsyncOpenAI): + truncation_size = -1 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + response = await client.post(path="embeddings", + cast_to=object, + body={**kwargs}) + + assert response["usage"]["prompt_tokens"] == max_model_len diff --git a/tests/models/embedding/language/test_truncation_control.py b/tests/models/embedding/language/test_truncation_control.py new file mode 100644 index 000000000000..a215e1ec564d --- /dev/null +++ b/tests/models/embedding/language/test_truncation_control.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2" +max_model_len = 128 + +input_str = """Immerse yourself in the enchanting chronicle of calculus, a + mathematical domain that has radically transformed our comprehension of + change and motion. Despite its roots in ancient civilizations, the + formal birth of calculus predominantly occurred in the 17th century, + primarily under the influential guidance of Sir Isaac Newton and Gottfried + Wilhelm Leibniz. The earliest traces of calculus concepts are found in + ancient Greek mathematics,most notably in the works of Eudoxus and + Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a + technique for computing areas and volumes through the use of finite sums. + This methodology laid crucial foundational work for integral calculus. + In the 17th century, both Newton and Leibniz independently pioneered + calculus, each contributing unique perspectives that would shape this new + field.""" + + +def test_smaller_truncation_size(vllm_runner, + model_name=MODEL_NAME, + input_str=input_str): + + truncate_prompt_tokens = 10 + + with vllm_runner(model_name, task="embed", + max_model_len=max_model_len) as vllm_model: + vllm_output = vllm_model.model.encode( + input_str, truncate_prompt_tokens=truncate_prompt_tokens) + + prompt_tokens = vllm_output[0].prompt_token_ids + + assert len(prompt_tokens) == truncate_prompt_tokens + + +def test_max_truncation_size(vllm_runner, + model_name=MODEL_NAME, + input_str=input_str): + truncate_prompt_tokens = -1 + + with vllm_runner(model_name, task="embed", + max_model_len=max_model_len) as vllm_model: + vllm_output = vllm_model.model.encode( + input_str, truncate_prompt_tokens=truncate_prompt_tokens) + + prompt_tokens = vllm_output[0].prompt_token_ids + + assert len(prompt_tokens) == max_model_len + + +def test_bigger_truncation_size(vllm_runner, + model_name=MODEL_NAME, + input_str=input_str): + + truncate_prompt_tokens = max_model_len + 1 + + with pytest.raises(ValueError), vllm_runner( + model_name, task="embed", + max_model_len=max_model_len) as vllm_model: + + llm_output = vllm_model.model.encode( + input_str, truncate_prompt_tokens=truncate_prompt_tokens) + + assert llm_output == f"""truncate_prompt_tokens value + ({truncate_prompt_tokens}) is greater than + max_model_len ({max_model_len}). 
Please, select + a smaller truncation size.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c23530990611..a4aefe61ebe5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -645,6 +645,7 @@ def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, @@ -678,6 +679,7 @@ def add_request( params: Optional[Union[SamplingParams, PoolingParams]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, @@ -758,6 +760,7 @@ def add_request( processed_inputs = self.input_preprocessor.preprocess( prompt, + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 7e5ac3a28452..5632e8ad446d 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -2,7 +2,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import AsyncGenerator, List, Mapping, Optional +from typing import AsyncGenerator, Mapping, Optional from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig, VllmConfig @@ -256,7 +256,7 @@ async def is_tracing_enabled(self) -> bool: async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, + model_output: Optional[list[SamplerOutput]] = None, ) -> None: ... 
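For reference, the request shape exercised by the API tests above boils down to the following (a usage sketch only, not part of the patch; the base URL, port, and API key are placeholder assumptions, and the server is assumed to be launched with --task embed and --max-model-len 128 as in the fixture):

import openai

# Plain sync client pointed at a locally running vLLM OpenAI-compatible server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# truncate_prompt_tokens=10 keeps only the last 10 prompt tokens;
# -1 truncates to max_model_len; values above max_model_len return a 400 error.
response = client.post(
    path="embeddings",
    cast_to=object,
    body={
        "model": "sentence-transformers/all-MiniLM-L12-v2",
        "input": "Immerse yourself in the enchanting chronicle of calculus ...",
        "truncate_prompt_tokens": 10,
    },
)
print(response["usage"]["prompt_tokens"])  # expected: 10
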
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 948e8f36e0e6..f1f48c700f4b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -25,6 +25,7 @@ resolve_chat_template_content_format) from vllm.entrypoints.score_utils import (_cosine_similarity, _validate_score_input_lens) +from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt from vllm.logger import init_logger @@ -793,6 +794,7 @@ def encode( pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, *, + truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -807,6 +809,7 @@ def encode( pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[list[int]] = None, + truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -821,6 +824,7 @@ def encode( pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[list[list[int]]] = None, + truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -836,6 +840,7 @@ def encode( Sequence[PoolingParams]]] = None, *, prompt_token_ids: list[int], + truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -851,6 +856,7 @@ def encode( Sequence[PoolingParams]]] = None, *, prompt_token_ids: list[list[int]], + truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -864,6 +870,7 @@ def encode( prompts: None, pooling_params: None, prompt_token_ids: Union[list[int], list[list[int]]], + truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -882,6 +889,7 @@ def encode( pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, + truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -946,10 +954,15 @@ def encode( for pooling_param in pooling_params: pooling_param.verify(self.llm_engine.model_config) + tokenization_kwargs: dict[str, Any] = {} + _validate_truncation_size(self.llm_engine.model_config.max_model_len, + truncate_prompt_tokens, tokenization_kwargs) + self._validate_and_add_requests( prompts=parsed_prompts, params=pooling_params, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, prompt_adapter_request=prompt_adapter_request, ) @@ -962,6 +975,7 @@ def embed( prompts: Union[PromptType, Sequence[PromptType]], /, *, + truncate_prompt_tokens: Optional[int] = None, 
use_tqdm: bool = True, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -995,6 +1009,7 @@ def embed( "Embedding API is only enabled for `--task embed`") items = self.encode(prompts, + truncate_prompt_tokens=truncate_prompt_tokens, use_tqdm=use_tqdm, pooling_params=pooling_params, lora_request=lora_request, @@ -1055,6 +1070,7 @@ def _embedding_score( encoded_output: list[PoolingRequestOutput] = self.encode( text_1 + text_2, + truncate_prompt_tokens=truncate_prompt_tokens, use_tqdm=use_tqdm, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) @@ -1098,9 +1114,8 @@ def _cross_encoding_score( pooling_params = PoolingParams() tokenization_kwargs: dict[str, Any] = {} - if truncate_prompt_tokens is not None: - tokenization_kwargs["truncation"] = True - tokenization_kwargs["max_length"] = truncate_prompt_tokens + _validate_truncation_size(self.llm_engine.model_config.max_model_len, + truncate_prompt_tokens, tokenization_kwargs) parsed_prompts = [] @@ -1323,6 +1338,7 @@ def _validate_and_add_requests( Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, guided_options: Optional[GuidedDecodingRequest] = None, priority: Optional[list[int]] = None, ) -> None: @@ -1359,6 +1375,7 @@ def _validate_and_add_requests( self._add_request( prompt, params[i] if isinstance(params, Sequence) else params, + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, prompt_adapter_request=prompt_adapter_request, @@ -1369,6 +1386,7 @@ def _add_request( self, prompt: PromptType, params: Union[SamplingParams, PoolingParams], + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, @@ -1379,6 +1397,7 @@ def _add_request( prompt, params, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, prompt_adapter_request=prompt_adapter_request, priority=priority, ) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 015943762ab1..d444442a9762 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1014,7 +1014,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): encoding_format: Literal["float", "base64"] = "float" dimensions: Optional[int] = None user: Optional[str] = None - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None # doc: begin-embedding-pooling-params additional_data: Optional[Any] = None @@ -1049,7 +1049,7 @@ class EmbeddingChatRequest(OpenAIBaseModel): encoding_format: Literal["float", "base64"] = "float" dimensions: Optional[int] = None user: Optional[str] = None - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None # doc: begin-chat-embedding-pooling-params additional_data: Optional[Any] = None @@ -1116,7 +1116,7 @@ class ScoreRequest(OpenAIBaseModel): model: Optional[str] = None text_1: Union[list[str], str] text_2: Union[list[str], str] - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None # doc: begin-score-pooling-params additional_data: Optional[Any] = 
None @@ -1142,7 +1142,7 @@ class RerankRequest(OpenAIBaseModel): query: str documents: list[str] top_n: int = Field(default_factory=lambda: 0) - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None # doc: begin-rerank-pooling-params additional_data: Optional[Any] = None diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index ba960de17cab..4b4d2d8b76f4 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -21,6 +21,7 @@ ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput) @@ -85,16 +86,7 @@ async def create_embedding( request_id = f"embd-{self._base_request_id(raw_request)}" created_time = int(time.time()) - truncate_prompt_tokens = None - - if request.truncate_prompt_tokens is not None: - if request.truncate_prompt_tokens <= self.max_model_len: - truncate_prompt_tokens = request.truncate_prompt_tokens - else: - return self.create_error_response( - "truncate_prompt_tokens value is " - "greater than max_model_len." - " Please, select a smaller truncation size.") + truncate_prompt_tokens = request.truncate_prompt_tokens pooling_params = request.to_pooling_params() @@ -104,6 +96,8 @@ async def create_embedding( return self.create_error_response(str(e)) try: + truncate_prompt_tokens = _validate_truncation_size( + self.max_model_len, truncate_prompt_tokens) ( lora_request, prompt_adapter_request, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 49b346a23baf..c3121eff562d 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -173,7 +173,7 @@ def _normalize_prompt_text_to_input( request: AnyRequest, tokenizer: AnyTokenizer, prompt: str, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]], add_special_tokens: bool, ) -> TextTokensPrompt: if (self.model_config.encoder_config is not None @@ -271,7 +271,7 @@ def _tokenize_prompt_input( request: AnyRequest, tokenizer: AnyTokenizer, prompt_input: Union[str, list[int]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> TextTokensPrompt: """ @@ -292,7 +292,7 @@ def _tokenize_prompt_inputs( request: AnyRequest, tokenizer: AnyTokenizer, prompt_inputs: Iterable[Union[str, list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> Iterator[TextTokensPrompt]: """ @@ -321,7 +321,7 @@ def _tokenize_prompt_input_or_inputs( request: AnyRequest, tokenizer: AnyTokenizer, input_or_inputs: Union[str, list[str], list[int], list[list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> list[TextTokensPrompt]: """ @@ -356,7 +356,7 @@ async def _preprocess_completion( request: CompletionLikeRequest, 
tokenizer: AnyTokenizer, input_or_inputs: Union[str, list[str], list[int], list[list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]: request_prompts = await self._tokenize_prompt_input_or_inputs_async( diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 779a3eded2c1..7c401d4f5cb1 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -21,6 +21,7 @@ PoolingResponseData, UsageInfo) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.utils import merge_async_iterators @@ -85,18 +86,11 @@ async def create_pooling( request_id = f"pool-{self._base_request_id(raw_request)}" created_time = int(time.time()) - truncate_prompt_tokens = None - - if request.truncate_prompt_tokens is not None: - if request.truncate_prompt_tokens <= self.max_model_len: - truncate_prompt_tokens = request.truncate_prompt_tokens - else: - return self.create_error_response( - "truncate_prompt_tokens value is " - "greater than max_model_len." - " Please, select a smaller truncation size.") + truncate_prompt_tokens = request.truncate_prompt_tokens try: + truncate_prompt_tokens = _validate_truncation_size( + self.max_model_len, truncate_prompt_tokens) ( lora_request, prompt_adapter_request, diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 73b4288cbb0d..9bdacb5518d6 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -18,6 +18,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.score_utils import (_cosine_similarity, _validate_score_input_lens) +from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -231,11 +232,6 @@ async def _run_scoring( truncate_prompt_tokens: Optional[int] = None, ) -> list[PoolingRequestOutput]: - tokenization_kwargs: dict[str, Any] = {} - if truncate_prompt_tokens is not None: - tokenization_kwargs["truncation"] = True - tokenization_kwargs["max_length"] = truncate_prompt_tokens - ( lora_request, prompt_adapter_request, @@ -247,12 +243,9 @@ async def _run_scoring( tokenizer = await self.engine_client.get_tokenizer(lora_request) - if truncate_prompt_tokens is not None and \ - truncate_prompt_tokens > self.max_model_len: - raise ValueError( - f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " - f"is greater than max_model_len ({self.max_model_len})." 
- f" Please, select a smaller truncation size.") + tokenization_kwargs: dict[str, Any] = {} + _validate_truncation_size(self.max_model_len, truncate_prompt_tokens, + tokenization_kwargs) trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 53411a27b41e..80b6c07c603f 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -46,4 +46,4 @@ def _validate_score_input_lens( if len(texts_1) == 0: raise ValueError("At least one text element must be given") if len(texts_2) == 0: - raise ValueError("At least one text_pair element must be given") + raise ValueError("At least one text_pair element must be given") \ No newline at end of file diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index b88c2b3a080f..2fe6e1a9e9c4 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -3,6 +3,7 @@ import asyncio import functools import os +from typing import Any, Optional from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse @@ -134,3 +135,26 @@ def cli_env_setup(): if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ: logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'") os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +def _validate_truncation_size( + max_model_len: int, + truncate_prompt_tokens: Optional[int], + tokenization_kwargs: Optional[dict[str, Any]] = None, +) -> Optional[int]: + + if truncate_prompt_tokens is not None: + if truncate_prompt_tokens <= -1: + truncate_prompt_tokens = max_model_len + + if truncate_prompt_tokens > max_model_len: + raise ValueError( + f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " + f"is greater than max_model_len ({max_model_len})." + f" Please, select a smaller truncation size.") + + if tokenization_kwargs is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + return truncate_prompt_tokens diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 0edb6da06209..56b60b893913 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -2,7 +2,7 @@ import asyncio from collections.abc import Mapping -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast from typing_extensions import assert_never @@ -183,18 +183,21 @@ def _tokenize_prompt( self, prompt: str, lora_request: Optional[LoRARequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[int]: """ Apply the model's tokenizer to a text prompt, returning the corresponding token IDs. """ tokenizer = self.get_tokenizer_group() - add_special_tokens = None + if tokenization_kwargs is None: + tokenization_kwargs = {} + if self.model_config.hf_config.model_type == "whisper": # For Whisper, special tokens should be provided by the user based # on the task and language of their request. Also needed to avoid # appending an EOS token to the prompt which disrupts generation. 
- add_special_tokens = False + tokenization_kwargs["add_special_tokens"] = False if (self.model_config.encoder_config is not None and self.model_config.encoder_config.get( @@ -203,25 +206,27 @@ def _tokenize_prompt( return tokenizer.encode(prompt=prompt, lora_request=lora_request, - add_special_tokens=add_special_tokens) + **tokenization_kwargs) async def _tokenize_prompt_async( self, prompt: str, lora_request: Optional[LoRARequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[int]: """Async version of :meth:`_tokenize_prompt`.""" tokenizer = self.get_tokenizer_group() - add_special_tokens = None + if tokenization_kwargs is None: + tokenization_kwargs = {} + if self.model_config.hf_config.model_type == "whisper": # For Whisper, special tokens should be provided by the user based # on the task and language of their request. Also needed to avoid # appending an EOS token to the prompt which disrupts generation. - add_special_tokens = False - return await tokenizer.encode_async( - prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens) + tokenization_kwargs["add_special_tokens"] = False + return await tokenizer.encode_async(prompt=prompt, + lora_request=lora_request, + **tokenization_kwargs) def _process_multimodal( self, @@ -281,6 +286,7 @@ async def _process_multimodal_async( def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> SingletonInputs: @@ -304,6 +310,7 @@ def _prompt_to_llm_inputs( prompt_token_ids = self._tokenize_prompt( prompt_text, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, ) return token_inputs( @@ -352,6 +359,7 @@ def _prompt_to_llm_inputs( prompt_token_ids = self._tokenize_prompt( prompt_text, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, ) return token_inputs( @@ -364,6 +372,7 @@ def _prompt_to_llm_inputs( async def _prompt_to_llm_inputs_async( self, prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, return_mm_hashes: bool = False, ) -> SingletonInputs: @@ -375,6 +384,7 @@ async def _prompt_to_llm_inputs_async( prompt_token_ids = await self._tokenize_prompt_async( prompt_text, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, ) return token_inputs( @@ -517,6 +527,7 @@ def _separate_enc_dec_inputs_from_mm_processor_outputs( def _process_encoder_decoder_prompt( self, prompt: PromptType, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> EncoderDecoderInputs: """ For encoder/decoder models only: @@ -553,7 +564,9 @@ def _process_encoder_decoder_prompt( if is_explicit_encoder_decoder_prompt(prompt): encoder_inputs = self._prompt_to_llm_inputs( - prompt["encoder_prompt"]) + prompt["encoder_prompt"], + tokenization_kwargs=tokenization_kwargs, + ) if (decoder_input := prompt["decoder_prompt"]) is None: decoder_inputs = None else: @@ -565,7 +578,10 @@ def _process_encoder_decoder_prompt( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) else: - inputs = self._prompt_to_llm_inputs(prompt) + inputs = self._prompt_to_llm_inputs( + prompt, + tokenization_kwargs=tokenization_kwargs, + ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( @@ -581,6 +597,7 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, 
prompt: PromptType, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_inputs: SingletonInputs @@ -588,13 +605,18 @@ async def _process_encoder_decoder_prompt_async( if is_explicit_encoder_decoder_prompt(prompt): encoder_task = self._prompt_to_llm_inputs_async( - prompt["encoder_prompt"]) + prompt["encoder_prompt"], + tokenization_kwargs=tokenization_kwargs, + ) if (decoder_input := prompt["decoder_prompt"]) is None: encoder_inputs = await encoder_task decoder_inputs = None else: - decoder_task = self._prompt_to_llm_inputs_async(decoder_input) + decoder_task = self._prompt_to_llm_inputs_async( + decoder_input, + tokenization_kwargs=tokenization_kwargs, + ) encoder_inputs, decoder_inputs = await asyncio.gather( encoder_task, decoder_task) @@ -606,7 +628,10 @@ async def _process_encoder_decoder_prompt_async( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) else: - inputs = await self._prompt_to_llm_inputs_async(prompt) + inputs = await self._prompt_to_llm_inputs_async( + prompt, + tokenization_kwargs=tokenization_kwargs, + ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( @@ -638,6 +663,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, @@ -660,6 +686,7 @@ def _process_decoder_only_prompt( prompt_comps = self._prompt_to_llm_inputs( prompt, + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) @@ -672,6 +699,7 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, @@ -679,6 +707,7 @@ async def _process_decoder_only_prompt_async( """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._prompt_to_llm_inputs_async( prompt, + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, return_mm_hashes=return_mm_hashes, ) @@ -691,6 +720,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, prompt: PromptType, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, @@ -711,6 +741,7 @@ def preprocess( # Decoder-only operation return self._process_decoder_only_prompt( prompt, + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, return_mm_hashes=return_mm_hashes, @@ -719,6 +750,7 @@ def preprocess( async def preprocess_async( self, prompt: PromptType, + tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, @@ -739,6 +771,7 @@ async def preprocess_async( # Decoder-only operation return await self._process_decoder_only_prompt_async( prompt, + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, 
return_mm_hashes=return_mm_hashes, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index c430b74a9db9..eedc02a763fd 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -190,9 +190,10 @@ class SamplingParams( logits_processors: list of functions that modify logits based on previously generated tokens, and optionally prompt tokens as a first argument. - truncate_prompt_tokens: If set to an integer k, will use only the last k - tokens from the prompt (i.e., left truncation). Defaults to None - (i.e., no truncation). + truncate_prompt_tokens: If set to -1, will use the truncation size + supported by the model. If set to an integer k, will use only + the last k tokens from the prompt (i.e., left truncation). + Defaults to None (i.e., no truncation). guided_decoding: If provided, the engine will construct a guided decoding logits processor from these parameters. Defaults to None. logit_bias: If provided, the engine will construct a logits processor diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index da5bec856662..57b9242b88ab 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -55,6 +55,8 @@ def encode_tokens( tokenizer: AnyTokenizer, text: str, *, + truncation: Optional[bool] = None, + max_length: Optional[int] = None, add_special_tokens: Optional[bool] = None, ) -> list[int]: """ @@ -64,10 +66,18 @@ def encode_tokens( :code:`add_special_tokens=None` means to use the backend's default settings. """ + + kw_args: dict[str, Any] = {} + if max_length is not None: + kw_args["max_length"] = max_length + + if truncation is not None: + kw_args["truncation"] = truncation + if add_special_tokens is not None: - return tokenizer.encode(text, add_special_tokens=add_special_tokens) + kw_args["add_special_tokens"] = add_special_tokens - return tokenizer.encode(text) + return tokenizer.encode(text, **kw_args) def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py index b4eb081c9b99..d69e5a6b4251 100644 --- a/vllm/transformers_utils/tokenizer_base.py +++ b/vllm/transformers_utils/tokenizer_base.py @@ -94,6 +94,8 @@ def encode_one( @abstractmethod def encode(self, text: str, + truncation: Optional[bool] = None, + max_length: Optional[int] = None, add_special_tokens: Optional[bool] = None) -> list[int]: raise NotImplementedError() diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py index a829985cb459..aff2d2eb1c35 100644 --- a/vllm/transformers_utils/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -45,11 +45,16 @@ def _raise_if_input_too_long(self, def encode(self, prompt: str, + max_length: Optional[int] = None, + truncation: Optional[bool] = None, lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) ret = encode_tokens(tokenizer, prompt, + max_length=max_length, + truncation=truncation, add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret @@ -57,11 +62,15 @@ def encode(self, async def encode_async( self, prompt: str, + max_length: Optional[int] = None, + truncation: Optional[bool] = None, lora_request: Optional[LoRARequest] = None, add_special_tokens: Optional[bool] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) ret = 
encode_tokens(tokenizer, prompt, + max_length=max_length, + truncation=truncation, add_special_tokens=add_special_tokens) self._raise_if_input_too_long(ret, lora_request) return ret diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 296149a45695..6d4655781f9e 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -359,6 +359,8 @@ def encode_one( def encode(self, text: str, + truncation: Optional[bool] = None, + max_length: Optional[int] = None, add_special_tokens: Optional[bool] = None) -> List[int]: # `encode` should only be used for prompt completion # it should never be used for chat_completion. diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 1334fb789aa4..2562fcc9cbd9 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -2,7 +2,7 @@ import asyncio from collections.abc import AsyncGenerator, Mapping from copy import copy -from typing import Optional, Union +from typing import Any, Optional, Union import numpy as np @@ -201,6 +201,7 @@ async def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, @@ -219,7 +220,8 @@ async def add_request( # Convert Input --> Request. prompt_str, request = self.processor.process_inputs( request_id, prompt, params, arrival_time, lora_request, - trace_headers, prompt_adapter_request, priority) + tokenization_kwargs, trace_headers, prompt_adapter_request, + priority) if params.n == 1: await self._add_request(request, prompt_str, None, 0, queue) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 85da58451c78..b471b153657f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -175,6 +175,7 @@ def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, @@ -182,7 +183,8 @@ def add_request( # Process raw inputs into the request. 
prompt_str, request = self.processor.process_inputs( request_id, prompt, params, arrival_time, lora_request, - trace_headers, prompt_adapter_request, priority) + tokenization_kwargs, trace_headers, prompt_adapter_request, + priority) n = params.n if isinstance(params, SamplingParams) else 1 diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 5c15e8baef2b..5269248a2624 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -2,7 +2,7 @@ import time from collections.abc import Mapping, Sequence -from typing import Literal, Optional, Union +from typing import Any, Literal, Optional, Union from vllm.config import VllmConfig from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs @@ -198,6 +198,7 @@ def process_inputs( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, @@ -224,6 +225,7 @@ def process_inputs( # 3. Apply prompt adapter to prompt token ids if one exists. processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, + tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, return_mm_hashes=self.use_hash,
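
As an offline counterpart (a sketch only, assuming the patch above is applied; the model name and lengths mirror the tests), the new truncate_prompt_tokens argument on LLM.encode can be used like this:

from vllm import LLM

llm = LLM(model="sentence-transformers/all-MiniLM-L12-v2",
          task="embed",
          max_model_len=128)

# Keep only the last 10 prompt tokens (left truncation).
outputs = llm.encode("Immerse yourself in the enchanting chronicle of calculus ...",
                     truncate_prompt_tokens=10)
assert len(outputs[0].prompt_token_ids) == 10

# -1 truncates to the model's max_model_len (128 here);
# anything larger than max_model_len raises ValueError.
outputs = llm.encode("Immerse yourself in the enchanting chronicle of calculus ...",
                     truncate_prompt_tokens=-1)
assert len(outputs[0].prompt_token_ids) == 128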