@@ -195,7 +195,9 @@ def _initialize_kv_caches(
195195 "warmup model) took %.2f seconds" ), elapsed )
196196 return num_gpu_blocks , num_cpu_blocks , scheduler_kv_cache_config
197197
198- def add_request (self , request : Request ):
198+ def add_request (self , request : Union [EngineCoreRequest , Request ]):
199199 """Add request to the scheduler."""
200+ if type (request ) is EngineCoreRequest :
201+ request = self ._preprocess_add_request (request )
200202 if pooling_params := request .pooling_params :
201203 supported_pooling_tasks = (
@@ -204,13 +206,13 @@ def add_request(self, request: Request):
204206 raise ValueError (f"Unsupported task: { pooling_params .task !r} "
205207 f"Supported tasks: { supported_pooling_tasks } " )
206208
207- if request .mm_hashes is not None :
209+ if request .mm_hashes :
208210 # Here, if hash exists for a multimodal input, then it will be
209211 # fetched from the cache, else it will be added to the cache.
210212 # Note that the cache here is mirrored with the client cache, so
211213 # anything that has a hash must have a HIT cache entry here
212214 # as well.
213- assert request .mm_inputs is not None
215+ assert request .mm_inputs
214216 updated_mm_inputs = self .mm_input_cache_server .get_and_update_p1 (
215217 request .mm_inputs , request .mm_hashes )
216218 assert isinstance (updated_mm_inputs , list )
@@ -389,6 +391,13 @@ def save_tensorized_model(
389391 self .model_executor .save_tensorized_model (
390392 tensorizer_config = tensorizer_config , )
391393
394+ def _preprocess_add_request (self , request : EngineCoreRequest ) -> Request :
395+ """Convert an EngineCoreRequest into a scheduler Request.
396+
397+ This can run on the input-processing thread so that request
398+ initialization overlaps with the model forward pass."""
399+ return Request .from_engine_core_request (request )
400+
392401
393402class EngineCoreProc (EngineCore ):
394403 """ZMQ-wrapper for running EngineCore in background process."""
@@ -772,7 +781,7 @@ def process_input_sockets(self, input_addresses: list[str],
772781 # Deserialize the request data.
773782 if request_type == EngineCoreRequestType .ADD :
774783 request = add_request_decoder .decode (data_frames )
775- request = self ._post_process_add_request (request )
784+ request = self ._preprocess_add_request (request )
776785 else :
777786 request = generic_decoder .decode (data_frames )
778787
@@ -840,13 +849,6 @@ def process_output_sockets(self, output_paths: list[str],
840849 # Limit the number of buffers to reuse.
841850 reuse_buffers .append (buffer )
842851
843- def _post_process_add_request (self , request : EngineCoreRequest ) -> Request :
844- """Post-processes the request before reaching to EngineCore.
845-
846- This call would be executed in parallel with Model forward which
847- relaxes request preparation works out from critical path."""
848- return Request .from_engine_core_request (request )
849-
850852
851853class DPEngineCoreProc (EngineCoreProc ):
852854 """ZMQ-wrapper for running EngineCore in background process
@@ -927,7 +929,7 @@ def shutdown(self):
927929 if dp_group := getattr (self , "dp_group" , None ):
928930 stateless_destroy_torch_distributed_process_group (dp_group )
929931
930- def add_request (self , request : Request ):
932+ def add_request (self , request : Union [ EngineCoreRequest , Request ] ):
931933 if self .has_coordinator and request .current_wave != self .current_wave :
932934 if request .current_wave > self .current_wave :
933935 self .current_wave = request .current_wave
0 commit comments