@@ -205,8 +205,12 @@ def _initialize_kv_caches(
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_executor.supported_tasks

-    def add_request(self, request: EngineCoreRequest):
-        """Add request to the scheduler."""
+    def add_request(self, request: Request, request_wave: int = 0):
+        """Add request to the scheduler.
+
+        `request_wave`: indicates which wave of requests this is expected
+        to belong to in the DP case.
+        """
         # Validate the request_id type.
         if not isinstance(request.request_id, str):
             raise TypeError(
@@ -222,27 +226,12 @@ def add_request(self, request: EngineCoreRequest):
                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                  f"Supported tasks: {supported_pooling_tasks}")

-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
-                request.mm_inputs, request.mm_hashes)
-
-        req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
-            # Start grammar compilation asynchronously
-            self.structured_output_manager.grammar_init(req)
-
-        if req.kv_transfer_params is not None and (
+        if request.kv_transfer_params is not None and (
                 not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

-        self.scheduler.add_request(req)
+        self.scheduler.add_request(request)

     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""
@@ -414,6 +403,31 @@ def save_tensorized_model(
         self.model_executor.save_tensorized_model(
             tensorizer_config=tensorizer_config, )

+    def preprocess_add_request(
+            self, request: EngineCoreRequest) -> tuple[Request, int]:
+        """Preprocess the request.
+
+        This function can be called directly from the input processing
+        thread to run request initialization in parallel with model forward.
+        """
+        if request.mm_hashes is not None:
+            assert request.mm_inputs is not None
+            # Note on thread safety: no race condition.
+            # `mm_input_cache_server` is reset at the end of LLMEngine init,
+            # and is only accessed from the input processing thread afterwards.
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+                request.mm_inputs, request.mm_hashes)
+
+        req = Request.from_engine_core_request(request)
+        if req.use_structured_output:
+            # Note on thread safety: no race condition.
+            # `grammar_init` is only invoked from the input processing
+            # thread. For `structured_output_manager`, each request is
+            # independent and grammar compilation is async. The scheduler
+            # always checks grammar status before scheduling a request.
+            self.structured_output_manager.grammar_init(req)
+        return req, request.current_wave
+

 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -707,7 +721,8 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
         """Dispatch request from client."""

         if request_type == EngineCoreRequestType.ADD:
-            self.add_request(request)
+            req, request_wave = request
+            self.add_request(req, request_wave)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
         elif request_type == EngineCoreRequestType.UTILITY:
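With this change, the payload delivered for `ADD` is the preprocessed `(Request, request_wave)` tuple rather than a raw `EngineCoreRequest`; `ABORT` payloads are unchanged. A type-level sketch, with aliases that are mine rather than vLLM's (the `X | Y` alias syntax needs Python 3.10+, and the `Request` import path is my assumption):

```python
from vllm.v1.request import Request  # assumed v1 import path

AddPayload = tuple[Request, int]     # built by preprocess_add_request()
AbortPayload = list[str]             # request ids; unchanged by this diff
Payload = AddPayload | AbortPayload  # what ADD/ABORT hand to this dispatcher
```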
@@ -806,10 +821,11 @@ def process_input_sockets(self, input_addresses: list[str],
                         bytes(type_frame.buffer))

                     # Deserialize the request data.
-                    decoder = add_request_decoder if (
-                        request_type
-                        == EngineCoreRequestType.ADD) else generic_decoder
-                    request = decoder.decode(data_frames)
+                    if request_type == EngineCoreRequestType.ADD:
+                        request = add_request_decoder.decode(data_frames)
+                        request = self.preprocess_add_request(request)
+                    else:
+                        request = generic_decoder.decode(data_frames)

                     # Push to input queue for core busy loop.
                     self.input_queue.put_nowait((request_type, request))
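The rewritten branch decodes with a type-specific decoder and immediately preprocesses `ADD` requests on this input thread. For readers unfamiliar with the decoder objects, a self-contained sketch of the per-type decoder pattern using msgspec, which vLLM's decoders wrap; the `AddMsg` schema is made up for the example and is not vLLM's `EngineCoreRequest`:

```python
import msgspec

class AddMsg(msgspec.Struct):
    # Made-up message schema for illustration only.
    request_id: str

# One typed decoder for ADD messages, one untyped fallback for the rest.
add_request_decoder = msgspec.msgpack.Decoder(AddMsg)
generic_decoder = msgspec.msgpack.Decoder()

data = msgspec.msgpack.encode({"request_id": "r1"})
msg = add_request_decoder.decode(data)  # -> AddMsg(request_id='r1')
assert msg.request_id == "r1"
```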
@@ -939,17 +955,17 @@ def shutdown(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)

-    def add_request(self, request: EngineCoreRequest):
-        if self.has_coordinator and request.current_wave != self.current_wave:
-            if request.current_wave > self.current_wave:
-                self.current_wave = request.current_wave
+    def add_request(self, request: Request, request_wave: int = 0):
+        if self.has_coordinator and request_wave != self.current_wave:
+            if request_wave > self.current_wave:
+                self.current_wave = request_wave
             elif not self.engines_running:
                 # Request received for an already-completed wave, notify
                 # front-end that we need to start the next one.
                 self.output_queue.put_nowait(
                     (-1, EngineCoreOutputs(start_wave=self.current_wave)))

-        super().add_request(request)
+        super().add_request(request, request_wave)

     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
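To make the wave bookkeeping concrete: in data-parallel mode, requests are grouped into numbered waves so all engines stay in lockstep. A self-contained sketch of the three branches above, where `reconcile_wave` is a hypothetical helper (not vLLM API) that mirrors the logic in `add_request()`:

```python
def reconcile_wave(current_wave: int, request_wave: int,
                   engines_running: bool) -> tuple[int, bool]:
    """Return (new current wave, whether front-end must start the next wave)."""
    if request_wave > current_wave:
        return request_wave, False   # this engine was behind; catch up
    if request_wave < current_wave and not engines_running:
        return current_wave, True    # stale wave arrived after wave completed
    return current_wave, False       # same wave, or engines still running

assert reconcile_wave(3, 4, True) == (4, False)
assert reconcile_wave(3, 2, False) == (3, True)
assert reconcile_wave(3, 3, True) == (3, False)
```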