diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 2695da5778aa..5e0e182b8caa 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -403,6 +403,15 @@ def validate_request(
     ) -> None:
         """Raises if this request is unsupported on this platform"""
 
+    def __getattr__(self, key: str):
+        device = getattr(torch, self.device_name, None)
+        if device is not None and hasattr(device, key):
+            return getattr(device, key)
+        else:
+            logger.warning("Current platform %s doesn't have '%s' attribute.",
+                           self.device_name, key)
+            return None
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index bc0e0a121cd5..0bb8d602ec8f 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -8,6 +8,7 @@
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler)
+from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 
 
@@ -89,14 +90,14 @@ def init_tensors(self,
         self._rank = rank
         if isinstance(device_type, torch.device):
            device_type = device_type.type
-        if device_type == 'cuda':
-            self._copy_stream = torch.cuda.Stream()
+        stream = current_platform.Stream
+        if stream is not None:
+            self._copy_stream = stream()
 
     def maybe_collect_rejsample_metrics(
             self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
-        # currently using cuda.Event, skip for any non_cuda_alike platform
-        from vllm.platforms import current_platform
-        if not current_platform.is_cuda_alike():
+        # Skip for any platform that doesn't have a device Event
+        if current_platform.Event is None:
             return None
 
         # If a copy was initiated in the previous call, collect and return.
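
With this change, device-namespace attributes resolve through the platform object instead of hard-coded torch.cuda references: Platform.__getattr__ forwards unknown attributes to the torch.<device_name> module, and returns None (with a warning) when the attribute is missing. A minimal sketch of the resulting call-site pattern, inferred from the hunks above (illustrative only, not part of the diff):

    from vllm.platforms import current_platform

    # On CUDA, device_name is "cuda", so this delegates to torch.cuda and
    # returns the torch.cuda.Stream class; on a platform whose torch device
    # namespace has no Stream, __getattr__ logs a warning and yields None.
    stream_cls = current_platform.Stream
    if stream_cls is not None:
        copy_stream = stream_cls()  # was: torch.cuda.Stream()

    # Callers can feature-test instead of checking is_cuda_alike():
    if current_platform.Event is None:
        pass  # skip device-event work, as maybe_collect_rejsample_metrics does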