
Commit 559df69

max-wittig authored and antoineauger committed

[Bugfix][FE]: Always include usage with --enable-force-include-usage (vllm-project#20983)

Signed-off-by: Max Wittig <max.wittig@siemens.com>
Signed-off-by: Antoine Auger <antoineauger@users.noreply.github.com>
Co-authored-by: Antoine Auger <antoineauger@users.noreply.github.com>

1 parent 72cb727 commit 559df69
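
In practice the fix means that, with a server started with --enable-force-include-usage, a streaming response carries a final usage chunk even when the client never sets stream_options. A minimal client-side sketch (not part of this commit; the base URL, API key, and model below are placeholder assumptions):

    # Hypothetical client snippet: assumes a local vLLM server started with
    # `vllm serve Qwen/Qwen3-0.6B --enable-force-include-usage`.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    stream = client.chat.completions.create(
        model="Qwen/Qwen3-0.6B",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=True,  # note: no stream_options={"include_usage": True}
    )

    for chunk in stream:
        if chunk.choices:
            print(chunk.choices[0].delta.content or "", end="")
        elif chunk.usage is not None:
            # Final chunk: usage is forced on by the server-side flag.
            print(f"\ntotal_tokens={chunk.usage.total_tokens}")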

File tree

11 files changed: +172 -30 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -107,6 +107,7 @@ markers = [
     "distributed: run this test only in distributed GPU tests",
     "skip_v1: do not run this test with v1",
     "optional: optional tests that are automatically skipped, include --optional to run them",
+    "extra_server_args: extra arguments to pass to the server fixture",
 ]

 [tool.ty.src]
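
The new extra_server_args marker is registered so individual tests can hand extra CLI flags to a shared server fixture. How the fixture consumes it is not shown in this diff; a hypothetical sketch of the pattern (fixture and test names are made up for illustration):

    # Hypothetical illustration of consuming the "extra_server_args" marker
    # registered above; this fixture is not part of the commit.
    import pytest

    @pytest.fixture(scope="module")
    def server(request):
        args = ["--dtype", "bfloat16", "--enforce-eager"]
        marker = request.node.get_closest_marker("extra_server_args")
        if marker is not None:
            args.extend(marker.args[0])
        # ... start the remote server with `args` here ...
        yield args

    @pytest.mark.extra_server_args(["--enable-force-include-usage"])
    def test_server_sees_extra_args(server):
        assert "--enable-force-include-usage" in server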
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import openai
+import pytest
+import pytest_asyncio
+
+from ...utils import RemoteOpenAIServer
+
+
+@pytest.fixture(scope="module")
+def chat_server_with_force_include_usage(request):  # noqa: F811
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "128",
+        "--enforce-eager",
+        "--max-num-seqs",
+        "1",
+        "--enable-force-include-usage",
+        "--port",
+        "55857",
+        "--gpu-memory-utilization",
+        "0.2",
+    ]
+
+    with RemoteOpenAIServer("Qwen/Qwen3-0.6B", args, auto_port=False) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def chat_client_with_force_include_usage(chat_server_with_force_include_usage):
+    async with chat_server_with_force_include_usage.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+async def test_chat_with_enable_force_include_usage(
+    chat_client_with_force_include_usage: openai.AsyncOpenAI,
+):
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is the capital of France?"},
+    ]
+
+    stream = await chat_client_with_force_include_usage.chat.completions.create(
+        model="Qwen/Qwen3-0.6B",
+        messages=messages,
+        max_completion_tokens=10,
+        extra_body=dict(min_tokens=10),
+        temperature=0.0,
+        stream=True,
+    )
+    last_completion_tokens = 0
+    async for chunk in stream:
+        if not len(chunk.choices):
+            assert chunk.usage.prompt_tokens >= 0
+            assert (
+                last_completion_tokens == 0
+                or chunk.usage.completion_tokens > last_completion_tokens
+                or (
+                    not chunk.choices
+                    and chunk.usage.completion_tokens == last_completion_tokens
+                )
+            )
+            assert chunk.usage.total_tokens == (
+                chunk.usage.prompt_tokens + chunk.usage.completion_tokens
+            )
+        else:
+            assert chunk.usage is None
+
+
+@pytest.fixture(scope="module")
+def transcription_server_with_force_include_usage():
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-num-seqs",
+        "1",
+        "--enforce-eager",
+        "--enable-force-include-usage",
+        "--gpu-memory-utilization",
+        "0.2",
+    ]
+
+    with RemoteOpenAIServer("openai/whisper-large-v3-turbo", args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def transcription_client_with_force_include_usage(
+    transcription_server_with_force_include_usage,
+):
+    async with (
+        transcription_server_with_force_include_usage.get_async_client() as async_client
+    ):
+        yield async_client
+
+
+@pytest.mark.asyncio
+async def test_transcription_with_enable_force_include_usage(
+    transcription_client_with_force_include_usage, winning_call
+):
+    res = (
+        await transcription_client_with_force_include_usage.audio.transcriptions.create(
+            model="openai/whisper-large-v3-turbo",
+            file=winning_call,
+            language="en",
+            temperature=0.0,
+            stream=True,
+            timeout=30,
+        )
+    )
+
+    async for chunk in res:
+        if not len(chunk.choices):
+            # final usage sent
+            usage = chunk.usage
+            assert isinstance(usage, dict)
+            assert usage["prompt_tokens"] > 0
+            assert usage["completion_tokens"] > 0
+            assert usage["total_tokens"] > 0
+        else:
+            assert not hasattr(chunk, "usage")

vllm/entrypoints/openai/api_server.py

Lines changed: 2 additions & 0 deletions
@@ -1808,6 +1808,7 @@ async def init_app_state(
             state.openai_serving_models,
             request_logger=request_logger,
             log_error_stack=args.log_error_stack,
+            enable_force_include_usage=args.enable_force_include_usage,
         )
         if "transcription" in supported_tasks
         else None
@@ -1818,6 +1819,7 @@ async def init_app_state(
             state.openai_serving_models,
             request_logger=request_logger,
             log_error_stack=args.log_error_stack,
+            enable_force_include_usage=args.enable_force_include_usage,
         )
         if "transcription" in supported_tasks
         else None

vllm/entrypoints/openai/run_batch.py

Lines changed: 8 additions & 0 deletions
@@ -104,6 +104,13 @@ def make_arg_parser(parser: FlexibleArgumentParser):
         default=False,
         help="If set to True, enable prompt_tokens_details in usage.",
     )
+    parser.add_argument(
+        "--enable-force-include-usage",
+        action="store_true",
+        default=False,
+        help="If set to True, include usage on every request "
+        "(even when stream_options is not specified)",
+    )

     return parser

@@ -361,6 +368,7 @@ async def run_batch(
             chat_template=None,
             chat_template_content_format="auto",
             enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+            enable_force_include_usage=args.enable_force_include_usage,
         )
         if "generate" in supported_tasks
         else None

vllm/entrypoints/openai/serving_chat.py

Lines changed: 4 additions & 11 deletions
@@ -58,7 +58,7 @@
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
-from vllm.entrypoints.utils import get_max_tokens
+from vllm.entrypoints.utils import get_max_tokens, should_include_usage
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
@@ -101,7 +101,6 @@ def __init__(
             models=models,
             request_logger=request_logger,
             return_tokens_as_token_ids=return_tokens_as_token_ids,
-            enable_force_include_usage=enable_force_include_usage,
             log_error_stack=log_error_stack,
         )

@@ -352,7 +351,6 @@ async def create_chat_completion(
                 conversation,
                 tokenizer,
                 request_metadata,
-                enable_force_include_usage=self.enable_force_include_usage,
             )

         try:
@@ -518,7 +516,6 @@ async def chat_completion_stream_generator(
         conversation: list[ConversationMessage],
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
-        enable_force_include_usage: bool,
     ) -> AsyncGenerator[str, None]:
         created_time = int(time.time())
         chunk_object_type: Final = "chat.completion.chunk"
@@ -596,13 +593,9 @@ async def chat_completion_stream_generator(
             return

         stream_options = request.stream_options
-        if stream_options:
-            include_usage = stream_options.include_usage or enable_force_include_usage
-            include_continuous_usage = (
-                include_usage and stream_options.continuous_usage_stats
-            )
-        else:
-            include_usage, include_continuous_usage = False, False
+        include_usage, include_continuous_usage = should_include_usage(
+            stream_options, self.enable_force_include_usage
+        )

         try:
             async for res in result_generator:
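
The refactor above replaces the inline stream_options handling with a should_include_usage helper imported from vllm.entrypoints.utils; the helper's body is not part of this diff view. A sketch of what it presumably does, reconstructed from the removed inline logic plus the commit's intent (usage stays on when the flag is set, even without stream_options):

    # Hypothetical reconstruction of should_include_usage(); the real
    # implementation lives in vllm/entrypoints/utils.py and is not shown here.
    def should_include_usage(stream_options, enable_force_include_usage: bool) -> tuple[bool, bool]:
        """Return (include_usage, include_continuous_usage) for a streaming request."""
        if stream_options:
            include_usage = stream_options.include_usage or enable_force_include_usage
            include_continuous_usage = bool(
                include_usage and stream_options.continuous_usage_stats
            )
        else:
            # The bugfix: with no stream_options at all, usage is still included
            # when the server runs with --enable-force-include-usage.
            include_usage = enable_force_include_usage
            include_continuous_usage = False
        return include_usage, include_continuous_usage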

vllm/entrypoints/openai/serving_completion.py

Lines changed: 5 additions & 11 deletions
@@ -27,7 +27,7 @@
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.renderer import RenderConfig
-from vllm.entrypoints.utils import get_max_tokens
+from vllm.entrypoints.utils import get_max_tokens, should_include_usage
 from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
@@ -56,11 +56,11 @@ def __init__(
             models=models,
             request_logger=request_logger,
             return_tokens_as_token_ids=return_tokens_as_token_ids,
-            enable_force_include_usage=enable_force_include_usage,
             log_error_stack=log_error_stack,
         )
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
+        self.enable_force_include_usage = enable_force_include_usage
         if self.default_sampling_params:
             source = self.model_config.generation_config
             source = "model" if source == "auto" else source
@@ -256,7 +256,6 @@ async def create_completion(
                 num_prompts=num_prompts,
                 tokenizer=tokenizer,
                 request_metadata=request_metadata,
-                enable_force_include_usage=self.enable_force_include_usage,
             )

         # Non-streaming response
@@ -320,7 +319,6 @@ async def completion_stream_generator(
         num_prompts: int,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
-        enable_force_include_usage: bool,
     ) -> AsyncGenerator[str, None]:
         num_choices = 1 if request.n is None else request.n
         previous_text_lens = [0] * num_choices * num_prompts
@@ -331,13 +329,9 @@ async def completion_stream_generator(
         first_iteration = True

         stream_options = request.stream_options
-        if stream_options:
-            include_usage = stream_options.include_usage or enable_force_include_usage
-            include_continuous_usage = (
-                include_usage and stream_options.continuous_usage_stats
-            )
-        else:
-            include_usage, include_continuous_usage = False, False
+        include_usage, include_continuous_usage = should_include_usage(
+            stream_options, self.enable_force_include_usage
+        )

         try:
             async for prompt_idx, res in result_generator:

vllm/entrypoints/openai/serving_engine.py

Lines changed: 0 additions & 3 deletions
@@ -249,7 +249,6 @@ def __init__(
         *,
         request_logger: RequestLogger | None,
         return_tokens_as_token_ids: bool = False,
-        enable_force_include_usage: bool = False,
         log_error_stack: bool = False,
     ):
         super().__init__()
@@ -260,8 +259,6 @@ def __init__(

         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
-        self.enable_force_include_usage = enable_force_include_usage
-
         self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
         self._apply_mistral_chat_template_async = make_async(
             apply_mistral_chat_template, executor=self._tokenizer_executor

vllm/entrypoints/openai/serving_responses.py

Lines changed: 0 additions & 1 deletion
@@ -127,7 +127,6 @@ def __init__(
             models=models,
             request_logger=request_logger,
             return_tokens_as_token_ids=return_tokens_as_token_ids,
-            enable_force_include_usage=enable_force_include_usage,
             log_error_stack=log_error_stack,
         )

vllm/entrypoints/openai/serving_transcription.py

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,7 @@ def __init__(
         request_logger: RequestLogger | None,
         return_tokens_as_token_ids: bool = False,
         log_error_stack: bool = False,
+        enable_force_include_usage: bool = False,
     ):
         super().__init__(
             engine_client=engine_client,
@@ -45,6 +46,7 @@ def __init__(
             return_tokens_as_token_ids=return_tokens_as_token_ids,
             task_type="transcribe",
             log_error_stack=log_error_stack,
+            enable_force_include_usage=enable_force_include_usage,
         )

     async def create_transcription(
@@ -96,6 +98,7 @@ def __init__(
         request_logger: RequestLogger | None,
         return_tokens_as_token_ids: bool = False,
         log_error_stack: bool = False,
+        enable_force_include_usage: bool = False,
     ):
         super().__init__(
             engine_client=engine_client,
@@ -104,6 +107,7 @@ def __init__(
             return_tokens_as_token_ids=return_tokens_as_token_ids,
             task_type="translate",
             log_error_stack=log_error_stack,
+            enable_force_include_usage=enable_force_include_usage,
         )

     async def create_translation(

vllm/entrypoints/openai/speech_to_text.py

Lines changed: 4 additions & 3 deletions
@@ -58,6 +58,7 @@ def __init__(
         return_tokens_as_token_ids: bool = False,
         task_type: Literal["transcribe", "translate"] = "transcribe",
         log_error_stack: bool = False,
+        enable_force_include_usage: bool = False,
     ):
         super().__init__(
             engine_client=engine_client,
@@ -74,6 +75,8 @@ def __init__(
             self.model_config, task_type
         )

+        self.enable_force_include_usage = enable_force_include_usage
+
         self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB

         if self.default_sampling_params:
@@ -261,9 +264,7 @@ async def _speech_to_text_stream_generator(
         completion_tokens = 0
         num_prompt_tokens = 0

-        include_usage = (
-            request.stream_include_usage if request.stream_include_usage else False
-        )
+        include_usage = self.enable_force_include_usage or request.stream_include_usage
         include_continuous_usage = (
             request.stream_continuous_usage_stats
             if include_usage and request.stream_continuous_usage_stats

0 commit comments