Skip to content

Commit 15b2645

Browse files
shen-shanshan authored and Yuqi Zhang committed
[Misc] Replace cuda hard code with current_platform (vllm-project#16983)
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent 7ba7f85 commit 15b2645

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

vllm/distributed/parallel_state.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,8 +1221,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
12211221
ray.shutdown()
12221222
gc.collect()
12231223
from vllm.platforms import current_platform
1224-
if not current_platform.is_cpu():
1225-
torch.cuda.empty_cache()
1224+
empty_cache = current_platform.empty_cache
1225+
if empty_cache is not None:
1226+
empty_cache()
12261227
try:
12271228
torch._C._host_emptyCache()
12281229
except AttributeError:

vllm/forward_context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
120120
# we use synchronous scheduling right now,
121121
# adding a sync point here should not affect
122122
# scheduling of the next batch
123-
torch.cuda.synchronize()
123+
from vllm.platforms import current_platform
124+
synchronize = current_platform.synchronize
125+
if synchronize is not None:
126+
synchronize()
124127
now = time.perf_counter()
125128
# time measurement is in milliseconds
126129
batchsize_forward_time[batchsize].append(

vllm/spec_decode/metrics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,12 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
126126
"""Copy rejection/typical-acceptance sampling metrics
127127
(number of accepted tokens, etc) to CPU asynchronously.
128128
129-
Returns a CUDA event recording when the copy is complete.
129+
Returns a device event recording when the copy is complete.
130130
"""
131131
assert self._copy_stream is not None
132-
self._copy_stream.wait_stream(torch.cuda.current_stream())
132+
self._copy_stream.wait_stream(current_platform.current_stream())
133133

134-
with torch.cuda.stream(self._copy_stream):
134+
with current_platform.stream(self._copy_stream):
135135
self._aggregate_num_accepted_tokens.copy_(
136136
self.spec_decode_sampler.num_accepted_tokens,
137137
non_blocking=True)
@@ -142,7 +142,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
142142
self._aggregate_num_draft_tokens = (
143143
self.spec_decode_sampler.num_draft_tokens)
144144

145-
aggregate_metrics_ready = torch.cuda.Event()
145+
aggregate_metrics_ready = current_platform.Event()
146146
aggregate_metrics_ready.record(self._copy_stream)
147147

148148
return aggregate_metrics_ready

0 commit comments

Comments (0)