File tree Expand file tree Collapse file tree 2 files changed +15
-5
lines changed Expand file tree Collapse file tree 2 files changed +15
-5
lines changed Original file line number Diff line number Diff line change @@ -403,6 +403,15 @@ def validate_request(
403403 ) -> None :
404404 """Raises if this request is unsupported on this platform"""
405405
406+ def __getattr__ (self , key : str ):
407+ device = getattr (torch , self .device_name , None )
408+ if device is not None and hasattr (device , key ):
409+ return getattr (device , key )
410+ else :
411+ logger .warning ("Current platform %s doesn't has '%s' attribute." ,
412+ self .device_name , key )
413+ return None
414+
406415
407416class UnspecifiedPlatform (Platform ):
408417 _enum = PlatformEnum .UNSPECIFIED
Original file line number Diff line number Diff line change 88
99from vllm .model_executor .layers .spec_decode_base_sampler import (
1010 SpecDecodeBaseSampler )
11+ from vllm .platforms import current_platform
1112from vllm .utils import is_pin_memory_available
1213
1314
@@ -89,14 +90,14 @@ def init_tensors(self,
8990 self ._rank = rank
9091 if isinstance (device_type , torch .device ):
9192 device_type = device_type .type
92- if device_type == 'cuda' :
93- self ._copy_stream = torch .cuda .Stream ()
93+ stream = current_platform .Stream
94+ if stream is not None :
95+ self ._copy_stream = stream ()
9496
9597 def maybe_collect_rejsample_metrics (
9698 self , k : int ) -> Optional [SpecDecodeWorkerMetrics ]:
97- # currently using cuda.Event, skip for any non_cuda_alike platform
98- from vllm .platforms import current_platform
99- if not current_platform .is_cuda_alike ():
99+ # Skip for any platform that doesn't have device Event
100+ if current_platform .Event is None :
100101 return None
101102
102103 # If a copy was initiated in the previous call, collect and return.
You can’t perform that action at this time.
0 commit comments