
Commit b074399

noooop authored and albertoperdomo2 committed
[Model] Add num_cached_tokens for PoolingRequestOutput (vllm-project#27378)
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent 7437ded commit b074399
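
For orientation, a minimal usage sketch of the new field (not part of the diff): the model name and prompts are illustrative placeholders, with prefix caching enabled as in the updated tests below.

# Hedged sketch, not from the commit: model and prompts are placeholders.
from vllm import LLM

llm = LLM(model="Qwen/Qwen3-0.6B", runner="pooling", enable_prefix_caching=True)
prompts = ["vLLM is a fast and easy-to-use library for LLM inference. " * 10] * 4

# First run warms the prefix cache.
llm.encode(prompts, pooling_task="embed")

# Second run over the same prompts should report per-request cache hits.
for out in llm.encode(prompts, pooling_task="embed"):
    print(out.request_id, out.num_cached_tokens)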

File tree

7 files changed: +75 -5 lines changed


tests/models/language/pooling/test_auto_prefix_cache_support.py

Lines changed: 24 additions & 4 deletions
@@ -19,14 +19,25 @@ def test_classify_models(
     model: str,
     dtype: str,
 ) -> None:
-    example_prompts = example_prompts * 2
+    # example_prompts is too short for testing prefix_caching
+    example_prompts = [s * 10 for s in example_prompts]
 
     with vllm_runner(
         model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
     ) as vllm_model:
         cache_config = vllm_model.llm.llm_engine.cache_config
         assert cache_config.enable_prefix_caching
-        vllm_outputs = vllm_model.classify(example_prompts)
+
+        # First Run
+        vllm_model.classify(example_prompts)
+
+        # assert prefix_caching works
+        pooling_outputs = vllm_model.llm.encode(
+            example_prompts, pooling_task="classify"
+        )
+        for output in pooling_outputs:
+            assert output.num_cached_tokens > 0
+        vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]
 
     with hf_runner(
         model, dtype=dtype, auto_cls=AutoModelForSequenceClassification
@@ -54,7 +65,8 @@ def test_embed_models(
     model: str,
     dtype: str,
 ):
-    example_prompts = [str(s).strip() for s in example_prompts] * 2
+    # example_prompts is too short for testing prefix_caching
+    example_prompts = [str(s).strip() * 10 for s in example_prompts]
 
     with vllm_runner(
         model,
@@ -64,7 +76,15 @@ def test_embed_models(
     ) as vllm_model:
         cache_config = vllm_model.llm.llm_engine.cache_config
         assert cache_config.enable_prefix_caching
-        vllm_outputs = vllm_model.embed(example_prompts)
+
+        # First Run
+        vllm_model.embed(example_prompts)
+
+        # assert prefix_caching works
+        pooling_outputs = vllm_model.llm.encode(example_prompts, pooling_task="embed")
+        for output in pooling_outputs:
+            assert output.num_cached_tokens > 0
+        vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]
 
     with hf_runner(
         model,

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+from vllm import TokensPrompt
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["Qwen/Qwen3-0.6B"],
+)
+@torch.inference_mode
+def test_embed_models(hf_runner, vllm_runner, model: str):
+    n_prompt_tokens = [55, 56, 57]
+    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
+
+    with vllm_runner(
+        model,
+        max_model_len=128,
+        enforce_eager=True,
+        runner="pooling",
+        enable_chunked_prefill=False,
+        enable_prefix_caching=False,
+    ) as vllm_model:
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="token_embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert output.num_cached_tokens == 0
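
The new test pins down the disabled-cache case (num_cached_tokens == 0). Since PoolingRequestOutput now carries both prompt_token_ids and num_cached_tokens, a small hedged helper sketch (not in the commit) can turn the two into a per-request cache hit rate:

# Hedged helper sketch, not part of the commit.
from vllm.outputs import PoolingRequestOutput


def cache_hit_rates(outputs: list[PoolingRequestOutput]) -> list[float]:
    """Fraction of prompt tokens served from the prefix cache, per request."""
    rates: list[float] = []
    for out in outputs:
        n_prompt = len(out.prompt_token_ids)
        # Guard against empty prompt_token_ids; the LLM.encode pooler path
        # below fills in [] for processed outputs.
        rates.append(out.num_cached_tokens / n_prompt if n_prompt else 0.0)
    return rates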

vllm/entrypoints/llm.py

Lines changed: 3 additions & 0 deletions
@@ -1078,6 +1078,9 @@ def encode(
                 PoolingRequestOutput[Any](
                     request_id="",
                     outputs=processed_outputs,
+                    num_cached_tokens=getattr(
+                        processed_outputs, "num_cached_tokens", 0
+                    ),
                     prompt_token_ids=[],
                     finished=True,
                 )
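
The getattr default keeps this post-processing path working when the processed pooler output does not expose a cached-token count, falling back to 0. Callers that need to run against both old and new vLLM versions could use the same defensive pattern, sketched here as an assumption rather than an official API:

# Hedged compatibility sketch, not from the commit.
def cached_tokens_or_zero(request_output) -> int:
    # Older vLLM releases have no num_cached_tokens on pooling outputs.
    return getattr(request_output, "num_cached_tokens", 0)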

vllm/entrypoints/openai/serving_embedding.py

Lines changed: 1 addition & 0 deletions
@@ -583,6 +583,7 @@ async def _collect_batch(
                     request_id=aggregator["request_id"],
                     prompt_token_ids=original_token_ids,
                     outputs=pooling_output_data,
+                    num_cached_tokens=0,
                     finished=True,
                 )
 

vllm/entrypoints/score_utils.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ def _cosine_similarity(
                 request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                 outputs=pair_score,
                 prompt_token_ids=tokens,
+                num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
                 finished=True,
             )
         )
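
In the embedding-based scoring path, each score is computed from two pooled inputs, so the reported count is the sum of the prefix-cache hits on both sides. A hedged aggregation sketch (not in the commit) that uses this to report total cache hits for a batch of scoring outputs, e.g. from LLM.score():

# Hedged aggregation sketch, not part of the commit.
def total_cached_tokens(score_outputs) -> int:
    # Each scoring output already sums the hits of its two embedded inputs,
    # per the _cosine_similarity change above.
    return sum(out.num_cached_tokens for out in score_outputs)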

vllm/outputs.py

Lines changed: 12 additions & 1 deletion
@@ -201,14 +201,21 @@ class PoolingRequestOutput(Generic[_O]):
         request_id (str): A unique identifier for the pooling request.
         outputs (PoolingOutput): The pooling results for the given input.
         prompt_token_ids (list[int]): A list of token IDs used in the prompt.
+        num_cached_tokens: The number of tokens with prefix cache hit.
         finished (bool): A flag indicating whether the pooling is completed.
     """
 
     def __init__(
-        self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool
+        self,
+        request_id: str,
+        outputs: _O,
+        prompt_token_ids: list[int],
+        num_cached_tokens: int,
+        finished: bool,
     ):
         self.request_id = request_id
         self.prompt_token_ids = prompt_token_ids
+        self.num_cached_tokens = num_cached_tokens
         self.finished = finished
         self.outputs = outputs
 
@@ -217,6 +224,7 @@ def __repr__(self):
             f"{type(self).__name__}(request_id={self.request_id!r}, "
             f"outputs={self.outputs!r}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
+            f"num_cached_tokens={self.num_cached_tokens}, "
             f"finished={self.finished})"
         )
 
@@ -255,6 +263,7 @@ def from_base(request_output: PoolingRequestOutput):
             request_id=request_output.request_id,
             outputs=EmbeddingOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
 
@@ -294,6 +303,7 @@ def from_base(request_output: PoolingRequestOutput):
             request_id=request_output.request_id,
             outputs=ClassificationOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
 
@@ -330,5 +340,6 @@ def from_base(request_output: PoolingRequestOutput):
             request_id=request_output.request_id,
             outputs=ScoringOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
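
num_cached_tokens is now a required parameter of the PoolingRequestOutput constructor and appears in its repr. A hedged construction sketch (not in the commit), assuming PoolingOutput still wraps a single data tensor:

# Hedged construction sketch, not from the commit; the tensor is a placeholder.
import torch

from vllm.outputs import PoolingOutput, PoolingRequestOutput

out = PoolingRequestOutput(
    request_id="req-0",
    outputs=PoolingOutput(data=torch.zeros(4)),
    prompt_token_ids=[1, 2, 3, 4],
    num_cached_tokens=0,
    finished=True,
)
print(out)  # repr now includes num_cached_tokens=0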

vllm/v1/engine/output_processor.py

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ def _new_request_output(
             return PoolingRequestOutput(
                 request_id=request_id,
                 outputs=first_output,
+                num_cached_tokens=self.num_cached_tokens,
                 prompt_token_ids=self.prompt_token_ids,
                 finished=finished,
             )
