File tree — 3 files changed: +11 / -7 lines changed
@@ -1221,8 +1221,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         ray.shutdown()
     gc.collect()
     from vllm.platforms import current_platform
-    if not current_platform.is_cpu():
-        torch.cuda.empty_cache()
+    empty_cache = current_platform.empty_cache
+    if empty_cache is not None:
+        empty_cache()
     try:
         torch._C._host_emptyCache()
     except AttributeError:
@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
-            torch.cuda.synchronize()
+            from vllm.platforms import current_platform
+            synchronize = current_platform.synchronize
+            if synchronize is not None:
+                synchronize()
             now = time.perf_counter()
             # time measurement is in milliseconds
             batchsize_forward_time[batchsize].append(
@@ -126,12 +126,12 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
         """Copy rejection/typical-acceptance sampling metrics
         (number of accepted tokens, etc) to CPU asynchronously.

-        Returns a CUDA event recording when the copy is complete.
+        Returns a device event recording when the copy is complete.
         """
         assert self._copy_stream is not None
-        self._copy_stream.wait_stream(torch.cuda.current_stream())
+        self._copy_stream.wait_stream(current_platform.current_stream())

-        with torch.cuda.stream(self._copy_stream):
+        with current_platform.stream(self._copy_stream):
             self._aggregate_num_accepted_tokens.copy_(
                 self.spec_decode_sampler.num_accepted_tokens,
                 non_blocking=True)
@@ -142,7 +142,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
         self._aggregate_num_draft_tokens = (
             self.spec_decode_sampler.num_draft_tokens)

-        aggregate_metrics_ready = torch.cuda.Event()
+        aggregate_metrics_ready = current_platform.Event()
         aggregate_metrics_ready.record(self._copy_stream)

         return aggregate_metrics_ready
You can’t perform that action at this time.
0 commit comments