@@ -51,7 +51,7 @@ use std::collections::HashMap;
 use std::collections::VecDeque;
 use std::sync::Arc;
 use tokio::sync::{mpsc, Mutex};
-use tokio::time::{interval, Duration};
+use tokio::time::Duration;
 use tokio_util::sync::CancellationToken;
 use uuid::Uuid;

@@ -81,6 +81,10 @@ impl SchedulerState {
         }
     }

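+    /// Returns `true` if the scheduler currently holds no requests.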
+    fn is_empty(&self) -> bool {
+        self.requests.is_empty()
+    }
+
     /// Create a new UUID for a DirectRequest, add it to requests, and push the UUID to waiting.
     fn receive(&mut self, request: DirectRequest) -> Uuid {
         // Use the provided UUID if available, otherwise generate a new one
@@ -295,11 +299,25 @@ impl Scheduler {

         // Spawn main background task with cancellation token
         tokio::spawn(async move {
-            let mut schedule_interval = interval(Duration::from_secs_f64(1e-3));
-            let mut simulate_interval = interval(Duration::from_secs_f64(1e-4));
            let mut should_schedule = true;

            loop {
+                {
+                    let state_guard = state_clone.lock().await;
+
+                    // Enqueue a new request; blocks until at least one is received, so no redundant work is done
+                    // TODO: clean this up? double lock acquisition is ugly, but needed to not hold the lock forever
+                    if state_guard.is_empty() {
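+                        // NOTE: the guard must be dropped before awaiting recv(); holding
+                        // the state Mutex across the await would stall every other task
+                        // that needs the lock.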
+                        drop(state_guard);
+                        let Some(request) = request_rx.recv().await else {
+                            tracing::warn!("request sender is dropped");
+                            break;
+                        };
+                        let mut state_guard = state_clone.lock().await;
+                        state_guard.receive(request);
+                    }
+                }
+
                 tokio::select! {
                     biased;

@@ -310,7 +328,7 @@ impl Scheduler {
                    }

                    // Try Scheduling Requests - runs on normal interval or after simulation
-                    _ = schedule_interval.tick() => {
+                    _ = tokio::task::yield_now() => {
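+                        // yield_now() completes immediately after rescheduling, so this
+                        // arm is always ready; pacing now comes from the simulated sleep
+                        // at the bottom of the loop instead of the old 1ms interval.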
                        // Skip if we just ran scheduling after simulation to prevent consecutive runs
                        if !should_schedule {
                            continue;
@@ -371,100 +389,117 @@ impl Scheduler {
                    _ = cancel_token_clone.cancelled() => {
                        break;
                    }
+                }

-                    // Simulate running requests (prefill + decode)
-                    _ = simulate_interval.tick() => {
-                        let mut state_guard = state_clone.lock().await;
-                        let mut kv_manager_guard = kv_manager_clone.lock().await;
-
-                        // Base time needed for decoding using active percentage and quadratic formula
-                        let active_perc = kv_manager_guard.get_active_perc();
-                        let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
-                        let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
-
-                        // Process prefilling
-                        while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = state_guard.try_prefill() {
-                            // NOTE: Prefill cost/time is always incremented for new blocks, even if they
-                            // could be cached by other requests in the same batch. This matches vLLM behavior.
-                            total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
-
-                            if let Some(creation_signal) = maybe_creation_signal {
-                                if !process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal)) {
-                                    panic!("Block allocation for prefilling cannot fail.");
-                                }
-
-                                // Drain KV events and forward to relay after prefill signal processing
-                                if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
-                                    while let Ok(event) = rx.try_recv() {
-                                        let _ = relay_tx.send(block_response_to_kv_event(event));
-                                    }
-                                }
-                            };
-
-                            // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
-                            if !is_full_prefill { break; }
+                // Simulates prefill + decode
+                let mut state_guard = state_clone.lock().await;
+                let mut kv_manager_guard = kv_manager_clone.lock().await;
+
+                // Base time needed for decoding using active percentage and quadratic formula
+                let active_perc = kv_manager_guard.get_active_perc();
+                let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
+                let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
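+                // (the quadratic is presumably an empirical fit; it yields a latency in
+                // milliseconds, hence the /1000.0 conversion to seconds)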
+
+                // Process prefilling
+                while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
+                    state_guard.try_prefill()
+                {
+                    // NOTE: Prefill cost/time is always incremented for new blocks, even if they
+                    // could be cached by other requests in the same batch. This matches vLLM behavior.
+                    total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
+
+                    if let Some(creation_signal) = maybe_creation_signal {
+                        if !process_signals(
+                            &mut kv_manager_guard,
+                            std::slice::from_ref(&creation_signal),
+                        ) {
+                            panic!("Block allocation for prefilling cannot fail.");
                        }

-                        state_guard.reset_active_tokens();
-
-                        // Process decoding
-                        let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
-                        if !uuids.is_empty() { should_schedule = true };
-                        for uuid in uuids {
-                            let Some(sequence) = state_guard.run(uuid) else {
-                                continue;
-                            };
-                            let signals = sequence.generate();
-
-                            // Process all signals with the KvManager
-                            // Handling of preemption on failure
-                            if !process_signals(&mut kv_manager_guard, &signals) {
-                                sequence.pop(); // revert the failed generation op
-                                for signal in state_guard.preempt() {
-                                    kv_manager_guard.process(&signal);
-                                }
-                                continue;
+                        // Drain KV events and forward to relay after prefill signal processing
+                        if let (Some(ref relay_tx), Some(ref mut rx)) =
+                            (&kv_events_tx, &mut block_resp_rx)
+                        {
+                            while let Ok(event) = rx.try_recv() {
+                                let _ = relay_tx.send(block_response_to_kv_event(event));
                            }
+                        }
+                    };

-                            // Drain KV events and forward to relay after decode signal processing
-                            if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
-                                while let Ok(event) = rx.try_recv() {
-                                    let _ = relay_tx.send(block_response_to_kv_event(event));
-                                }
-                            }
+                    // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
+                    if !is_full_prefill {
+                        break;
+                    }
+                }

-                            // Check completion and send notification
-                            let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
-                            let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
+                state_guard.reset_active_tokens();
+
+                // Process decoding
+                let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
+                if !uuids.is_empty() {
+                    should_schedule = true
+                };
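+                // (re-enables the scheduling arm on the next pass, since decoding may
+                // have freed capacity for waiting requests)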
+                for uuid in uuids {
+                    let Some(sequence) = state_guard.run(uuid) else {
+                        continue;
+                    };
+                    let signals = sequence.generate();
+
+                    // Process all signals with the KvManager
+                    // Handling of preemption on failure
+                    if !process_signals(&mut kv_manager_guard, &signals) {
+                        sequence.pop(); // revert the failed generation op
+                        for signal in state_guard.preempt() {
+                            kv_manager_guard.process(&signal);
+                        }
+                        continue;
+                    }

-                            let mut send_failed = false;
-                            if should_output {
-                                send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
-                                    tx.send(OutputSignal { uuid, completed: is_complete }).is_err()
-                                });
-                            }
+                    // Drain KV events and forward to relay after decode signal processing
+                    if let (Some(ref relay_tx), Some(ref mut rx)) =
+                        (&kv_events_tx, &mut block_resp_rx)
+                    {
+                        while let Ok(event) = rx.try_recv() {
+                            let _ = relay_tx.send(block_response_to_kv_event(event));
+                        }
+                    }

-                            if send_failed {
-                                for signal in &sequence.free_signal() {
-                                    kv_manager_guard.process(signal);
-                                }
-                            }
+                    // Check completion and send notification
+                    let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
+                    let should_output =
+                        sequence.generated_tokens() > sequence.already_generated_tokens();
+
+                    let mut send_failed = false;
+                    if should_output {
+                        send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
+                            tx.send(OutputSignal {
+                                uuid,
+                                completed: is_complete,
+                            })
+                            .is_err()
+                        });
+                    }
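+                    // A failed send means the output receiver was dropped; the request
+                    // is treated as abandoned, so its blocks are freed below and it is
+                    // marked complete.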

-                            if send_failed || is_complete {
-                                state_guard.complete(&uuid);
-                                continue;
-                            }
+                    if send_failed {
+                        for signal in &sequence.free_signal() {
+                            kv_manager_guard.process(signal);
                        }
+                    }

-                        // Sleep once for the adjusted duration
-                        drop(kv_manager_guard);
-                        drop(state_guard);
-                        let adjusted_time = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
-                        if adjusted_time.as_millis() > 0 {
-                            tokio::time::sleep(adjusted_time).await;
-                        }
+                    if send_failed || is_complete {
+                        state_guard.complete(&uuid);
+                        continue;
                    }
                }
+
+                // Sleep once for the adjusted duration
+                drop(kv_manager_guard);
+                drop(state_guard);
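+                // (both guards are released before sleeping so other tasks can touch the
+                // state meanwhile; a speedup_ratio above 1.0 compresses simulated time
+                // relative to wall-clock time)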
+                let adjusted_time =
+                    Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
+                if adjusted_time.as_millis() > 0 {
+                    tokio::time::sleep(adjusted_time).await;
+                }
            }
        });

@@ -632,6 +667,7 @@ mod tests {
     use super::*;
     use rstest::rstest;
     use std::time::Duration;
+    use tokio::time::interval;
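+    // `interval` is now only used by the tests, since the scheduler loop no
+    // longer ticks on timers.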

    #[rstest]
    #[case::case_1(false, false, false)]