1616from  vllm .lora .request  import  LoRARequest 
1717from  vllm .prompt_adapter .request  import  PromptAdapterRequest 
1818from  vllm .sequence  import  (Sequence , SequenceData , SequenceGroup ,
19-                            SequenceGroupMetadata , SequenceGroupMetadataDelta ,
20-                            SequenceStage , SequenceStatus )
19+                            SequenceGroupBase , SequenceGroupMetadata ,
20+                            SequenceGroupMetadataDelta , SequenceStage ,
21+                            SequenceStatus )
2122from  vllm .utils  import  Device , PyObjectCache 
2223
2324logger  =  init_logger (__name__ )
@@ -561,7 +562,11 @@ def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
561562        # Only for testing purposes. 
562563        self .swapped .append (seq_group )
563564
564-     def  abort_seq_group (self , request_id : Union [str , Iterable [str ]]) ->  None :
565+     def  abort_seq_group (
566+         self ,
567+         request_id : Union [str , Iterable [str ]],
568+         seq_id_to_seq_group : Optional [Dict [str , SequenceGroupBase ]] =  None ,
569+     ) ->  None :
565570        """Aborts a sequence group with the given ID. 
566571
567572        Check if the sequence group with the given ID 
@@ -573,21 +578,29 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
573578
574579        Args: 
575580            request_id: The ID(s) of the sequence group to abort. 
581+             seq_id_to_seq_group: helper for groups with n>1 
576582        """ 
577583        if  isinstance (request_id , str ):
578584            request_id  =  (request_id , )
579585        request_ids  =  set (request_id )
586+         seq_id_to_seq_group  =  seq_id_to_seq_group  or  {}
580587        for  state_queue  in  [self .waiting , self .running , self .swapped ]:
581588            aborted_groups : List [SequenceGroup ] =  []
582589            for  seq_group  in  state_queue :
583-                 if  not  request_ids :
584-                     # Using 'break' here may add two extra iterations, 
585-                     # but is acceptable to reduce complexity. 
586-                     break 
587-                 if  seq_group .request_id  in  request_ids :
590+                 # When n>1, seq_group.request_id looks like 
591+                 # foo_parallel_sample_0, while request_ids is just foo, and we 
592+                 # should resolve it as real_request_id to match. 
593+                 if  seq_group .request_id  in  seq_id_to_seq_group :
594+                     real_request_id  =  seq_id_to_seq_group [
595+                         seq_group .request_id ].group_id 
596+                 else :
597+                     real_request_id  =  seq_group .request_id 
598+                 if  real_request_id  in  request_ids :
588599                    # Appending aborted group into pending list. 
589600                    aborted_groups .append (seq_group )
590-                     request_ids .remove (seq_group .request_id )
601+                     # We can't remove real_request_id in request_ids here, 
602+                     # because there may be other seq groups sharing the same 
603+                     # real_request_id 
591604            for  aborted_group  in  aborted_groups :
592605                # Remove the sequence group from the state queue. 
593606                state_queue .remove (aborted_group )
@@ -598,6 +611,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
598611                        continue 
599612                    seq .status  =  SequenceStatus .FINISHED_ABORTED 
600613                    self .free_seq (seq )
614+                 if  aborted_group .request_id  in  seq_id_to_seq_group :
615+                     del  seq_id_to_seq_group [aborted_group .request_id ]
601616
602617                self ._free_seq_group_cross_attn_blocks (aborted_group )
603618
0 commit comments