diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index 5bfeac0cf027e..ae5062bd64073 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -10,6 +10,7 @@
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (cuda_device_count_stateless,
+                        error_on_invalid_device_count_status,
                         get_distributed_init_method, get_open_port,
                         get_vllm_instance_id, make_async,
                         update_environment_variables)
@@ -39,6 +40,8 @@ def _init_executor(self) -> None:
         assert world_size <= cuda_device_count_stateless(), (
             "please set tensor_parallel_size to less than max local gpu count")
 
+        error_on_invalid_device_count_status()
+
         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
         # 127.0.0.1 for communication.
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index e742d11bb3e62..e0b9441a90410 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -11,7 +11,8 @@
 from vllm.executor.ray_utils import RayWorkerWrapper, ray
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
-from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+from vllm.utils import (error_on_invalid_device_count_status,
+                        get_distributed_init_method, get_ip, get_open_port,
                         get_vllm_instance_id, make_async)
 
 if ray is not None:
@@ -175,6 +176,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())
 
+        error_on_invalid_device_count_status()
+
         # Initialize the actual workers inside worker wrapper.
         init_worker_all_kwargs = [
             self._get_worker_kwargs(
diff --git a/vllm/utils.py b/vllm/utils.py
index 763b0b91c8646..854decc290fae 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+import contextlib
 import datetime
 import enum
 import gc
@@ -816,6 +817,27 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
 
 
+def error_on_invalid_device_count_status():
+    cache_entries = 0
+    with contextlib.suppress(Exception):
+        # Future PyTorch versions will fix this and stop caching device_count;
+        # `.cache_info()` will then raise, and this check is skipped.
+        cache_entries = torch.cuda.device_count.cache_info().currsize
+    if cache_entries != 0:
+        # torch.cuda.device_count() has already been called; its result is cached.
+        remembered = torch.cuda.device_count()
+        current = cuda_device_count_stateless()
+        if remembered > current:
+            raise RuntimeError(
+                "The number of CUDA devices has changed since the first "
+                "call to torch.cuda.device_count(). This is not allowed "
+                "and may result in undefined behavior. Please check out "
+                "https://github.com/vllm-project/vllm/issues/6056 to "
+                "find the first call to torch.cuda.device_count() "
+                "and defer it until the engine is up. Or you can set "
+                "CUDA_VISIBLE_DEVICES to the GPUs you want to use.")
+
+
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
 # all the related functions work on real physical device ids.
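
For reference, a minimal sketch (not part of the patch) of the failure mode this check guards against. It assumes a host with two visible GPUs and a PyTorch build where torch.cuda.device_count() is cached via functools.lru_cache (the situation the `.cache_info()` probe above detects); cuda_device_count_stateless and error_on_invalid_device_count_status come from vllm.utils as shown in the diff.

```python
# Illustration only: how the cached and stateless device counts can diverge.
import os

import torch

from vllm.utils import cuda_device_count_stateless

print(torch.cuda.device_count())       # e.g. 2; the result is now cached

# Restrict visibility after the first call (e.g. done by a launcher or test).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

print(torch.cuda.device_count())       # still 2: the stale cached value
print(cuda_device_count_stateless())   # 1: re-reads CUDA_VISIBLE_DEVICES

# At this point error_on_invalid_device_count_status() would raise
# RuntimeError, since the remembered count (2) exceeds the current count (1).
```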