Skip to content

Commit 77801b5

Browse files
committed
Add scheduler-worker aggreation tests in test_nixl_connector.py
Signed-off-by: Qier Li <kevin44036@gmail.com>
1 parent 01f313c commit 77801b5

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,75 @@ def make_multi_stats(nixl_count: int, foo_count: int) -> MultiKVConnectorStats:
785785
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
786786

787787

788+
@patch(
789+
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
790+
FakeNixlWrapper,
791+
)
792+
def test_scheduler_kv_connector_stats_aggregation():
793+
"""Test scheduler and worker KV connector stats aggregation."""
794+
from vllm.v1.core.sched.output import SchedulerOutput
795+
796+
scheduler = create_scheduler(create_vllm_config())
797+
798+
# Worker stats with transfer metrics
799+
worker_stats = NixlKVConnectorStats()
800+
worker_stats.record_transfer(get_default_xfer_telemetry())
801+
worker_stats.data["remote_tokens"] = []
802+
803+
# Scheduler stats with custom metric (needs dummy transfer to avoid being skipped)
804+
scheduler_stats = NixlKVConnectorStats()
805+
scheduler_stats.data.update(
806+
{
807+
"transfer_duration": [0],
808+
"post_duration": [0],
809+
"bytes_transferred": [0],
810+
"num_descriptors": [0],
811+
"remote_tokens": [128],
812+
}
813+
)
814+
815+
# Mock the scheduler connector's stats method
816+
scheduler.connector.get_kv_connector_stats = lambda: MultiKVConnectorStats(
817+
data={"NixlConnector": scheduler_stats}
818+
)
819+
820+
model_output = ModelRunnerOutput(
821+
req_ids=["req_0"],
822+
req_id_to_index={"req_0": 0},
823+
sampled_token_ids=[[123]],
824+
logprobs=None,
825+
prompt_logprobs_dict={},
826+
pooler_output=[None],
827+
kv_connector_output=KVConnectorOutput(
828+
kv_connector_stats=MultiKVConnectorStats(
829+
data={"NixlConnector": worker_stats}
830+
)
831+
),
832+
)
833+
scheduler_output = SchedulerOutput(
834+
scheduled_new_reqs=[],
835+
scheduled_cached_reqs=None,
836+
num_scheduled_tokens={"req_0": 1},
837+
total_num_scheduled_tokens=1,
838+
scheduled_spec_decode_tokens={},
839+
scheduled_encoder_inputs={},
840+
num_common_prefix_blocks=[0],
841+
finished_req_ids=set(),
842+
free_encoder_mm_hashes=set(),
843+
structured_output_request_ids={},
844+
grammar_bitmask=None,
845+
)
846+
847+
engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
848+
849+
final_stats = next(
850+
iter(engine_core_outputs.values())
851+
).scheduler_stats.kv_connector_stats
852+
nixl_stats = final_stats["NixlConnector"]
853+
assert nixl_stats.num_successful_transfers == 2
854+
assert nixl_stats.data["remote_tokens"] == [128]
855+
856+
788857
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
789858
@patch(
790859
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",

0 commit comments

Comments
 (0)