
Commit ce15379

wangxiyuan authored and ilmarkov committed
[1/N][Platform] Cleanup useless function (vllm-project#26982)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent b72014b commit ce15379

File tree: 7 files changed, +21 −106 lines


tests/models/quantization/test_fp8.py

Lines changed: 5 additions & 2 deletions
@@ -9,6 +9,7 @@
 import pytest
 
 from tests.quantization.utils import is_quant_method_supported
+from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
 from vllm.platforms import current_platform
 from vllm.utils import STR_BACKEND_ENV_VAR
 
@@ -69,8 +70,10 @@ def test_models(
     if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
         pytest.skip(f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
 
-    if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype, None):
-        pytest.skip(f"{kv_cache_dtype} is not supported on this platform.")
+    if not flash_attn_supports_fp8():
+        pytest.skip(
+            f"{kv_cache_dtype} is not supported on this GPU type with {backend} attention."
+        )
 
     with monkeypatch.context() as m:
         m.setenv("TOKENIZERS_PARALLELISM", "true")
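Note: with the platform-level is_kv_cache_dtype_supported hook gone, the test asks the FlashAttention utility directly. The same helper can also gate an entire test at collection time; a minimal sketch (the test name and marker placement are illustrative, not part of this commit):

# Hypothetical module-level guard, assuming flash_attn_supports_fp8()
# is importable exactly as in the test above.
import pytest

from vllm.attention.utils.fa_utils import flash_attn_supports_fp8


@pytest.mark.skipif(
    not flash_attn_supports_fp8(),
    reason="FP8 KV cache requires FlashAttention FP8 support on this GPU.",
)
def test_fp8_kv_cache_smoke():
    ...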

tests/quantization/test_compressed_tensors.py

Lines changed: 0 additions & 4 deletions
@@ -356,10 +356,6 @@ def check_model(model):
         assert output
 
 
-@pytest.mark.skipif(
-    not current_platform.is_kv_cache_dtype_supported("fp8", None),
-    reason="FP8 KV cache is not supported on this device.",
-)
 @pytest.mark.skipif(
     not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
 )

vllm/platforms/cuda.py

Lines changed: 1 addition & 44 deletions
@@ -23,7 +23,7 @@
 
 if TYPE_CHECKING:
     from vllm.attention.backends.registry import _Backend
-    from vllm.config import ModelConfig, VllmConfig
+    from vllm.config import VllmConfig
 else:
     _Backend = None
 
@@ -457,49 +457,6 @@ def get_static_graph_wrapper_cls(cls) -> str:
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
 
-    @classmethod
-    def is_kv_cache_dtype_supported(
-        cls, kv_cache_dtype: str, model_config: "ModelConfig"
-    ) -> bool:
-        fp8_attention = kv_cache_dtype.startswith("fp8")
-        attention_backend = envs.VLLM_ATTENTION_BACKEND
-
-        supported = False
-        if model_config is not None and model_config.use_mla:
-            # Default to CutlassMLA for blackwell,
-            # FlashMLA otherwise
-            if attention_backend is None:
-                if cls.is_device_capability(100):
-                    attention_backend = "CUTLASS_MLA"
-                else:
-                    attention_backend = "FLASHMLA"
-
-            # Only FlashMLA and CUTLASS_MLA support fp8
-            if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]:
-                supported = True
-            else:
-                supported = not fp8_attention
-        else:
-            # Default to FlashAttention
-            if attention_backend is None:
-                attention_backend = "FLASH_ATTN"
-
-            # All Blackwell backends support fp8
-            if cls.is_device_capability(100):
-                supported = True
-            elif attention_backend == "FLASH_ATTN":
-                if fp8_attention:
-                    from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
-
-                    supported = flash_attn_supports_fp8()
-                else:
-                    supported = True
-            elif attention_backend == "FLASHINFER":
-                supported = True
-            elif attention_backend == "TRITON_ATTN":
-                supported = cls.supports_fp8()
-        return supported
-
     @classmethod
     def check_if_supports_dtype(cls, dtype: torch.dtype):
         if dtype == torch.bfloat16:  # noqa: SIM102
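Note: the removed CUDA heuristic is not relocated anywhere in this commit; callers such as the FP8 test above now query the attention utilities directly. Out-of-tree code that still wants the old decision table could keep a standalone copy. A minimal sketch of the deleted logic (a sketch only, assuming the backend names and the envs / flash_attn_supports_fp8 / supports_fp8 / is_device_capability helpers shown in the removed code remain available):

# Standalone copy of the deleted heuristic, kept only as an illustration;
# it is not part of this commit.
import vllm.envs as envs
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
from vllm.platforms import current_platform


def kv_cache_dtype_supported(kv_cache_dtype: str, use_mla: bool) -> bool:
    fp8_attention = kv_cache_dtype.startswith("fp8")
    backend = envs.VLLM_ATTENTION_BACKEND

    if use_mla:
        # Default to CutlassMLA on Blackwell, FlashMLA otherwise.
        if backend is None:
            is_blackwell = current_platform.is_device_capability(100)
            backend = "CUTLASS_MLA" if is_blackwell else "FLASHMLA"
        # Only the listed MLA backends supported fp8 KV caches.
        if backend in ("FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"):
            return True
        return not fp8_attention

    # Non-MLA path: every Blackwell backend supported fp8.
    if current_platform.is_device_capability(100):
        return True
    backend = backend or "FLASH_ATTN"
    if backend == "FLASH_ATTN":
        return flash_attn_supports_fp8() if fp8_attention else True
    if backend == "FLASHINFER":
        return True
    if backend == "TRITON_ATTN":
        return current_platform.supports_fp8()
    return False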

vllm/platforms/interface.py

Lines changed: 14 additions & 27 deletions
@@ -7,36 +7,31 @@
 import random
 import sys
 from datetime import timedelta
-from platform import uname
 from typing import TYPE_CHECKING, Any, NamedTuple
 
 import numpy as np
 import torch
-from torch.distributed import PrefixStore, ProcessGroup
 
-from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
+    from torch.distributed import PrefixStore, ProcessGroup
+
     from vllm.attention.backends.registry import _Backend
-    from vllm.config import ModelConfig, VllmConfig
+    from vllm.config import VllmConfig
+    from vllm.inputs import ProcessorInputs, PromptType
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
     from vllm.utils import FlexibleArgumentParser
 else:
-    _Backend = object
-    ModelConfig = object
-    VllmConfig = object
-    PoolingParams = object
-    SamplingParams = object
     FlexibleArgumentParser = object
 
 logger = init_logger(__name__)
 
 
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
-    return "microsoft" in " ".join(uname()).lower()
+    return "microsoft" in " ".join(platform.uname()).lower()
 
 
 class PlatformEnum(enum.Enum):
@@ -178,15 +173,16 @@ def import_kernels(cls) -> None:
         import vllm._moe_C  # noqa: F401
 
     @classmethod
-    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
+    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
+        # Import _Backend here to avoid circular import.
         from vllm.attention.backends.registry import _Backend
 
         return _Backend.TORCH_SDPA
 
     @classmethod
     def get_attn_backend_cls(
         cls,
-        selected_backend: _Backend,
+        selected_backend: "_Backend",
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: str | None,
@@ -317,7 +313,7 @@ def pre_register_and_update(
         pass
 
     @classmethod
-    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         """
         Check and update the configuration for the current platform.
@@ -498,9 +494,9 @@ def opaque_attention_op(cls) -> bool:
     @classmethod
     def validate_request(
         cls,
-        prompt: PromptType,
-        params: SamplingParams | PoolingParams,
-        processed_inputs: ProcessorInputs,
+        prompt: "PromptType",
+        params: "SamplingParams | PoolingParams",
+        processed_inputs: "ProcessorInputs",
     ) -> None:
         """Raises if this request is unsupported on this platform"""
 
@@ -543,25 +539,16 @@ def get_static_graph_wrapper_cls(cls) -> str:
     def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
-        prefix_store: PrefixStore,
+        prefix_store: "PrefixStore",
        group_rank: int,
        group_size: int,
        timeout: timedelta,
-    ) -> ProcessGroup:
+    ) -> "ProcessGroup":
         """
         Init platform-specific torch distributed process group.
         """
         raise NotImplementedError
 
-    @classmethod
-    def is_kv_cache_dtype_supported(
-        cls, kv_cache_dtype: str, model_config: ModelConfig
-    ) -> bool:
-        """
-        Returns if the kv_cache_dtype is supported by the current platform.
-        """
-        return False
-
     @classmethod
     def check_if_supports_dtype(cls, dtype: torch.dtype):
         """

vllm/platforms/rocm.py

Lines changed: 1 addition & 7 deletions
@@ -15,7 +15,7 @@
 
 if TYPE_CHECKING:
     from vllm.attention.backends.registry import _Backend
-    from vllm.config import ModelConfig, VllmConfig
+    from vllm.config import VllmConfig
 else:
     _Backend = None
 
@@ -474,12 +474,6 @@ def get_static_graph_wrapper_cls(cls) -> str:
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
 
-    @classmethod
-    def is_kv_cache_dtype_supported(
-        cls, kv_cache_dtype: str, model_config: "ModelConfig"
-    ) -> bool:
-        return True
-
     @classmethod
     def check_if_supports_dtype(cls, dtype: torch.dtype):
         if dtype == torch.bfloat16:  # noqa: SIM102

vllm/platforms/tpu.py

Lines changed: 0 additions & 6 deletions
@@ -222,12 +222,6 @@ def validate_request(
         ):
             raise ValueError("Torch XLA does not support per-request seed.")
 
-    @classmethod
-    def is_kv_cache_dtype_supported(
-        cls, kv_cache_dtype: str, model_config: "ModelConfig"
-    ) -> bool:
-        return True
-
     @classmethod
     @torch.compile(backend="openxla")
     def insert_blocks_to_device(

vllm/platforms/xpu.py

Lines changed: 0 additions & 16 deletions
@@ -86,22 +86,6 @@ def get_attn_backend_cls(
         logger.info("Using Flash Attention backend on V1 engine.")
         return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
 
-    @classmethod
-    def is_kv_cache_dtype_supported(
-        cls, kv_cache_dtype: str, model_config: "ModelConfig"
-    ) -> bool:
-        """
-        Check if the kv_cache_dtype is supported.
-        XPU only support fp8 kv cache with triton backend.
-        """
-        if (
-            envs.is_set("VLLM_ATTENTION_BACKEND")
-            and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN"
-        ):
-            return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
-
-        return False
-
     @classmethod
     def set_device(cls, device: torch.device) -> None:
         """
