Skip to content

Commit 66231cf

Browse files
authored
feat: reduce / revert routing overheads, do not consider output tokens (#2182)
1 parent dbd33df commit 66231cf

File tree

6 files changed

+252
-404
lines changed

6 files changed

+252
-404
lines changed

lib/llm/src/kv_router.rs

Lines changed: 19 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ use crate::{
3131
kv_router::{
3232
approx::ApproxKvIndexer,
3333
indexer::{
34-
compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError,
35-
OverlapScores, RouterEvent,
34+
compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface,
35+
KvRouterError, OverlapScores, RouterEvent,
3636
},
3737
// metrics_aggregator::EndpointCollector,
3838
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
@@ -71,7 +71,8 @@ pub struct KvRouterConfig {
7171

7272
pub use_kv_events: bool,
7373

74-
// note: this is not actually used for now
74+
// TODO: this is not actually used for now
75+
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
7576
pub max_num_batched_tokens: u32,
7677
}
7778

@@ -231,25 +232,25 @@ impl KvRouter {
231232
let _guard = self.find_best_match_mutex.lock().await;
232233

233234
let isl_tokens = tokens.len();
234-
let block_size = self.block_size;
235235

236-
let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
237-
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
236+
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
237+
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
238+
239+
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
238240

239241
let best_worker_id = self
240242
.scheduler
241243
.schedule(
242244
context_id.to_string(),
243245
isl_tokens,
244-
block_size,
245-
tokens,
246+
seq_hashes.clone(),
246247
overlap_scores.clone(),
247248
)
248249
.await?;
249250

250251
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
251252
indexer
252-
.process_routing_decision_for_request(tokens, best_worker_id)
253+
.process_routing_decision(best_worker_id, block_hashes, seq_hashes)
253254
.await
254255
.unwrap();
255256
};
@@ -262,9 +263,9 @@ impl KvRouter {
262263
Ok((best_worker_id, overlap_amount))
263264
}
264265

265-
/// Push tokens to a specific request's sequence
266-
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
267-
self.scheduler.push(request_id, tokens).await
266+
/// Mark the prefill phase of a request as completed
267+
pub async fn mark_prefill_completed(&self, request_id: &String) {
268+
self.scheduler.mark_prefill_completed(request_id).await
268269
}
269270

270271
/// Free all blocks associated with a request
@@ -331,7 +332,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
331332
let stream_context = request.context().clone();
332333
// Update the request with the estimated prefix hit blocks
333334
let (mut backend_input, context) = request.into_parts();
334-
let isl = backend_input.token_ids.len();
335335
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
336336
let updated_request = context.map(|_| backend_input);
337337

@@ -345,55 +345,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
345345
let stream = stream::iter(vec![response]);
346346
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
347347
}
348-
// Get the response stream from the worker
349-
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
350348

351-
// Wrap the stream to track tokens
349+
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
352350
let stream_context = response_stream.context();
353351
let chooser = self.chooser.clone();
354-
let request_id = context_id.clone();
355-
let block_size = chooser.block_size() as usize;
356352

357353
let wrapped_stream = Box::pin(async_stream::stream! {
358-
let mut accumulated_tokens = Vec::new();
359-
let mut total_output_length = 0usize;
360-
let mut last_block_index = (isl.saturating_sub(1)) / block_size;
361-
let mut first_push_done = false;
354+
if let Some(first_item) = response_stream.next().await {
355+
chooser.mark_prefill_completed(&context_id).await;
356+
yield first_item;
357+
}
362358

363359
while let Some(item) = response_stream.next().await {
364-
// Track tokens if they exist in the response
365-
let Some(ref output) = item.data else {
366-
yield item;
367-
continue;
368-
};
369-
if output.token_ids.is_empty() {
370-
yield item;
371-
continue;
372-
}
373-
374-
// Add tokens to accumulator
375-
accumulated_tokens.extend_from_slice(&output.token_ids);
376-
total_output_length += output.token_ids.len();
377-
378-
// Always push for the first generated token (to mark prefill done)
379-
// or when we've moved to a new block
380-
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
381-
let should_push = (!first_push_done && total_output_length >= 1) ||
382-
(first_push_done && current_block_index > last_block_index);
383-
384-
if should_push {
385-
chooser.push(&request_id, &accumulated_tokens).await;
386-
accumulated_tokens.clear();
387-
last_block_index = current_block_index;
388-
if !first_push_done {
389-
first_push_done = true;
390-
}
391-
}
392-
393360
yield item;
394361
}
395362

396-
chooser.free(&request_id).await;
363+
chooser.free(&context_id).await;
397364
});
398365
Ok(ResponseStream::new(wrapped_stream, stream_context))
399366
}

lib/llm/src/kv_router/approx.rs

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use tokio::sync::{mpsc, oneshot};
2323
use tokio::time::{Duration, Instant};
2424
use tokio_util::sync::CancellationToken;
2525

26-
use crate::tokens::TokenBlockSequence;
26+
use crate::tokens::{SequenceHash, TokenBlockSequence};
2727

2828
use crate::kv_router::indexer::{
2929
compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores,
@@ -295,6 +295,26 @@ impl ApproxKvIndexer {
295295
self.kv_block_size
296296
}
297297

298+
/// Core function to process a routing decision with pre-computed hashes
299+
pub async fn process_routing_decision(
300+
&self,
301+
worker_id: WorkerId,
302+
local_hashes: Vec<LocalBlockHash>,
303+
sequence_hashes: Vec<SequenceHash>,
304+
) -> Result<(), KvRouterError> {
305+
self.route_tx
306+
.send(RouterResult {
307+
worker_id,
308+
local_hashes,
309+
sequence_hashes,
310+
})
311+
.await
312+
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
313+
314+
Ok(())
315+
}
316+
317+
/// Wrapper function that computes hashes from tokens and calls the core function
298318
pub async fn process_routing_decision_for_request(
299319
&self,
300320
tokens: &[u32],
@@ -309,16 +329,8 @@ impl ApproxKvIndexer {
309329
.map(|b| b.sequence_hash())
310330
.collect::<Vec<_>>();
311331

312-
self.route_tx
313-
.send(RouterResult {
314-
worker_id,
315-
local_hashes,
316-
sequence_hashes,
317-
})
332+
self.process_routing_decision(worker_id, local_hashes, sequence_hashes)
318333
.await
319-
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
320-
321-
Ok(())
322334
}
323335
}
324336

lib/llm/src/kv_router/indexer.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ use xxhash_rust::xxh3;
6363
pub const XXH3_SEED: u64 = 1337;
6464

6565
use crate::kv_router::protocols::*;
66+
use crate::tokens::SequenceHash;
6667

6768
/// Errors that can occur in the KV Router.
6869
#[derive(Debug, thiserror::Error)]
@@ -133,6 +134,40 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<Loc
133134
.collect()
134135
}
135136

137+
/// Compute rolling sequence hashes for a vector of block hashes.
138+
///
139+
/// This mirrors the behavior in tokens.rs where:
140+
/// - The first block's sequence hash equals its block hash
141+
/// - Subsequent blocks' sequence hash = hash([parent_sequence_hash, current_block_hash], seed)
142+
///
143+
/// ### Arguments
144+
///
145+
/// * `block_hashes` - A slice of `LocalBlockHash` values representing the block hashes, in sequence order.
146+
///
147+
/// ### Returns
148+
///
149+
/// A vector of `SequenceHash` values, one per input block and in the same order (empty if `block_hashes` is empty).
150+
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
151+
if block_hashes.is_empty() {
152+
return Vec::new();
153+
}
154+
155+
let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
156+
sequence_hashes.push(block_hashes[0].0);
157+
158+
for i in 1..block_hashes.len() {
159+
let parent_seq_hash = sequence_hashes[i - 1];
160+
let current_block_hash = block_hashes[i].0;
161+
162+
let combined = [parent_seq_hash, current_block_hash];
163+
let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect();
164+
let seq_hash = compute_hash(&bytes);
165+
sequence_hashes.push(seq_hash);
166+
}
167+
168+
sequence_hashes
169+
}
170+
136171
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
137172
#[derive(Debug, Clone, Serialize, Deserialize)]
138173
pub struct RouterEvent {

lib/llm/src/kv_router/scheduler.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use crate::kv_router::protocols::LoadMetrics;
2929
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
3030
use crate::kv_router::KvRouterConfig;
3131
use crate::kv_router::KV_HIT_RATE_SUBJECT;
32-
use crate::tokens::TokenBlockSequence;
32+
use crate::tokens::SequenceHash;
3333
use dynamo_runtime::component::Instance;
3434

3535
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -217,15 +217,13 @@ impl KvScheduler {
217217
&self,
218218
request_id: String,
219219
isl_tokens: usize,
220-
block_size: u32,
221-
tokens: &[u32],
220+
token_seq: Vec<SequenceHash>,
222221
overlaps: OverlapScores,
223222
) -> Result<i64, KvSchedulerError> {
224223
let mut sequences = self.sequences.lock().await;
225224

226-
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
227225
let (potential_blocks, potential_tokens) =
228-
sequences.potential_blocks_and_tokens(token_sequence, overlaps.clone());
226+
sequences.potential_blocks_and_tokens(token_seq.clone(), isl_tokens, overlaps.clone());
229227

230228
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
231229
let request = SchedulingRequest {
@@ -247,21 +245,20 @@ impl KvScheduler {
247245
sequences.update_workers(new_worker_ids);
248246
}
249247

250-
let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
251248
sequences.add_request(
252249
request_id,
253-
token_sequence,
250+
token_seq,
251+
isl_tokens,
254252
response.overlap_blocks,
255253
response.best_worker_id,
256254
);
257255

258256
Ok(response.best_worker_id)
259257
}
260258

261-
/// Push tokens to a specific request's sequence
262-
pub async fn push(&self, request_id: &String, tokens: &[u32]) {
259+
pub async fn mark_prefill_completed(&self, request_id: &String) {
263260
let mut sequences = self.sequences.lock().await;
264-
sequences.push(request_id, tokens)
261+
sequences.mark_prefill_completed(request_id)
265262
}
266263

267264
/// Free all blocks associated with a request

0 commit comments

Comments
 (0)