
Commit 644a68d

PeaBrane authored and ZichengMa committed
feat: prefill aware routing (#1895)
1 parent 899a964 commit 644a68d

File tree

8 files changed: +169 -58 lines

components/metrics/src/bin/mock_worker.rs

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ async fn mock_event_publisher(namespace: Namespace) {
         let event = KVHitRateEvent {
             worker_id,
             isl_blocks,
-            overlap_blocks,
+            overlap_blocks: overlap_blocks as u32,
         };

         if let Err(e) = namespace.publish(KV_HIT_RATE_SUBJECT, &event).await {

components/metrics/src/main.rs

Lines changed: 1 addition & 1 deletion
@@ -199,7 +199,7 @@ async fn app(runtime: Runtime) -> Result<()> {
                        &config_clone,
                        event.worker_id,
                        event.isl_blocks,
-                        event.overlap_blocks,
+                        event.overlap_blocks as usize,
                    );
                }
                Err(e) => {

docs/guides/dynamo_run.md

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.

 Usage:
 ```
-dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--use-kv-events=true] [--verbosity (-v|-vv)]
+dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.0] [--use-kv-events=true] [--verbosity (-v|-vv)]
 ```

 Example: `dynamo run Qwen/Qwen3-0.6B`

launch/dynamo-run/src/flags.rs

Lines changed: 2 additions & 2 deletions
@@ -118,13 +118,13 @@ pub struct Flags {
     pub max_num_batched_tokens: Option<u32>,

     /// KV Router: Weight for overlap score in worker selection.
-    /// Higher values prioritize KV cache reuse. Default: 2.0
+    /// Higher values prioritize KV cache reuse. Default: 1.0
     #[arg(long)]
     pub kv_overlap_score_weight: Option<f64>,

     /// KV Router: Temperature for worker sampling via softmax.
     /// Higher values promote more randomness, and 0 fallbacks to deterministic.
-    /// Default: 0.5
+    /// Default: 0.0
     #[arg(long)]
     pub router_temperature: Option<f64>,

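For intuition about the temperature semantics documented above, here is a small standalone sketch. It is not the crate's `softmax_sample`; it only illustrates, assuming "lower is better" logits (as the scheduler code below states) and probabilities proportional to exp(-(logit - min) / temperature), how 0.0 falls back to a deterministic argmin while higher values randomize the pick. The function `pick_worker` and the `uniform_draw` parameter are made up for this example.

```rust
use std::collections::HashMap;

/// Hypothetical sketch (not the crate's `softmax_sample`): pick a worker from
/// "lower is better" logits. Temperature 0.0 is a deterministic argmin; higher
/// temperatures sample with weights exp(-(logit - min) / temperature).
/// `uniform_draw` stands in for a random number in [0, 1) so the sketch needs
/// no external dependencies.
fn pick_worker(logits: &HashMap<i64, f64>, temperature: f64, uniform_draw: f64) -> i64 {
    let (&argmin, &min_logit) = logits
        .iter()
        .min_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .unwrap();
    if temperature <= 0.0 {
        return argmin; // deterministic: always the lowest-cost worker
    }
    let weights: Vec<(i64, f64)> = logits
        .iter()
        .map(|(&id, &logit)| (id, (-(logit - min_logit) / temperature).exp()))
        .collect();
    let total: f64 = weights.iter().map(|(_, w)| w).sum();
    let mut cumulative = 0.0;
    for (id, w) in &weights {
        cumulative += w / total;
        if uniform_draw <= cumulative {
            return *id;
        }
    }
    argmin
}

fn main() {
    let logits = HashMap::from([(0_i64, 20.0_f64), (1, 22.0)]);
    assert_eq!(pick_worker(&logits, 0.0, 0.5), 0); // new default: deterministic
    println!("sampled: {}", pick_worker(&logits, 2.0, 0.9)); // higher temperature: randomized
}
```

With the new default of 0.0, routing always follows the lowest logit; raising the temperature flattens the weights toward a uniform choice.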

lib/llm/src/kv_router.rs

Lines changed: 11 additions & 3 deletions
@@ -78,7 +78,7 @@ impl Default for KvRouterConfig {
     fn default() -> Self {
         Self {
             overlap_score_weight: 1.0,
-            router_temperature: 0.5,
+            router_temperature: 0.0,
             use_kv_events: true,
             max_num_batched_tokens: 8192,
         }
@@ -337,6 +337,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
         let mut accumulated_tokens = Vec::new();
         let mut total_output_length = 0usize;
         let mut last_block_index = (isl.saturating_sub(1)) / block_size;
+        let mut first_push_done = false;

         while let Some(item) = response_stream.next().await {
             // Track tokens if they exist in the response
@@ -353,12 +354,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
                 accumulated_tokens.extend_from_slice(&output.token_ids);
                 total_output_length += output.token_ids.len();

-                // Check if we've moved to a new block
+                // Always push for the first generated token (to mark prefill done)
+                // or when we've moved to a new block
                 let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
-                if current_block_index > last_block_index {
+                let should_push = (!first_push_done && total_output_length >= 1) ||
+                    (first_push_done && current_block_index > last_block_index);
+
+                if should_push {
                     chooser.push(&request_id, &accumulated_tokens).await;
                     accumulated_tokens.clear();
                     last_block_index = current_block_index;
+                    if !first_push_done {
+                        first_push_done = true;
+                    }
                 }

                 yield item;
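To make the new push condition concrete, here is a minimal standalone sketch of the logic in the hunk above, using made-up numbers (a 100-token prompt, 64-token blocks, one generated token per iteration) and a counter in place of `chooser.push`:

```rust
fn main() {
    // Made-up example values: a 100-token prompt with 64-token KV blocks.
    let isl: usize = 100;
    let block_size: usize = 64;

    let mut last_block_index = isl.saturating_sub(1) / block_size; // 99 / 64 = 1
    let mut first_push_done = false;
    let mut pushes = 0;

    for total_output_length in 1..=40usize {
        let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
        // Same condition as the patched kv_router.rs: always push on the first
        // generated token (prefill is done), then only on block boundaries.
        let should_push = (!first_push_done && total_output_length >= 1)
            || (first_push_done && current_block_index > last_block_index);
        if should_push {
            pushes += 1;
            last_block_index = current_block_index;
            if !first_push_done {
                first_push_done = true;
            }
            println!("push #{pushes} at output token {total_output_length} (block {current_block_index})");
        }
    }
}
```

The first push fires at output token 1 even though the block index has not advanced, which is what lets the router mark the request's prefill as finished; the next push only happens when `isl + output` crosses a block boundary (here at output token 29, when 128 tokens fill two full blocks).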

lib/llm/src/kv_router/protocols.rs

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ pub struct WorkerSelectionResult {

     /// The number of blocks that the selected worker may already have cached.
     /// This is not a guarantee, but an estimate.
-    pub overlap_blocks: usize,
+    pub overlap_blocks: u32,
 }

 #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]

lib/llm/src/kv_router/scheduler.rs

Lines changed: 39 additions & 42 deletions
@@ -25,7 +25,6 @@ use tokio::sync::Mutex;
 use super::protocols::WorkerSelectionResult;
 use super::WorkerSelector;
 use crate::kv_router::indexer::OverlapScores;
-use crate::kv_router::indexer::WorkerId;
 use crate::kv_router::protocols::LoadMetrics;
 use crate::kv_router::scoring::ProcessedEndpoints;
 use crate::kv_router::sequence::ActiveSequencesMultiWorker;
@@ -37,7 +36,7 @@ use crate::tokens::TokenBlockSequence;
 pub struct KVHitRateEvent {
     pub worker_id: i64,
     pub isl_blocks: usize,
-    pub overlap_blocks: usize,
+    pub overlap_blocks: u32,
 }

 #[derive(Debug, thiserror::Error)]
@@ -79,13 +78,15 @@ impl Endpoint {
 #[derive(Debug)]
 pub struct SchedulingResponse {
     pub best_worker_id: i64,
+    pub overlap_blocks: u32, // Add this field
     pub endpoints_changed: Option<Vec<i64>>,
 }

 pub struct SchedulingRequest {
     pub isl_tokens: usize,
-    pub overlap: OverlapScores,
+    pub overlaps: OverlapScores,
     pub potential_blocks: HashMap<i64, usize>,
+    pub potential_tokens: HashMap<i64, usize>,
     resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>,
 }

@@ -174,6 +175,7 @@ impl KvScheduler {

             let response = SchedulingResponse {
                 best_worker_id: selection.worker_id,
+                overlap_blocks: selection.overlap_blocks,
                 endpoints_changed: pending_endpoint_update.take(),
             };
             request.respond(response);
@@ -207,18 +209,20 @@ impl KvScheduler {
         isl_tokens: usize,
         block_size: u32,
         tokens: &[u32],
-        overlap: OverlapScores,
+        overlaps: OverlapScores,
     ) -> Result<i64, KvSchedulerError> {
         let mut sequences = self.sequences.lock().await;

         let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
-        let potential_blocks = sequences.potential_blocks(token_sequence);
+        let (potential_blocks, potential_tokens) =
+            sequences.potential_blocks_and_tokens(token_sequence, overlaps.clone());

         let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
         let request = SchedulingRequest {
             isl_tokens,
-            overlap,
+            overlaps,
             potential_blocks,
+            potential_tokens,
             resp_tx,
         };
         self.request_tx
@@ -234,31 +238,16 @@ impl KvScheduler {
         }

         let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
-        sequences.add_request(request_id, token_sequence, response.best_worker_id);
+        sequences.add_request(
+            request_id,
+            token_sequence,
+            response.overlap_blocks,
+            response.best_worker_id,
+        );

         Ok(response.best_worker_id)
     }

-    /// Find the potential blocks for each worker if the sequence were routed there
-    pub async fn potential_blocks(
-        &self,
-        token_sequence: TokenBlockSequence,
-    ) -> HashMap<i64, usize> {
-        let sequences = self.sequences.lock().await;
-        sequences.potential_blocks(token_sequence)
-    }
-
-    /// Add a new request with its initial tokens to a specific worker
-    pub async fn add_request(
-        &self,
-        request_id: String,
-        token_sequence: TokenBlockSequence,
-        worker_id: WorkerId,
-    ) {
-        let mut sequences = self.sequences.lock().await;
-        sequences.add_request(request_id, token_sequence, worker_id)
-    }
-
     /// Push tokens to a specific request's sequence
     pub async fn push(&self, request_id: &String, tokens: &[u32]) {
         let mut sequences = self.sequences.lock().await;
@@ -370,34 +359,47 @@ impl WorkerSelector for DefaultWorkerSelector {
             return Err(KvSchedulerError::NoEndpoints);
         }

-        let request_blocks = request.isl_tokens.div_ceil(block_size as usize);
+        let isl = request.isl_tokens;
+        let request_blocks = isl.div_ceil(block_size as usize);
+        let overlaps = &request.overlaps.scores;
+
+        // active blocks for decoding
         let potential_active_blocks = &request.potential_blocks;
+        // active tokens in the batch (processed by the linear layers), mostly prefill tokens
+        let potential_active_tokens = &request.potential_tokens;

         let mut worker_logits = HashMap::new();
         let mut max_logit = f64::NEG_INFINITY;

         // Calculate logits for each worker
         for (worker_id, _) in workers.endpoints.iter() {
-            let cached_blocks = request.overlap.scores.get(worker_id).copied().unwrap_or(0) as f64;
-            let prefill_blocks = request_blocks as f64 - cached_blocks;
+            // this is the number of tokens each worker would have if the request were scheduled there
+            let potential_tokens = *potential_active_tokens.get(worker_id).unwrap_or_else(|| {
+                tracing::warn!(
+                    "assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet"
+                );
+                &isl
+            }) as f64;

             // this is the number of blocks each worker would have if the request were scheduled there
             let potential_blocks = *potential_active_blocks.get(worker_id).unwrap_or_else(||
-                {tracing::warn!("assuming 0 decoding blocks for {worker_id}, as the load metrics endpoint does not exist yet");
-                &0
+                {tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet");
+                &request_blocks
             }) as f64;

+            let potential_prefill_blocks = potential_tokens / (block_size as f64);
+
             // Calculate logit (lower is better)
-            let logit =
-                self.kv_router_config.overlap_score_weight * prefill_blocks + potential_blocks;
+            let logit = self.kv_router_config.overlap_score_weight * potential_prefill_blocks
+                + potential_blocks;
             max_logit = max_logit.max(logit);

             worker_logits.insert(*worker_id, logit);

             tracing::info!(
-                "Formula for {worker_id}: {logit:.3} = {:.1} * {prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {cached_blocks})",
+                "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})",
                 self.kv_router_config.overlap_score_weight,
-                cached_blocks = cached_blocks
+                overlaps.get(worker_id).unwrap_or(&0),
             );
         }

@@ -412,12 +414,7 @@ impl WorkerSelector for DefaultWorkerSelector {
         let temperature = self.kv_router_config.router_temperature;
         let best_worker_id = softmax_sample(&worker_logits, temperature);

-        let overlap_blocks = request
-            .overlap
-            .scores
-            .get(&best_worker_id)
-            .copied()
-            .unwrap_or(0) as usize;
+        let overlap_blocks = overlaps.get(&best_worker_id).copied().unwrap_or(0);
         let best_logit = worker_logits[&best_worker_id];

         tracing::info!(
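For intuition about the new cost function, here is a small standalone sketch with made-up numbers (two hypothetical workers, a 512-token request, 64-token blocks, `overlap_score_weight = 1.0`). It replaces `softmax_sample` with a plain argmin, which matches the new `router_temperature = 0.0` default, and the per-worker potentials are invented for illustration rather than taken from `ActiveSequencesMultiWorker`:

```rust
fn main() {
    // Made-up scenario: a 512-token request (8 blocks of 64 tokens each).
    let block_size = 64.0_f64;
    let overlap_score_weight = 1.0_f64;

    // Hypothetical (worker_id, potential prefill tokens, potential decode blocks)
    // if the request were routed to that worker.
    let potentials = [
        (0_i64, 512.0_f64, 12.0_f64), // worker 0: no cache hit, lighter decode load
        (1_i64, 128.0_f64, 20.0_f64), // worker 1: 6 of 8 blocks cached, heavier decode load
    ];

    let mut best: Option<(i64, f64)> = None;
    for (worker_id, potential_tokens, potential_blocks) in potentials {
        // Same shape as the patched logit: weight * prefill blocks + decode blocks
        // (lower is better).
        let potential_prefill_blocks = potential_tokens / block_size;
        let logit = overlap_score_weight * potential_prefill_blocks + potential_blocks;
        println!("worker {worker_id}: logit = {logit:.3}");
        if best.map_or(true, |(_, b)| logit < b) {
            best = Some((worker_id, logit));
        }
    }

    // worker 0: 1.0 * 8.0 + 12.0 = 20.0
    // worker 1: 1.0 * 2.0 + 20.0 = 22.0
    // With the deterministic (temperature 0.0) argmin, worker 0 wins even though
    // worker 1 has the cache hits, because its projected decode load is lighter.
    let (winner, winning_logit) = best.expect("at least one worker");
    println!("selected worker {winner} with logit {winning_logit:.3}");
}
```

In this invented example, raising `--kv-overlap-score-weight` past roughly 1.33 would penalize the uncached prefill enough to tip the choice toward worker 1 instead.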
