
Commit 75e6e14

[V1][Metrics] Add several request timing histograms (#12644)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 110f59a commit 75e6e14
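This commit adds four per-request timing histograms to the V1 metrics, alongside the existing end-to-end latency histogram: queue time (QUEUED to SCHEDULED), prefill time, decode time, and inference time. The sketch below is not the code from this commit; it only illustrates, using prometheus_client directly, how histograms with these names could be declared and fed from per-request monotonic timestamps. Bucket boundaries and the exact interval definitions are assumptions; vLLM's own metrics logger is not shown in this diff.

from prometheus_client import Histogram

# Illustrative buckets; the real logger chooses its own.
_BUCKETS = [0.01, 0.05, 0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0]

request_queue_time = Histogram(
    "vllm:request_queue_time_seconds",
    "Time the request spent waiting before it was first scheduled.",
    buckets=_BUCKETS)
request_prefill_time = Histogram(
    "vllm:request_prefill_time_seconds",
    "Time from first schedule to the first output token.",
    buckets=_BUCKETS)
request_decode_time = Histogram(
    "vllm:request_decode_time_seconds",
    "Time from the first output token to request completion.",
    buckets=_BUCKETS)
request_inference_time = Histogram(
    "vllm:request_inference_time_seconds",
    "Time from first schedule to request completion (prefill + decode).",
    buckets=_BUCKETS)

def observe_request(queued: float, scheduled: float,
                    first_token: float, finished: float) -> None:
    # All arguments are time.monotonic() values recorded for one request.
    request_queue_time.observe(scheduled - queued)
    request_prefill_time.observe(first_token - scheduled)
    request_decode_time.observe(finished - first_token)
    request_inference_time.observe(finished - scheduled)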

File tree: 16 files changed (+334, -84 lines)


tests/entrypoints/openai/test_metrics.py

Lines changed: 31 additions & 0 deletions
@@ -85,6 +85,10 @@ async def client(server):
     "vllm:time_per_output_token_seconds":
     [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
     "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
+    "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
+    "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
+    "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
+    "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
     "vllm:request_prompt_tokens":
     [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
      ("_count", _NUM_REQUESTS)],
@@ -169,6 +173,18 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
     "vllm:e2e_request_latency_seconds_sum",
     "vllm:e2e_request_latency_seconds_bucket",
     "vllm:e2e_request_latency_seconds_count",
+    "vllm:request_queue_time_seconds_sum",
+    "vllm:request_queue_time_seconds_bucket",
+    "vllm:request_queue_time_seconds_count",
+    "vllm:request_inference_time_seconds_sum",
+    "vllm:request_inference_time_seconds_bucket",
+    "vllm:request_inference_time_seconds_count",
+    "vllm:request_prefill_time_seconds_sum",
+    "vllm:request_prefill_time_seconds_bucket",
+    "vllm:request_prefill_time_seconds_count",
+    "vllm:request_decode_time_seconds_sum",
+    "vllm:request_decode_time_seconds_bucket",
+    "vllm:request_decode_time_seconds_count",
     "vllm:request_prompt_tokens_sum",
     "vllm:request_prompt_tokens_bucket",
     "vllm:request_prompt_tokens_count",
@@ -220,6 +236,21 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
     "vllm:time_per_output_token_seconds_sum",
     "vllm:time_per_output_token_seconds_bucket",
     "vllm:time_per_output_token_seconds_count",
+    "vllm:e2e_request_latency_seconds_sum",
+    "vllm:e2e_request_latency_seconds_bucket",
+    "vllm:e2e_request_latency_seconds_count",
+    "vllm:request_queue_time_seconds_sum",
+    "vllm:request_queue_time_seconds_bucket",
+    "vllm:request_queue_time_seconds_count",
+    "vllm:request_inference_time_seconds_sum",
+    "vllm:request_inference_time_seconds_bucket",
+    "vllm:request_inference_time_seconds_count",
+    "vllm:request_prefill_time_seconds_sum",
+    "vllm:request_prefill_time_seconds_bucket",
+    "vllm:request_prefill_time_seconds_count",
+    "vllm:request_decode_time_seconds_sum",
+    "vllm:request_decode_time_seconds_bucket",
+    "vllm:request_decode_time_seconds_count",
 ]
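For each new histogram the test expects the usual three Prometheus series (_sum, _bucket, _count) to be exported. A quick way to sanity-check this against a running server is to scrape /metrics directly; the snippet below is a hand-rolled check, not part of the test suite, and the server URL and port are assumptions.

import requests  # assumes a vLLM OpenAI-compatible server on localhost:8000

def check_timing_series(base_url: str = "http://localhost:8000") -> None:
    body = requests.get(f"{base_url}/metrics", timeout=5).text
    for family in ("request_queue_time", "request_inference_time",
                   "request_prefill_time", "request_decode_time"):
        for suffix in ("_sum", "_bucket", "_count"):
            series = f"vllm:{family}_seconds{suffix}"
            assert series in body, f"missing series: {series}"

check_timing_series()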

tests/v1/core/test_scheduler.py

Lines changed: 2 additions & 1 deletion
@@ -38,7 +38,8 @@ def create_scheduler(
     return Scheduler(scheduler_config,
                      model_config,
                      cache_config,
-                     lora_config=None)
+                     lora_config=None,
+                     log_stats=True)


 def create_requests(

tests/v1/engine/test_engine_core.py

Lines changed: 4 additions & 2 deletions
@@ -50,7 +50,8 @@ def test_engine_core(monkeypatch):
     executor_class = Executor.get_class(vllm_config)

     engine_core = EngineCore(vllm_config=vllm_config,
-                             executor_class=executor_class)
+                             executor_class=executor_class,
+                             log_stats=True)
     """Test basic request lifecycle."""

     # First request.
@@ -157,7 +158,8 @@ def test_engine_core_advanced_sampling(monkeypatch):
     executor_class = Executor.get_class(vllm_config)

     engine_core = EngineCore(vllm_config=vllm_config,
-                             executor_class=executor_class)
+                             executor_class=executor_class,
+                             log_stats=True)
     """Test basic request lifecycle."""
     # First request.
     request: EngineCoreRequest = make_request()

tests/v1/engine/test_engine_core_client.py

Lines changed: 2 additions & 0 deletions
@@ -94,6 +94,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
         asyncio_mode=False,
         vllm_config=vllm_config,
         executor_class=executor_class,
+        log_stats=False,
     )

     MAX_TOKENS = 20
@@ -163,6 +164,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
         asyncio_mode=True,
         vllm_config=vllm_config,
         executor_class=executor_class,
+        log_stats=True,
     )

     MAX_TOKENS = 20

tests/v1/engine/test_output_processor.py

Lines changed: 15 additions & 8 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import math
+import time
 from typing import Dict, List, Optional

 import pytest
@@ -15,6 +16,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.metrics.stats import IterationStats


 def _ref_convert_id_to_token(
@@ -603,6 +605,7 @@ def test_iteration_stats(dummy_test_vectors):
     output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                        log_stats=True)
     engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
+    engine_core_timestamp = time.monotonic()

     # Make N requests.
     requests = [
@@ -630,8 +633,9 @@
     # First iteration has 2 prefills.
     outputs = engine_core.get_outputs()[:num_active]
-    processed_outputs = output_processor.process_outputs(outputs)
-    iteration_stats = processed_outputs.iteration_stats
+    iteration_stats = IterationStats()
+    output_processor.process_outputs(outputs, engine_core_timestamp,
+                                     iteration_stats)
     total_prompt_tokens = sum([
         len(prompt_tokens)
         for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
@@ -642,8 +646,9 @@
     # Just decodes in this step.
     outputs = engine_core.get_outputs()[:num_active]
-    processed_outputs = output_processor.process_outputs(outputs)
-    iteration_stats = processed_outputs.iteration_stats
+    iteration_stats = IterationStats()
+    output_processor.process_outputs(outputs, engine_core_timestamp,
+                                     iteration_stats)

     assert iteration_stats.num_prompt_tokens == 0
     assert iteration_stats.num_generation_tokens == num_active
@@ -652,17 +657,19 @@
     output_processor.add_request(inactive_request)
     num_active += 1
     outputs = engine_core.get_outputs()[:num_active]
-    processed_outputs = output_processor.process_outputs(outputs)
-    iteration_stats = processed_outputs.iteration_stats
+    iteration_stats = IterationStats()
+    output_processor.process_outputs(outputs, engine_core_timestamp,
+                                     iteration_stats)
     total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])

     assert iteration_stats.num_prompt_tokens == total_prompt_tokens
     assert iteration_stats.num_generation_tokens == num_active

     # Just decodes in this step.
     outputs = engine_core.get_outputs()[:num_active]
-    processed_outputs = output_processor.process_outputs(outputs)
-    iteration_stats = processed_outputs.iteration_stats
+    iteration_stats = IterationStats()
+    output_processor.process_outputs(outputs, engine_core_timestamp,
+                                     iteration_stats)

     assert iteration_stats.num_prompt_tokens == 0
     assert iteration_stats.num_generation_tokens == num_active
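The test changes above reflect the new OutputProcessor calling convention: process_outputs() no longer returns iteration stats; the caller constructs an IterationStats object and passes it in together with the engine-core timestamp that intervals should be measured against. A minimal sketch of that pattern under the same assumptions as the test (the outputs list and processor instance are supplied by the caller):

import time

from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.metrics.stats import IterationStats

def run_iteration(output_processor: OutputProcessor, outputs) -> IterationStats:
    # The caller now owns the stats object and supplies the timestamp.
    iteration_stats = IterationStats()
    output_processor.process_outputs(outputs, time.monotonic(), iteration_stats)
    return iteration_stats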

vllm/v1/core/kv_cache_manager.py

Lines changed: 3 additions & 0 deletions
@@ -26,13 +26,16 @@ def __init__(
         sliding_window: Optional[int] = None,
         enable_caching: bool = True,
         num_preallocate_tokens: int = 64,
+        log_stats: bool = False,
     ) -> None:
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.max_model_len = max_model_len
         self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
         self.sliding_window = sliding_window
         self.enable_caching = enable_caching
+        # FIXME: make prefix cache stats conditional on log_stats
+        self.log_stats = log_stats
         # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
         # blocks for each request. For example, when a request reaches the end
         # of its block table, we preallocate N blocks in advance. This way, we

vllm/v1/core/scheduler.py

Lines changed: 29 additions & 4 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

+import time
 from collections import deque
 from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

@@ -10,7 +11,8 @@
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
                                            SchedulerOutput)
-from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
+                            EngineCoreOutput, EngineCoreOutputs)
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
@@ -26,10 +28,12 @@ def __init__(
         model_config: ModelConfig,
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
+        log_stats: bool,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
         self.lora_config = lora_config
+        self.log_stats = log_stats

         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -45,7 +49,8 @@ def __init__(
             num_gpu_blocks=num_gpu_blocks,
             max_model_len=self.max_model_len,
             sliding_window=self.cache_config.sliding_window,
-            enable_caching=self.cache_config.enable_prefix_caching)
+            enable_caching=self.cache_config.enable_prefix_caching,
+            log_stats=self.log_stats)
         self.block_size = self.cache_config.block_size

         # req_id -> Request
@@ -107,6 +112,8 @@ def schedule(self) -> "SchedulerOutput":
         scheduled_encoder_inputs: Dict[str, List[int]] = {}
         encoder_budget = self.max_num_encoder_input_tokens

+        scheduled_timestamp = time.monotonic()
+
         # First, schedule the RUNNING requests.
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
@@ -246,6 +253,7 @@
             self.running.append(request)
             if request.status == RequestStatus.WAITING:
                 scheduled_new_reqs.append(request)
+                self.request_scheduled(request, scheduled_timestamp)
             elif request.status == RequestStatus.PREEMPTED:
                 scheduled_resumed_reqs.append(request)
             else:
@@ -508,7 +516,8 @@ def update_from_output(
                     finish_reason=request.get_finished_reason(),
                     new_logprobs=new_logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs_tensors,
-                    stop_reason=request.stop_reason))
+                    stop_reason=request.stop_reason,
+                    events=request.take_events()))

             if not stopped:
                 new_running.append(request)
@@ -541,6 +550,7 @@ def _check_stop(self, request: Request) -> bool:
     def add_request(self, request: Request) -> None:
         self.waiting.append(request)
         self.requests[request.request_id] = request
+        self.request_queued(request)

     def finish_requests(
         self,
@@ -588,7 +598,22 @@ def has_unfinished_requests(self) -> bool:
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()

-    def make_stats(self) -> SchedulerStats:
+    def request_queued(self, request: Request):
+        if not self.log_stats:
+            return
+        request.events.append(
+            EngineCoreEvent.new_event(EngineCoreEventType.QUEUED))
+
+    def request_scheduled(self, request: Request, timestamp: float):
+        if not self.log_stats:
+            return
+        request.events.append(
+            EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
+                                      timestamp))
+
+    def make_stats(self) -> Optional[SchedulerStats]:
+        if not self.log_stats:
+            return None
         return SchedulerStats(
             num_running_reqs=len(self.running),
             num_waiting_reqs=len(self.waiting),
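With these hooks the scheduler attaches a QUEUED event when a request enters the waiting queue and a SCHEDULED event when it is first scheduled, and the events travel to the frontend via EngineCoreOutput.events. A hedged sketch of how a frontend could turn those events into a queue-time value (the actual vLLM frontend logger is not part of this diff):

from typing import List, Optional

from vllm.v1.engine import EngineCoreEvent, EngineCoreEventType

def queue_time_seconds(events: List[EngineCoreEvent]) -> Optional[float]:
    # Seconds between QUEUED and SCHEDULED, if both events arrived.
    queued = next((e.timestamp for e in events
                   if e.type == EngineCoreEventType.QUEUED), None)
    scheduled = next((e.timestamp for e in events
                      if e.type == EngineCoreEventType.SCHEDULED), None)
    if queued is None or scheduled is None:
        return None
    return scheduled - queued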

vllm/v1/engine/__init__.py

Lines changed: 32 additions & 1 deletion
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import enum
+import time
 from typing import List, Optional, Union

 import msgspec
@@ -60,6 +61,30 @@ class EngineCoreRequest(
     lora_request: Optional[LoRARequest]


+class EngineCoreEventType(enum.IntEnum):
+    """The type of engine core request event."""
+    QUEUED = 1
+    SCHEDULED = 2
+
+
+class EngineCoreEvent(msgspec.Struct):
+    """A timestamped engine core event associated with a request.
+
+    The timestamp is a monotonic timestamp, used by the engine frontend
+    to calculate intervals between engine core events. These timestamps
+    should not be compared with timestamps from other processes.
+    """
+    type: EngineCoreEventType
+    timestamp: float
+
+    @classmethod
+    def new_event(cls,
+                  event_type: EngineCoreEventType,
+                  timestamp: Optional[float] = None) -> "EngineCoreEvent":
+        timestamp = time.monotonic() if timestamp is None else timestamp
+        return cls(event_type, timestamp)
+
+
 class EngineCoreOutput(
     msgspec.Struct,
     array_like=True,  # type: ignore[call-arg]
@@ -74,6 +99,7 @@ class EngineCoreOutput(

     finish_reason: Optional[FinishReason] = None
     stop_reason: Union[int, str, None] = None
+    events: Optional[List[EngineCoreEvent]] = None

     @property
     def finished(self) -> bool:
@@ -91,7 +117,12 @@ class EngineCoreOutputs(

     # [num_reqs]
     outputs: List[EngineCoreOutput]
-    scheduler_stats: SchedulerStats
+    scheduler_stats: Optional[SchedulerStats]
+    timestamp: float = 0.0
+
+    def __post_init__(self):
+        if self.timestamp == 0.0:
+            self.timestamp = time.monotonic()


 class EngineCoreRequestType(enum.Enum):
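The new_event() helper is used in two forms by the scheduler changes above: without a timestamp (QUEUED, stamped at call time) and with an explicitly supplied timestamp (SCHEDULED, so every request scheduled in the same step shares one stamp). A small, self-contained illustration of both forms; the sleep is only there to make the interval non-zero.

import time

from vllm.v1.engine import EngineCoreEvent, EngineCoreEventType

# Stamped automatically with time.monotonic() at call time.
queued = EngineCoreEvent.new_event(EngineCoreEventType.QUEUED)

time.sleep(0.01)  # stand-in for time spent in the waiting queue

# Stamped with an explicitly supplied, shared monotonic timestamp.
step_timestamp = time.monotonic()
scheduled = EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
                                      step_timestamp)

# Only differences between these monotonic stamps are meaningful; this gap
# is the kind of interval that feeds vllm:request_queue_time_seconds.
queue_time = scheduled.timestamp - queued.timestamp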
