@@ -36,6 +36,7 @@ def __init__(
3636 blocks_to_swap_out : Dict [int , int ],
3737 blocks_to_copy : Dict [int , List [int ]],
3838 ignored_seq_groups : List [SequenceGroup ],
39+ finished_seqs : List [int ],
3940 ) -> None :
4041 self .scheduled_seq_groups = scheduled_seq_groups
4142 self .prompt_run = prompt_run
@@ -46,11 +47,13 @@ def __init__(
4647 # Swap in and swap out should never happen at the same time.
4748 assert not (blocks_to_swap_in and blocks_to_swap_out )
4849 self .ignored_seq_groups = ignored_seq_groups
50+ self .finished_seqs = finished_seqs
4951
def is_empty(self) -> bool:
    """Report whether this scheduling step carries nothing to act on.

    Returns:
        True when the scheduled sequence groups, the swap-in / swap-out /
        copy block mappings and the finished sequence ids are all empty;
        False as soon as any one of them holds something.
    """
    # NOTE: We do not consider the ignored sequence groups.
    # Attributes are read one at a time, in the same order as the checks,
    # so evaluation stops at the first non-empty field.
    for field_name in (
            "scheduled_seq_groups",
            "blocks_to_swap_in",
            "blocks_to_swap_out",
            "blocks_to_copy",
            "finished_seqs",
    ):
        if getattr(self, field_name):
            return False
    return True
5457
5558
5659class Scheduler :
@@ -417,6 +420,7 @@ def __init__(
417420 self .waiting : List [SequenceGroup ] = []
418421 # Sequence groups in the RUNNING state.
419422 self .running : List [SequenceGroup ] = []
423+ self .cleaned : List [int ] = []
420424
421425 def add_seq_group (self , seq_group : SequenceGroup ) -> None :
422426 # Add sequence groups to the waiting queue.
@@ -456,6 +460,8 @@ def _schedule(self) -> SchedulerOutputs:
456460
457461 ignored_seq_groups : List [SequenceGroup ] = []
458462 scheduled : List [SequenceGroup ] = []
463+ finished_seqs : List [int ] = self .cleaned .copy ()
464+ self .cleaned = []
459465 # The total number of sequences on the fly, including the
460466 # requests in the generation phase.
461467 num_curr_seqs = sum (seq_group .get_max_num_running_seqs ()
@@ -518,6 +524,7 @@ def _schedule(self) -> SchedulerOutputs:
518524 blocks_to_swap_out = {},
519525 blocks_to_copy = {},
520526 ignored_seq_groups = ignored_seq_groups ,
527+ finished_seqs = finished_seqs ,
521528 )
522529 return scheduler_outputs
523530
@@ -539,6 +546,7 @@ def _schedule(self) -> SchedulerOutputs:
539546 blocks_to_swap_out = {},
540547 blocks_to_copy = {},
541548 ignored_seq_groups = [],
549+ finished_seqs = finished_seqs ,
542550 )
543551 return scheduler_outputs
544552
@@ -576,7 +584,8 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
576584 self .block_manager .fork (parent_seq , child_seq )
577585
def free_seq(self, seq: Sequence) -> None:
    """Record ``seq`` as finished so its id is reported on the next schedule.

    The sequence id is appended to ``self.cleaned``; ``_schedule`` drains
    that list into the ``finished_seqs`` field of the ``SchedulerOutputs``
    it returns.

    Args:
        seq: The sequence whose ``seq_id`` should be queued.
    """
    # NOTE(review): this method used to call self.block_manager.free(seq)
    # directly. That call is intentionally not made here any more; only the
    # id is queued. Presumably the consumer of `finished_seqs` is now
    # responsible for releasing the blocks -- confirm against the caller.
    self.cleaned.append(seq.seq_id)
580589
581590 def free_finished_seq_groups (self ) -> None :
582591 for seq_group in self .running :
0 commit comments