60 changes: 60 additions & 0 deletions tests/v1/core/test_scheduler.py
@@ -1014,6 +1014,66 @@ def test_kv_connector_basic():
)


def test_external_prefix_cache_metrics():
"""
Verify connector prefix cache metrics are updated
correctly when the scheduler processes requests with KV connector hits.
"""

# Setup Scheduler.
scheduler = create_scheduler(
enable_prefix_caching=False,
use_kv_connector=True,
)

# Mock connector to simulate a partial external cache hit
NUM_MATCHED_NEW_TOKENS = 4
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS,
False,
)

# --- Prepare simple requests ---
NUM_REQUESTS = 2
NUM_TOKENS = 8
MAX_TOKENS = 2
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
)

for req in requests:
scheduler.add_request(req)

# --- Trigger scheduling and simulate model output ---
output = scheduler.schedule()
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[r.request_id for r in requests],
req_id_to_index={r.request_id: i for i, r in enumerate(requests)},
sampled_token_ids=[[1000]] * NUM_REQUESTS,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)

# Update scheduler stats
ecos = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)

# --- Assertions ---
assert ecos is not None and len(ecos) > 0
assert ecos[0].scheduler_stats is not None

external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats
assert external_stats is not None

assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS
assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS
assert external_stats.requests == NUM_REQUESTS
assert external_stats.preempted_requests == 0
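
    # A quick sanity check on the expected values (illustrative sketch; the
    # constants are the ones defined in the test above):
    #   queries = NUM_TOKENS * NUM_REQUESTS = 8 * 2 = 16
    #   hits = NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS = 4 * 2 = 8
    #   -> external prefix cache hit rate = 8 / 16 = 50%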


def test_kv_connector_unable_to_allocate():
"""
Test whether scheduler with KVConnector is able to handle
15 changes: 5 additions & 10 deletions vllm/v1/core/kv_cache_manager.py
@@ -208,16 +208,11 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:

        if self.log_stats:
            assert self.prefix_cache_stats is not None
-            if request.num_preemptions > 0:
-                # Previously preempted request
-                self.prefix_cache_stats.preempted_requests += 1
-                self.prefix_cache_stats.preempted_queries += request.num_tokens
-                self.prefix_cache_stats.preempted_hits += num_new_computed_tokens
-            else:
-                # New request
-                self.prefix_cache_stats.requests += 1
-                self.prefix_cache_stats.queries += request.num_tokens
-                self.prefix_cache_stats.hits += num_new_computed_tokens
+            self.prefix_cache_stats.record(
+                num_tokens=request.num_tokens,
+                num_hits=num_new_computed_tokens,
+                preempted=request.num_preemptions > 0,
+            )

        return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens

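This hunk replaces the duplicated preempted/new branching with the PrefixCacheStats.record() helper introduced in vllm/v1/metrics/stats.py (last hunk of this diff). A minimal usage sketch, assuming a checkout with this change applied (the token counts are illustrative):

    from vllm.v1.metrics.stats import PrefixCacheStats

    stats = PrefixCacheStats()
    # A new request: 8 prompt tokens queried, 4 of them already cached.
    stats.record(num_tokens=8, num_hits=4, preempted=False)
    # A previously preempted request goes to the preempted_* fields instead.
    stats.record(num_tokens=8, num_hits=8, preempted=True)
    assert (stats.requests, stats.queries, stats.hits) == (1, 8, 4)
    assert (stats.preempted_requests, stats.preempted_queries, stats.preempted_hits) == (1, 8, 8)
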
29 changes: 28 additions & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -28,7 +28,7 @@
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.metrics.stats import SchedulerStats
+from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
@@ -84,6 +84,7 @@ def __init__(
# will have a corresponding KVConnector with Role=WORKER.
# KV Connector pushes/pull of remote KVs for P/D and offloading.
self.connector = None
self.connector_prefix_cache_stats: PrefixCacheStats | None = None
if self.vllm_config.kv_transfer_config is not None:
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported "
@@ -95,6 +95,8 @@
self.connector = KVConnectorFactory.create_connector(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER
)
if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats()

self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
@@ -525,6 +528,9 @@ def schedule(self) -> SchedulerOutput:
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
self._update_connector_prefix_cache_stats(
request, num_external_computed_tokens
)

# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
@@ -1246,11 +1252,13 @@ def make_stats(
return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
assert prefix_cache_stats is not None
connector_prefix_cache_stats = self._make_connector_prefix_cache_stats()
return SchedulerStats(
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
kv_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=prefix_cache_stats,
connector_prefix_cache_stats=connector_prefix_cache_stats,
spec_decoding_stats=spec_decoding_stats,
num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running),
kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None,
@@ -1281,6 +1289,25 @@ def shutdown(self) -> None:
# KV Connector Related Methods
########################################################################

def _update_connector_prefix_cache_stats(
self, request: Request, num_external_tokens: int
) -> None:
if self.connector_prefix_cache_stats is None:
return

self.connector_prefix_cache_stats.record(
num_tokens=request.num_tokens,
num_hits=num_external_tokens,
preempted=request.num_preemptions > 0,
)

def _make_connector_prefix_cache_stats(self) -> PrefixCacheStats | None:
if self.connector_prefix_cache_stats is None:
return None
stats = self.connector_prefix_cache_stats
self.connector_prefix_cache_stats = PrefixCacheStats()
return stats

def get_kv_connector(self) -> KVConnectorBase_V1 | None:
return self.connector

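The two helpers above implement a take-and-reset pattern: _make_connector_prefix_cache_stats() hands back the stats accumulated since the last poll and installs a fresh accumulator, so each SchedulerStats snapshot carries only that interval's delta. A self-contained sketch of the pattern (names are illustrative, not vLLM API):

    class IntervalStats:
        def __init__(self) -> None:
            self.queries = 0
            self.hits = 0

    class Accumulator:
        def __init__(self) -> None:
            self._stats = IntervalStats()

        def record(self, num_tokens: int, num_hits: int) -> None:
            self._stats.queries += num_tokens
            self._stats.hits += num_hits

        def take(self) -> IntervalStats:
            # Swap in a fresh object; the caller owns the returned interval.
            stats, self._stats = self._stats, IntervalStats()
            return stats

    acc = Accumulator()
    acc.record(num_tokens=8, num_hits=4)
    first, second = acc.take(), acc.take()
    assert (first.queries, first.hits) == (8, 4)
    assert (second.queries, second.hits) == (0, 0)  # nothing recorded since last take
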
45 changes: 45 additions & 0 deletions vllm/v1/metrics/loggers.py
@@ -93,6 +93,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
# Caching metrics. This cannot be reset.
# TODO: Make the interval configurable.
self.prefix_caching_metrics = CachingMetrics()
self.connector_prefix_caching_metrics = CachingMetrics()
self.mm_caching_metrics = CachingMetrics()

self.spec_decoding_logging = SpecDecodingLogging()
@@ -140,6 +141,11 @@ def record(
if scheduler_stats is not None:
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

if scheduler_stats.connector_prefix_cache_stats is not None:
self.connector_prefix_caching_metrics.observe(
scheduler_stats.connector_prefix_cache_stats
)

if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
if kv_connector_stats := scheduler_stats.kv_connector_stats:
@@ -192,6 +198,9 @@ def log(self):
self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
]
if not self.connector_prefix_caching_metrics.empty:
log_parts.append("External prefix cache hit rate: %.1f%%")
log_args.append(self.connector_prefix_caching_metrics.hit_rate * 100)
if not self.mm_caching_metrics.empty:
log_parts.append("MM cache hit rate: %.1f%%")
log_args.append(self.mm_caching_metrics.hit_rate * 100)
@@ -457,6 +466,34 @@ def __init__(
counter_prefix_cache_hits, engine_indexes, model_name
)

#
# External - KV connector prefix cache
#

counter_connector_prefix_cache_queries = self._counter_cls(
name="vllm:external_prefix_cache_queries",
documentation=(
"External prefix cache queries from KV connector "
"cross-instance cache sharing, in terms of number of queried tokens."
),
labelnames=labelnames,
)
self.counter_connector_prefix_cache_queries = make_per_engine(
counter_connector_prefix_cache_queries, engine_indexes, model_name
)

counter_connector_prefix_cache_hits = self._counter_cls(
name="vllm:external_prefix_cache_hits",
documentation=(
"External prefix cache hits from KV connector "
"cross-instance cache sharing, in terms of number of cached tokens."
),
labelnames=labelnames,
)
self.counter_connector_prefix_cache_hits = make_per_engine(
counter_connector_prefix_cache_hits, engine_indexes, model_name
)

#
# Multi-modal cache
#
@@ -883,6 +920,14 @@ def record(
scheduler_stats.prefix_cache_stats.hits
)

if scheduler_stats.connector_prefix_cache_stats is not None:
self.counter_connector_prefix_cache_queries[engine_idx].inc(
scheduler_stats.connector_prefix_cache_stats.queries
)
self.counter_connector_prefix_cache_hits[engine_idx].inc(
scheduler_stats.connector_prefix_cache_stats.hits
)

if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats, engine_idx
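The Prometheus counters added here export raw monotonic totals and leave ratio computation to the metrics backend, while the logging path derives a hit rate through CachingMetrics. A simplified stand-in for that aggregation, using only the queries/hits fields visible in this diff (assumption: the real CachingMetrics aggregates over a window of recent requests, whereas this sketch keeps plain running totals):

    class RunningCacheMetrics:
        def __init__(self) -> None:
            self.queries = 0
            self.hits = 0

        @property
        def empty(self) -> bool:
            return self.queries == 0

        @property
        def hit_rate(self) -> float:
            return self.hits / self.queries if self.queries else 0.0

        def observe(self, stats) -> None:
            # `stats` is PrefixCacheStats-like, carrying per-interval deltas.
            self.queries += stats.queries
            self.hits += stats.hits
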
14 changes: 14 additions & 0 deletions vllm/v1/metrics/stats.py
@@ -126,6 +126,19 @@ class PrefixCacheStats(BaseCacheStats):
preempted_hits: int = 0
"""The `hits` number for preempted requests."""

def record(self, num_tokens: int, num_hits: int, preempted: bool) -> None:
"""Aggregate request information into the stats."""
if preempted:
# Previously preempted request
self.preempted_requests += 1
self.preempted_queries += num_tokens
self.preempted_hits += num_hits
else:
# New request
self.requests += 1
self.queries += num_tokens
self.hits += num_hits


@dataclass
class MultiModalCacheStats(BaseCacheStats):
@@ -151,6 +164,7 @@ class SchedulerStats:
kv_cache_usage: float = 0.0

prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
connector_prefix_cache_stats: PrefixCacheStats | None = None

spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: dict[str, Any] | None = None
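Because connector_prefix_cache_stats defaults to None, a None value means "no KV connector configured (or stats logging disabled)" rather than zero traffic, which is why both logger paths above None-check before observing. A quick illustration, assuming a checkout with this change and the dataclass's zero-valued defaults:

    from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats

    stats = SchedulerStats()
    assert stats.connector_prefix_cache_stats is None  # no connector in play

    stats = SchedulerStats(connector_prefix_cache_stats=PrefixCacheStats())
    assert stats.connector_prefix_cache_stats.queries == 0  # connector present, no traffic yet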