From 1bc86dd12f0ceaa21826f8a00151b95fbe605789 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 18 Sep 2024 18:38:11 +0800 Subject: [PATCH] [CI/Build] Avoid CUDA initialization (#8534) Signed-off-by: Amit Garg --- benchmarks/kernels/benchmark_layernorm.py | 9 +-- benchmarks/kernels/benchmark_moe.py | 6 +- .../kernels/benchmark_paged_attention.py | 7 +-- benchmarks/kernels/benchmark_quant.py | 9 +-- benchmarks/kernels/benchmark_rope.py | 6 +- tests/kernels/test_activation.py | 9 +-- tests/kernels/test_attention.py | 18 ++---- tests/kernels/test_attention_selector.py | 2 +- tests/kernels/test_awq_triton.py | 5 +- tests/kernels/test_blocksparse_attention.py | 12 +--- tests/kernels/test_cache.py | 25 +++----- tests/kernels/test_causal_conv1d.py | 5 +- tests/kernels/test_cutlass.py | 11 ++-- tests/kernels/test_flash_attn.py | 5 +- tests/kernels/test_flashinfer.py | 10 +-- tests/kernels/test_fp8_quant.py | 10 ++- tests/kernels/test_gguf.py | 5 +- tests/kernels/test_int8_quant.py | 13 ++-- tests/kernels/test_layernorm.py | 5 +- tests/kernels/test_machete_gemm.py | 2 +- tests/kernels/test_mamba_ssm.py | 5 +- tests/kernels/test_moe.py | 3 +- tests/kernels/test_pos_encoding.py | 14 ++--- tests/kernels/test_prefix_prefill.py | 12 +--- tests/lora/test_layers.py | 5 +- tests/lora/test_punica_sizes.py | 18 ++---- tests/lora/test_punica_variation.py | 18 ++---- .../decoder_only/language/test_granite.py | 9 +-- tests/quantization/test_fp8.py | 4 +- tests/quantization/utils.py | 8 ++- vllm/attention/backends/rocm_flash_attn.py | 3 +- .../ops/blocksparse_attention/interface.py | 5 +- vllm/attention/ops/prefix_prefill.py | 3 +- vllm/attention/selector.py | 4 +- vllm/config.py | 12 ++-- vllm/distributed/parallel_state.py | 3 +- vllm/envs.py | 1 + .../compressed_tensors/compressed_tensors.py | 6 +- .../layers/quantization/fbgemm_fp8.py | 4 +- .../model_executor/layers/quantization/fp8.py | 5 +- .../layers/quantization/utils/marlin_utils.py | 10 +-- .../quantization/utils/marlin_utils_fp8.py | 3 +- .../layers/quantization/utils/w8a8_utils.py | 5 +- vllm/model_executor/model_loader/loader.py | 6 +- vllm/model_executor/models/qwen2_vl.py | 2 +- vllm/model_executor/utils.py | 10 +-- vllm/platforms/cpu.py | 8 +-- vllm/platforms/cuda.py | 17 ++--- vllm/platforms/interface.py | 62 ++++++++++++++++--- vllm/platforms/rocm.py | 14 ++--- vllm/platforms/tpu.py | 8 ++- vllm/prompt_adapter/utils.py | 4 +- vllm/usage/usage_lib.py | 3 +- vllm/utils.py | 28 ++++++--- vllm/worker/worker.py | 16 +++-- 55 files changed, 256 insertions(+), 256 deletions(-) diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index 4947fda02e1cc..92f6053cc6d7e 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -1,10 +1,10 @@ -import random import time import torch from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) @torch.inference_mode() @@ -16,10 +16,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device("cuda") layer = RMSNorm(hidden_size).to(dtype=dtype) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index fd233c71b10a6..c2ad98b7e2656 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -10,7 +10,7 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything class BenchmarkConfig(TypedDict): @@ -166,7 +166,7 @@ class BenchmarkWorker: def __init__(self, seed: int) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(seed) + seed_everything(seed) self.seed = seed def benchmark( @@ -180,7 +180,7 @@ def benchmark( use_fp8_w8a8: bool, use_int8_w8a16: bool, ) -> Tuple[Dict[str, int], float]: - torch.cuda.manual_seed_all(self.seed) + seed_everything(self.seed) dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index a04433142da42..87864d038d593 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -6,7 +6,7 @@ from vllm import _custom_ops as ops from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) + create_kv_caches_with_random, seed_everything) NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -28,10 +28,7 @@ def main( device: str = "cuda", kv_cache_dtype: Optional[str] = None, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 4c1a7b26213a5..743a5744e8614 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -1,10 +1,10 @@ -import random import time import torch from vllm import _custom_ops as ops -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) @torch.inference_mode() @@ -17,10 +17,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device("cuda") x = torch.randn(num_tokens, hidden_size, dtype=dtype) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index f542684a9a2a9..73fc9e9dbf461 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -6,7 +6,7 @@ from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, get_rope) -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything def benchmark_rope_kernels_multi_lora( @@ -22,9 +22,7 @@ def benchmark_rope_kernels_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index ed050ce851535..9b476585fa19e 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, NewGELU, QuickGELU, SiluAndMul) +from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -34,9 +35,7 @@ def test_act_and_mul( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, 2 * d, dtype=dtype) if activation == "silu": @@ -77,9 +76,7 @@ def test_activation( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, d, dtype=dtype) layer = activation[0]() diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 46831b506aff3..4bd6f7863a658 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,7 +6,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops -from vllm.utils import get_max_shared_memory_bytes, is_hip +from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -139,10 +139,8 @@ def test_paged_attention( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -354,10 +352,7 @@ def test_paged_attention_rocm( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -506,10 +501,7 @@ def test_multi_query_kv_attention( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index a20a741c27f74..c1fb45955a0e5 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -45,7 +45,7 @@ def test_flash_attn(monkeypatch): override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + with patch("torch.cuda.get_device_capability", return_value=(7, 5)): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != STR_FLASH_ATTN_VAL diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py index 198d40a155ccb..e95e5bd948212 100644 --- a/tests/kernels/test_awq_triton.py +++ b/tests/kernels/test_awq_triton.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.quantization.awq_triton import ( AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) +from vllm.utils import seed_everything device = "cuda" @@ -79,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): zeros_cols = qweight_cols zeros_dtype = torch.int32 - torch.manual_seed(0) + seed_everything(0) qweight = torch.randint(0, torch.iinfo(torch.int32).max, @@ -133,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size): qzeros_rows = scales_rows qzeros_cols = qweight_cols - torch.manual_seed(0) + seed_everything(0) input = torch.rand((input_rows, input_cols), dtype=input_dtype, diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 7357508751ae1..f3bd8f0524264 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn) -from vllm.utils import get_max_shared_memory_bytes, is_hip +from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -172,10 +172,7 @@ def test_paged_attention( blocksparse_block_size: int, blocksparse_head_sliding_step: int, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -386,10 +383,7 @@ def test_varlen_blocksparse_attention_prefill( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 19402a337b8d6..b0e7097fdfbd4 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -6,6 +6,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops +from vllm.utils import seed_everything COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -55,10 +56,7 @@ def test_copy_blocks( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Generate random block mappings where each source block is mapped to two # destination blocks. @@ -134,10 +132,7 @@ def test_reshape_and_cache( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks @@ -229,9 +224,7 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. @@ -345,10 +338,8 @@ def test_swap_blocks( pytest.skip() if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) src_device = device if direction[0] == "cuda" else 'cpu' dst_device = device if direction[1] == "cuda" else 'cpu' @@ -417,9 +408,7 @@ def test_fp8_e4m3_conversion( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) low = -224.0 high = 224.0 diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 344e07e739454..043c4923bd660 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) +from vllm.utils import seed_everything def causal_conv1d_ref( @@ -104,7 +105,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - torch.random.manual_seed(0) + seed_everything(0) if not channel_last: x = torch.randn(batch, 4096 + dim + 64, @@ -175,7 +176,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - torch.random.manual_seed(0) + seed_everything(0) batch = 2 x = torch.randn(batch, dim, device=device, dtype=itype) conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index d1f0524f83c4c..cc4ca2e91e76f 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -15,9 +15,6 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -capability = current_platform.get_device_capability() -capability = capability[0] * 10 + capability[1] - def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) @@ -119,7 +116,7 @@ def cutlass_int8_gemm_helper(m: int, @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool): @@ -157,7 +154,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, out_dtype: Type[torch.dtype], @@ -175,7 +172,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, use_bias: bool, device: str): @@ -207,7 +204,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool, use_bias: bool): diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 870a8bf65eb92..8e960d098c408 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,6 +4,7 @@ import torch import vllm.attention.backends.flash_attn # noqa: F401 +from vllm.utils import seed_everything NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] @@ -87,7 +88,7 @@ def test_flash_attn_with_paged_kv( num_blocks: int, ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -174,7 +175,7 @@ def test_varlen_with_paged_kv( num_blocks: int, ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index 696cc0c6cdf10..80a388db6530e 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -4,6 +4,8 @@ import pytest import torch +from vllm.utils import seed_everything + NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)] HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] @@ -82,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv( soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -168,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -266,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( head_size: int, dtype: torch.dtype, block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -379,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( ) -> None: # test doesn't work for num_heads = (16,16) torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index bae9b39203ff9..49f5ce53aab54 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -5,6 +5,7 @@ from tests.kernels.quant_utils import (FP8_DTYPE, ref_dynamic_per_tensor_fp8_quant, ref_dynamic_per_token_quant) +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, @@ -24,8 +25,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 # avoid nans @@ -49,8 +49,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") @@ -67,8 +66,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() @pytest.mark.parametrize("seed", SEEDS) def test_fp8_quant_large(seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings hidden_size = 1152 # Smallest hidden_size to reproduce the error diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py index ee29ed93b61fc..1513fc196153c 100644 --- a/tests/kernels/test_gguf.py +++ b/tests/kernels/test_gguf.py @@ -7,6 +7,7 @@ from huggingface_hub import snapshot_download import vllm._custom_ops as ops +from vllm.utils import seed_everything GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") @@ -74,7 +75,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype, @torch.inference_mode() def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): - torch.cuda.manual_seed_all(0) + seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") @@ -110,7 +111,7 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype, @torch.inference_mode() def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): - torch.cuda.manual_seed_all(0) + seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index e93cb535d715a..41e103e1d09f9 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -4,6 +4,7 @@ from tests.kernels.quant_utils import ref_dynamic_per_token_quant from tests.kernels.utils import opcheck from vllm._custom_ops import scaled_int8_quant +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -44,8 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True): @torch.inference_mode() def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -68,8 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, @@ -113,8 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -140,8 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float, azp: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 6eaf67ec75f41..382079d472ee9 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -3,6 +3,7 @@ from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing @@ -30,9 +31,7 @@ def test_rms_norm( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) layer = RMSNorm(hidden_size).to(dtype=dtype) layer.weight.data.normal_(mean=1.0, std=0.1) diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py index ce65aaef60ac6..0a90882223077 100644 --- a/tests/kernels/test_machete_gemm.py +++ b/tests/kernels/test_machete_gemm.py @@ -48,7 +48,7 @@ # `is_quant_method_supported` conflates kernels with quantization methods # an assumption which is breaking down as quantizations methods can have # have kernels and some kernels support multiple quantization methods. -IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) def rand_data(shape, dtype=torch.float16): diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index d3cb0a8656a02..f582445692344 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -5,6 +5,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) +from vllm.utils import seed_everything def selective_state_update_ref(state, @@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, rtolw = max(rtolw, rtol) atolw = max(atolw, atol) # set seed - torch.random.manual_seed(0) + seed_everything(0) batch_size = 2 dim = 4 dstate = 8 @@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): if torch.version.hip: atol *= 2 # set seed - torch.random.manual_seed(0) + seed_everything(0) batch_size = 1 state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 8072cf09e5b65..b1f0516dfa0b3 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -18,6 +18,7 @@ marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.scalar_type import scalar_types +from vllm.utils import seed_everything def torch_moe(a, w1, w2, score, topk): @@ -151,7 +152,7 @@ def test_fused_marlin_moe( act_order: bool, num_bits: int, ): - torch.manual_seed(7) + seed_everything(7) if topk > e: return diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 65242e275650c..ba9d2d4389b21 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -5,6 +5,7 @@ import torch from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -46,9 +47,8 @@ def test_rotary_embedding( ) -> None: if rotary_dim is None: rotary_dim = head_size - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -100,9 +100,7 @@ def test_batched_rotary_embedding( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 60f9a4dc9f90f..3181d92562399 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -9,7 +9,7 @@ from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.ops.prefix_prefill import context_attention_fwd -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] @@ -39,10 +39,7 @@ def test_contexted_kv_attention( kv_cache_dtype: str, device: str, ) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) + seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process @@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi( kv_cache_dtype: str, device: str, ) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) + seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index effcffc5c174e..e3233c6b60696 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) from vllm.model_executor.utils import set_random_seed +from vllm.utils import seed_everything from .utils import DummyLoRAManager @@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, seq_len) -> None: dtype = torch.float16 seed = 0 - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index c36fb3afb0cc3..314d6215cbd9c 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -4,7 +4,6 @@ whether the corresponding Triton kernel can run normally when tensor parallelism is set to [1, 2, 4, 8, 16, 32, 64]. """ -import random from unittest.mock import patch import pytest @@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.triton_utils.libentry import LibEntry +from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -145,11 +145,8 @@ def test_punica_sgmv( seed: int, device: str, ): - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 128 ( @@ -238,11 +235,8 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 1 ( @@ -329,11 +323,9 @@ def test_punica_expand_nslices( ): from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) + seq_length = 128 if op_type == "sgmv" else 1 ( inputs_tensor, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index d026e34878e04..28a395af19e6d 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -3,7 +3,6 @@ under different conditions, including various batches, numbers of LoRA , and maximum ranks. """ -import random from unittest.mock import patch import pytest @@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.triton_utils.libentry import LibEntry +from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -60,11 +60,8 @@ def test_punica_sgmv( seed: int, device: str, ): - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 128 ( @@ -153,11 +150,8 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 1 ( @@ -244,11 +238,9 @@ def test_punica_expand_nslices( ): from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) + seq_length = 128 if op_type == "sgmv" else 1 ( inputs_tensor, diff --git a/tests/models/decoder_only/language/test_granite.py b/tests/models/decoder_only/language/test_granite.py index 82c753855e714..e5c5ce4a8f745 100644 --- a/tests/models/decoder_only/language/test_granite.py +++ b/tests/models/decoder_only/language/test_granite.py @@ -2,23 +2,18 @@ Run `pytest tests/models/test_granite.py`. """ -import importlib.metadata - import pytest +import transformers from ...utils import check_logprobs_close -TRANSFORMERS_VERSION = tuple( - map(int, - importlib.metadata.version("transformers").split("."))) - MODELS = [ "ibm/PowerLM-3b", ] # GraniteForCausalLM will be in transformers >= 4.45 -@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45), +@pytest.mark.skipif(transformers.__version__ < "4.45", reason="granite model test requires transformers >= 4.45") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 58864e83173f9..a0c1d7e24c503 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, assert attn._k_scale == 1.0 assert attn._v_scale == 1.0 - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability >= 89 and not force_marlin: + if current_platform.has_device_capability(89) and not force_marlin: # For GPUs with hardware support, we keep weights in fp8 assert fc1.weight.dtype == torch.float8_e4m3fn else: diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index 5fad06878f4a3..061a077592e80 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool: return False capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - return (capability >= - QUANTIZATION_METHODS[quant_method].get_min_capability()) + assert capability is not None + + min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability() + + return capability.to_int() >= min_capability diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index ae56403cfb7b1..22fceda8c80e6 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -13,6 +13,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -303,7 +304,7 @@ def __init__( else: # if not using triton, navi3x/navi21/navi10 do not use flash-attn # either - if torch.cuda.get_device_capability()[0] != 9: + if not current_platform.has_device_capability(90): self.use_naive_attn = True else: try: diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py index e870a8e614d12..1ead541f391b5 100644 --- a/vllm/attention/ops/blocksparse_attention/interface.py +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -8,8 +8,7 @@ from .utils import (dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask) -IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() - and current_platform.get_device_capability()[0] >= 8) +IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) if IS_COMPUTE_8_OR_ABOVE: from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd @@ -36,7 +35,7 @@ def __init__( use_spda = is_hip() or is_cpu() or not \ IS_COMPUTE_8_OR_ABOVE device = device or (torch.cuda.current_device() - if torch.cuda.is_available() else "cpu") + if current_platform.is_cuda_alike() else "cpu") device = torch.device(device) # NOTE: vllm CPU backend support BF16 instead of FP16. dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 558b2f3eeac7e..a2a649c8ebcfd 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -709,8 +709,7 @@ def context_attention_fwd(q, alibi_slopes=None, sliding_window=None): - cap = current_platform.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 + BLOCK = 128 if current_platform.has_device_capability(80) else 64 NUM_WARPS = 8 # need to reduce num. blocks when using fp32 diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 855586d4e5961..fbda263ba8e08 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -203,7 +203,7 @@ def which_attn_to_use( selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if selected_backend == _Backend.ROCM_FLASH: - if current_platform.get_device_capability()[0] != 9: + if not current_platform.has_device_capability(90): # not Instinct series GPUs. logger.info("flash_attn is not supported on NAVI GPUs.") else: @@ -212,7 +212,7 @@ def which_attn_to_use( # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: - if current_platform.get_device_capability()[0] < 8: + if not current_platform.has_device_capability(80): # Volta and Turing NVIDIA GPUs. logger.info( "Cannot use FlashAttention-2 backend for Volta and Turing " diff --git a/vllm/config.py b/vllm/config.py index 6c24d15640e99..9d42b75c1c462 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,7 +17,7 @@ get_hf_image_processor_config, get_hf_text_config) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - is_cpu, is_hip, is_neuron, is_openvino, is_xpu, + is_hip, is_neuron, is_openvino, is_xpu, print_warning_once) if TYPE_CHECKING: @@ -1035,20 +1035,20 @@ class DeviceConfig: def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection - if is_neuron(): + if current_platform.is_cuda_alike(): + self.device_type = "cuda" + elif is_neuron(): self.device_type = "neuron" elif is_openvino(): self.device_type = "openvino" elif current_platform.is_tpu(): self.device_type = "tpu" - elif is_cpu(): + elif current_platform.is_cpu(): self.device_type = "cpu" elif is_xpu(): self.device_type = "xpu" else: - # We don't call torch.cuda.is_available() here to - # avoid initializing CUDA before workers are forked - self.device_type = "cuda" + raise RuntimeError("Failed to infer device type") else: # Device type is assigned explicitly self.device_type = device diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 1c864bcd5d708..df07842edfa56 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -35,6 +35,7 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform @dataclass @@ -191,7 +192,7 @@ def __init__( assert self.cpu_group is not None assert self.device_group is not None - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") diff --git a/vllm/envs.py b/vllm/envs.py index 2003ede95d2d8..6edb06ecd2e20 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -60,6 +60,7 @@ VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b5b2570966600..ab8207f128348 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -116,10 +116,10 @@ def get_config_filenames(cls) -> List[str]: def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: - capability = current_platform.get_device_capability() # type: ignore + capability_tuple = current_platform.get_device_capability() - if capability is not None: - capability = capability[0] * 10 + capability[1] + if capability_tuple is not None: + capability = capability_tuple.to_int() supported = capability >= min_capability if error and not supported: raise RuntimeError( diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 3ccf1af9eb898..eb59344f36d2e 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -32,9 +32,7 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 + self.use_marlin = not current_platform.has_device_capability(89) @classmethod def get_name(cls) -> str: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 32affe06b89b7..b5feb55db0e74 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -120,9 +120,8 @@ def __init__(self, quant_config: Fp8Config): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) # Disable marlin for rocm if is_hip(): self.use_marlin = False diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 699d5f1844146..fea94cf7322ad 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool, device_capability: Optional[int] = None ): if device_capability is None: - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) if device_capability < 80: return [] @@ -52,8 +53,9 @@ def _check_marlin_supported( device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: if device_capability is None: - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) supported_types = query_marlin_supported_quant_types( has_zp, device_capability) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 5f9d8658a342f..8b3dfaae971c3 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -10,8 +10,7 @@ def is_fp8_marlin_supported(): - capability = current_platform.get_device_capability() - return capability[0] >= 8 + return current_platform.has_device_capability(80) def apply_fp8_marlin_linear( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 887ee6605560c..d86fea63d8a1b 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -17,8 +17,9 @@ def cutlass_fp8_supported() -> bool: # cutlass is not supported on Rocm if is_hip(): return False - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() return ops.cutlass_scaled_mm_supports_fp8(capability) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index fd9533ab156a5..f0d2a9e7f06be 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -97,10 +97,10 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() # type: ignore + capability_tuple = current_platform.get_device_capability() - if capability is not None: - capability = capability[0] * 10 + capability[1] + if capability_tuple is not None: + capability = capability_tuple.to_int() if capability < quant_config.get_min_capability(): raise ValueError( f"The quantization method {model_config.quantization} " diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 179399a12a3d5..a9a0329e99f08 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -207,7 +207,7 @@ def __init__( selected_backend = backend_name_to_enum(backend_by_env_var) if selected_backend is None: # For Volta and Turing GPUs, use xformers instead. - device_available = current_platform.get_device_capability()[0] >= 8 + device_available = current_platform.has_device_capability(80) if device_available: from transformers.utils import is_flash_attn_2_available diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 336bc1cd005cf..d7eec818cbba4 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,17 +1,13 @@ """Utils for model executor.""" -import random from typing import Any, Dict, Optional -import numpy as np import torch +from vllm.utils import seed_everything + def set_random_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + seed_everything(seed) def set_weight_attrs( diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 4736e898b6a52..9b348f3e17a5f 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -6,10 +6,10 @@ class CpuPlatform(Platform): _enum = PlatformEnum.CPU - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: return "cpu" - @staticmethod - def inference_mode(): + @classmethod + def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8d18527e7c973..a9978d5d84d7c 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -11,7 +11,7 @@ from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum logger = init_logger(__name__) @@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int: class CudaPlatform(Platform): _enum = PlatformEnum.CUDA - @staticmethod - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: physical_device_id = device_id_to_physical_device_id(device_id) - return get_physical_device_capability(physical_device_id) + major, minor = get_physical_device_capability(physical_device_id) + return DeviceCapability(major=major, minor=minor) - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) return get_physical_device_name(physical_device_id) - @staticmethod + @classmethod @with_nvml_context - def is_full_nvlink(physical_device_ids: List[int]) -> bool: + def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: """ query if the set of gpus are fully connected by nvlink (1 hop) """ diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 676f4c9fccf5a..360590d7d5eb6 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,5 +1,5 @@ import enum -from typing import Optional, Tuple +from typing import NamedTuple, Optional, Tuple, Union import torch @@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum): UNSPECIFIED = enum.auto() +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + class Platform: _enum: PlatformEnum @@ -27,16 +44,47 @@ def is_tpu(self) -> bool: def is_cpu(self) -> bool: return self._enum == PlatformEnum.CPU - @staticmethod - def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]: + def is_cuda_alike(self) -> bool: + """Stateless version of :func:`torch.cuda.is_available`.""" + return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + + @classmethod + def get_device_capability( + cls, + device_id: int = 0, + ) -> Optional[DeviceCapability]: + """Stateless version of :func:`torch.cuda.get_device_capability`.""" return None - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def has_device_capability( + cls, + capability: Union[Tuple[int, int], int], + device_id: int = 0, + ) -> bool: + """ + Test whether this platform is compatible with a device capability. + + The ``capability`` argument can either be: + + - A tuple ``(major, minor)``. + - An integer ````. (See :meth:`DeviceCapability.to_int`) + """ + current_capability = cls.get_device_capability(device_id=device_id) + if current_capability is None: + return False + + if isinstance(capability, tuple): + return current_capability >= capability + + return current_capability.to_int() >= capability + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: raise NotImplementedError - @staticmethod - def inference_mode(): + @classmethod + def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. This wrapper is recommended because some hardware backends such as TPU diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 28525e8ff8811..b6a19eca01745 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,12 +1,11 @@ import os from functools import lru_cache -from typing import Tuple import torch from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum logger = init_logger(__name__) @@ -20,12 +19,13 @@ class RocmPlatform(Platform): _enum = PlatformEnum.ROCM - @staticmethod + @classmethod @lru_cache(maxsize=8) - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: - return torch.cuda.get_device_capability(device_id) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) - @staticmethod + @classmethod @lru_cache(maxsize=8) - def get_device_name(device_id: int = 0) -> str: + def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 393fc230da0b9..b30bccb103af3 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -6,6 +6,10 @@ class TpuPlatform(Platform): _enum = PlatformEnum.TPU - @staticmethod - def inference_mode(): + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def inference_mode(cls): return torch.no_grad() diff --git a/vllm/prompt_adapter/utils.py b/vllm/prompt_adapter/utils.py index 989cc5a0f87c8..4cde2a0254b90 100644 --- a/vllm/prompt_adapter/utils.py +++ b/vllm/prompt_adapter/utils.py @@ -8,13 +8,15 @@ from huggingface_hub.utils import EntryNotFoundError from safetensors.torch import load_file as safe_load_file +from vllm.platforms import current_platform + WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" # Get current device name based on available devices def infer_device() -> str: - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): return "cuda" return "cpu" diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 515e0a4d8abe7..7fadfd5dfffb4 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -17,6 +17,7 @@ import vllm.envs as envs from vllm.connections import global_http_connection +from vllm.platforms import current_platform from vllm.version import __version__ as VLLM_VERSION _config_home = envs.VLLM_CONFIG_ROOT @@ -151,7 +152,7 @@ def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, extra_kvs: Dict[str, Any]) -> None: # Platform information - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): device_property = torch.cuda.get_device_properties(0) self.gpu_count = torch.cuda.device_count() self.gpu_type = device_property.name diff --git a/vllm/utils.py b/vllm/utils.py index 29b8a8c2907eb..060b387ec7834 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,6 +5,7 @@ import enum import gc import os +import random import socket import subprocess import sys @@ -32,6 +33,7 @@ import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -373,6 +375,22 @@ def get_cpu_memory() -> int: return psutil.virtual_memory().total +def seed_everything(seed: int) -> None: + """ + Set the seed of each random module. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + + if current_platform.is_cuda_alike(): + torch.cuda.manual_seed_all(seed) + + if is_xpu(): + torch.xpu.manual_seed_all(seed) + + def random_uuid() -> str: return str(uuid.uuid4().hex) @@ -634,9 +652,7 @@ def create_kv_caches_with_random_flash( seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) @@ -678,9 +694,7 @@ def create_kv_caches_with_random( f"Does not support key cache of type fp8 with head_size {head_size}" ) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) @@ -750,7 +764,7 @@ def __init__(self, device: Optional[torch.types.Device] = None): def current_memory_usage(self) -> float: # Return the memory usage in bytes. - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): torch.cuda.reset_peak_memory_stats(self.device) mem = torch.cuda.max_memory_allocated(self.device) elif is_xpu(): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 52092dc2dc291..3851843afc960 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -454,14 +454,20 @@ def init_worker_distributed_environment( def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: - compute_capability = current_platform.get_device_capability() - if compute_capability[0] < 8: + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU has compute capability " - f"{compute_capability[0]}.{compute_capability[1]}. " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " "You can use float16 instead by explicitly setting the" "`dtype` flag in CLI, for example: --dtype=half.")