Skip to content

Commit 0c73026

Browse files
authored
[V1][PP] Fix memory profiling in PP (#13315)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 6a854c7 commit 0c73026

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,11 +1158,12 @@ def profile_run(self) -> None:
11581158
# Trigger compilation for general shape.
11591159
hidden_states = self._dummy_run(self.max_num_tokens,
11601160
dummy_kv_caches)
1161-
if not get_pp_group().is_last_rank:
1162-
return hidden_states
1163-
hidden_states = hidden_states[logit_indices]
1164-
logits = self.model.compute_logits(hidden_states, None)
1165-
# TODO(woosuk): Consider the memory usage of the sampler.
1161+
if get_pp_group().is_last_rank:
1162+
hidden_states = hidden_states[logit_indices]
1163+
logits = self.model.compute_logits(hidden_states, None)
1164+
# TODO(woosuk): Consider the memory usage of the sampler.
1165+
else:
1166+
logits = None
11661167
torch.cuda.synchronize()
11671168
del hidden_states, logits
11681169
self.encoder_cache.clear()

0 commit comments

Comments
 (0)