2 files changed: +17 −18 lines changed
@@ -1288,9 +1288,18 @@ def _dummy_sampler_run(
             allowed_token_ids_mask=None,
             bad_words_token_ids={},
         )
-        sampler_output = self.model.sample(logits=logits,
-                                           sampling_metadata=dummy_metadata)
-
+        try:
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                raise RuntimeError(
+                    "CUDA out of memory occurred when warming up sampler with "
+                    f"{num_reqs} dummy requests. Please try lowering "
+                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "initializing the engine.") from e
+            else:
+                raise e
         return sampler_output

     def profile_run(self) -> None:
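For reference, a minimal standalone sketch of the wrapping pattern the new `_dummy_sampler_run` body follows: catch the `RuntimeError` PyTorch raises on CUDA OOM, re-raise it with an actionable message that includes the dummy-request count, and chain the original exception with `from e`. The `warm_up_sampler` and `run_sampler` names below are illustrative, not vLLM APIs.

```python
# Hypothetical sketch of the OOM-wrapping pattern above; `warm_up_sampler`
# and `run_sampler` are illustrative names, not part of vLLM.
def warm_up_sampler(run_sampler, num_reqs: int):
    try:
        return run_sampler(num_reqs)
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise  # unrelated RuntimeError: propagate unchanged
        # `from e` keeps the original CUDA OOM traceback on __cause__.
        raise RuntimeError(
            "CUDA out of memory occurred when warming up sampler with "
            f"{num_reqs} dummy requests. Please try lowering "
            "`max_num_seqs` or `gpu_memory_utilization` when "
            "initializing the engine.") from e
```

On recent PyTorch versions CUDA OOM is raised as `torch.cuda.OutOfMemoryError`, a `RuntimeError` subclass, so catching that type would be an alternative to matching on the message string; the string match used in the diff also covers older releases.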
@@ -221,21 +221,11 @@ def compile_or_warm_up_model(self) -> None:
         # NOTE: This is called after `capture_model` on purpose to prevent
         # memory buffers from being cleared by `torch.cuda.empty_cache`.
         if get_pp_group().is_last_rank:
-            try:
-                max_num_reqs = min(
-                    self.scheduler_config.max_num_seqs,
-                    self.scheduler_config.max_num_batched_tokens)
-                self.model_runner._dummy_sampler_run(
-                    hidden_states=self.model_runner._dummy_run(
-                        num_tokens=max_num_reqs))
-            except RuntimeError as e:
-                if 'out of memory' in str(e):
-                    raise RuntimeError(
-                        "CUDA out of memory occurred when warming up sampler. "
-                        "Please try lowering `gpu_memory_utilization` when "
-                        "initializing the engine.") from None
-                else:
-                    raise e
+            max_num_reqs = min(self.scheduler_config.max_num_seqs,
+                               self.scheduler_config.max_num_batched_tokens)
+            self.model_runner._dummy_sampler_run(
+                hidden_states=self.model_runner._dummy_run(
+                    num_tokens=max_num_reqs))

         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
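With the handling moved into the model runner, the worker now only decides how many dummy requests to warm the sampler with. A rough worked example of that sizing, using made-up config values rather than vLLM defaults:

```python
# Illustrative sizing of the sampler warm-up batch; the numbers are made up
# for the example and are not vLLM defaults.
max_num_seqs = 256              # upper bound on concurrent sequences
max_num_batched_tokens = 2048   # per-step token budget

# Presumably each dummy request contributes a single token to the warm-up
# batch, so the request count is capped by whichever limit is smaller.
max_num_reqs = min(max_num_seqs, max_num_batched_tokens)
assert max_num_reqs == 256

# With a tighter token budget, the cap comes from the other limit instead.
assert min(1024, 512) == 512
```

If the warm-up still runs out of memory, the error now surfaces from `_dummy_sampler_run` with the request count in the message, so the removed worker-level `try/except` is no longer needed.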