Skip to content

Commit d2020ac

Browse files
authored
config check sleep mode support oot platforms (#16562)
1 parent 1eb3c2e commit d2020ac

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

vllm/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,10 @@ def __init__(
417417

418418
from vllm.platforms import current_platform
419419

420-
if self.enable_sleep_mode and not current_platform.is_cuda():
421-
raise ValueError("Sleep mode is only supported on CUDA devices.")
420+
if (self.enable_sleep_mode
421+
and not current_platform.is_sleep_mode_available()):
422+
raise ValueError(
423+
"Sleep mode is not supported on current platform.")
422424

423425
hf_config = get_config(self.hf_config_path or self.model,
424426
trust_remote_code, revision, code_revision,

vllm/platforms/interface.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ def is_cuda_alike(self) -> bool:
148148
"""Stateless version of :func:`torch.cuda.is_available`."""
149149
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
150150

151+
def is_sleep_mode_available(self) -> bool:
152+
return self._enum == PlatformEnum.CUDA
153+
151154
@classmethod
152155
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
153156
dtype: torch.dtype, kv_cache_dtype: Optional[str],

0 commit comments

Comments
 (0)