diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index aaac2deb12ac..c17c6f6c89b0 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1014,6 +1014,66 @@ def test_kv_connector_basic(): ) +def test_external_prefix_cache_metrics(): + """ + Verify connector prefix cache metrics are updated + correctly when the scheduler processes requests with KV connector hits. + """ + + # Setup Scheduler. + scheduler = create_scheduler( + enable_prefix_caching=False, + use_kv_connector=True, + ) + + # Mock connector to simulate a partial external cache hit + NUM_MATCHED_NEW_TOKENS = 4 + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS, + False, + ) + + # --- Prepare simple requests --- + NUM_REQUESTS = 2 + NUM_TOKENS = 8 + MAX_TOKENS = 2 + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + ) + + for req in requests: + scheduler.add_request(req) + + # --- Trigger scheduling and simulate model output --- + output = scheduler.schedule() + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[r.request_id for r in requests], + req_id_to_index={r.request_id: i for i, r in enumerate(requests)}, + sampled_token_ids=[[1000]] * NUM_REQUESTS, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + + # Update scheduler stats + ecos = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # --- Assertions --- + assert ecos is not None and len(ecos) > 0 + assert ecos[0].scheduler_stats is not None + + external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats + assert external_stats is not None + + assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS + assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS + assert external_stats.requests == NUM_REQUESTS + assert external_stats.preempted_requests == 0 + + def test_kv_connector_unable_to_allocate(): """ Test whether scheduler with KVConnector is able to handle diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 74176e4b2051..bb8cec91f36d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -208,16 +208,11 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: if self.log_stats: assert self.prefix_cache_stats is not None - if request.num_preemptions > 0: - # Previously preempted request - self.prefix_cache_stats.preempted_requests += 1 - self.prefix_cache_stats.preempted_queries += request.num_tokens - self.prefix_cache_stats.preempted_hits += num_new_computed_tokens - else: - # New request - self.prefix_cache_stats.requests += 1 - self.prefix_cache_stats.queries += request.num_tokens - self.prefix_cache_stats.hits += num_new_computed_tokens + self.prefix_cache_stats.record( + num_tokens=request.num_tokens, + num_hits=num_new_computed_tokens, + preempted=request.num_preemptions > 0, + ) return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 08368b7d99ef..e27915f31dad 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -28,7 +28,7 @@ from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats @@ -84,6 +84,7 @@ def __init__( # will have a corresponding KVConnector with Role=WORKER. # KV Connector pushes/pull of remote KVs for P/D and offloading. self.connector = None + self.connector_prefix_cache_stats: PrefixCacheStats | None = None if self.vllm_config.kv_transfer_config is not None: assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " @@ -95,6 +96,8 @@ def __init__( self.connector = KVConnectorFactory.create_connector( config=self.vllm_config, role=KVConnectorRole.SCHEDULER ) + if self.log_stats: + self.connector_prefix_cache_stats = PrefixCacheStats() self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -525,6 +528,9 @@ def schedule(self) -> SchedulerOutput: new_computed_blocks + new_blocks, num_external_computed_tokens, ) + self._update_connector_prefix_cache_stats( + request, num_external_computed_tokens + ) # Request was already popped from self.waiting # unless it was re-added above due to new_blocks being None. @@ -1246,11 +1252,13 @@ def make_stats( return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None + connector_prefix_cache_stats = self._make_connector_prefix_cache_stats() return SchedulerStats( num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting), kv_cache_usage=self.kv_cache_manager.usage, prefix_cache_stats=prefix_cache_stats, + connector_prefix_cache_stats=connector_prefix_cache_stats, spec_decoding_stats=spec_decoding_stats, num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running), kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, @@ -1281,6 +1289,25 @@ def shutdown(self) -> None: # KV Connector Related Methods ######################################################################## + def _update_connector_prefix_cache_stats( + self, request: Request, num_external_tokens: int + ) -> None: + if self.connector_prefix_cache_stats is None: + return + + self.connector_prefix_cache_stats.record( + num_tokens=request.num_tokens, + num_hits=num_external_tokens, + preempted=request.num_preemptions > 0, + ) + + def _make_connector_prefix_cache_stats(self) -> PrefixCacheStats | None: + if self.connector_prefix_cache_stats is None: + return None + stats = self.connector_prefix_cache_stats + self.connector_prefix_cache_stats = PrefixCacheStats() + return stats + def get_kv_connector(self) -> KVConnectorBase_V1 | None: return self.connector diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index ca322f104020..a31f8147959b 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -93,6 +93,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): # Caching metrics. This cannot be reset. # TODO: Make the interval configurable. self.prefix_caching_metrics = CachingMetrics() + self.connector_prefix_caching_metrics = CachingMetrics() self.mm_caching_metrics = CachingMetrics() self.spec_decoding_logging = SpecDecodingLogging() @@ -140,6 +141,11 @@ def record( if scheduler_stats is not None: self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + if scheduler_stats.connector_prefix_cache_stats is not None: + self.connector_prefix_caching_metrics.observe( + scheduler_stats.connector_prefix_cache_stats + ) + if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) if kv_connector_stats := scheduler_stats.kv_connector_stats: @@ -192,6 +198,9 @@ def log(self): self.last_scheduler_stats.kv_cache_usage * 100, self.prefix_caching_metrics.hit_rate * 100, ] + if not self.connector_prefix_caching_metrics.empty: + log_parts.append("External prefix cache hit rate: %.1f%%") + log_args.append(self.connector_prefix_caching_metrics.hit_rate * 100) if not self.mm_caching_metrics.empty: log_parts.append("MM cache hit rate: %.1f%%") log_args.append(self.mm_caching_metrics.hit_rate * 100) @@ -457,6 +466,34 @@ def __init__( counter_prefix_cache_hits, engine_indexes, model_name ) + # + # External - KV connector prefix cache + # + + counter_connector_prefix_cache_queries = self._counter_cls( + name="vllm:external_prefix_cache_queries", + documentation=( + "External prefix cache queries from KV connector " + "cross-instance cache sharing, in terms of number of queried tokens." + ), + labelnames=labelnames, + ) + self.counter_connector_prefix_cache_queries = make_per_engine( + counter_connector_prefix_cache_queries, engine_indexes, model_name + ) + + counter_connector_prefix_cache_hits = self._counter_cls( + name="vllm:external_prefix_cache_hits", + documentation=( + "External prefix cache hits from KV connector " + "cross-instance cache sharing, in terms of number of cached tokens." + ), + labelnames=labelnames, + ) + self.counter_connector_prefix_cache_hits = make_per_engine( + counter_connector_prefix_cache_hits, engine_indexes, model_name + ) + # # Multi-modal cache # @@ -883,6 +920,14 @@ def record( scheduler_stats.prefix_cache_stats.hits ) + if scheduler_stats.connector_prefix_cache_stats is not None: + self.counter_connector_prefix_cache_queries[engine_idx].inc( + scheduler_stats.connector_prefix_cache_stats.queries + ) + self.counter_connector_prefix_cache_hits[engine_idx].inc( + scheduler_stats.connector_prefix_cache_stats.hits + ) + if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_prom.observe( scheduler_stats.spec_decoding_stats, engine_idx diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index a4a8ab32ad72..7868141d1b1d 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -126,6 +126,19 @@ class PrefixCacheStats(BaseCacheStats): preempted_hits: int = 0 """The `hits` number for preempted requests.""" + def record(self, num_tokens: int, num_hits: int, preempted: bool) -> None: + """Aggregate request information into the stats.""" + if preempted: + # Previously preempted request + self.preempted_requests += 1 + self.preempted_queries += num_tokens + self.preempted_hits += num_hits + else: + # New request + self.requests += 1 + self.queries += num_tokens + self.hits += num_hits + @dataclass class MultiModalCacheStats(BaseCacheStats): @@ -151,6 +164,7 @@ class SchedulerStats: kv_cache_usage: float = 0.0 prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats) + connector_prefix_cache_stats: PrefixCacheStats | None = None spec_decoding_stats: SpecDecodingStats | None = None kv_connector_stats: dict[str, Any] | None = None