@@ -25,7 +25,6 @@ use tokio::sync::Mutex;
2525use super :: protocols:: WorkerSelectionResult ;
2626use super :: WorkerSelector ;
2727use crate :: kv_router:: indexer:: OverlapScores ;
28- use crate :: kv_router:: indexer:: WorkerId ;
2928use crate :: kv_router:: protocols:: LoadMetrics ;
3029use crate :: kv_router:: scoring:: ProcessedEndpoints ;
3130use crate :: kv_router:: sequence:: ActiveSequencesMultiWorker ;
@@ -37,7 +36,7 @@ use crate::tokens::TokenBlockSequence;
3736pub struct KVHitRateEvent {
3837 pub worker_id : i64 ,
3938 pub isl_blocks : usize ,
40- pub overlap_blocks : usize ,
39+ pub overlap_blocks : u32 ,
4140}
4241
4342#[ derive( Debug , thiserror:: Error ) ]
@@ -79,13 +78,15 @@ impl Endpoint {
7978#[ derive( Debug ) ]
8079pub struct SchedulingResponse {
8180 pub best_worker_id : i64 ,
81+ pub overlap_blocks : u32 , // overlap score (cached blocks) of the selected worker
8282 pub endpoints_changed : Option < Vec < i64 > > ,
8383}
8484
8585pub struct SchedulingRequest {
8686 pub isl_tokens : usize ,
87- pub overlap : OverlapScores ,
87+ pub overlaps : OverlapScores ,
8888 pub potential_blocks : HashMap < i64 , usize > ,
89+ pub potential_tokens : HashMap < i64 , usize > ,
8990 resp_tx : tokio:: sync:: oneshot:: Sender < SchedulingResponse > ,
9091}
9192
@@ -174,6 +175,7 @@ impl KvScheduler {
174175
175176 let response = SchedulingResponse {
176177 best_worker_id : selection. worker_id ,
178+ overlap_blocks : selection. overlap_blocks ,
177179 endpoints_changed : pending_endpoint_update. take ( ) ,
178180 } ;
179181 request. respond ( response) ;
@@ -207,18 +209,20 @@ impl KvScheduler {
207209 isl_tokens : usize ,
208210 block_size : u32 ,
209211 tokens : & [ u32 ] ,
210- overlap : OverlapScores ,
212+ overlaps : OverlapScores ,
211213 ) -> Result < i64 , KvSchedulerError > {
212214 let mut sequences = self . sequences . lock ( ) . await ;
213215
214216 let token_sequence = TokenBlockSequence :: from_slice ( tokens, block_size, None ) ;
215- let potential_blocks = sequences. potential_blocks ( token_sequence) ;
217+ let ( potential_blocks, potential_tokens) =
218+ sequences. potential_blocks_and_tokens ( token_sequence, overlaps. clone ( ) ) ;
216219
217220 let ( resp_tx, resp_rx) = tokio:: sync:: oneshot:: channel ( ) ;
218221 let request = SchedulingRequest {
219222 isl_tokens,
220- overlap ,
223+ overlaps ,
221224 potential_blocks,
225+ potential_tokens,
222226 resp_tx,
223227 } ;
224228 self . request_tx
@@ -234,31 +238,16 @@ impl KvScheduler {
234238 }
235239
236240 let token_sequence = TokenBlockSequence :: from_slice ( tokens, block_size, None ) ;
237- sequences. add_request ( request_id, token_sequence, response. best_worker_id ) ;
241+ sequences. add_request (
242+ request_id,
243+ token_sequence,
244+ response. overlap_blocks ,
245+ response. best_worker_id ,
246+ ) ;
238247
239248 Ok ( response. best_worker_id )
240249 }
241250
242- /// Find the potential blocks for each worker if the sequence were routed there
243- pub async fn potential_blocks (
244- & self ,
245- token_sequence : TokenBlockSequence ,
246- ) -> HashMap < i64 , usize > {
247- let sequences = self . sequences . lock ( ) . await ;
248- sequences. potential_blocks ( token_sequence)
249- }
250-
251- /// Add a new request with its initial tokens to a specific worker
252- pub async fn add_request (
253- & self ,
254- request_id : String ,
255- token_sequence : TokenBlockSequence ,
256- worker_id : WorkerId ,
257- ) {
258- let mut sequences = self . sequences . lock ( ) . await ;
259- sequences. add_request ( request_id, token_sequence, worker_id)
260- }
261-
262251 /// Push tokens to a specific request's sequence
263252 pub async fn push ( & self , request_id : & String , tokens : & [ u32 ] ) {
264253 let mut sequences = self . sequences . lock ( ) . await ;
@@ -370,34 +359,47 @@ impl WorkerSelector for DefaultWorkerSelector {
370359 return Err ( KvSchedulerError :: NoEndpoints ) ;
371360 }
372361
373- let request_blocks = request. isl_tokens . div_ceil ( block_size as usize ) ;
362+ let isl = request. isl_tokens ;
363+ let request_blocks = isl. div_ceil ( block_size as usize ) ;
364+ let overlaps = & request. overlaps . scores ;
365+
366+ // active blocks for decoding
374367 let potential_active_blocks = & request. potential_blocks ;
368+ // active tokens in the batch (processed by the linear layers), mostly prefill tokens
369+ let potential_active_tokens = & request. potential_tokens ;
375370
376371 let mut worker_logits = HashMap :: new ( ) ;
377372 let mut max_logit = f64:: NEG_INFINITY ;
378373
379374 // Calculate logits for each worker
380375 for ( worker_id, _) in workers. endpoints . iter ( ) {
381- let cached_blocks = request. overlap . scores . get ( worker_id) . copied ( ) . unwrap_or ( 0 ) as f64 ;
382- let prefill_blocks = request_blocks as f64 - cached_blocks;
376+ // this is the number of tokens each worker would have if the request were scheduled there
377+ let potential_tokens = * potential_active_tokens. get ( worker_id) . unwrap_or_else ( || {
378+ tracing:: warn!(
379+ "assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet"
380+ ) ;
381+ & isl
382+ } ) as f64 ;
383383
384384 // this is the number of blocks each worker would have if the request were scheduled there
385385 let potential_blocks = * potential_active_blocks. get ( worker_id) . unwrap_or_else ( ||
386- { tracing:: warn!( "assuming 0 decoding blocks for {worker_id}, as the load metrics endpoint does not exist yet" ) ;
387- & 0
386+ { tracing:: warn!( "assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet" ) ;
387+ & request_blocks
388388 } ) as f64 ;
389389
390+ let potential_prefill_blocks = potential_tokens / ( block_size as f64 ) ;
391+
390392 // Calculate logit (lower is better)
391- let logit =
392- self . kv_router_config . overlap_score_weight * prefill_blocks + potential_blocks;
393+ let logit = self . kv_router_config . overlap_score_weight * potential_prefill_blocks
394+ + potential_blocks;
393395 max_logit = max_logit. max ( logit) ;
394396
395397 worker_logits. insert ( * worker_id, logit) ;
396398
397399 tracing:: info!(
398- "Formula for {worker_id}: {logit:.3} = {:.1} * {prefill_blocks :.3} + {potential_blocks:.3} (cached_blocks: {cached_blocks })" ,
400+ "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks :.3} + {potential_blocks:.3} (cached_blocks: {})" ,
399401 self . kv_router_config. overlap_score_weight,
400- cached_blocks = cached_blocks
402+ overlaps . get ( worker_id ) . unwrap_or ( & 0 ) ,
401403 ) ;
402404 }
403405
@@ -412,12 +414,7 @@ impl WorkerSelector for DefaultWorkerSelector {
412414 let temperature = self . kv_router_config . router_temperature ;
413415 let best_worker_id = softmax_sample ( & worker_logits, temperature) ;
414416
415- let overlap_blocks = request
416- . overlap
417- . scores
418- . get ( & best_worker_id)
419- . copied ( )
420- . unwrap_or ( 0 ) as usize ;
417+ let overlap_blocks = overlaps. get ( & best_worker_id) . copied ( ) . unwrap_or ( 0 ) ;
421418 let best_logit = worker_logits[ & best_worker_id] ;
422419
423420 tracing:: info!(
0 commit comments