File tree Expand file tree Collapse file tree 2 files changed +15
-5
lines changed Expand file tree Collapse file tree 2 files changed +15
-5
lines changed Original file line number Diff line number Diff line change @@ -404,6 +404,15 @@ def validate_request(
404404 ) -> None :
405405 """Raises if this request is unsupported on this platform"""
406406
407+ def __getattr__ (self , key : str ):
408+ device = getattr (torch , self .device_name , None )
409+ if device is not None and hasattr (device , key ):
410+ return getattr (device , key )
411+ else :
412+ logger .warning ("Current platform %s doesn't has '%s' attribute." ,
413+ self .device_name , key )
414+ return None
415+
407416
408417class UnspecifiedPlatform (Platform ):
409418 _enum = PlatformEnum .UNSPECIFIED
Original file line number Diff line number Diff line change 88
99from vllm .model_executor .layers .spec_decode_base_sampler import (
1010 SpecDecodeBaseSampler )
11+ from vllm .platforms import current_platform
1112from vllm .utils import is_pin_memory_available
1213
1314
@@ -89,14 +90,14 @@ def init_tensors(self,
8990 self ._rank = rank
9091 if isinstance (device_type , torch .device ):
9192 device_type = device_type .type
92- if device_type == 'cuda' :
93- self ._copy_stream = torch .cuda .Stream ()
93+ stream = current_platform .Stream
94+ if stream is not None :
95+ self ._copy_stream = stream ()
9496
9597 def maybe_collect_rejsample_metrics (
9698 self , k : int ) -> Optional [SpecDecodeWorkerMetrics ]:
97- # currently using cuda.Event, skip for any non_cuda_alike platform
98- from vllm .platforms import current_platform
99- if not current_platform .is_cuda_alike ():
99+ # Skip for any platform that doesn't have device Event
100+ if current_platform .Event is None :
100101 return None
101102
102103 # If a copy was initiated in the previous call, collect and return.
You can’t perform that action at this time.
0 commit comments