Skip to content

Commit a3d087a

Browse files
authored
[P/D][Nixl] Introduce KVTransferMetrics and aggregation strategy (#22188)
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent 058525b commit a3d087a

File tree

11 files changed

+525
-25
lines changed

11 files changed

+525
-25
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 210 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,18 @@
1818

1919
from vllm import LLM
2020
from vllm.config import KVTransferConfig
21+
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
22+
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
23+
KVConnectorStats)
24+
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
25+
MultiKVConnectorStats)
2126
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
2227
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
23-
NixlConnectorWorker)
28+
NixlConnectorWorker, NixlKVConnectorStats)
2429
from vllm.forward_context import ForwardContext
2530
from vllm.sampling_params import SamplingParams
2631
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
32+
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
2733

2834
from .utils import create_request, create_scheduler, create_vllm_config
2935

@@ -475,6 +481,209 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
475481
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
476482
# we put here is important. First run ray, it will clean up the resources, then
477483
# the rest of the tests.
484+
@patch(
485+
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
486+
FakeNixlWrapper)
487+
def test_kv_connector_stats(dist_init):
488+
"""Test that KV transfer stats are properly recorded and retrieved."""
489+
vllm_config = create_vllm_config()
490+
491+
# Test worker role in decode server.
492+
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
493+
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
494+
connector.engine_id,
495+
hand_shake_latency=0)
496+
497+
# Verify that xfer_stats starts empty
498+
initial_stats = connector.get_kv_connector_stats()
499+
assert initial_stats is None
500+
501+
# Create transfer metadata
502+
request_id = "test_req_for_stats"
503+
metadata = NixlConnectorMetadata()
504+
metadata.add_new_req(request_id=request_id,
505+
local_block_ids=[1, 2, 3],
506+
kv_transfer_params={
507+
"remote_block_ids": [4, 5, 6],
508+
"remote_engine_id":
509+
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
510+
"remote_host": "localhost",
511+
"remote_port": 1234,
512+
"remote_tp_size": 1,
513+
})
514+
connector.bind_connector_metadata(metadata)
515+
516+
# Start the transfer
517+
dummy_ctx = ForwardContext(
518+
no_compile_layers={},
519+
attn_metadata={},
520+
virtual_engine=0,
521+
)
522+
connector.start_load_kv(dummy_ctx)
523+
524+
# Verify stats are recorded after transfer is complete
525+
max_iterations = 2
526+
# Clear metadata before start_load_kv to prevent reprocessing same request
527+
connector.bind_connector_metadata(NixlConnectorMetadata())
528+
for _ in range(max_iterations):
529+
# Need to call start_load_kv to process completed handshakes
530+
connector.start_load_kv(dummy_ctx)
531+
_, done_recving = connector.get_finished(finished_req_ids=set())
532+
if len(done_recving) > 0 and request_id in done_recving:
533+
break
534+
time.sleep(
535+
0.1) # Small delay to allow background handshake to complete
536+
else:
537+
assert "Transfer did not complete within expected iterations"
538+
539+
# Now check that stats were recorded
540+
stats_after_transfer = connector.get_kv_connector_stats()
541+
assert isinstance(stats_after_transfer, NixlKVConnectorStats)
542+
543+
# Verify stats values are recorded
544+
assert not stats_after_transfer.is_empty()
545+
assert stats_after_transfer.data["num_successful_transfers"] == 1
546+
547+
# Verify stats are reset after retrieval
548+
stats_after_reset = connector.get_kv_connector_stats()
549+
assert stats_after_reset is None
550+
551+
552+
def test_kv_connector_stats_aggregation():
553+
"""
554+
Test KV transfer stats aggregation across TP ranks using
555+
KVOutputAggregator (used by MultiprocExecutor).
556+
"""
557+
558+
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
559+
# done in MultiprocExecutor.execute_model
560+
aggregator = KVOutputAggregator(world_size=3)
561+
562+
# Create stats for multiple workers with different transfer patterns
563+
worker1_stats = NixlKVConnectorStats()
564+
worker2_stats = NixlKVConnectorStats()
565+
worker3_stats = NixlKVConnectorStats()
566+
567+
# Record different transfers on each worker
568+
# Worker 1: 2 transfers
569+
worker1_stats.record_transfer()
570+
worker1_stats.record_transfer()
571+
572+
# Worker 2: 1 transfer
573+
worker2_stats.record_transfer()
574+
575+
# Worker 3: 3 transfers
576+
worker3_stats.record_transfer()
577+
worker3_stats.record_transfer()
578+
worker3_stats.record_transfer()
579+
580+
# Create ModelRunnerOutput instances for each worker
581+
worker_outputs = []
582+
for i, worker_stats in enumerate(
583+
[worker1_stats, worker2_stats, worker3_stats]):
584+
output = ModelRunnerOutput(
585+
req_ids=[f"req_{i}"],
586+
req_id_to_index={f"req_{i}": 0},
587+
sampled_token_ids=[[123]], # dummy token
588+
logprobs=None,
589+
prompt_logprobs_dict={},
590+
pooler_output=[None],
591+
kv_connector_output=KVConnectorOutput(
592+
finished_sending=set([f"req_{i}_send"])
593+
if i < 2 else None, # Workers 0,1 finished sending
594+
finished_recving=set([f"req_{i}_recv"])
595+
if i > 0 else None, # Workers 1,2 finished receiving
596+
kv_connector_stats=worker_stats,
597+
))
598+
worker_outputs.append(output)
599+
600+
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
601+
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
602+
kv_connector_stats = \
603+
aggregated_output.kv_connector_output.kv_connector_stats
604+
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
605+
# Number of total transfers across all workers.
606+
assert kv_connector_stats.data["num_successful_transfers"] == 6
607+
608+
609+
def test_multi_kv_connector_stats_aggregation():
610+
"""
611+
Test MultiKVConnectorStats aggregation across TP ranks using
612+
KVOutputAggregator (used by MultiprocExecutor).
613+
"""
614+
615+
aggregator = KVOutputAggregator(world_size=3)
616+
617+
from dataclasses import dataclass
618+
619+
@dataclass
620+
class FooKVConnectorStats(KVConnectorStats):
621+
622+
def reset(self):
623+
self.data = {"num_foo_transfers": 0}
624+
625+
def record_transfer(self):
626+
if "num_foo_transfers" not in self.data:
627+
self.data["num_foo_transfers"] = 0
628+
self.data["num_foo_transfers"] += 1
629+
630+
def is_empty(self) -> bool:
631+
return self.data["num_foo_transfers"] == 0
632+
633+
def aggregate(self,
634+
other: "FooKVConnectorStats") -> "FooKVConnectorStats":
635+
if not other.is_empty():
636+
self.data["num_foo_transfers"] += other.data[
637+
"num_foo_transfers"]
638+
return self
639+
640+
def make_multi_stats(nixl_count: int,
641+
foo_count: int) -> MultiKVConnectorStats:
642+
data: dict[str, KVConnectorStats] = {}
643+
if nixl_count > 0:
644+
nixl_stats = NixlKVConnectorStats()
645+
for _ in range(nixl_count):
646+
nixl_stats.record_transfer()
647+
data["NixlConnector"] = nixl_stats
648+
if foo_count > 0:
649+
foo_stats = FooKVConnectorStats()
650+
for _ in range(foo_count):
651+
foo_stats.record_transfer()
652+
data["FooConnector"] = foo_stats
653+
return MultiKVConnectorStats(data=data)
654+
655+
# Create heterogeneous stats across 3 workers
656+
worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo)
657+
658+
worker_outputs: list[ModelRunnerOutput] = []
659+
for i, (nixl, foo) in enumerate(worker_patterns):
660+
stats = make_multi_stats(nixl, foo)
661+
output = ModelRunnerOutput(
662+
req_ids=[f"req_{i}"],
663+
req_id_to_index={f"req_{i}": 0},
664+
sampled_token_ids=[[123]],
665+
logprobs=None,
666+
prompt_logprobs_dict={},
667+
pooler_output=[None],
668+
kv_connector_output=KVConnectorOutput(
669+
finished_sending=set([f"req_{i}_send"]) if i < 2 else None,
670+
finished_recving=set([f"req_{i}_recv"]) if i > 0 else None,
671+
kv_connector_stats=stats,
672+
),
673+
)
674+
worker_outputs.append(output)
675+
676+
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
677+
kv_connector_stats = \
678+
aggregated_output.kv_connector_output.kv_connector_stats
679+
assert isinstance(kv_connector_stats, MultiKVConnectorStats)
680+
681+
# Validate per-connector totals across workers
682+
assert kv_connector_stats["NixlConnector"].data[
683+
"num_successful_transfers"] == 5
684+
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
685+
686+
478687
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
479688
@patch(
480689
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def __init__(self, world_size: int):
129129
def aggregate(self,
130130
outputs: list[ModelRunnerOutput],
131131
output_rank: int = 0) -> ModelRunnerOutput:
132-
# aggregate kv_connector_output from all workers
132+
# Aggregate kv_connector_output from all workers
133133

134134
def update_finished_set(req_ids: Optional[set[str]],
135135
remaining_count_dict: dict[str, int],
@@ -142,21 +142,36 @@ def update_finished_set(req_ids: Optional[set[str]],
142142

143143
finished_sending = set[str]()
144144
finished_recving = set[str]()
145-
for output in outputs:
146-
output = output.kv_connector_output
145+
aggregated_kv_connector_stats = None
146+
for model_runner_output in outputs:
147+
output = model_runner_output.kv_connector_output
147148
if not output:
148149
continue
149150
update_finished_set(output.finished_sending,
150151
self._send_remaining_count, finished_sending)
151152
update_finished_set(output.finished_recving,
152153
self._recv_remaining_count, finished_recving)
153154

155+
# Aggregate kv_connector_stats from all workers.
156+
if aggregated_kv_connector_stats is None:
157+
# Use the first worker's kv_connector_stats as accumulator.
158+
aggregated_kv_connector_stats = output.kv_connector_stats
159+
elif kv_connector_stats := output.kv_connector_stats:
160+
if aggregated_kv_connector_stats is None:
161+
aggregated_kv_connector_stats = kv_connector_stats
162+
else:
163+
assert isinstance(aggregated_kv_connector_stats,
164+
type(kv_connector_stats))
165+
aggregated_kv_connector_stats = \
166+
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
167+
154168
# select output of the worker specified by output_rank
155169
output = outputs[output_rank]
156170

157171
output.kv_connector_output = KVConnectorOutput(
158172
finished_sending=finished_sending or None,
159173
finished_recving=finished_recving or None,
174+
kv_connector_stats=aggregated_kv_connector_stats or None,
160175
)
161176

162177
return output

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
from vllm.attention.backends.abstract import AttentionMetadata
5050
from vllm.config import VllmConfig
5151
from vllm.distributed.kv_events import KVCacheEvent
52+
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
53+
KVConnectorStats)
5254
from vllm.forward_context import ForwardContext
5355
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
5456
from vllm.v1.request import Request
@@ -235,6 +237,12 @@ def shutdown(self):
235237
"""
236238
return None
237239

240+
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
241+
"""
242+
Get the KV connector stats collected during the last interval.
243+
"""
244+
return None
245+
238246
# ==============================
239247
# Scheduler-side methods
240248
# ==============================
@@ -365,4 +373,16 @@ def get_finished_count(self) -> Optional[int]:
365373
int: expected sending or receiving completion count.
366374
"""
367375

368-
return None
376+
return None
377+
378+
@classmethod
379+
def build_kv_connector_stats(
380+
cls,
381+
data: Optional[dict[str,
382+
Any]] = None) -> Optional["KVConnectorStats"]:
383+
"""
384+
KVConnectorStats resolution method. This method allows dynamically
385+
registered connectors to return their own KVConnectorStats object,
386+
which can implement custom aggregation logic on the data dict.
387+
"""
388+
return None

0 commit comments

Comments
 (0)