|
14 | 14 | from vllm.distributed import (ensure_model_parallel_initialized, |
15 | 15 | init_distributed_environment, |
16 | 16 | set_custom_all_reduce) |
| 17 | +from vllm.distributed.parallel_state import get_pp_group |
17 | 18 | from vllm.logger import init_logger |
18 | 19 | from vllm.lora.request import LoRARequest |
19 | 20 | from vllm.model_executor import set_random_seed |
@@ -219,20 +220,22 @@ def compile_or_warm_up_model(self) -> None: |
219 | 220 | # fragmentation issue. |
220 | 221 | # NOTE: This is called after `capture_model` on purpose to prevent |
221 | 222 | # memory buffers from being cleared by `torch.cuda.empty_cache`. |
222 | | - try: |
223 | | - max_num_reqs = min(self.scheduler_config.max_num_seqs, |
224 | | - self.scheduler_config.max_num_batched_tokens) |
225 | | - self.model_runner._dummy_sampler_run( |
226 | | - hidden_states=self.model_runner._dummy_run( |
227 | | - num_tokens=max_num_reqs)) |
228 | | - except RuntimeError as e: |
229 | | - if 'out of memory' in str(e): |
230 | | - raise RuntimeError( |
231 | | - "CUDA out of memory occurred when warming up sampler. " |
232 | | - "Please try lowering `gpu_memory_utilization` when " |
233 | | - "initializing the engine.") from None |
234 | | - else: |
235 | | - raise e |
| 223 | + if get_pp_group().is_last_rank: |
| 224 | + try: |
| 225 | + max_num_reqs = min( |
| 226 | + self.scheduler_config.max_num_seqs, |
| 227 | + self.scheduler_config.max_num_batched_tokens) |
| 228 | + self.model_runner._dummy_sampler_run( |
| 229 | + hidden_states=self.model_runner._dummy_run( |
| 230 | + num_tokens=max_num_reqs)) |
| 231 | + except RuntimeError as e: |
| 232 | + if 'out of memory' in str(e): |
| 233 | + raise RuntimeError( |
| 234 | + "CUDA out of memory occurred when warming up sampler. " |
| 235 | + "Please try lowering `gpu_memory_utilization` when " |
| 236 | + "initializing the engine.") from None |
| 237 | + else: |
| 238 | + raise e |
236 | 239 |
|
237 | 240 | # Reset the seed to ensure that the random state is not affected by |
238 | 241 | # the model initialization and profiling. |
|
0 commit comments