Skip to content

Commit

Permalink
[ROCm][AMD] unify CUDA_VISIBLE_DEVICES usage in cuda/rocm (vllm-proje…
Browse files Browse the repository at this point in the history
  • Loading branch information
hongxiayang authored and jimpang committed Jul 24, 2024
1 parent f2fb555 commit 4b3657a
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 34 deletions.
14 changes: 7 additions & 7 deletions Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,25 @@ RUN pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
-# Install torch == 2.4.0 on ROCm
+# Install torch == 2.5.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-5.7"*) \
pip uninstall -y torch torchaudio torchvision \
&& pip install --no-cache-dir --pre \
-torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \
-torchvision==0.19.0.dev20240612 \
+torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
+torchvision==0.20.0.dev20240710 \
--index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
*"rocm-6.0"*) \
pip uninstall -y torch torchaudio torchvision \
&& pip install --no-cache-dir --pre \
-torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \
-torchvision==0.19.0.dev20240612 \
+torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
+torchvision==0.20.0.dev20240710 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
*"rocm-6.1"*) \
pip uninstall -y torch torchaudio torchvision \
&& pip install --no-cache-dir --pre \
-torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \
-torchvision==0.19.0.dev20240612 \
+torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
+torchvision==0.20.0.dev20240710 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
*) ;; esac

Expand Down
7 changes: 1 addition & 6 deletions tests/distributed/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ray

import vllm.envs as envs
-from vllm.utils import (cuda_device_count_stateless, is_hip,
+from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)


Expand All @@ -22,11 +22,6 @@ def get_cuda_visible_devices(self):
def test_cuda_device_count_stateless():
"""Test that cuda_device_count_stateless changes return value if
CUDA_VISIBLE_DEVICES is changed."""
-if is_hip():
-# Set HIP_VISIBLE_DEVICES == CUDA_VISIBLE_DEVICES. Conversion
-# is handled by `update_environment_variables`
-update_environment_variables(
-{"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
num_gpus=2).remote()
assert sorted(ray.get(
Expand Down
9 changes: 1 addition & 8 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
import torch
from transformers import PretrainedConfig

-import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
-print_warning_once, update_environment_variables)
+print_warning_once)

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
Expand Down Expand Up @@ -695,12 +694,6 @@ def __init__(
self.distributed_executor_backend = backend
logger.info("Defaulting to use %s for distributed inference",
backend)
-# If CUDA_VISIBLE_DEVICES is set on ROCm prior to vLLM init,
-# propagate changes to HIP_VISIBLE_DEVICES (conversion handled by
-# the update_environment_variables function)
-if is_hip() and envs.CUDA_VISIBLE_DEVICES:
-update_environment_variables(
-{"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})

self._verify_args()
self.rank = 0
Expand Down
4 changes: 0 additions & 4 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,6 @@ def get_open_port() -> int:


def update_environment_variables(envs: Dict[str, str]):
-if is_hip() and "CUDA_VISIBLE_DEVICES" in envs:
-# Propagate changes to CUDA_VISIBLE_DEVICES to
-# ROCm's HIP_VISIBLE_DEVICES as well
-envs["HIP_VISIBLE_DEVICES"] = envs["CUDA_VISIBLE_DEVICES"]
for k, v in envs.items():
if k in os.environ and os.environ[k] != v:
logger.warning(
Expand Down
10 changes: 1 addition & 9 deletions vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from vllm.lora.request import LoRARequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
-from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
+from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

Expand Down Expand Up @@ -309,14 +309,6 @@ def update_environment_variables(envs: Dict[str, str]) -> None:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`
del os.environ[key]
-if is_hip():
-hip_env_var = "HIP_VISIBLE_DEVICES"
-if hip_env_var in os.environ:
-logger.warning(
-"Ignoring pre-set environment variable `%s=%s` as "
-"%s has also been set, which takes precedence.",
-hip_env_var, os.environ[hip_env_var], key)
-os.environ.pop(hip_env_var, None)
update_environment_variables(envs)

def init_worker(self, *args, **kwargs):
Expand Down

0 comments on commit 4b3657a

Please sign in to comment.