[Bugfix] StatLoggers: cache spec decode metrics when they get collected. #6645

Merged · 4 commits · Jul 23, 2024
91 changes: 91 additions & 0 deletions tests/metrics/test_metrics.py
@@ -1,3 +1,4 @@
import time
from typing import List

import pytest
@@ -10,6 +11,8 @@
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams

from ..conftest import cleanup

MODELS = [
"facebook/opt-125m",
]
@@ -219,6 +222,94 @@ def test_metric_spec_decode(
"does not meet expectation")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize("log_interval", [1, 3, 5, 7])
def test_metric_spec_decode_interval(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
log_interval: int,
) -> None:
k = 5

engine_args = EngineArgs(model=model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
use_v2_block_manager=True,
enforce_eager=True)

engine = LLMEngine.from_engine_args(engine_args)

try:

engine.add_request(
"request-id-0",
example_prompts[0],
SamplingParams(max_tokens=max_tokens),
)

# set log interval
stat_logger = engine.stat_loggers['prometheus']
stat_logger.local_interval = log_interval

# prefill
engine.step()

# wait for 5 seconds to ensure that spec decode metrics
# get triggered in the first decode step
time.sleep(5)

# first decode step should trigger async collection of metrics
engine.step()

# wait one second to allow H2D transfer to finish
time.sleep(1)

# second decode step should now be able to collect the spec
# decode stats and the request should also be finished
engine.step()

# must have finished now
assert not engine.has_unfinished_requests()

# wait to ensure logging occurs
time.sleep(log_interval)

# force logging
engine.step()

# Note that the purpose of this test is to verify spec decode
# metrics instead of functional correctness, so the expected values
# are intended to be loose.
metric_name_to_expected_fn = {
"gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
"gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
"counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
"counter_spec_decode_num_draft_tokens": lambda v: v == k,
"counter_spec_decode_num_emitted_tokens":
lambda v: 0 <= v <= k + 1,
}

for metric_name, is_expected in metric_name_to_expected_fn.items():
metric_val = getattr(
stat_logger.metrics,
metric_name).labels(**stat_logger.labels)._value.get()
assert is_expected(metric_val), (
f"the value of metric {metric_name} ({metric_val}) "
"does not meet expectation")

finally:
del engine
cleanup()


def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
num_requests: int) -> None:
if disable_log_stats:
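The assertions above read raw values straight off the Prometheus client objects via the private `_value` attribute. A minimal standalone illustration of that access pattern, assuming plain `prometheus_client` metrics (the metric name and label below are made up for the example):

```python
from prometheus_client import Counter

# Hypothetical counter, mirroring how the test peeks at vLLM's metrics.
draft_tokens = Counter("spec_decode_num_draft_tokens_example",
                       "Draft tokens proposed by the speculative model",
                       ["model_name"])

draft_tokens.labels(model_name="facebook/opt-125m").inc(5)

# Same private-API read as in the test: labels(...)._value.get()
value = draft_tokens.labels(model_name="facebook/opt-125m")._value.get()
assert value == 5
```

Because `_value` is not public API, the test (and this sketch) may need adjusting if `prometheus_client` ever changes its internals.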
47 changes: 31 additions & 16 deletions vllm/engine/metrics.py
@@ -355,6 +355,7 @@ def __init__(self, local_interval: float) -> None:
self.num_generation_tokens: List[int] = []
self.last_local_log = time.time()
self.local_interval = local_interval
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None

@abstractmethod
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
@@ -364,6 +365,12 @@ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
def log(self, stats: Stats) -> None:
raise NotImplementedError

def maybe_update_spec_decode_metrics(self, stats: Stats):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if stats.spec_decode_metrics is not None:
self.spec_decode_metrics = stats.spec_decode_metrics


class LoggingStatLogger(StatLoggerBase):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
@@ -379,6 +386,9 @@ def log(self, stats: Stats) -> None:
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)

# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)

# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
@@ -408,15 +418,16 @@ def log(self, stats: Stats) -> None:
stats.cpu_cache_usage_sys * 100,
)

if self.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
self.spec_decode_metrics))

# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now

if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
self.spec_decode_metrics = None

def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
@@ -533,6 +544,9 @@ def log(self, stats: Stats):
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)

# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)

# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
@@ -550,26 +564,27 @@ def log(self, stats: Stats):
prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput)

# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now

if stats.spec_decode_metrics is not None:
if self.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
stats.spec_decode_metrics.draft_acceptance_rate)
self.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
stats.spec_decode_metrics.system_efficiency)
self.spec_decode_metrics.system_efficiency)
self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens,
stats.spec_decode_metrics.accepted_tokens)
self.spec_decode_metrics.accepted_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens,
stats.spec_decode_metrics.draft_tokens)
self.spec_decode_metrics.draft_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens,
stats.spec_decode_metrics.emitted_tokens)
self.spec_decode_metrics.emitted_tokens)

# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
self.spec_decode_metrics = None


class RayPrometheusStatLogger(PrometheusStatLogger):
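Taken together, the metrics.py changes amount to a small cache-and-flush pattern: every call to `log()` stashes any freshly collected spec decode metrics, and the next time the logging interval elapses they are emitted and cleared. A condensed, self-contained sketch of that pattern (illustrative only; `Stats` is reduced to the two fields used here, and `print` stands in for the real logger/Prometheus calls):

```python
import time
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Stats:
    # Reduced stand-in for vllm.engine.metrics.Stats: only the fields
    # this sketch needs.
    now: float
    spec_decode_metrics: Optional[Any] = None


class StatLoggerBase:
    def __init__(self, local_interval: float) -> None:
        self.local_interval = local_interval
        self.last_local_log = time.time()
        # Cache: spec decode metrics are collected asynchronously and
        # rarely arrive on the same step that crosses the log interval.
        self.spec_decode_metrics: Optional[Any] = None

    def maybe_update_spec_decode_metrics(self, stats: Stats) -> None:
        """Remember the latest spec decode metrics until the next flush."""
        if stats.spec_decode_metrics is not None:
            self.spec_decode_metrics = stats.spec_decode_metrics


class SketchLogger(StatLoggerBase):
    def log(self, stats: Stats) -> None:
        # Cache on every call, even when nothing is emitted this step.
        self.maybe_update_spec_decode_metrics(stats)

        # Emit once per interval, using the cached value rather than
        # whatever happens to be on the current Stats object.
        if stats.now - self.last_local_log > self.local_interval:
            if self.spec_decode_metrics is not None:
                print("spec decode metrics:", self.spec_decode_metrics)
            self.last_local_log = stats.now
            self.spec_decode_metrics = None  # reset for the next interval
```

With this in place, metrics that arrive on a step in the middle of an interval are still reported when a later step finally crosses the interval boundary, instead of being dropped because `stats.spec_decode_metrics` happened to be `None` at flush time.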