Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,15 @@ def validate_request(
) -> None:
"""Raises if this request is unsupported on this platform"""

def __getattr__(self, key: str):
device = getattr(torch, self.device_name, None)
if device is not None and hasattr(device, key):
return getattr(device, key)
else:
logger.warning("Current platform %s doesn't has '%s' attribute.",
self.device_name, key)
return None


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
Expand Down
11 changes: 6 additions & 5 deletions vllm/spec_decode/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available


Expand Down Expand Up @@ -89,14 +90,14 @@ def init_tensors(self,
self._rank = rank
if isinstance(device_type, torch.device):
device_type = device_type.type
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()
stream = current_platform.Stream
if stream is not None:
self._copy_stream = stream()

def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform
from vllm.platforms import current_platform
if not current_platform.is_cuda_alike():
# Skip for any platform that doesn't have device Event
if current_platform.Event is None:
return None

# If a copy was initiated in the previous call, collect and return.
Expand Down