2 changes: 1 addition & 1 deletion lib/llm/src/kv_router.rs
@@ -191,7 +191,7 @@ impl KvRouter {
}
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::debug!(
tracing::warn!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
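The change above only promotes the log level on the indexer shutdown path from `debug` to `warn`. A minimal sketch of why that matters under a typical `info`-level filter (the subscriber setup below is illustrative, not part of this PR):

```rust
// Requires the `tracing` and `tracing-subscriber` (with "env-filter") crates.
fn main() {
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::new("info"))
        .init();

    // Filtered out at the default info level, so the shutdown went unnoticed.
    tracing::debug!("failed to send kv event to indexer; shutting down");
    // Emitted at info level and above, making the shutdown visible.
    tracing::warn!("failed to send kv event to indexer; shutting down");
}
```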
7 changes: 7 additions & 0 deletions lib/llm/src/kv_router/scheduler.rs
@@ -177,6 +177,13 @@ impl KvScheduler {
request.respond(response);
continue 'outer;
}
Err(KvSchedulerError::NoEndpoints) => {
tracing::trace!("no endpoints available; waiting for endpoints update");
endpoints_rx.changed().await.ok();
endpoints = endpoints_rx.borrow_and_update().clone();
pending_endpoint_update = Some(endpoints.worker_ids());
continue;
}
// TODO: this is not actually hooked up
Err(KvSchedulerError::AllWorkersBusy) => {
tracing::trace!("all workers busy; waiting for more capacity");
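The new `NoEndpoints` arm parks the scheduler on a `tokio::sync::watch` receiver until the endpoint set changes: `changed().await` resolves when a new value is published, and `borrow_and_update()` reads it while clearing the change flag so the next wait blocks again. A minimal sketch of that pattern, with a stand-in `Endpoints` type in place of the real one:

```rust
use tokio::sync::watch;

#[derive(Clone, Default, Debug)]
struct Endpoints {
    worker_ids: Vec<u64>, // stand-in for the real endpoint set
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(Endpoints::default());

    let scheduler = tokio::spawn(async move {
        // Park until the sender publishes a new value (Err means it was dropped).
        rx.changed().await.ok();
        // Read the latest value and mark it seen, so a later
        // `changed().await` waits for the *next* update.
        let endpoints = rx.borrow_and_update().clone();
        println!("endpoints updated: {:?}", endpoints.worker_ids);
    });

    // Simulate a worker registering later.
    tx.send(Endpoints { worker_ids: vec![42] }).unwrap();
    scheduler.await.unwrap();
}
```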
206 changes: 121 additions & 85 deletions lib/llm/src/mocker/scheduler.rs
@@ -51,7 +51,7 @@ use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{interval, Duration};
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;

Expand Down Expand Up @@ -81,6 +81,10 @@ impl SchedulerState {
}
}

fn is_empty(&self) -> bool {
self.requests.is_empty()
}

/// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting.
fn receive(&mut self, request: DirectRequest) -> Uuid {
// Use the provided UUID if available, otherwise generate a new one
Expand Down Expand Up @@ -295,11 +299,25 @@ impl Scheduler {

// Spawn main background task with cancellation token
tokio::spawn(async move {
let mut schedule_interval = interval(Duration::from_secs_f64(1e-3));
let mut simulate_interval = interval(Duration::from_secs_f64(1e-4));
let mut should_schedule = true;

loop {
{
let state_guard = state_clone.lock().await;

// Enqueue a new request; this blocks until at least one is received, so no redundant work is done
// TODO: clean this up? The double lock acquisition is ugly, but it's needed to avoid holding the lock forever
if state_guard.is_empty() {
drop(state_guard);
let Some(request) = request_rx.recv().await else {
tracing::warn!("request sender is dropped");
break;
};
let mut state_guard = state_clone.lock().await;
state_guard.receive(request);
}
}

tokio::select! {
biased;

@@ -310,7 +328,7 @@
}

// Try Scheduling Requests - runs on normal interval or after simulation
_ = schedule_interval.tick() => {
_ = tokio::task::yield_now() => {
// Skip if we just ran scheduling after simulation to prevent consecutive runs
if !should_schedule {
continue;
Expand Down Expand Up @@ -371,100 +389,117 @@ impl Scheduler {
_ = cancel_token_clone.cancelled() => {
break;
}
}

// Simulate running requests (prefill + decode)
_ = simulate_interval.tick() => {
let mut state_guard = state_clone.lock().await;
let mut kv_manager_guard = kv_manager_clone.lock().await;

// Base time needed for decoding using active percentage and quadratic formula
let active_perc = kv_manager_guard.get_active_perc();
let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);

// Process prefilling
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = state_guard.try_prefill() {
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior.
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);

if let Some(creation_signal) = maybe_creation_signal {
if !process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal)) {
panic!("Block allocation for prefilling cannot fail.");
}

// Drain KV events and forward to relay after prefill signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
}
}
};

// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if !is_full_prefill { break; }
// Simulates prefill + decode
let mut state_guard = state_clone.lock().await;
let mut kv_manager_guard = kv_manager_clone.lock().await;

// Base time needed for decoding using active percentage and quadratic formula
let active_perc = kv_manager_guard.get_active_perc();
let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);

// Process prefilling
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
state_guard.try_prefill()
{
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior.
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);

if let Some(creation_signal) = maybe_creation_signal {
if !process_signals(
&mut kv_manager_guard,
std::slice::from_ref(&creation_signal),
) {
panic!("Block allocation for prefilling cannot fail.");
}

state_guard.reset_active_tokens();

// Process decoding
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
if !uuids.is_empty() {should_schedule = true};
for uuid in uuids {
let Some(sequence) = state_guard.run(uuid) else {
continue;
};
let signals = sequence.generate();

// Process all signals with the KvManager
// Handling of preemption on failure
if !process_signals(&mut kv_manager_guard, &signals) {
sequence.pop(); // revert the failed generation op
for signal in state_guard.preempt() {
kv_manager_guard.process(&signal);
}
continue;
// Drain KV events and forward to relay after prefill signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) =
(&kv_events_tx, &mut block_resp_rx)
{
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
}
}
};

// Drain KV events and forward to relay after decode signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
}
}
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if !is_full_prefill {
break;
}
}

// Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
state_guard.reset_active_tokens();

// Process decoding
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
if !uuids.is_empty() {
should_schedule = true
};
for uuid in uuids {
let Some(sequence) = state_guard.run(uuid) else {
continue;
};
let signals = sequence.generate();

// Process all signals with the KvManager
// Handling of preemption on failure
if !process_signals(&mut kv_manager_guard, &signals) {
sequence.pop(); // revert the failed generation op
for signal in state_guard.preempt() {
kv_manager_guard.process(&signal);
}
continue;
}

let mut send_failed = false;
if should_output {
send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
tx.send(OutputSignal { uuid, completed: is_complete }).is_err()
});
}
// Drain KV events and forward to relay after decode signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) =
(&kv_events_tx, &mut block_resp_rx)
{
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
}
}

if send_failed {
for signal in &sequence.free_signal() {
kv_manager_guard.process(signal);
}
}
// Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output =
sequence.generated_tokens() > sequence.already_generated_tokens();

let mut send_failed = false;
if should_output {
send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
tx.send(OutputSignal {
uuid,
completed: is_complete,
})
.is_err()
});
}

if send_failed || is_complete {
state_guard.complete(&uuid);
continue;
}
if send_failed {
for signal in &sequence.free_signal() {
kv_manager_guard.process(signal);
}
}

// Sleep once for the adjusted duration
drop(kv_manager_guard);
drop(state_guard);
let adjusted_time = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
if adjusted_time.as_millis() > 0 {
tokio::time::sleep(adjusted_time).await;
}
if send_failed || is_complete {
state_guard.complete(&uuid);
continue;
}
}

// Sleep once for the adjusted duration
drop(kv_manager_guard);
drop(state_guard);
let adjusted_time =
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
if adjusted_time.as_millis() > 0 {
tokio::time::sleep(adjusted_time).await;
}
}
});

Expand Down Expand Up @@ -632,6 +667,7 @@ mod tests {
use super::*;
use rstest::rstest;
use std::time::Duration;
use tokio::time::interval;

#[rstest]
#[case::case_1(false, false, false)]
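Two things are worth pulling out of the restructure above. First, the decode-latency model is unchanged: each step charges `-5.47·p² + 43.88·p + 19.44` milliseconds, where `p` is the active-block fraction, so roughly 19.4 ms at p = 0 and about 57.9 ms at p = 1. Second, the fixed 1 ms schedule / 0.1 ms simulate intervals are replaced by an event-driven loop: when no requests are in flight the task parks on `request_rx.recv().await`, and otherwise it yields once per iteration via `tokio::task::yield_now()`. A minimal sketch of that recv-when-idle shape, with the scheduler state reduced to a plain queue:

```rust
use std::{collections::VecDeque, sync::Arc};
use tokio::sync::{mpsc, Mutex};

async fn run(state: Arc<Mutex<VecDeque<u64>>>, mut request_rx: mpsc::Receiver<u64>) {
    loop {
        {
            let guard = state.lock().await;
            if guard.is_empty() {
                // Idle: release the lock and block until work arrives,
                // instead of waking on a timer tick to find nothing to do.
                drop(guard);
                let Some(request) = request_rx.recv().await else {
                    break; // all senders dropped; shut down
                };
                state.lock().await.push_back(request);
            }
        }

        // One simulation step; in the real scheduler this is the
        // prefill/decode pass guarded by the state and KV-manager locks.
        let _ = state.lock().await.pop_front();

        // Cooperatively yield so other tasks run between steps.
        tokio::task::yield_now().await;
    }
}
```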
72 changes: 72 additions & 0 deletions tests/conftest.py
@@ -33,6 +33,66 @@
datefmt=DATE_FORMAT, # ISO 8601 UTC format
)

# List of models used in tests
TEST_MODELS = [
"Qwen/Qwen3-0.6B",
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"llava-hf/llava-1.5-7b-hf",
]


def download_models(model_list=None):
"""Download models - can be called directly or via fixture

Args:
model_list: List of model IDs to download. If None, downloads TEST_MODELS.
"""
if model_list is None:
model_list = TEST_MODELS

# Check for HF_TOKEN in environment
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
logging.info("HF_TOKEN found in environment")
else:
logging.warning(
"HF_TOKEN not found in environment. "
"Some models may fail to download or you may encounter rate limits. "
"Get a token from https://huggingface.co/settings/tokens"
)

try:
from huggingface_hub import snapshot_download

for model_id in model_list:
logging.info(f"Pre-downloading model: {model_id}")

try:
# Download the full model snapshot (includes all files)
# HuggingFace will handle caching automatically
snapshot_download(
repo_id=model_id,
token=hf_token,
)
logging.info(f"Successfully pre-downloaded: {model_id}")

except Exception as e:
logging.error(f"Failed to pre-download {model_id}: {e}")
# Don't fail the fixture - let individual tests handle missing models

except ImportError:
logging.warning(
"huggingface_hub not installed. "
"Models will be downloaded during test execution."
)


@pytest.fixture(scope="session")
def predownload_models():
"""Fixture wrapper around download_models for all TEST_MODELS"""
download_models()
yield


@pytest.fixture(autouse=True)
def logger(request):
Expand Down Expand Up @@ -64,6 +124,18 @@ def pytest_collection_modifyitems(config, items):
if "tensorrtllm" in item.keywords:
item.add_marker(skip_tensorrtllm)

# Auto-inject predownload_models fixture for serve tests only (not router tests)
# Skip items that don't have fixturenames (like MypyFileItem)
if hasattr(item, "fixturenames"):
# Only apply to tests in the serve directory
if (
("serve" in str(item.path))
and ("predownload_models" not in item.fixturenames)
and (not item.get_closest_marker("skip_model_download"))
):
item.fixturenames = list(item.fixturenames)
item.fixturenames.append("predownload_models")


class EtcdServer(ManagedProcess):
def __init__(self, request, port=2379, timeout=300):
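The collection hook above injects `predownload_models` into every test under the serve directory unless the test already requests the fixture or carries the `skip_model_download` marker. A hedged sketch of how this could be used (the test name, import path, and warmup-script framing are illustrative, not from this PR):

```python
import pytest

# download_models can also be called outside pytest, e.g. from a CI
# warmup step; the import path assumes the repo layout shown above.
from tests.conftest import download_models

# Pre-fetch a single model from the TEST_MODELS list.
download_models(["Qwen/Qwen3-0.6B"])


@pytest.mark.skip_model_download  # opt this serve test out of auto-injection
def test_serve_smoke():
    ...
```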
2 changes: 2 additions & 0 deletions tests/router/__init__.py
@@ -0,0 +1,2 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0