diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 34ff145e6ed4..fe775bb370f2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -271,6 +271,15 @@ steps: commands: - pytest -v -s prefix_caching + +- label: Platform Tests (CUDA) + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/cuda + commands: + - pytest -v -s cuda/test_cuda_context.py + - label: Samplers Test # 36min mirror_hardwares: [amdexperimental] source_file_dependencies: diff --git a/tests/cuda/test_cuda_context.py b/tests/cuda/test_cuda_context.py new file mode 100644 index 000000000000..f973b284b87e --- /dev/null +++ b/tests/cuda/test_cuda_context.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ctypes +from concurrent.futures import ThreadPoolExecutor + +import pytest +import torch + +from vllm.platforms import current_platform + + +def check_cuda_context(): + """Check CUDA driver context status""" + try: + cuda = ctypes.CDLL('libcuda.so') + device = ctypes.c_int() + result = cuda.cuCtxGetDevice(ctypes.byref(device)) + return (True, device.value) if result == 0 else (False, None) + except Exception: + return False, None + + +def run_cuda_test_in_thread(device_input, expected_device_id): + """Run CUDA context test in separate thread for isolation""" + try: + # New thread should have no CUDA context initially + valid_before, device_before = check_cuda_context() + if valid_before: + return False, \ + "CUDA context should not exist in new thread, " \ + f"got device {device_before}" + + # Test setting CUDA context + current_platform.set_device(device_input) + + # Verify context is created correctly + valid_after, device_id = check_cuda_context() + if not valid_after: + return False, "CUDA context should be valid after set_cuda_context" + if device_id != expected_device_id: + return False, \ + f"Expected device {expected_device_id}, got {device_id}" + + return True, "Success" + except Exception as e: + return False, f"Exception in thread: {str(e)}" + + +class TestSetCudaContext: + """Test suite for the set_cuda_context function.""" + + @pytest.mark.skipif(not current_platform.is_cuda(), + reason="CUDA not available") + @pytest.mark.parametrize(argnames="device_input,expected_device_id", + argvalues=[ + (0, 0), + (torch.device('cuda:0'), 0), + ('cuda:0', 0), + ], + ids=["int", "torch_device", "string"]) + def test_set_cuda_context_parametrized(self, device_input, + expected_device_id): + """Test setting CUDA context in isolated threads.""" + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_cuda_test_in_thread, device_input, + expected_device_id) + success, message = future.result(timeout=30) + assert success, message + + @pytest.mark.skipif(not current_platform.is_cuda(), + reason="CUDA not available") + def test_set_cuda_context_invalid_device_type(self): + """Test error handling for invalid device type.""" + with pytest.raises(ValueError, match="Expected a cuda device"): + current_platform.set_device(torch.device('cpu')) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 54719a3e79dd..879d094f6578 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -71,6 +71,17 @@ def supported_dtypes(self) -> list[torch.dtype]: # though vLLM doesn't support these GPUs. return [torch.float32] + @classmethod + def set_device(cls, device: torch.device) -> None: + """ + Set the device for the current platform. + """ + super().set_device(device) + # With this trick we can force the device to be set eagerly + # see https://github.com/pytorch/pytorch/issues/155668 + # for why and when it is needed + _ = torch.zeros(1, device=device) + @classmethod def get_device_capability(cls, device_id: int = 0 diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 3ff173dcd8c8..f962fafabf50 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -298,6 +298,13 @@ def seed_everything(cls, seed: Optional[int] = None) -> None: np.random.seed(seed) torch.manual_seed(seed) + @classmethod + def set_device(cls, device: torch.device) -> None: + """ + Set the device for the current platform. + """ + torch.cuda.set_device(device) + @classmethod def pre_register_and_update(cls, parser: Optional[FlexibleArgumentParser] = None