11# SPDX-License-Identifier: Apache-2.0
22
33import asyncio
4+ from collections .abc import Iterable
45from dataclasses import dataclass
56from typing import Optional , Union
67
@@ -102,33 +103,32 @@ def make_request_output(
102103 ) -> Optional [RequestOutput ]:
103104
104105 finished = finish_reason is not None
105- output_kind = self .output_kind
106- final_only = output_kind == RequestOutputKind .FINAL_ONLY
106+ final_only = self .output_kind == RequestOutputKind .FINAL_ONLY
107107
108108 # In follow up, we will switch to invariant where EngineCore
109109 # does not stream partial prefills.
110110 if not finished and (self .is_prefilling or final_only ):
111111 # Only the final output is required in FINAL_ONLY mode.
112112 return None
113113
114- def new_request_output (request_id : str ) -> RequestOutput :
115- return self ._new_request_output (request_id , finished )
116-
117114 completion_output = self ._new_completion_output (
118115 new_token_ids , finish_reason , stop_reason )
119116
120- if self .parent_req is not None :
121- return self .parent_req .make_request_output (final_only ,
122- completion_output ,
123- new_request_output )
117+ request_id = self .request_id
118+ if self .parent_req is None :
119+ outputs = [completion_output ]
120+ else :
121+ request_id , outputs , finished = self .parent_req .get_outputs (
122+ request_id , completion_output )
123+ if not outputs :
124+ return None
124125
125- request_output = new_request_output (self .request_id )
126- request_output .outputs .append (completion_output )
127- return request_output
126+ return self ._new_request_output (request_id , outputs , finished )
128127
129128 def _new_request_output (
130129 self ,
131130 request_id : str ,
131+ outputs : list [CompletionOutput ],
132132 finished : bool ,
133133 ) -> RequestOutput :
134134
@@ -143,7 +143,7 @@ def _new_request_output(
143143 prompt = self .prompt ,
144144 prompt_token_ids = self .prompt_token_ids ,
145145 prompt_logprobs = prompt_logprobs ,
146- outputs = [] ,
146+ outputs = outputs ,
147147 finished = finished ,
148148 )
149149
@@ -188,6 +188,7 @@ def __init__(
188188 self .log_stats = log_stats
189189 self .tokenizer = tokenizer
190190 self .request_states : dict [str , RequestState ] = {}
191+ self .parent_requests : dict [str , ParentRequest ] = {}
191192 self .lora_states = LoRARequestStates ()
192193
193194 def get_num_unfinished_requests (self ):
@@ -198,14 +199,20 @@ def has_unfinished_requests(self) -> bool:
198199
def abort_requests(
    self,
    request_ids: Iterable[str],
) -> list[str]:
    """Stop tracking the given requests and report which IDs were aborted.

    Each ID is first looked up in ``self.request_states``. When an entry
    exists it is removed, ``self.lora_states`` is told about the abort,
    and the ID is added to the result.

    When no entry exists, the ID is looked up in ``self.parent_requests``
    instead. If a parent entry with a non-empty ``child_requests``
    collection is found, the children are aborted through a recursive
    call and every ID in ``child_requests`` is added to the result.
    The parent entry is removed from ``self.parent_requests`` either way.

    Args:
        request_ids: IDs of the requests (or parent requests) to abort.

    Returns:
        The IDs collected as described above, in processing order.
    """
    aborted: list[str] = []
    for req_id in request_ids:
        state = self.request_states.pop(req_id, None)
        if state is not None:
            # A directly tracked request: release its LoRA bookkeeping.
            self.lora_states.abort_request(state)
            aborted.append(req_id)
            continue

        # Not a tracked request; the ID may name a parent request instead.
        parent = self.parent_requests.pop(req_id, None)
        if not parent or not parent.child_requests:
            continue

        children = parent.child_requests
        self.abort_requests(children)
        aborted.extend(children)
    return aborted
209216
210217 def add_request (
211218 self ,
@@ -227,6 +234,8 @@ def add_request(
227234 log_stats = self .log_stats )
228235 self .request_states [request_id ] = req_state
229236 self .lora_states .add_request (req_state )
237+ if parent_req :
238+ self .parent_requests [parent_req .request_id ] = parent_req
230239
231240 def process_outputs (
232241 self ,
@@ -314,12 +323,14 @@ def process_outputs(
314323 # Free completed requests.
315324 if finish_reason is not None :
316325 self .request_states .pop (req_id )
326+ # Remove parent request if applicable.
327+ parent_req = req_state .parent_req
328+ if parent_req and not parent_req .child_requests :
329+ self .parent_requests .pop (parent_req .request_id , None )
317330 if not engine_core_output .finished :
318331 # If req not finished in EngineCore, but Detokenizer
319332 # detected stop string, abort needed in EngineCore.
320333 reqs_to_abort .append (req_id )
321- if req_state .parent_req is not None :
322- req_state .parent_req .finish_child_request (req_id )
323334
324335 # Track per-request stats
325336 self ._update_stats_from_finished (req_state , finish_reason ,
0 commit comments