18 | 18 | |
19 | 19 | from vllm import LLM |
20 | 20 | from vllm.config import KVTransferConfig |
| 21 | +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator |
| 22 | +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( |
| 23 | + KVConnectorStats) |
| 24 | +from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( |
| 25 | + MultiKVConnectorStats) |
21 | 26 | from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( |
22 | 27 | KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, |
23 | | - NixlConnectorWorker) |
| 28 | + NixlConnectorWorker, NixlKVConnectorStats) |
24 | 29 | from vllm.forward_context import ForwardContext |
25 | 30 | from vllm.sampling_params import SamplingParams |
26 | 31 | from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend |
| 32 | +from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput |
27 | 33 | |
28 | 34 | from .utils import create_request, create_scheduler, create_vllm_config |
29 | 35 | |
@@ -475,6 +481,209 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): |
475 | 481 | # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which |
476 | 482 | # we put here is important. First run ray, it will clean up the resources, then |
477 | 483 | # the rest of the tests. |
| 484 | +@patch( |
| 485 | + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", |
| 486 | + FakeNixlWrapper) |
| 487 | +def test_kv_connector_stats(dist_init): |
| 488 | + """Test that KV transfer stats are properly recorded and retrieved.""" |
| 489 | + vllm_config = create_vllm_config() |
| 490 | + |
| 491 | + # Test worker role in decode server. |
| 492 | + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) |
| 493 | + connector.connector_worker = FakeNixlConnectorWorker(vllm_config, |
| 494 | + connector.engine_id, |
| 495 | + hand_shake_latency=0) |
| 496 | + |
| 497 | + # Verify that xfer_stats starts empty |
| 498 | + initial_stats = connector.get_kv_connector_stats() |
| 499 | + assert initial_stats is None |
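| | +    # Note: stats remain None until a transfer is recorded, and retrieval |
| | +    # resets them, as verified at the end of this test. |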
| 500 | + |
| 501 | + # Create transfer metadata |
| 502 | + request_id = "test_req_for_stats" |
| 503 | + metadata = NixlConnectorMetadata() |
| 504 | + metadata.add_new_req(request_id=request_id, |
| 505 | + local_block_ids=[1, 2, 3], |
| 506 | + kv_transfer_params={ |
| 507 | + "remote_block_ids": [4, 5, 6], |
| 508 | + "remote_engine_id": |
| 509 | + FakeNixlConnectorWorker.REMOTE_ENGINE_ID, |
| 510 | + "remote_host": "localhost", |
| 511 | + "remote_port": 1234, |
| 512 | + "remote_tp_size": 1, |
| 513 | + }) |
| 514 | + connector.bind_connector_metadata(metadata) |
| 515 | + |
| 516 | + # Start the transfer |
| 517 | + dummy_ctx = ForwardContext( |
| 518 | + no_compile_layers={}, |
| 519 | + attn_metadata={}, |
| 520 | + virtual_engine=0, |
| 521 | + ) |
| 522 | + connector.start_load_kv(dummy_ctx) |
| 523 | + |
| 524 | + # Verify stats are recorded after transfer is complete |
| 525 | + max_iterations = 2 |
| 526 | + # Clear metadata before start_load_kv to prevent reprocessing same request |
| 527 | + connector.bind_connector_metadata(NixlConnectorMetadata()) |
| 528 | + for _ in range(max_iterations): |
| 529 | + # Need to call start_load_kv to process completed handshakes |
| 530 | + connector.start_load_kv(dummy_ctx) |
| 531 | + _, done_recving = connector.get_finished(finished_req_ids=set()) |
| 532 | + if len(done_recving) > 0 and request_id in done_recving: |
| 533 | + break |
 | 534 | +        # Small delay to allow background handshake to complete. |
 | 535 | +        time.sleep(0.1) |
| 536 | + else: |
 | 537 | +        pytest.fail("Transfer did not complete within expected iterations") |
| 538 | + |
| 539 | + # Now check that stats were recorded |
| 540 | + stats_after_transfer = connector.get_kv_connector_stats() |
| 541 | + assert isinstance(stats_after_transfer, NixlKVConnectorStats) |
| 542 | + |
| 543 | + # Verify stats values are recorded |
| 544 | + assert not stats_after_transfer.is_empty() |
| 545 | + assert stats_after_transfer.data["num_successful_transfers"] == 1 |
| 546 | + |
| 547 | + # Verify stats are reset after retrieval |
| 548 | + stats_after_reset = connector.get_kv_connector_stats() |
| 549 | + assert stats_after_reset is None |
| 550 | + |
| 551 | + |
| 552 | +def test_kv_connector_stats_aggregation(): |
| 553 | + """ |
| 554 | + Test KV transfer stats aggregation across TP ranks using |
| 555 | + KVOutputAggregator (used by MultiprocExecutor). |
| 556 | + """ |
| 557 | + |
 | 558 | +    # Create KVOutputAggregator for 3 workers (simulating TP=3), mirroring |
 | 559 | +    # what MultiprocExecutor.execute_model does. |
| 560 | + aggregator = KVOutputAggregator(world_size=3) |
| 561 | + |
| 562 | + # Create stats for multiple workers with different transfer patterns |
| 563 | + worker1_stats = NixlKVConnectorStats() |
| 564 | + worker2_stats = NixlKVConnectorStats() |
| 565 | + worker3_stats = NixlKVConnectorStats() |
| 566 | + |
| 567 | + # Record different transfers on each worker |
| 568 | + # Worker 1: 2 transfers |
| 569 | + worker1_stats.record_transfer() |
| 570 | + worker1_stats.record_transfer() |
| 571 | + |
| 572 | + # Worker 2: 1 transfer |
| 573 | + worker2_stats.record_transfer() |
| 574 | + |
| 575 | + # Worker 3: 3 transfers |
| 576 | + worker3_stats.record_transfer() |
| 577 | + worker3_stats.record_transfer() |
| 578 | + worker3_stats.record_transfer() |
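| | +    # Expected aggregated total across the three TP ranks: 2 + 1 + 3 = 6. |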
| 579 | + |
| 580 | + # Create ModelRunnerOutput instances for each worker |
| 581 | + worker_outputs = [] |
| 582 | + for i, worker_stats in enumerate( |
| 583 | + [worker1_stats, worker2_stats, worker3_stats]): |
| 584 | + output = ModelRunnerOutput( |
| 585 | + req_ids=[f"req_{i}"], |
| 586 | + req_id_to_index={f"req_{i}": 0}, |
| 587 | + sampled_token_ids=[[123]], # dummy token |
| 588 | + logprobs=None, |
| 589 | + prompt_logprobs_dict={}, |
| 590 | + pooler_output=[None], |
| 591 | + kv_connector_output=KVConnectorOutput( |
| 592 | + finished_sending=set([f"req_{i}_send"]) |
| 593 | + if i < 2 else None, # Workers 0,1 finished sending |
| 594 | + finished_recving=set([f"req_{i}_recv"]) |
| 595 | + if i > 0 else None, # Workers 1,2 finished receiving |
| 596 | + kv_connector_stats=worker_stats, |
| 597 | + )) |
| 598 | + worker_outputs.append(output) |
| 599 | + |
| 600 | + # Use the real aggregation mechanism (like MultiprocExecutor.execute_model) |
| 601 | + aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) |
| 602 | + kv_connector_stats = \ |
| 603 | + aggregated_output.kv_connector_output.kv_connector_stats |
| 604 | + assert isinstance(kv_connector_stats, NixlKVConnectorStats) |
| 605 | + # Number of total transfers across all workers. |
| 606 | + assert kv_connector_stats.data["num_successful_transfers"] == 6 |
| 607 | + |
| 608 | + |
| 609 | +def test_multi_kv_connector_stats_aggregation(): |
| 610 | + """ |
| 611 | + Test MultiKVConnectorStats aggregation across TP ranks using |
| 612 | + KVOutputAggregator (used by MultiprocExecutor). |
| 613 | + """ |
| 614 | + |
| 615 | + aggregator = KVOutputAggregator(world_size=3) |
| 616 | + |
| 617 | + from dataclasses import dataclass |
| 618 | + |
| 619 | + @dataclass |
| 620 | + class FooKVConnectorStats(KVConnectorStats): |
| 621 | + |
| 622 | + def reset(self): |
| 623 | + self.data = {"num_foo_transfers": 0} |
| 624 | + |
| 625 | + def record_transfer(self): |
| 626 | + if "num_foo_transfers" not in self.data: |
| 627 | + self.data["num_foo_transfers"] = 0 |
| 628 | + self.data["num_foo_transfers"] += 1 |
| 629 | + |
| 630 | + def is_empty(self) -> bool: |
| 631 | + return self.data["num_foo_transfers"] == 0 |
| 632 | + |
| 633 | + def aggregate(self, |
| 634 | + other: "FooKVConnectorStats") -> "FooKVConnectorStats": |
| 635 | + if not other.is_empty(): |
| 636 | + self.data["num_foo_transfers"] += other.data[ |
| 637 | + "num_foo_transfers"] |
| 638 | + return self |
| 639 | + |
| 640 | + def make_multi_stats(nixl_count: int, |
| 641 | + foo_count: int) -> MultiKVConnectorStats: |
| 642 | + data: dict[str, KVConnectorStats] = {} |
| 643 | + if nixl_count > 0: |
| 644 | + nixl_stats = NixlKVConnectorStats() |
| 645 | + for _ in range(nixl_count): |
| 646 | + nixl_stats.record_transfer() |
| 647 | + data["NixlConnector"] = nixl_stats |
| 648 | + if foo_count > 0: |
| 649 | + foo_stats = FooKVConnectorStats() |
| 650 | + for _ in range(foo_count): |
| 651 | + foo_stats.record_transfer() |
| 652 | + data["FooConnector"] = foo_stats |
| 653 | + return MultiKVConnectorStats(data=data) |
| 654 | + |
| 655 | + # Create heterogeneous stats across 3 workers |
| 656 | + worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo) |
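| | +    # Expected per-connector totals: Nixl = 2 + 3 + 0 = 5, Foo = 1 + 0 + 5 = 6. |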
| 657 | + |
| 658 | + worker_outputs: list[ModelRunnerOutput] = [] |
| 659 | + for i, (nixl, foo) in enumerate(worker_patterns): |
| 660 | + stats = make_multi_stats(nixl, foo) |
| 661 | + output = ModelRunnerOutput( |
| 662 | + req_ids=[f"req_{i}"], |
| 663 | + req_id_to_index={f"req_{i}": 0}, |
| 664 | + sampled_token_ids=[[123]], |
| 665 | + logprobs=None, |
| 666 | + prompt_logprobs_dict={}, |
| 667 | + pooler_output=[None], |
| 668 | + kv_connector_output=KVConnectorOutput( |
| 669 | + finished_sending=set([f"req_{i}_send"]) if i < 2 else None, |
| 670 | + finished_recving=set([f"req_{i}_recv"]) if i > 0 else None, |
| 671 | + kv_connector_stats=stats, |
| 672 | + ), |
| 673 | + ) |
| 674 | + worker_outputs.append(output) |
| 675 | + |
| 676 | + aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) |
| 677 | + kv_connector_stats = \ |
| 678 | + aggregated_output.kv_connector_output.kv_connector_stats |
| 679 | + assert isinstance(kv_connector_stats, MultiKVConnectorStats) |
| 680 | + |
| 681 | + # Validate per-connector totals across workers |
| 682 | + assert kv_connector_stats["NixlConnector"].data[ |
| 683 | + "num_successful_transfers"] == 5 |
| 684 | + assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6 |
| 685 | + |
| 686 | + |
478 | 687 | @pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) |
479 | 688 | @patch( |
480 | 689 | "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", |