@@ -31,8 +31,8 @@ use crate::{
3131 kv_router:: {
3232 approx:: ApproxKvIndexer ,
3333 indexer:: {
34- compute_block_hash_for_seq, KvIndexer , KvIndexerInterface , KvRouterError ,
35- OverlapScores , RouterEvent ,
34+ compute_block_hash_for_seq, compute_seq_hash_for_block , KvIndexer , KvIndexerInterface ,
35+ KvRouterError , OverlapScores , RouterEvent ,
3636 } ,
3737 // metrics_aggregator::EndpointCollector,
3838 protocols:: { LocalBlockHash , RouterRequest , RouterResponse , WorkerSelectionResult } ,
@@ -71,7 +71,8 @@ pub struct KvRouterConfig {
7171
7272 pub use_kv_events : bool ,
7373
74- // note: this is not actually used for now
74+ // TODO: this is not actually used for now
75+ // Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
7576 pub max_num_batched_tokens : u32 ,
7677}
7778
@@ -231,25 +232,25 @@ impl KvRouter {
231232 let _guard = self . find_best_match_mutex . lock ( ) . await ;
232233
233234 let isl_tokens = tokens. len ( ) ;
234- let block_size = self . block_size ;
235235
236- let local_block_hashes = compute_block_hash_for_seq ( tokens, self . block_size ) ;
237- let overlap_scores = self . indexer . find_matches ( local_block_hashes) . await ?;
236+ let block_hashes = compute_block_hash_for_seq ( tokens, self . block_size ) ;
237+ let seq_hashes = compute_seq_hash_for_block ( & block_hashes) ;
238+
239+ let overlap_scores = self . indexer . find_matches ( block_hashes. clone ( ) ) . await ?;
238240
239241 let best_worker_id = self
240242 . scheduler
241243 . schedule (
242244 context_id. to_string ( ) ,
243245 isl_tokens,
244- block_size,
245- tokens,
246+ seq_hashes. clone ( ) ,
246247 overlap_scores. clone ( ) ,
247248 )
248249 . await ?;
249250
250251 if let Indexer :: ApproxKvIndexer ( ref indexer) = self . indexer {
251252 indexer
252- . process_routing_decision_for_request ( tokens , best_worker_id )
253+ . process_routing_decision ( best_worker_id , block_hashes , seq_hashes )
253254 . await
254255 . unwrap ( ) ;
255256 } ;
@@ -262,9 +263,9 @@ impl KvRouter {
262263 Ok ( ( best_worker_id, overlap_amount) )
263264 }
264265
265- /// Push tokens to a specific request's sequence
266- pub async fn push ( & self , request_id : & String , tokens : & [ u32 ] ) {
267- self . scheduler . push ( request_id, tokens ) . await
266+ /// Free all blocks associated with a request
267+ pub async fn mark_prefill_completed ( & self , request_id : & String ) {
268+ self . scheduler . mark_prefill_completed ( request_id) . await
268269 }
269270
270271 /// Free all blocks associated with a request
@@ -331,7 +332,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
331332 let stream_context = request. context ( ) . clone ( ) ;
332333 // Update the request with the estimated prefix hit blocks
333334 let ( mut backend_input, context) = request. into_parts ( ) ;
334- let isl = backend_input. token_ids . len ( ) ;
335335 backend_input. estimated_prefix_hit_num_blocks = Some ( overlap_amount) ;
336336 let updated_request = context. map ( |_| backend_input) ;
337337
@@ -345,55 +345,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
345345 let stream = stream:: iter ( vec ! [ response] ) ;
346346 return Ok ( ResponseStream :: new ( Box :: pin ( stream) , stream_context) ) ;
347347 }
348- // Get the response stream from the worker
349- let mut response_stream = self . inner . direct ( updated_request, instance_id) . await ?;
350348
351- // Wrap the stream to track tokens
349+ let mut response_stream = self . inner . direct ( updated_request , instance_id ) . await ? ;
352350 let stream_context = response_stream. context ( ) ;
353351 let chooser = self . chooser . clone ( ) ;
354- let request_id = context_id. clone ( ) ;
355- let block_size = chooser. block_size ( ) as usize ;
356352
357353 let wrapped_stream = Box :: pin ( async_stream:: stream! {
358- let mut accumulated_tokens = Vec :: new ( ) ;
359- let mut total_output_length = 0usize ;
360- let mut last_block_index = ( isl . saturating_sub ( 1 ) ) / block_size ;
361- let mut first_push_done = false ;
354+ if let Some ( first_item ) = response_stream . next ( ) . await {
355+ chooser . mark_prefill_completed ( & context_id ) . await ;
356+ yield first_item ;
357+ }
362358
363359 while let Some ( item) = response_stream. next( ) . await {
364- // Track tokens if they exist in the response
365- let Some ( ref output) = item. data else {
366- yield item;
367- continue ;
368- } ;
369- if output. token_ids. is_empty( ) {
370- yield item;
371- continue ;
372- }
373-
374- // Add tokens to accumulator
375- accumulated_tokens. extend_from_slice( & output. token_ids) ;
376- total_output_length += output. token_ids. len( ) ;
377-
378- // Always push for the first generated token (to mark prefill done)
379- // or when we've moved to a new block
380- let current_block_index = ( isl + total_output_length) . saturating_sub( 1 ) / block_size;
381- let should_push = ( !first_push_done && total_output_length >= 1 ) ||
382- ( first_push_done && current_block_index > last_block_index) ;
383-
384- if should_push {
385- chooser. push( & request_id, & accumulated_tokens) . await ;
386- accumulated_tokens. clear( ) ;
387- last_block_index = current_block_index;
388- if !first_push_done {
389- first_push_done = true ;
390- }
391- }
392-
393360 yield item;
394361 }
395362
396- chooser. free( & request_id ) . await ;
363+ chooser. free( & context_id ) . await ;
397364 } ) ;
398365 Ok ( ResponseStream :: new ( wrapped_stream, stream_context) )
399366 }
0 commit comments