9 changes: 9 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -271,6 +271,15 @@ steps:
  commands:
  - pytest -v -s prefix_caching


- label: Platform Tests (CUDA)
  mirror_hardwares: [amdexperimental]
  source_file_dependencies:
  - vllm/
  - tests/cuda
  commands:
  - pytest -v -s cuda/test_cuda_context.py

- label: Samplers Test # 36min
  mirror_hardwares: [amdexperimental]
  source_file_dependencies:
80 changes: 80 additions & 0 deletions tests/cuda/test_cuda_context.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ctypes
from concurrent.futures import ThreadPoolExecutor

import pytest
import torch

from vllm.platforms import current_platform


def check_cuda_context():
Review comment (Member):
use torch._C._cuda_hasPrimaryContext(int device)

Reply (@kouroshHakha, Collaborator, Author, Jun 18, 2025):
Not sure how this API captures what we are trying to do here. torch._C._cuda_hasPrimaryContext(0) returns true even when called from a newly created background thread, whereas the current method returns false there, which is consistent with the problem I was trying to solve.
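A minimal repro sketch of that distinction (hypothetical, not part of the PR; assumes a CUDA-enabled PyTorch build and a loadable libcuda.so):

import ctypes
import threading

import torch

# Create the primary context on device 0 in the main thread.
_ = torch.zeros(1, device='cuda:0')

def probe():
    # The primary context is process-wide, so this reports True even
    # though this thread has never touched CUDA:
    print(torch._C._cuda_hasPrimaryContext(0))  # True
    # cuCtxGetDevice reflects the calling thread's current context,
    # which a fresh thread does not have yet:
    cuda = ctypes.CDLL('libcuda.so')
    device = ctypes.c_int()
    print(cuda.cuCtxGetDevice(ctypes.byref(device)))  # non-zero error code

t = threading.Thread(target=probe)
t.start()
t.join()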

"""Check CUDA driver context status"""
try:
cuda = ctypes.CDLL('libcuda.so')
device = ctypes.c_int()
result = cuda.cuCtxGetDevice(ctypes.byref(device))
return (True, device.value) if result == 0 else (False, None)
except Exception:
return False, None


def run_cuda_test_in_thread(device_input, expected_device_id):
    """Run the CUDA context test in a separate thread for isolation."""
    try:
        # A new thread should have no CUDA context initially.
        valid_before, device_before = check_cuda_context()
        if valid_before:
            return False, \
                "CUDA context should not exist in new thread, " \
                f"got device {device_before}"

        # Test setting the CUDA context.
        current_platform.set_device(device_input)

        # Verify the context was created correctly.
        valid_after, device_id = check_cuda_context()
        if not valid_after:
            return False, "CUDA context should be valid after set_device"
        if device_id != expected_device_id:
            return False, \
                f"Expected device {expected_device_id}, got {device_id}"

        return True, "Success"
    except Exception as e:
        return False, f"Exception in thread: {e}"


class TestSetCudaContext:
    """Test suite for the set_device function on the CUDA platform."""

    @pytest.mark.skipif(not current_platform.is_cuda(),
                        reason="CUDA not available")
    @pytest.mark.parametrize(argnames="device_input,expected_device_id",
                             argvalues=[
                                 (0, 0),
                                 (torch.device('cuda:0'), 0),
                                 ('cuda:0', 0),
                             ],
                             ids=["int", "torch_device", "string"])
    def test_set_cuda_context_parametrized(self, device_input,
                                           expected_device_id):
        """Test setting the CUDA context in isolated threads."""
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(run_cuda_test_in_thread, device_input,
                                     expected_device_id)
            success, message = future.result(timeout=30)
            assert success, message

    @pytest.mark.skipif(not current_platform.is_cuda(),
                        reason="CUDA not available")
    def test_set_cuda_context_invalid_device_type(self):
        """Test error handling for an invalid device type."""
        with pytest.raises(ValueError, match="Expected a cuda device"):
            current_platform.set_device(torch.device('cpu'))


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
11 changes: 11 additions & 0 deletions vllm/platforms/cuda.py
@@ -71,6 +71,17 @@ def supported_dtypes(self) -> list[torch.dtype]:
        # though vLLM doesn't support these GPUs.
        return [torch.float32]

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        super().set_device(device)
        # This trick forces the device to be set eagerly; see
        # https://github.com/pytorch/pytorch/issues/155668
        # for why and when it is needed.
        _ = torch.zeros(1, device=device)

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
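As an aside, a minimal illustration of why the eager allocation above matters (hypothetical snippet, not part of the diff; assumes a CUDA build of PyTorch and a loadable libcuda.so). On recent PyTorch versions torch.cuda.set_device alone can defer driver-context creation until first use, which is exactly what the torch.zeros trick forces. The has_current_context helper below is made up for the demo and mirrors check_cuda_context from the test:

import ctypes

import torch

def has_current_context() -> bool:
    # Same probe as check_cuda_context in the test above.
    cuda = ctypes.CDLL('libcuda.so')
    return cuda.cuCtxGetDevice(ctypes.byref(ctypes.c_int())) == 0

torch.cuda.set_device(0)
print(has_current_context())         # may be False: context creation is deferred
_ = torch.zeros(1, device='cuda:0')  # the trick from the diff above
print(has_current_context())         # True: the context now exists and is current
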
7 changes: 7 additions & 0 deletions vllm/platforms/interface.py
@@ -298,6 +298,13 @@ def seed_everything(cls, seed: Optional[int] = None) -> None:
            np.random.seed(seed)
            torch.manual_seed(seed)

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)
Review comment (Member):
should be getattr(torch, self.device_type)?
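A sketch of what that suggestion might look like (hypothetical, not part of the PR), dispatching through the platform's device_type attribute instead of hard-coding torch.cuda in the base class:

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        # Resolve e.g. torch.cuda or torch.xpu from the platform's
        # device_type; cls is used here since set_device is a classmethod.
        device_module = getattr(torch, cls.device_type)
        device_module.set_device(device)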


    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None