
Commit 38bf92d

SageMoore authored and rtourgeman committed

Reduce the Cuda Graph memory footprint when running with DBO (vllm-project#25779)

Signed-off-by: Sage Moore <sage@neuralmagic.com>

1 parent 247ed6c commit 38bf92d

File tree

2 files changed: +32 −28 lines changed


vllm/v1/worker/gpu_model_runner.py

Lines changed: 20 additions & 28 deletions
```diff
@@ -3477,8 +3477,10 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
         # We skip EPLB here since we don't want to record dummy metrics
         for num_tokens in compilation_cases:
             # We currently only capture ubatched graphs when its a FULL
-            # cudagraph and for uniform decode batches.
-            capture_ubatched_graph = self.parallel_config.enable_dbo \
+            # cudagraph, a uniform decode batch, and the number of tokens
+            # is above the threshold. Otherwise we just capture a non-ubatched
+            # version of the graph
+            allow_microbatching = self.parallel_config.enable_dbo \
                 and cudagraph_runtime_mode == CUDAGraphMode.FULL \
                 and uniform_decode \
                 and check_ubatch_thresholds(
@@ -3487,37 +3489,27 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
                     uniform_decode=uniform_decode,
                 )
 
-            # Currently we capture both microbatched and non-microbatched
-            # graphs when capture_ubatched_graph is True, this is because
-            # occasionally we will be forced out of microbatching due to other
-            # DP ranks not microbatching (usually caused by an empty second
-            # microbatch; once we resolve this, we can remove the
-            # non-microbatched graph capture).
-            allow_microbatching_options = [True, False] if \
-                capture_ubatched_graph else [False]
-            for allow_microbatching in allow_microbatching_options:
-                for _ in range(
-                        self.compilation_config.cudagraph_num_of_warmups):
-                    # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
-                    # But be careful, warm up with `NONE`is orthogonal to
-                    # if we want to warm up attention or not. This is
-                    # different from the case where `FULL` implies capture
-                    # attention while `PIECEWISE` implies no attention.
-                    force_attention = (
-                        cudagraph_runtime_mode == CUDAGraphMode.FULL)
-                    self._dummy_run(num_tokens,
-                                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
-                                    force_attention=force_attention,
-                                    uniform_decode=uniform_decode,
-                                    allow_microbatching=allow_microbatching,
-                                    skip_eplb=True,
-                                    remove_lora=False)
+            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
+                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
+                # But be careful, warm up with `NONE`is orthogonal to
+                # if we want to warm up attention or not. This is
+                # different from the case where `FULL` implies capture
+                # attention while `PIECEWISE` implies no attention.
+                force_attention = (
+                    cudagraph_runtime_mode == CUDAGraphMode.FULL)
                 self._dummy_run(num_tokens,
-                                cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                                force_attention=force_attention,
                                 uniform_decode=uniform_decode,
                                 allow_microbatching=allow_microbatching,
                                 skip_eplb=True,
                                 remove_lora=False)
+            self._dummy_run(num_tokens,
+                            cudagraph_runtime_mode=cudagraph_runtime_mode,
+                            uniform_decode=uniform_decode,
+                            allow_microbatching=allow_microbatching,
+                            skip_eplb=True,
+                            remove_lora=False)
             self.maybe_remove_all_loras(self.lora_config)
 
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
```
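Taken together, the two hunks change the capture policy from "capture both a microbatched and a non-microbatched FULL graph for every DBO-eligible shape" to "capture exactly one graph per shape, ubatched when eligible". The sketch below is a minimal toy model of that decision, not vLLM's actual classes: the `CUDAGraphMode` enum and the flag names mirror the diff, while `above_threshold` stands in for the `check_ubatch_thresholds(...)` call.

```python
from enum import Enum


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def graphs_to_capture(enable_dbo: bool, mode: CUDAGraphMode,
                      uniform_decode: bool,
                      above_threshold: bool) -> list[bool]:
    """Return the `allow_microbatching` values captured for one batch shape."""
    allow_microbatching = (enable_dbo and mode == CUDAGraphMode.FULL
                           and uniform_decode and above_threshold)
    # Before this commit, microbatching-eligible shapes captured BOTH
    # variants, doubling graph memory for those shapes:
    #     return [True, False] if allow_microbatching else [False]
    # After: exactly one graph per shape, ubatched when eligible.
    return [allow_microbatching]


# A DBO-eligible uniform-decode shape now captures one (ubatched) graph
# instead of two; everything else still captures one non-ubatched graph.
assert graphs_to_capture(True, CUDAGraphMode.FULL, True, True) == [True]
assert graphs_to_capture(False, CUDAGraphMode.FULL, True, True) == [False]
```

For a DBO deployment where the eligible shapes previously captured two FULL graphs each, this halves the graph memory for those shapes, which is where the memory reduction in the commit title comes from.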

vllm/v1/worker/gpu_ubatch_wrapper.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -330,6 +330,18 @@ def __call__(self, *args, **kwargs):
 
         # If there's no ubatching, just run the runnable object
         if ubatch_slices is None:
+
+            # This is to account for the case where ubatching was aborted.
+            # When we capture full graphs we only capture one graph per shape,
+            # meaning that if we have a ubatched cudagraph for the current
+            # num_tokens, we don't have a non-ubatched one. Without this
+            # check, the cudagraph wrapper will try to capture a cudagraph
+            # for this shape during a normal run.
+            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
+                assert batch_descriptor is not None
+                if batch_descriptor.num_tokens in self.cudagraphs:
+                    cudagraph_runtime_mode = CUDAGraphMode.NONE
+
             if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
                                           CUDAGraphMode.PIECEWISE):
                 return self.runnable(*args, **kwargs)
```
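This wrapper-side check is what makes the single-graph policy safe: if ubatching is aborted at runtime (e.g. by other DP ranks) for a shape whose only captured FULL graph is the ubatched one, the run is downgraded to eager execution rather than letting the downstream cudagraph wrapper capture a second, non-ubatched graph for that shape. A toy sketch of that control flow, using simplified stand-ins (`BatchDescriptor` and a dict of captured graphs keyed by token count) rather than vLLM's real types:

```python
from dataclasses import dataclass
from enum import Enum


class CUDAGraphMode(Enum):
    NONE = 0
    FULL = 2


@dataclass
class BatchDescriptor:
    num_tokens: int


def resolve_runtime_mode(runtime_mode: CUDAGraphMode,
                         ubatching_aborted: bool,
                         batch_descriptor: BatchDescriptor,
                         captured_ubatched_graphs: dict) -> CUDAGraphMode:
    # Ubatching was aborted, but the only FULL graph captured for this token
    # count is the ubatched one. Falling back to NONE prevents the downstream
    # cudagraph wrapper from capturing a fresh non-ubatched graph, which would
    # re-grow the memory footprint this commit just reduced.
    if (ubatching_aborted and runtime_mode is CUDAGraphMode.FULL
            and batch_descriptor.num_tokens in captured_ubatched_graphs):
        return CUDAGraphMode.NONE
    return runtime_mode


# If a ubatched graph exists for 256 tokens, a non-ubatched 256-token run
# executes eagerly instead of triggering a new capture.
graphs = {256: object()}
assert resolve_runtime_mode(CUDAGraphMode.FULL, True,
                            BatchDescriptor(256), graphs) is CUDAGraphMode.NONE
```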
