2222from vllm .usage .usage_lib import UsageContext
2323from vllm .v1 .engine .core_client import EngineCoreClient
2424from vllm .v1 .engine .output_processor import OutputProcessor
25- from vllm .v1 .engine .parallel_sampling import SyncParallelSamplingManager
25+ from vllm .v1 .engine .parallel_sampling import ParentRequest
2626from vllm .v1 .engine .processor import Processor
2727from vllm .v1 .executor .abstract import Executor
2828
@@ -50,9 +50,6 @@ def __init__(
5050 self .model_config = vllm_config .model_config
5151 self .cache_config = vllm_config .cache_config
5252
53- # Bookkeeping for parallel sampling requests
54- self .parallel_manager = SyncParallelSamplingManager ()
55-
5653 # important: init dp group before init the engine_core
5754 self .parallel_config = vllm_config .parallel_config
5855 self .dp_enabled = self .parallel_config .data_parallel_size > 1 # noqa
def get_num_unfinished_requests(self) -> int:
    """Return how many requests still have outputs pending.

    After the ParentRequest refactor, parallel-sampling bookkeeping lives
    in the output processor, so this is a straight delegation.
    """
    return self.output_processor.get_num_unfinished_requests()
126122 def has_unfinished_requests (self ) -> bool :
127123 has_unfinished = self .output_processor .has_unfinished_requests ()
def add_request(
    self,
    request_id: str,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
    arrival_time: Optional[float] = None,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    priority: int = 0,
) -> None:
    """Add a request to the engine.

    For ``SamplingParams`` with ``n > 1`` this fans the request out into
    ``n`` child requests sharing one ``ParentRequest``; pooling requests
    and ``n == 1`` requests are added directly (``from_params`` returns
    ``None`` in that case).
    """
    # 1) Fan out child requests (for n > 1). parent_req is None when no
    # fan-out is needed.
    parent_req = ParentRequest.from_params(request_id, params)
    num_children = params.n if isinstance(params, SamplingParams) else 1
    for idx in range(num_children):
        if parent_req is not None:
            # Each child gets a derived request_id and its own params.
            request_id, params = parent_req.get_child_info(idx)

        # 2) Process raw inputs into the request.
        request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            trace_headers, prompt_adapter_request, priority)

        # 3) Make a new RequestState and queue.
        self.output_processor.add_request(request, parent_req, idx)

        # 4) Add the request to EngineCore.
        self.engine_core.add_request(request)
203176 def step (self ) -> list [RequestOutput ]:
204177
@@ -217,10 +190,7 @@ def step(self) -> list[RequestOutput]:
217190 # 3) Abort any reqs that finished due to stop strings.
218191 self .engine_core .abort_requests (processed_outputs .reqs_to_abort )
219192
220- request_outputs = processed_outputs .request_outputs
221-
222- # 4) Process unfinished parallel sampling requests
223- return self .parallel_manager .step (request_outputs )
193+ return processed_outputs .request_outputs
224194
def get_model_config(self):
    """Return the ModelConfig this engine was constructed with."""
    return self.model_config
0 commit comments