
Commit ba3ac23

test: add router e2e test with mockers to per-merge ci (#2073)
Signed-off-by: Yan Ru Pei <yanrpei@gmail.com>
1 parent 2fc65ad commit ba3ac23

File tree

7 files changed: +444, -172 lines

lib/llm/src/kv_router.rs

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ impl KvRouter {
             }
         };
         if let Err(e) = kv_events_tx.send(event).await {
-            tracing::debug!(
+            tracing::warn!(
                 "failed to send kv event to indexer; shutting down: {:?}",
                 e
             );

lib/llm/src/kv_router/scheduler.rs

Lines changed: 7 additions & 0 deletions
@@ -177,6 +177,13 @@ impl KvScheduler {
                     request.respond(response);
                     continue 'outer;
                 }
+                Err(KvSchedulerError::NoEndpoints) => {
+                    tracing::trace!("no endpoints available; waiting for endpoints update");
+                    endpoints_rx.changed().await.ok();
+                    endpoints = endpoints_rx.borrow_and_update().clone();
+                    pending_endpoint_update = Some(endpoints.worker_ids());
+                    continue;
+                }
                 // TODO: this is not actually hooked up
                 Err(KvSchedulerError::AllWorkersBusy) => {
                     tracing::trace!("all workers busy; waiting for more capacity");
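
The new NoEndpoints arm parks the scheduler on a tokio watch channel until the worker set changes, then takes a fresh snapshot of it. Below is a minimal sketch of that wait-then-snapshot pattern; the Endpoints struct and worker_ids field are illustrative stand-ins, not the crate's actual types.

use tokio::sync::watch;

#[derive(Clone, Default, Debug)]
struct Endpoints {
    worker_ids: Vec<u64>,
}

// Wait until the sender publishes a new endpoint set (or is dropped),
// then snapshot the latest value and mark it as seen.
async fn wait_for_endpoints(rx: &mut watch::Receiver<Endpoints>) -> Endpoints {
    rx.changed().await.ok();
    rx.borrow_and_update().clone()
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(Endpoints::default());
    tokio::spawn(async move {
        tx.send(Endpoints { worker_ids: vec![1, 2] }).ok();
    });
    let endpoints = wait_for_endpoints(&mut rx).await;
    println!("workers now available: {:?}", endpoints.worker_ids);
}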

lib/llm/src/mocker/scheduler.rs

Lines changed: 121 additions & 85 deletions
@@ -51,7 +51,7 @@ use std::collections::HashMap;
 use std::collections::VecDeque;
 use std::sync::Arc;
 use tokio::sync::{mpsc, Mutex};
-use tokio::time::{interval, Duration};
+use tokio::time::Duration;
 use tokio_util::sync::CancellationToken;
 use uuid::Uuid;

@@ -81,6 +81,10 @@ impl SchedulerState {
         }
     }

+    fn is_empty(&self) -> bool {
+        self.requests.is_empty()
+    }
+
     /// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting.
     fn receive(&mut self, request: DirectRequest) -> Uuid {
         // Use the provided UUID if available, otherwise generate a new one

@@ -295,11 +299,25 @@ impl Scheduler {

         // Spawn main background task with cancellation token
         tokio::spawn(async move {
-            let mut schedule_interval = interval(Duration::from_secs_f64(1e-3));
-            let mut simulate_interval = interval(Duration::from_secs_f64(1e-4));
             let mut should_schedule = true;

             loop {
+                {
+                    let state_guard = state_clone.lock().await;
+
+                    // Enqueue new request, blocks until at least one is received, so no redundant work is done
+                    // TODO: clean this up? double lock acquisition is ugly, but needed to not hold the lock forever
+                    if state_guard.is_empty() {
+                        drop(state_guard);
+                        let Some(request) = request_rx.recv().await else {
+                            tracing::warn!("request sender is dropped");
+                            break;
+                        };
+                        let mut state_guard = state_clone.lock().await;
+                        state_guard.receive(request);
+                    }
+                }
+
                 tokio::select! {
                     biased;


@@ -310,7 +328,7 @@ impl Scheduler {
                     }

                     // Try Scheduling Requests - runs on normal interval or after simulation
-                    _ = schedule_interval.tick() => {
+                    _ = tokio::task::yield_now() => {
                         // Skip if we just ran scheduling after simulation to prevent consecutive runs
                         if !should_schedule {
                             continue;

@@ -371,100 +389,117 @@ impl Scheduler {
                     _ = cancel_token_clone.cancelled() => {
                         break;
                     }
+                }

-                    // Simulate running requests (prefill + decode)
-                    _ = simulate_interval.tick() => {
-                        let mut state_guard = state_clone.lock().await;
-                        let mut kv_manager_guard = kv_manager_clone.lock().await;
-
-                        // Base time needed for decoding using active percentage and quadratic formula
-                        let active_perc = kv_manager_guard.get_active_perc();
-                        let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
-                        let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
-
-                        // Process prefilling
-                        while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = state_guard.try_prefill() {
-                            // NOTE: Prefill cost/time is always incremented for new blocks, even if they
-                            // could be cached by other requests in the same batch. This matches vLLM behavior.
-                            total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
-
-                            if let Some(creation_signal) = maybe_creation_signal {
-                                if !process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal)) {
-                                    panic!("Block allocation for prefilling cannot fail.");
-                                }
-
-                                // Drain KV events and forward to relay after prefill signal processing
-                                if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
-                                    while let Ok(event) = rx.try_recv() {
-                                        let _ = relay_tx.send(block_response_to_kv_event(event));
-                                    }
-                                }
-                            };
-
-                            // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
-                            if !is_full_prefill { break; }
+                // Simulates prefill + decode
+                let mut state_guard = state_clone.lock().await;
+                let mut kv_manager_guard = kv_manager_clone.lock().await;
+
+                // Base time needed for decoding using active percentage and quadratic formula
+                let active_perc = kv_manager_guard.get_active_perc();
+                let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
+                let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
+
+                // Process prefilling
+                while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
+                    state_guard.try_prefill()
+                {
+                    // NOTE: Prefill cost/time is always incremented for new blocks, even if they
+                    // could be cached by other requests in the same batch. This matches vLLM behavior.
+                    total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
+
+                    if let Some(creation_signal) = maybe_creation_signal {
+                        if !process_signals(
+                            &mut kv_manager_guard,
+                            std::slice::from_ref(&creation_signal),
+                        ) {
+                            panic!("Block allocation for prefilling cannot fail.");
                         }

-                        state_guard.reset_active_tokens();
-
-                        // Process decoding
-                        let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
-                        if !uuids.is_empty() {should_schedule = true};
-                        for uuid in uuids {
-                            let Some(sequence) = state_guard.run(uuid) else {
-                                continue;
-                            };
-                            let signals = sequence.generate();
-
-                            // Process all signals with the KvManager
-                            // Handling of preemption on failure
-                            if !process_signals(&mut kv_manager_guard, &signals) {
-                                sequence.pop(); // revert the failed generation op
-                                for signal in state_guard.preempt() {
-                                    kv_manager_guard.process(&signal);
-                                }
-                                continue;
+                        // Drain KV events and forward to relay after prefill signal processing
+                        if let (Some(ref relay_tx), Some(ref mut rx)) =
+                            (&kv_events_tx, &mut block_resp_rx)
+                        {
+                            while let Ok(event) = rx.try_recv() {
+                                let _ = relay_tx.send(block_response_to_kv_event(event));
                            }
+                        }
+                    };

-                            // Drain KV events and forward to relay after decode signal processing
-                            if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
-                                while let Ok(event) = rx.try_recv() {
-                                    let _ = relay_tx.send(block_response_to_kv_event(event));
-                                }
-                            }
+                    // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
+                    if !is_full_prefill {
+                        break;
+                    }
+                }

-                            // Check completion and send notification
-                            let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
-                            let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
+                state_guard.reset_active_tokens();
+
+                // Process decoding
+                let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
+                if !uuids.is_empty() {
+                    should_schedule = true
+                };
+                for uuid in uuids {
+                    let Some(sequence) = state_guard.run(uuid) else {
+                        continue;
+                    };
+                    let signals = sequence.generate();
+
+                    // Process all signals with the KvManager
+                    // Handling of preemption on failure
+                    if !process_signals(&mut kv_manager_guard, &signals) {
+                        sequence.pop(); // revert the failed generation op
+                        for signal in state_guard.preempt() {
+                            kv_manager_guard.process(&signal);
+                        }
+                        continue;
+                    }

-                            let mut send_failed = false;
-                            if should_output {
-                                send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
-                                    tx.send(OutputSignal { uuid, completed: is_complete }).is_err()
-                                });
-                            }
+                    // Drain KV events and forward to relay after decode signal processing
+                    if let (Some(ref relay_tx), Some(ref mut rx)) =
+                        (&kv_events_tx, &mut block_resp_rx)
+                    {
+                        while let Ok(event) = rx.try_recv() {
+                            let _ = relay_tx.send(block_response_to_kv_event(event));
+                        }
+                    }

-                            if send_failed {
-                                for signal in &sequence.free_signal() {
-                                    kv_manager_guard.process(signal);
-                                }
-                            }
+                    // Check completion and send notification
+                    let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
+                    let should_output =
+                        sequence.generated_tokens() > sequence.already_generated_tokens();
+
+                    let mut send_failed = false;
+                    if should_output {
+                        send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
+                            tx.send(OutputSignal {
+                                uuid,
+                                completed: is_complete,
+                            })
+                            .is_err()
+                        });
+                    }

-                            if send_failed || is_complete {
-                                state_guard.complete(&uuid);
-                                continue;
-                            }
+                    if send_failed {
+                        for signal in &sequence.free_signal() {
+                            kv_manager_guard.process(signal);
                        }
+                    }

-                        // Sleep once for the adjusted duration
-                        drop(kv_manager_guard);
-                        drop(state_guard);
-                        let adjusted_time = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
-                        if adjusted_time.as_millis() > 0 {
-                            tokio::time::sleep(adjusted_time).await;
-                        }
+                    if send_failed || is_complete {
+                        state_guard.complete(&uuid);
+                        continue;
                    }
                }
+
+                // Sleep once for the adjusted duration
+                drop(kv_manager_guard);
+                drop(state_guard);
+                let adjusted_time =
+                    Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
+                if adjusted_time.as_millis() > 0 {
+                    tokio::time::sleep(adjusted_time).await;
+                }
            }
        });

@@ -632,6 +667,7 @@ mod tests {
     use super::*;
     use rstest::rstest;
     use std::time::Duration;
+    use tokio::time::interval;

     #[rstest]
     #[case::case_1(false, false, false)]
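
Net effect of the mocker changes above: the background task no longer ticks on fixed 1 ms / 0.1 ms intervals. It blocks on the request channel while idle, yields to the scheduling arm when busy, runs the prefill/decode simulation every iteration, and sleeps only for the simulated compute time divided by the speedup ratio. The following is a minimal, self-contained sketch of that loop shape; the queue of integer "requests", the 10 ms of simulated work per item, and all names are illustrative stand-ins, not the crate's real types.

use std::{collections::VecDeque, sync::Arc, time::Duration};
use tokio::sync::{mpsc, Mutex};
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel::<u32>();
    let queue = Arc::new(Mutex::new(VecDeque::<u32>::new()));
    let cancel = CancellationToken::new();
    let speedup_ratio = 10.0;

    let q = queue.clone();
    let c = cancel.clone();
    let worker = tokio::spawn(async move {
        loop {
            // Block on the channel only when nothing is queued,
            // so an idle worker burns no CPU on timer ticks.
            if q.lock().await.is_empty() {
                let Some(req) = rx.recv().await else { break };
                q.lock().await.push_back(req);
            }

            tokio::select! {
                biased;
                _ = c.cancelled() => break,
                // Yield instead of ticking a fixed interval before simulating.
                _ = tokio::task::yield_now() => {}
            }

            // "Simulate" the queued work, then sleep for the scaled duration.
            let simulated = Duration::from_millis(10) * q.lock().await.drain(..).count() as u32;
            let adjusted = Duration::from_secs_f64(simulated.as_secs_f64() / speedup_ratio);
            if adjusted.as_millis() > 0 {
                tokio::time::sleep(adjusted).await;
            }
        }
    });

    for req in 0..3 {
        tx.send(req).unwrap();
    }
    drop(tx); // worker exits once the queue drains and the channel closes
    worker.await.unwrap();
    cancel.cancel();
}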

tests/conftest.py

Lines changed: 72 additions & 0 deletions
@@ -33,6 +33,66 @@
     datefmt=DATE_FORMAT, # ISO 8601 UTC format
 )

+# List of models used in tests
+TEST_MODELS = [
+    "Qwen/Qwen3-0.6B",
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+    "llava-hf/llava-1.5-7b-hf",
+]
+
+
+def download_models(model_list=None):
+    """Download models - can be called directly or via fixture
+
+    Args:
+        model_list: List of model IDs to download. If None, downloads TEST_MODELS.
+    """
+    if model_list is None:
+        model_list = TEST_MODELS
+
+    # Check for HF_TOKEN in environment
+    hf_token = os.environ.get("HF_TOKEN")
+    if hf_token:
+        logging.info("HF_TOKEN found in environment")
+    else:
+        logging.warning(
+            "HF_TOKEN not found in environment. "
+            "Some models may fail to download or you may encounter rate limits. "
+            "Get a token from https://huggingface.co/settings/tokens"
+        )
+
+    try:
+        from huggingface_hub import snapshot_download
+
+        for model_id in model_list:
+            logging.info(f"Pre-downloading model: {model_id}")
+
+            try:
+                # Download the full model snapshot (includes all files)
+                # HuggingFace will handle caching automatically
+                snapshot_download(
+                    repo_id=model_id,
+                    token=hf_token,
+                )
+                logging.info(f"Successfully pre-downloaded: {model_id}")
+
+            except Exception as e:
+                logging.error(f"Failed to pre-download {model_id}: {e}")
+                # Don't fail the fixture - let individual tests handle missing models
+
+    except ImportError:
+        logging.warning(
+            "huggingface_hub not installed. "
+            "Models will be downloaded during test execution."
+        )
+
+
+@pytest.fixture(scope="session")
+def predownload_models():
+    """Fixture wrapper around download_models for all TEST_MODELS"""
+    download_models()
+    yield
+

 @pytest.fixture(autouse=True)
 def logger(request):
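
Because download_models can also be called outside the fixture, a CI warm-up step can pre-fetch just the models a job needs before pytest starts. A small sketch, assuming it runs from the tests/ directory so that conftest.py is importable as a plain module:

# Hypothetical warm-up script for CI, run from the tests/ directory so that
# conftest.py is importable as a plain module.
from conftest import download_models

if __name__ == "__main__":
    # Pre-fetch a single model; calling download_models() with no argument
    # would fall back to the full TEST_MODELS list.
    download_models(["Qwen/Qwen3-0.6B"])
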
@@ -64,6 +124,18 @@ def pytest_collection_modifyitems(config, items):
         if "tensorrtllm" in item.keywords:
             item.add_marker(skip_tensorrtllm)

+        # Auto-inject predownload_models fixture for serve tests only (not router tests)
+        # Skip items that don't have fixturenames (like MypyFileItem)
+        if hasattr(item, "fixturenames"):
+            # Only apply to tests in the serve directory
+            if (
+                ("serve" in str(item.path))
+                and ("predownload_models" not in item.fixturenames)
+                and (not item.get_closest_marker("skip_model_download"))
+            ):
+                item.fixturenames = list(item.fixturenames)
+                item.fixturenames.append("predownload_models")
+

 class EtcdServer(ManagedProcess):
     def __init__(self, request, port=2379, timeout=300):
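
From a test author's point of view, anything collected under tests/serve gets the session-scoped download injected automatically, while router tests are left untouched. A hypothetical serve test module showing the default behavior and the opt-out via the skip_model_download marker (assuming that marker is registered in the project's pytest configuration):

import pytest


def test_serve_happy_path():
    # Collected from tests/serve/, so conftest auto-injects predownload_models
    # and the session-scoped model download runs before this test.
    ...


@pytest.mark.skip_model_download
def test_serve_without_model_download():
    # The marker tells pytest_collection_modifyitems not to inject the fixture.
    ...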

tests/router/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
