@@ -72,6 +72,9 @@ pub enum SlotState {
7272
7373 /// The slot is finished and all resources have been released.
7474 Finished ,
75+
76+ /// The slot is preempted and is waiting for the next iteration to resume.
77+ Preempted ,
7578}
7679
7780pub trait Slot : std:: fmt:: Debug {
@@ -122,6 +125,9 @@ pub trait Slot: std::fmt::Debug {
122125
123126 /// Record the number of tokens that were cached on the disk.
124127 fn record_cached_disk_tokens ( & mut self , num_tokens : usize ) ;
128+
129+ /// Reset the slot after preemption.
130+ fn reset_after_preemption ( & mut self ) -> Result < ( ) , SlotError > ;
125131}
126132
127133pub trait ExternallyManagedDeviceSlot : Slot {
@@ -341,6 +347,22 @@ impl Slot for VllmConnectorSlot {
341347 self . state
342348 }
343349
350+ fn reset_after_preemption ( & mut self ) -> Result < ( ) , SlotError > {
351+ assert ! ( self . staging_from_disk. is_none( ) ) ;
352+ assert ! ( self . staging_from_host. is_none( ) ) ;
353+ assert ! ( self . pending_operations. is_none( ) ) ;
354+
355+ self . state = SlotState :: Preempted ;
356+ self . iteration_first_scheduled = None ;
357+ self . current_position = 0 ;
358+ self . evaluated_blocks = 0 ;
359+ self . device_blocks . clear ( ) ;
360+ self . tokens_cached_from_device = 0 ;
361+ self . tokens_cached_from_host = 0 ;
362+ self . tokens_cached_from_disk = 0 ;
363+ Ok ( ( ) )
364+ }
365+
344366 fn record_cached_device_tokens ( & mut self , num_tokens : usize ) {
345367 self . tokens_cached_from_device = num_tokens;
346368 tracing:: debug!( "recording {} cached device tokens" , num_tokens, ) ;
@@ -511,13 +533,17 @@ impl Slot for VllmConnectorSlot {
511533 return Ok ( ( ) ) ;
512534 }
513535
514- if !matches ! ( self . state( ) , SlotState :: Initialized ) {
536+ if !matches ! ( self . state( ) , SlotState :: Initialized | SlotState :: Preempted ) {
515537 return Err ( SlotError :: InvalidOperation ( format ! (
516538 "slot must be in the NotScheduled state to acquire local matches; got {:?}" ,
517539 self . state( )
518540 ) ) ) ;
519541 }
520542
543+ if matches ! ( self . state( ) , SlotState :: Preempted ) {
544+ tracing:: info!( "slot is in the Preempted state; we get another chance to match" ) ;
545+ }
546+
521547 let block_size = self . block_manager . block_size ( ) ;
522548 let num_computed_blocks = num_computed_tokens / block_size;
523549 debug_assert ! ( num_computed_tokens % block_size == 0 ) ;
0 commit comments