
Commit c468b21

fhl2000, ProExpertProg, and hmellor authored and committed
[MISC] cudagraph_capture_sizes related improvements (vllm-project#26016)

Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>

1 parent 088c922 commit c468b21

14 files changed: +303 −110 lines

tests/compile/test_config.py
Lines changed: 73 additions & 0 deletions

@@ -1,13 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
+from contextlib import nullcontext

 import pytest

 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.config.compilation import CompilationMode
+from vllm.engine.arg_utils import EngineArgs
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer


@@ -233,3 +236,73 @@ def test_resolve_operator_overload():
     assert len(resolved) == 2  # Only 2 valid ops
     assert resolved[0] is torch.ops.aten.mm.default
     assert resolved[1] is torch.ops.aten.addmm.default
+
+
+@pytest.mark.skipif(
+    not current_platform.support_static_graph_mode(),
+    reason="Skip if cudagraph mode is not supported",
+)
+@pytest.mark.parametrize(
+    (
+        "cudagraph_capture_sizes",
+        "max_cudagraph_capture_size",
+        "tp_size",
+        "enable_sequence_parallelism",
+        "max_num_batched_tokens",
+        "use_cudagraph",
+        "expected_max_size",
+    ),
+    [
+        (None, None, 1, False, 2048, True, 512),
+        ([1, 2, 4], 4, 1, False, 2048, True, 4),
+        ([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
+        ([1, 256], None, 1, False, 2048, True, 256),
+        ([], None, 1, False, 2048, False, 0),
+        (None, 0, 1, False, 2048, False, 0),
+        # truncated to the nearest multiple of 8 or 16
+        (None, 257, 1, False, 2048, True, 256),
+        ([1, 2, 4, 15], None, 1, False, 2048, True, 15),  # max from list
+        ([1, 2, 4, 15], None, 2, True, 2048, True, 4),  # 15 filtered out due to SP
+        ([1, 2, 4, 15], None, 1, False, 8, True, 4),  # limited by max_num_batched_tokens
+        # the list should contain at least 1 element when cudagraph is used
+        ([], None, 1, False, 2048, True, RuntimeError),
+        # the max capture size should be >= 1 when cudagraph is used
+        (None, 0, 1, False, 2048, True, RuntimeError),
+    ],
+)
+def test_cudagraph_sizes_post_init(
+    cudagraph_capture_sizes,
+    max_cudagraph_capture_size,
+    tp_size,
+    enable_sequence_parallelism,
+    max_num_batched_tokens,
+    use_cudagraph,
+    expected_max_size,
+):
+    ctx = nullcontext()
+    if isinstance(expected_max_size, type) and issubclass(expected_max_size, Exception):
+        ctx = pytest.raises(expected_max_size)
+
+    cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
+    with ctx:
+        compilation_config = CompilationConfig(
+            cudagraph_capture_sizes=cudagraph_capture_sizes,
+            max_cudagraph_capture_size=max_cudagraph_capture_size,
+            pass_config={
+                "enable_sequence_parallelism": enable_sequence_parallelism,
+                "enable_fusion": True,
+                "enable_noop": True,
+            },
+            cudagraph_mode=cudagraph_mode,
+        )
+        engine_args = EngineArgs(
+            model="facebook/opt-125m",
+            tensor_parallel_size=tp_size,
+            max_num_batched_tokens=max_num_batched_tokens,
+            compilation_config=compilation_config,
+        )
+        vllm_config = engine_args.create_engine_config()
+
+        assert (
+            vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
+        )
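The parametrization above pins down how `max_cudagraph_capture_size` is resolved against `cudagraph_capture_sizes`. Outside pytest, a minimal sketch of the consistency rule the test asserts (the model name and sizes here are arbitrary examples, not part of the commit):

    from vllm.config import CompilationConfig, CUDAGraphMode
    from vllm.engine.arg_utils import EngineArgs

    # with an explicit capture list, the declared max must equal the list's maximum
    config = CompilationConfig(
        cudagraph_capture_sizes=[1, 2, 4, 8],
        max_cudagraph_capture_size=8,
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
    )
    args = EngineArgs(model="facebook/opt-125m", compilation_config=config)
    vllm_config = args.create_engine_config()
    assert vllm_config.compilation_config.max_cudagraph_capture_size == 8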

vllm/config/compilation.py
Lines changed: 40 additions & 38 deletions

@@ -154,6 +154,8 @@ class CompilationConfig:
     - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
     - [`cudagraph_capture_sizes`]
       [vllm.config.CompilationConfig.cudagraph_capture_sizes]
+    - [`max_cudagraph_capture_size`]
+      [vllm.config.CompilationConfig.max_cudagraph_capture_size]
     - [`cudagraph_num_of_warmups`]
       [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
     - [`cudagraph_copy_inputs`]
@@ -327,18 +329,16 @@ class CompilationConfig:
     more modes may be added.
     """
     use_cudagraph: bool = True
-    """Whether to use cudagraph inside compilation.
-    - False: cudagraph inside compilation is not used.
+    """Whether to use cudagraph inside compilation:
+
+    - False: cudagraph inside compilation is not used.\n
     - True: cudagraph inside compilation is used. It requires
       that all input buffers have fixed addresses, and all
       splitting ops write their outputs to input buffers.
-    In the vLLM V1 Engine, this flag only applies for
-    CompilationMode.VLLM_COMPILE (aka -O3).
-    Note that this is orthogonal to the cudagraph capture logic
-    outside of compilation.
+
     Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
-    instead.
+    minor release, i.e. v0.11.0 or v1.0.0. Please use
+    cudagraph_mode=FULL_AND_PIECEWISE instead.
     """
     cudagraph_num_of_warmups: int = 0
     """Number of warmup runs for cudagraph.
@@ -398,8 +398,22 @@ class CompilationConfig:
     pass_config: PassConfig = field(default_factory=PassConfig)
     """Custom inductor passes, see PassConfig for more details"""

-    max_capture_size: int = field(default=None, init=False)  # type: ignore
-    """not configurable, computed after init"""
+    max_cudagraph_capture_size: int | None = field(default=None)
+    """The maximum cudagraph capture size.
+
+    If cudagraph_capture_sizes is specified, this will be set to the largest
+    size in that list (or, if set explicitly as well, checked for consistency
+    with that list). If cudagraph_capture_sizes is not specified, the list of
+    sizes is generated automatically following the pattern:
+
+        [1, 2, 4] + list(range(8, 256, 8)) + list(
+            range(256, max_cudagraph_capture_size + 1, 16))
+
+    If not specified, max_cudagraph_capture_size is set to min(max_num_seqs * 2,
+    512) by default. This avoids OOM in tight memory scenarios with small
+    max_num_seqs, and prevents capture of many large graphs (>512) that would
+    greatly increase startup time with limited performance benefit.
+    """
     local_cache_dir: str = field(default=None, init=False)  # type: ignore
     """local cache dir for each rank"""
     bs_to_padded_graph_size: list[int] = field(
@@ -408,7 +422,7 @@ class CompilationConfig:
     )
     """optimization:
     Intuitively, bs_to_padded_graph_size should be dict[int, int].
-    since we know all keys are in a range [0, max_capture_size],
+    since we know all keys are in a range [0, max_cudagraph_capture_size],
     we can optimize it to list[int] for better lookup performance."""

     # keep track of enabled and disabled custom ops
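For a concrete feel of the documented default, a small sketch that materializes the generated capture list (assuming the default cap of 512; `max_size` is a local stand-in, not a config attribute):

    # default cap: min(max_num_seqs * 2, 512)
    max_size = 512
    sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, max_size + 1, 16))
    print(sizes[:6])   # [1, 2, 4, 8, 16, 24]
    print(sizes[-3:])  # [480, 496, 512]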
@@ -672,25 +686,12 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:

         return VllmBackend(vllm_config)

-    def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
-        """To complete the initialization of config,
-        we need to know the cudagraph sizes."""
-
-        if self.cudagraph_capture_sizes is None:
-            self.cudagraph_capture_sizes = cudagraph_capture_sizes
-        else:
-            # de-duplicate the sizes provided by the config
-            dedup_sizes = list(set(self.cudagraph_capture_sizes))
-            if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
-                logger.info(
-                    (
-                        "cudagraph sizes specified by model runner"
-                        " %s is overridden by config %s"
-                    ),
-                    cudagraph_capture_sizes,
-                    dedup_sizes,
-                )
-            self.cudagraph_capture_sizes = dedup_sizes
+    def post_init_cudagraph_sizes(self) -> None:
+        """Complete the initialization after cudagraph-related
+        configs are set. This includes:
+        - initializing compile_sizes
+        - pre-computing the mapping bs_to_padded_graph_size
+        """

         computed_compile_sizes = []
         if self.compile_sizes is not None:
@@ -708,23 +709,24 @@ def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
                 computed_compile_sizes.append(x)
         self.compile_sizes = computed_compile_sizes  # type: ignore

-        # sort to make sure cudagraph capture sizes are in descending order
-        self.cudagraph_capture_sizes.sort(reverse=True)
-        self.max_capture_size = (
-            self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
-        )
+        # make sure the sizes are in ascending order
+        self.cudagraph_capture_sizes.sort()
+        if self.cudagraph_capture_sizes:
+            assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

         # pre-compute the mapping from batch size to padded graph size
-        self.bs_to_padded_graph_size = [0 for i in range(self.max_capture_size + 1)]
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_cudagraph_capture_size + 1)
+        ]
         for end, start in zip(
-            self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]
+            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
+            [0] + self.cudagraph_capture_sizes,
         ):
             for bs in range(start, end):
                 if bs == start:
                     self.bs_to_padded_graph_size[bs] = start
                 else:
                     self.bs_to_padded_graph_size[bs] = end
-        self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size

     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when mode is
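The rewritten padding map walks the capture sizes in ascending order: exact capture sizes map to themselves, and every other batch size rounds up to the next capture size. A self-contained sketch of that loop (local names and an example size list, mirroring the diff's logic):

    capture_sizes = [1, 2, 4, 8]   # example ascending capture list
    max_size = capture_sizes[-1]   # plays the role of max_cudagraph_capture_size

    bs_to_padded = [0] * (max_size + 1)
    for end, start in zip(capture_sizes + [max_size + 1], [0] + capture_sizes):
        for bs in range(start, end):
            # exact capture sizes map to themselves; in-between sizes round up
            bs_to_padded[bs] = start if bs == start else end

    assert bs_to_padded == [0, 1, 2, 4, 4, 8, 8, 8, 8]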

vllm/config/scheduler.py
Lines changed: 0 additions & 15 deletions

@@ -71,14 +71,6 @@ class SchedulerConfig:
     NOTE: This will be replaced by speculative config in the future; it is
     present to enable correctness tests until then."""

-    cuda_graph_sizes: list[int] = field(default_factory=list)
-    """Cuda graph capture sizes
-    1. if none provided, then default set to [min(max_num_seqs * 2, 512)]
-    2. if one value is provided, then the capture list would follow the
-    pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)]
-    3. more than one value (e.g. 1 2 128) is provided, then the capture list
-    will follow the provided list."""
-
     enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
     """If True, prefill requests can be chunked based
     on the remaining max_num_batched_tokens."""
@@ -235,13 +227,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
             self.long_prefill_token_threshold,
         )

-        # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
-        # This avoids OOM in tight memory scenarios with small max_num_seqs,
-        # and prevents capture of many large graphs (>512) that would greatly
-        # increase startup time with limited performance benefit.
-        if not self.cuda_graph_sizes:
-            self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
-
         if self.async_scheduling:
             self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
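With cuda_graph_sizes gone from SchedulerConfig, capture sizes now live entirely on CompilationConfig. A hedged before/after sketch of the migration (field names come from this diff; the value 256 is an arbitrary example):

    from vllm.config import CompilationConfig

    # before: SchedulerConfig(cuda_graph_sizes=[256])  # removed by this commit
    # after: cap the auto-generated capture list via the compilation config
    config = CompilationConfig(max_cudagraph_capture_size=256)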
