@@ -785,6 +785,75 @@ def make_multi_stats(nixl_count: int, foo_count: int) -> MultiKVConnectorStats:
785785 assert kv_connector_stats ["FooConnector" ].data ["num_foo_transfers" ] == 6
786786
787787
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_scheduler_kv_connector_stats_aggregation():
    """Verify that scheduler-side and worker-side KV connector stats are combined.

    The worker reports one recorded transfer (with an empty ``remote_tokens``
    list) through ``ModelRunnerOutput.kv_connector_output``, while the
    scheduler's connector is stubbed to report one zero-valued placeholder
    transfer plus ``remote_tokens == [128]``. After
    ``scheduler.update_from_output`` the resulting "NixlConnector" stats are
    expected to count two successful transfers and expose ``[128]`` as
    ``remote_tokens``.
    """
    from vllm.v1.core.sched.output import SchedulerOutput

    scheduler = create_scheduler(create_vllm_config())

    # Worker side: a single real transfer record and no remote tokens.
    stats_from_worker = NixlKVConnectorStats()
    stats_from_worker.record_transfer(get_default_xfer_telemetry())
    stats_from_worker.data["remote_tokens"] = []

    # Scheduler side: the metric of interest is ``remote_tokens``; the
    # zero-valued transfer entries are placeholders so that these stats are
    # not skipped during aggregation.
    placeholder_metrics = {
        metric: [0]
        for metric in (
            "transfer_duration",
            "post_duration",
            "bytes_transferred",
            "num_descriptors",
        )
    }
    placeholder_metrics["remote_tokens"] = [128]
    stats_from_scheduler = NixlKVConnectorStats()
    stats_from_scheduler.data.update(placeholder_metrics)

    def _stubbed_connector_stats():
        # Stand-in for the scheduler connector's stats getter.
        return MultiKVConnectorStats(data={"NixlConnector": stats_from_scheduler})

    scheduler.connector.get_kv_connector_stats = _stubbed_connector_stats

    # One request scheduled for exactly one token; nothing else is in flight.
    sched_out = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=None,
        num_scheduled_tokens={"req_0": 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=[0],
        finished_req_ids=set(),
        free_encoder_mm_hashes=set(),
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    # The worker stats travel to the scheduler inside the model runner output.
    worker_connector_output = KVConnectorOutput(
        kv_connector_stats=MultiKVConnectorStats(
            data={"NixlConnector": stats_from_worker}
        )
    )
    runner_out = ModelRunnerOutput(
        req_ids=["req_0"],
        req_id_to_index={"req_0": 0},
        sampled_token_ids=[[123]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[None],
        kv_connector_output=worker_connector_output,
    )

    outputs_by_client = scheduler.update_from_output(sched_out, runner_out)

    # Inspect the first returned engine-core output.
    first_output = next(iter(outputs_by_client.values()))
    merged_stats = first_output.scheduler_stats.kv_connector_stats
    nixl_merged = merged_stats["NixlConnector"]

    assert nixl_merged.num_successful_transfers == 2
    assert nixl_merged.data["remote_tokens"] == [128]
856+
788857@pytest .mark .parametrize ("distributed_executor_backend" , ["ray" , None ])
789858@patch (
790859 "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" ,
0 commit comments