Skip to content

Commit cc14657

Browse files
shen-shanshanadobrzyn
authored andcommitted
[Misc] Refactor platform to get device specific stream and event (vllm-project#14411)
Signed-off-by: shen-shanshan <467638484@qq.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
1 parent 38cde69 commit cc14657

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

vllm/platforms/interface.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,15 @@ def validate_request(
404404
) -> None:
405405
"""Raises if this request is unsupported on this platform"""
406406

407+
def __getattr__(self, key: str):
408+
device = getattr(torch, self.device_name, None)
409+
if device is not None and hasattr(device, key):
410+
return getattr(device, key)
411+
else:
412+
logger.warning("Current platform %s doesn't has '%s' attribute.",
413+
self.device_name, key)
414+
return None
415+
407416

408417
class UnspecifiedPlatform(Platform):
409418
_enum = PlatformEnum.UNSPECIFIED

vllm/spec_decode/metrics.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from vllm.model_executor.layers.spec_decode_base_sampler import (
1010
SpecDecodeBaseSampler)
11+
from vllm.platforms import current_platform
1112
from vllm.utils import is_pin_memory_available
1213

1314

@@ -89,14 +90,14 @@ def init_tensors(self,
8990
self._rank = rank
9091
if isinstance(device_type, torch.device):
9192
device_type = device_type.type
92-
if device_type == 'cuda':
93-
self._copy_stream = torch.cuda.Stream()
93+
stream = current_platform.Stream
94+
if stream is not None:
95+
self._copy_stream = stream()
9496

9597
def maybe_collect_rejsample_metrics(
9698
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
97-
# currently using cuda.Event, skip for any non_cuda_alike platform
98-
from vllm.platforms import current_platform
99-
if not current_platform.is_cuda_alike():
99+
# Skip for any platform that doesn't have device Event
100+
if current_platform.Event is None:
100101
return None
101102

102103
# If a copy was initiated in the previous call, collect and return.

0 commit comments

Comments
 (0)