[V1][Metrics] Add iteration_tokens_total histogram from V0
Basing bucket sizes on the cudagraph capture sizes was introduced in
PRs vllm-project#11031 and vllm-project#12243.

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
markmc committed Feb 14, 2025
commit 8b162c9 (1 parent: 83481ce)
Showing 3 changed files with 35 additions and 8 deletions.
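
For orientation: the metric added here is a per-engine-step histogram, observed once on every log() call with the total number of prompt plus generation tokens processed in that step, so its _sum divided by its _count gives the average tokens per iteration. A minimal, self-contained sketch of that idea (the bucket values and model label below are illustrative only, not the exact wiring in the diff):

import prometheus_client

# A per-iteration token histogram, observed once per engine step.
# Bucket values here are made up; the commit derives them from the
# cudagraph capture sizes (see build_cudagraph_buckets in loggers.py below).
iteration_tokens = prometheus_client.Histogram(
    name="vllm:iteration_tokens_total",
    documentation="Histogram of number of tokens per engine_step.",
    buckets=[1, 8, 16, 32, 64, 128],
    labelnames=["model_name"]).labels("example-model")

# One prefill step (6 prompt tokens + 1 generated token), then 3 decode steps.
for tokens_this_step in (7, 1, 1, 1):
    iteration_tokens.observe(tokens_this_step)

# Average tokens per engine step = _sum / _count = 10 / 4 = 2.5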
13 changes: 10 additions & 3 deletions tests/entrypoints/openai/test_metrics.py
@@ -96,9 +96,14 @@ async def client(server):
     [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
      ("_count", _NUM_REQUESTS)],
     "vllm:request_params_n": [("_count", _NUM_REQUESTS)],
-    "vllm:request_params_max_tokens":
-    [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
-     ("_count", _NUM_REQUESTS)],
+    "vllm:request_params_max_tokens": [
+        ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
+        ("_count", _NUM_REQUESTS)
+    ],
+    "vllm:iteration_tokens_total":
+    [("_sum", _NUM_REQUESTS *
+      (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST)),
+     ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST)],
     "vllm:prompt_tokens": [("_total",
                             _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
     "vllm:generation_tokens": [
@@ -197,6 +202,7 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
     "vllm:request_params_max_tokens_sum",
     "vllm:request_params_max_tokens_bucket",
     "vllm:request_params_max_tokens_count",
+    "vllm:iteration_tokens_total",
     "vllm:num_preemptions_total",
     "vllm:prompt_tokens_total",
     "vllm:generation_tokens_total",
@@ -223,6 +229,7 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
     "vllm:gpu_prefix_cache_hits",
     "vllm:prompt_tokens_total",
     "vllm:generation_tokens_total",
+    "vllm:iteration_tokens_total",
     "vllm:request_success_total",
     "vllm:request_prompt_tokens_sum",
     "vllm:request_prompt_tokens_bucket",
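
The expected values for vllm:iteration_tokens_total above follow from a simple model of V1 scheduling, assuming each request is prefilled in a single engine step that also yields its first output token, every later token takes one decode step, and the test's requests are not batched together: each request then contributes _NUM_GENERATION_TOKENS_PER_REQUEST observations whose values sum to its prompt plus generation tokens. A quick sanity check of that arithmetic (the constant values are assumed here purely for illustration; the test defines its own):

# Hypothetical per-request values, mirroring the test's constant names.
_NUM_PROMPT_TOKENS_PER_REQUEST = 5
_NUM_GENERATION_TOKENS_PER_REQUEST = 10

# One prefill observation (prompt tokens + first generated token),
# then one observation of a single token per remaining decode step.
observations = ([_NUM_PROMPT_TOKENS_PER_REQUEST + 1] +
                [1] * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))

# Matches the expected _sum and _count contributions per request.
assert sum(observations) == (_NUM_PROMPT_TOKENS_PER_REQUEST +
                             _NUM_GENERATION_TOKENS_PER_REQUEST)
assert len(observations) == _NUM_GENERATION_TOKENS_PER_REQUEST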
2 changes: 1 addition & 1 deletion vllm/v1/engine/async_llm.py
@@ -57,7 +57,7 @@ def __init__(
         if self.log_stats:
             self.stat_loggers.extend([
                 LoggingStatLogger(),
-                PrometheusStatLogger(vllm_config.model_config),
+                PrometheusStatLogger(vllm_config),
             ])
 
         # Tokenizer (+ ensure liveness if running in another process).
28 changes: 24 additions & 4 deletions vllm/v1/metrics/loggers.py
@@ -7,7 +7,7 @@
 import numpy as np
 import prometheus_client
 
-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
@@ -92,13 +92,13 @@ def log(self, scheduler_stats: SchedulerStats,
 
 class PrometheusStatLogger(StatLoggerBase):
 
-    def __init__(self, model_config: ModelConfig):
+    def __init__(self, vllm_config: VllmConfig):
         self._unregister_vllm_metrics()
 
         labelnames = ["model_name"]
-        labelvalues = [model_config.served_model_name]
+        labelvalues = [vllm_config.model_config.served_model_name]
 
-        max_model_len = model_config.max_model_len
+        max_model_len = vllm_config.model_config.max_model_len
 
         self.gauge_scheduler_running = prometheus_client.Gauge(
             name="vllm:num_requests_running",
@@ -162,6 +162,13 @@ def __init__(self, model_config: ModelConfig):
             buckets=build_1_2_5_buckets(max_model_len),
             labelnames=labelnames).labels(*labelvalues)
 
+        self.histogram_iteration_tokens = \
+            prometheus_client.Histogram(
+                name="vllm:iteration_tokens_total",
+                documentation="Histogram of number of tokens per engine_step.",
+                buckets=build_cudagraph_buckets(vllm_config),
+                labelnames=labelnames).labels(*labelvalues)
+
         self.histogram_time_to_first_token = \
             prometheus_client.Histogram(
                 name="vllm:time_to_first_token_seconds",
@@ -237,6 +244,9 @@ def log(self, scheduler_stats: SchedulerStats,
         self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
         self.counter_generation_tokens.inc(
             iteration_stats.num_generation_tokens)
+        self.histogram_iteration_tokens.observe(
+            iteration_stats.num_prompt_tokens + \
+            iteration_stats.num_generation_tokens)
 
         for finished_request in iteration_stats.finished_requests:
             self.counter_request_success[finished_request.finish_reason].inc()
@@ -293,3 +303,13 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
     [1, 2, 5, 10, 20, 50, 100]
     """
     return build_buckets([1, 2, 5], max_value)
+
+
+def build_cudagraph_buckets(vllm_config: VllmConfig) -> List[int]:
+    if not vllm_config.model_config.enforce_eager:
+        buckets = vllm_config.compilation_config.\
+            cudagraph_capture_sizes.copy()
+        buckets.sort()
+        return buckets
+    else:
+        return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
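
Taken together, the loggers.py changes derive histogram buckets from the cudagraph capture sizes when graphs are in use (falling back to a fixed list under enforce_eager), register the histogram, and observe the per-step token total. A rough standalone sketch of the bucket derivation, using SimpleNamespace stand-ins rather than vLLM's real VllmConfig, and assuming capture sizes are recorded largest-first (hence the sort):

from types import SimpleNamespace
from typing import List

def derive_buckets(vllm_config) -> List[int]:
    # Mirrors the build_cudagraph_buckets helper added above: sorted capture
    # sizes when cudagraphs are enabled, otherwise a fixed ascending fallback.
    if not vllm_config.model_config.enforce_eager:
        buckets = vllm_config.compilation_config.cudagraph_capture_sizes.copy()
        buckets.sort()
        return buckets
    return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]

# Stand-in config (hypothetical values, not vLLM defaults).
config = SimpleNamespace(
    model_config=SimpleNamespace(enforce_eager=False),
    compilation_config=SimpleNamespace(
        cudagraph_capture_sizes=[512, 256, 128, 64, 32, 16, 8, 4, 2, 1]))

print(derive_buckets(config))  # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

config.model_config.enforce_eager = True
print(derive_buckets(config))  # the fixed fallback bucket list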
