@@ -41,7 +41,7 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.sampling_params import SamplingType
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
@@ -1819,7 +1819,7 @@ def load_model(self, eep_scale_up: bool = False) -> None:
         old_global_expert_indices = None
         rank_mapping = None
 
-        with DeviceMemoryProfiler() as m:  # noqa: SIM117
+        with DeviceMemoryProfiler() as m:
             time_before_load = time.perf_counter()
             model_loader = get_model_loader(self.load_config)
             if not hasattr(self, "model"):
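(`SIM117` is flake8-simplify's "combine nested `with` statements" rule; the suppression comment is dropped here, presumably because the nested `with` it once silenced is gone and there is nothing left for it to suppress.)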
@@ -2215,12 +2215,11 @@ def _dummy_sampler_run(
         )
         return sampler_output
 
-    @torch.inference_mode()
-    def _dummy_pooler_run(
+    def _dummy_pooler_run_task(
         self,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
+        task: PoolingTask,
+    ) -> PoolerOutput:
         num_tokens = hidden_states.shape[0]
         max_num_reqs = self.scheduler_config.max_num_seqs
         num_reqs = min(num_tokens, max_num_reqs)
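(The warm-up body below is otherwise unchanged: the pooling task becomes an explicit `task` parameter instead of being hard-coded to the first entry of `get_supported_pooling_tasks()`, the return type becomes `PoolerOutput`, and `@torch.inference_mode()` moves to the new `_dummy_pooler_run` wrapper added at the end of the next hunk.)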
@@ -2232,37 +2231,55 @@ def _dummy_pooler_run(
 
         hidden_states_list = list(
             torch.split(hidden_states, num_scheduled_tokens_list))
-
         req_num_tokens = num_tokens // num_reqs
 
-        model = cast(VllmModelForPooling, self.model)
-        dummy_task = self.get_supported_pooling_tasks()[0]
-        dummy_pooling_params = PoolingParams(task=dummy_task)
+        dummy_prompt_lens = torch.tensor(
+            [h.shape[0] for h in hidden_states_list],
+            device=self.device,
+        )
+        dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
+                                      dtype=torch.int32,
+                                      device=self.device)
 
-        to_update = model.pooler.get_pooling_updates(dummy_task)
+        model = cast(VllmModelForPooling, self.model)
+        dummy_pooling_params = PoolingParams(task=task)
+        to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)
 
         dummy_metadata = PoolingMetadata(
-            prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
-                                     device=self.device),
-            prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
-                                         dtype=torch.int32,
-                                         device=self.device),
-            pooling_params=[dummy_pooling_params] * num_reqs)
+            prompt_lens=dummy_prompt_lens,
+            prompt_token_ids=dummy_token_ids,
+            pooling_params=[dummy_pooling_params] * num_reqs,
+        )
 
         try:
-            pooler_output = model.pooler(hidden_states=hidden_states_list,
-                                         pooling_metadata=dummy_metadata)
+            return model.pooler(hidden_states=hidden_states_list,
+                                pooling_metadata=dummy_metadata)
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 raise RuntimeError(
-                    "CUDA out of memory occurred when warming up pooler with "
-                    f"{num_reqs} dummy requests. Please try lowering "
-                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "CUDA out of memory occurred when warming up pooler "
+                    f"({task=}) with {num_reqs} dummy requests. Please try "
+                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                     "initializing the engine.") from e
             else:
                 raise e
-        return pooler_output
+
+    @torch.inference_mode()
+    def _dummy_pooler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> PoolerOutput:
+        # Find the task that has the largest output for subsequent steps
+        output_size = dict[PoolingTask, float]()
+        for task in self.get_supported_pooling_tasks():
+            # Run a full batch with each task to ensure none of them OOMs
+            output = self._dummy_pooler_run_task(hidden_states, task)
+            output_size[task] = output.get_data_nbytes()
+            del output  # Allow GC
+
+        max_task = max(output_size.items(), key=lambda x: x[1])[0]
+        return self._dummy_pooler_run_task(hidden_states, max_task)
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
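
For readers skimming the diff: the new `_dummy_pooler_run` probes every supported pooling task once (so a task that would OOM fails during profiling rather than while serving), records each output's size, and then re-runs the task with the largest output so peak memory is measured against the worst case. A minimal, self-contained sketch of that pattern follows; `warm_up_largest_task`, `run_task`, and the byte-string outputs are hypothetical stand-ins for `_dummy_pooler_run`, `_dummy_pooler_run_task`, and `PoolerOutput.get_data_nbytes()`.

```python
from typing import Callable


def warm_up_largest_task(tasks: list[str],
                         run_task: Callable[[str], bytes]) -> bytes:
    output_size: dict[str, int] = {}
    for task in tasks:
        # Probe each task once so any task that would OOM fails here,
        # during profiling, rather than later under real traffic.
        output = run_task(task)
        output_size[task] = len(output)  # stand-in for get_data_nbytes()
        del output  # drop the reference so the buffer can be reclaimed

    # Re-run only the task whose output is largest; its allocation is the
    # one that determines peak memory for subsequent steps.
    max_task = max(output_size.items(), key=lambda kv: kv[1])[0]
    return run_task(max_task)


# "embed" yields the larger output here, so it is the task run twice.
result = warm_up_largest_task(
    ["encode", "embed"],
    lambda t: b"x" * (8 if t == "encode" else 4096),
)
assert len(result) == 4096
```

Running every task once, instead of only the first supported task as before, is what lets the final profiling pass size memory against the most expensive pooler output.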