Merged

Commits (60)
9b0ce65
[FEATURE] Enables offline /score for embedding models
gmarinho2 Jan 21, 2025
74cd2dd
Merge branch 'upstream_main'
gmarinho2 Jan 21, 2025
590ab4d
Changes variable name and uses tuple instead of list.
gmarinho2 Jan 21, 2025
f219024
Separates scoring logic and makes small adjustments
gmarinho2 Jan 22, 2025
47211e5
Separates scoring logic and makes small adjustments
gmarinho2 Jan 22, 2025
1a89033
Moves scoring functions declaration to llm class
gmarinho2 Jan 22, 2025
db7919b
Completes embedding_score function signature
gmarinho2 Jan 23, 2025
41b11be
Passes new parameters to self.encode() in embedding_score
gmarinho2 Jan 23, 2025
b57e01a
Adds type annotations for the parameters
gmarinho2 Jan 23, 2025
4956eae
Minor adjustments
gmarinho2 Jan 23, 2025
f8b8d8c
Minor adjustments in embedding_score
gmarinho2 Jan 23, 2025
cd835de
trigger ci
gmarinho2 Jan 24, 2025
0e95ead
trigger ci
gmarinho2 Jan 24, 2025
bba2ea6
Merge branch 'vllm-project:main' into main
gmarinho2 Feb 4, 2025
635b8e8
Merge branch 'vllm-project:main' into main
gmarinho2 Feb 5, 2025
9e1dbbe
Merge branch 'vllm-project:main' into main
gmarinho2 Feb 20, 2025
2ed8d7b
Merge branch 'vllm-project:main' into main
gmarinho2 Feb 21, 2025
7e05fb0
Merge branch 'vllm-project:main' into main
gmarinho2 Feb 25, 2025
3931763
Merge branch 'vllm-project:main' into main
gmarinho2 Mar 11, 2025
45dba65
Merge branch 'vllm-project:main' into truncation-control
gmarinho2 Mar 12, 2025
8520fd3
Set truncation size to max_model_len when truncate_prompt_tokens is -1
gmarinho2 Feb 25, 2025
b3798a0
put truncation control in score_utils
gmarinho2 Feb 26, 2025
9035bfa
implement local truncation control
gmarinho2 Mar 12, 2025
8ba7e22
--amend
gmarinho2 Mar 12, 2025
b3a8522
fix typing errors
gmarinho2 Mar 12, 2025
64b7d6d
fix local truncation control
gmarinho2 Mar 12, 2025
1db5069
implement tests for local truncation control
gmarinho2 Mar 13, 2025
79a5693
implement tests for local truncation control
gmarinho2 Mar 13, 2025
48fef03
fix mypy errors
gmarinho2 Mar 14, 2025
d692498
fix type error in tokenizer.encode parameter
gmarinho2 Mar 14, 2025
281dafd
Merge branch 'vllm-project:main' into main
gmarinho2 Mar 20, 2025
5ce5ef2
Merge branch 'vllm-project:main' into main
gmarinho2 Apr 1, 2025
b55e759
refactor
gmarinho2 Apr 1, 2025
e7e7197
Merge branch 'vllm-project:main' into main
gmarinho2 Apr 1, 2025
1f9aa00
enable max size config (using -1) for api
gmarinho2 Apr 3, 2025
50364f4
Merge branch 'main' into truncation-control
gmarinho2 Apr 3, 2025
493d118
fix mypy error
gmarinho2 Apr 3, 2025
14fe15a
add back request_id param and create test for online api
gmarinho2 Apr 3, 2025
7064e5d
implement tests for online API and minor changes
gmarinho2 Apr 8, 2025
7fca843
minor changes to fix mypy errors
gmarinho2 Apr 8, 2025
95da043
minor changes to fix mypy errors
gmarinho2 Apr 8, 2025
a11bcc0
add documentation
gmarinho2 Apr 8, 2025
0b8f0c5
make async tokenization method the same as the sync versions
maxdebayser Apr 8, 2025
4be4e0d
Merge pull request #1 from maxdebayser/truncation-control
gmarinho2 Apr 8, 2025
fea93ce
Refactor function to cut down on the repeated code
maxdebayser Apr 8, 2025
50da440
Merge pull request #2 from maxdebayser/truncation-control
gmarinho2 Apr 8, 2025
c0f0f5f
change tests
gmarinho2 Apr 8, 2025
bb3b804
refactor test_truncation_control
gmarinho2 Apr 10, 2025
1d150e4
refactor online API test
gmarinho2 Apr 10, 2025
6569363
fix mypy errors
gmarinho2 Apr 10, 2025
30706ab
add truncation control to v1's LLMEngine
gmarinho2 Apr 16, 2025
cae8d6a
Merge branch 'main' into truncation-control
gmarinho2 Apr 16, 2025
7876694
Merge branch 'vllm-project:main' into main
gmarinho2 Apr 17, 2025
f2ee840
Merge branch 'main' into truncation-control
gmarinho2 Apr 17, 2025
83746e3
Merge branch 'vllm-project:main' into main
gmarinho2 Apr 22, 2025
76829b5
Merge branch 'main' into truncation-control
gmarinho2 Apr 22, 2025
fe13004
Merge branch 'upstream_main' into truncation-control
maxdebayser Apr 24, 2025
37fa5c9
Merge branch 'upstream_main' into truncation-control
maxdebayser Apr 28, 2025
13f98ac
Merge branch 'vllm-project:main' into main
gmarinho2 Apr 29, 2025
7f53602
Merge branch 'main' into truncation-control
gmarinho2 Apr 29, 2025
103 changes: 103 additions & 0 deletions tests/entrypoints/openai/test_truncation.py
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import openai
import pytest
import pytest_asyncio

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
max_model_len = 128

input = """Immerse yourself in the enchanting chronicle of calculus, a
mathematical domain that has radically transformed our comprehension of
change and motion. Despite its roots in ancient civilizations, the
formal birth of calculus predominantly occurred in the 17th century,
primarily under the influential guidance of Sir Isaac Newton and Gottfried
Wilhelm Leibniz. The earliest traces of calculus concepts are found in
ancient Greek mathematics,most notably in the works of Eudoxus and
Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a
technique for computing areas and volumes through the use of finite sums.
This methodology laid crucial foundational work for integral calculus.
In the 17th century, both Newton and Leibniz independently pioneered
calculus, each contributing unique perspectives that would shape this new
field."""


@pytest.fixture(scope="module")
def server():
args = [
"--task",
"embed",
"--dtype",
"bfloat16",
"--enforce-eager",
"--max-model-len",
str(max_model_len),
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client


@pytest.mark.asyncio
async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
truncation_size = 10
kwargs: dict[str, Any] = {
"model": MODEL_NAME,
"input": input,
"truncate_prompt_tokens": truncation_size
}

response = await client.post(path="embeddings",
cast_to=object,
body={**kwargs})

assert response["usage"]["prompt_tokens"] == truncation_size


@pytest.mark.asyncio
async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
truncation_size = max_model_len + 1
kwargs: dict[str, Any] = {
"model": MODEL_NAME,
"input": input,
"truncate_prompt_tokens": truncation_size
}

with pytest.raises(openai.BadRequestError) as err:
err = await client.post(path="embeddings",
cast_to=object,
body={**kwargs})

assert str(err) == f"""openai.BadRequestError:
Error code: 400 - {{'object': 'error',
'message': 'truncate_prompt_tokens value
({truncation_size})
is greater than max_model_len ({max_model_len}).
Please, select a smaller truncation size.',
'type': 'BadRequestError',
'param': None, 'code': 400}}"""


@pytest.mark.asyncio
async def test_max_truncation_size(client: openai.AsyncOpenAI):
truncation_size = -1
kwargs: dict[str, Any] = {
"model": MODEL_NAME,
"input": input,
"truncate_prompt_tokens": truncation_size
}

response = await client.post(path="embeddings",
cast_to=object,
body={**kwargs})

assert response["usage"]["prompt_tokens"] == max_model_len
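For context, a hedged usage sketch of the new field from a stock OpenAI client; the `extra_body` passthrough and the server address are assumptions for illustration, not part of this diff:

```python
# Illustrative only: calling the embeddings endpoint with the new
# truncate_prompt_tokens field through the OpenAI client's extra_body.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.embeddings.create(
    model="sentence-transformers/all-MiniLM-L12-v2",
    input="Immerse yourself in the enchanting chronicle of calculus ...",
    extra_body={"truncate_prompt_tokens": -1},  # -1: truncate to max_model_len
)
print(response.usage.prompt_tokens)  # expected to equal max_model_len
```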
69 changes: 69 additions & 0 deletions tests/models/embedding/language/test_truncation_control.py
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
max_model_len = 128

input_str = """Immerse yourself in the enchanting chronicle of calculus, a
mathematical domain that has radically transformed our comprehension of
change and motion. Despite its roots in ancient civilizations, the
formal birth of calculus predominantly occurred in the 17th century,
primarily under the influential guidance of Sir Isaac Newton and Gottfried
Wilhelm Leibniz. The earliest traces of calculus concepts are found in
ancient Greek mathematics,most notably in the works of Eudoxus and
Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a
technique for computing areas and volumes through the use of finite sums.
This methodology laid crucial foundational work for integral calculus.
In the 17th century, both Newton and Leibniz independently pioneered
calculus, each contributing unique perspectives that would shape this new
field."""


def test_smaller_truncation_size(vllm_runner,
model_name=MODEL_NAME,
input_str=input_str):

truncate_prompt_tokens = 10

with vllm_runner(model_name, task="embed",
max_model_len=max_model_len) as vllm_model:
vllm_output = vllm_model.model.encode(
input_str, truncate_prompt_tokens=truncate_prompt_tokens)

prompt_tokens = vllm_output[0].prompt_token_ids

assert len(prompt_tokens) == truncate_prompt_tokens


def test_max_truncation_size(vllm_runner,
model_name=MODEL_NAME,
input_str=input_str):
truncate_prompt_tokens = -1

with vllm_runner(model_name, task="embed",
max_model_len=max_model_len) as vllm_model:
vllm_output = vllm_model.model.encode(
input_str, truncate_prompt_tokens=truncate_prompt_tokens)

prompt_tokens = vllm_output[0].prompt_token_ids

assert len(prompt_tokens) == max_model_len


def test_bigger_truncation_size(vllm_runner,
model_name=MODEL_NAME,
input_str=input_str):

truncate_prompt_tokens = max_model_len + 1

with pytest.raises(ValueError), vllm_runner(
model_name, task="embed",
max_model_len=max_model_len) as vllm_model:

llm_output = vllm_model.model.encode(
input_str, truncate_prompt_tokens=truncate_prompt_tokens)

assert llm_output == f"""truncate_prompt_tokens value
({truncate_prompt_tokens}) is greater than
max_model_len ({max_model_len}). Please, select
a smaller truncation size."""
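A minimal offline sketch of the same control through `LLM.encode()`, whose new `truncate_prompt_tokens` argument this PR adds; the snippet mirrors the tests above and is illustrative rather than taken from the PR:

```python
# Illustrative sketch of the truncate_prompt_tokens argument added to
# LLM.encode(); mirrors what the tests exercise via vllm_runner.
from vllm import LLM

llm = LLM(model="sentence-transformers/all-MiniLM-L12-v2",
          task="embed", max_model_len=128)

text = "Immerse yourself in the enchanting chronicle of calculus. " * 20

# Truncate the prompt to 10 tokens before embedding.
outputs = llm.encode(text, truncate_prompt_tokens=10)
print(len(outputs[0].prompt_token_ids))  # 10

# -1 is the new sentinel: truncate to max_model_len instead of raising.
outputs = llm.encode(text, truncate_prompt_tokens=-1)
print(len(outputs[0].prompt_token_ids))  # 128
```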
3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py
@@ -645,6 +645,7 @@ def add_request(
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
@@ -678,6 +679,7 @@ def add_request(
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
@@ -758,6 +760,7 @@ def add_request(

processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
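The engine change above only threads `tokenization_kwargs` through to the input preprocessor; the sketch below shows how such kwargs are typically consumed at tokenization time. The function and names are hypothetical, not vLLM's actual preprocessor code:

```python
# Hypothetical sketch (not vLLM's preprocessor): how the tokenization_kwargs
# forwarded by add_request() is assumed to reach the tokenizer call.
from typing import Any, Optional

def tokenize_prompt(tokenizer, prompt: str,
                    tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:
    # With truncation control enabled, the dict typically carries
    # {"truncation": True, "max_length": <resolved truncate_prompt_tokens>},
    # which Hugging Face tokenizers accept directly on encode().
    kwargs = tokenization_kwargs or {}
    return tokenizer.encode(prompt, **kwargs)
```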
4 changes: 2 additions & 2 deletions vllm/engine/protocol.py
@@ -2,7 +2,7 @@

import asyncio
from abc import ABC, abstractmethod
- from typing import AsyncGenerator, List, Mapping, Optional
+ from typing import AsyncGenerator, Mapping, Optional

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
@@ -256,7 +256,7 @@ async def is_tracing_enabled(self) -> bool:
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
- model_output: Optional[List[SamplerOutput]] = None,
+ model_output: Optional[list[SamplerOutput]] = None,
) -> None:
...

25 changes: 22 additions & 3 deletions vllm/entrypoints/llm.py
@@ -25,6 +25,7 @@
resolve_chat_template_content_format)
from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.logger import init_logger
@@ -793,6 +794,7 @@ def encode(
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -807,6 +809,7 @@ def encode(
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[int]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -821,6 +824,7 @@ def encode(
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -836,6 +840,7 @@ def encode(
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: list[int],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -851,6 +856,7 @@ def encode(
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: list[list[int]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -864,6 +870,7 @@ def encode(
prompts: None,
pooling_params: None,
prompt_token_ids: Union[list[int], list[list[int]]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -882,6 +889,7 @@ def encode(
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -946,10 +954,15 @@ def encode(
for pooling_param in pooling_params:
pooling_param.verify(self.llm_engine.model_config)

tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs)

self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
)

@@ -962,6 +975,7 @@ def embed(
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@@ -995,6 +1009,7 @@ def embed(
"Embedding API is only enabled for `--task embed`")

items = self.encode(prompts,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
@@ -1055,6 +1070,7 @@ def _embedding_score(

encoded_output: list[PoolingRequestOutput] = self.encode(
text_1 + text_2,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
@@ -1098,9 +1114,8 @@ def _cross_encoding_score(
pooling_params = PoolingParams()

tokenization_kwargs: dict[str, Any] = {}
- if truncate_prompt_tokens is not None:
- tokenization_kwargs["truncation"] = True
- tokenization_kwargs["max_length"] = truncate_prompt_tokens
+ _validate_truncation_size(self.llm_engine.model_config.max_model_len,
+ truncate_prompt_tokens, tokenization_kwargs)

parsed_prompts = []

@@ -1323,6 +1338,7 @@ def _validate_and_add_requests(
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[list[int]] = None,
) -> None:
@@ -1359,6 +1375,7 @@ def _add_request(
self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
@@ -1369,6 +1386,7 @@ def _add_request(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
@@ -1379,6 +1397,7 @@ def _add_request(
prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
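The imported `_validate_truncation_size` helper (in `vllm/entrypoints/utils.py`) is not part of this diff view; the sketch below is inferred from its call sites here and from the tests above, so treat it as an assumption rather than the merged implementation:

```python
# Inferred sketch of _validate_truncation_size, not the actual merged code:
# resolve -1 to max_model_len, reject oversized values, and fill the
# tokenization kwargs that are later passed down to the tokenizer.
from typing import Any, Optional

def _validate_truncation_size(
        max_model_len: int,
        truncate_prompt_tokens: Optional[int],
        tokenization_kwargs: dict[str, Any]) -> Optional[int]:
    if truncate_prompt_tokens is None:
        return None
    if truncate_prompt_tokens == -1:
        # -1 means "truncate to the model's maximum length".
        truncate_prompt_tokens = max_model_len
    if truncate_prompt_tokens > max_model_len:
        raise ValueError(
            f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
            f"is greater than max_model_len ({max_model_len}). "
            f"Please, select a smaller truncation size.")
    tokenization_kwargs["truncation"] = True
    tokenization_kwargs["max_length"] = truncate_prompt_tokens
    return truncate_prompt_tokens
```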
8 changes: 4 additions & 4 deletions vllm/entrypoints/openai/protocol.py
@@ -1014,7 +1014,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
- truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

# doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None
@@ -1049,7 +1049,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
- truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

# doc: begin-chat-embedding-pooling-params
additional_data: Optional[Any] = None
@@ -1116,7 +1116,7 @@ class ScoreRequest(OpenAIBaseModel):
model: Optional[str] = None
text_1: Union[list[str], str]
text_2: Union[list[str], str]
- truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

# doc: begin-score-pooling-params
additional_data: Optional[Any] = None
@@ -1142,7 +1142,7 @@ class RerankRequest(OpenAIBaseModel):
query: str
documents: list[str]
top_n: int = Field(default_factory=lambda: 0)
- truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

# doc: begin-rerank-pooling-params
additional_data: Optional[Any] = None
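A self-contained illustration (not vLLM code) of what relaxing the constraint from `ge=1` to `ge=-1` buys: the `-1` sentinel now passes request validation, other negative values are still rejected, and the upper bound against `max_model_len` is enforced later by the server-side check shown in the entrypoints:

```python
# Standalone Pydantic demo of the relaxed constraint; mirrors the field
# definition above but is not the actual vLLM request model.
from typing import Annotated, Optional
from pydantic import BaseModel, Field, ValidationError

class TruncationDemo(BaseModel):
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

print(TruncationDemo(truncate_prompt_tokens=-1))  # accepted: sentinel for max_model_len
print(TruncationDemo(truncate_prompt_tokens=10))  # accepted

try:
    TruncationDemo(truncate_prompt_tokens=-2)     # still rejected by ge=-1
except ValidationError as err:
    print(err.errors()[0]["type"])                # "greater_than_equal"
```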