
Commit ce65ce2

BoyuanFeng and ProExpertProg authored and committed
[torch.compile] CUDAGraph Inductor partition integration (#24281)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Signed-off-by: boyuanfeng <boyuan@meta.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent d4006bd commit ce65ce2

File tree: 9 files changed, +280 -32 lines changed

tests/compile/piecewise/test_simple.py

Lines changed: 60 additions & 11 deletions
@@ -15,6 +15,7 @@
     VllmConfig, set_current_vllm_config)
 from vllm.envs import VLLM_USE_V1
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer
 
 # This import automatically registers `torch.ops.silly.attention`
 from ..silly_attention import get_global_counter, reset_global_counter
@@ -50,16 +51,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-@pytest.mark.parametrize("use_inductor", [True, False])
-@torch.inference_mode()
-def test_simple_piecewise_compile(use_inductor):
-    assert VLLM_USE_V1
-
+def _run_simple_model(
+    splitting_ops,
+    use_inductor_graph_partition,
+    use_inductor,
+    expected_num_piecewise_graphs_seen,
+    expected_num_piecewise_capturable_graphs_seen,
+    expected_num_backend_compilations,
+    expected_num_cudagraph_captured,
+):
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
         use_inductor=use_inductor,
-        splitting_ops=["silly.attention"],
+        splitting_ops=splitting_ops,
+        use_inductor_graph_partition=use_inductor_graph_partition,
         cudagraph_copy_inputs=True,
         cudagraph_capture_sizes=[1, 2],
     ))
@@ -70,11 +76,11 @@ def test_simple_piecewise_compile(use_inductor):
 
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
-            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
-            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
-            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
-            num_cudagraph_captured=
-            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+            num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+            num_piecewise_capturable_graphs_seen=
+            expected_num_piecewise_capturable_graphs_seen,
+            num_backend_compilations=expected_num_backend_compilations,
+            num_cudagraph_captured=expected_num_cudagraph_captured,
     ), set_forward_context(None,
                            vllm_config=vllm_config):  # background context
         # warm up with background context
@@ -104,3 +110,46 @@ def test_simple_piecewise_compile(use_inductor):
     output = model(input)
     assert get_global_counter() == 2
     assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
+
+
+@pytest.mark.parametrize("use_inductor", [True, False])
+@torch.inference_mode()
+def test_simple_piecewise_compile(use_inductor):
+    assert VLLM_USE_V1
+    _run_simple_model(
+        splitting_ops=["silly.attention"],
+        use_inductor_graph_partition=False,
+        use_inductor=use_inductor,
+        expected_num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
+        expected_num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
+        expected_num_backend_compilations=
+        3,  # num_piecewise_capturable_graphs_seen
+        expected_num_cudagraph_captured=
+        6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+    )
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
+def test_simple_inductor_graph_partition(splitting_ops):
+    assert VLLM_USE_V1
+    if not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
+    _run_simple_model(
+        # inductor graph partition automatically resets splitting_ops
+        # to be an empty list
+        splitting_ops=splitting_ops,
+        use_inductor_graph_partition=True,
+        use_inductor=True,
+        expected_num_piecewise_graphs_seen=
+        1,  # since not splitting at fx graph level
+        expected_num_piecewise_capturable_graphs_seen=
+        1,  # since not splitting at fx graph level
+        expected_num_backend_compilations=
+        1,  # since not splitting at fx graph level
+        expected_num_cudagraph_captured=
+        6,  # inductor graph partition still captures 6
+        # graphs, same as fx graph partition.
+    )
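The expected counter values in the two tests above follow directly from the toy model's structure. A minimal sketch, not part of the commit, that derives them, assuming two silly.attention layers and cudagraph_capture_sizes=[1, 2]:

def expected_counts(num_layers: int, num_cudagraph_sizes: int,
                    use_inductor_graph_partition: bool) -> dict:
    # With FX-level splitting, each attention op contributes one
    # non-capturable piecewise graph plus capturable graphs around it;
    # with inductor graph partition, Dynamo sees a single graph and the
    # split happens inside Inductor, so cudagraph captures stay the same.
    if use_inductor_graph_partition:
        piecewise = capturable = compilations = 1
    else:
        piecewise = 2 * num_layers + 1
        capturable = 1 + num_layers
        compilations = capturable
    captured = num_cudagraph_sizes * (1 + num_layers)
    return {
        "num_piecewise_graphs_seen": piecewise,
        "num_piecewise_capturable_graphs_seen": capturable,
        "num_backend_compilations": compilations,
        "num_cudagraph_captured": captured,
    }

assert expected_counts(2, 2, False)["num_cudagraph_captured"] == 6
assert expected_counts(2, 2, True)["num_piecewise_graphs_seen"] == 1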

tests/compile/silly_attention.py

Lines changed: 1 addition & 0 deletions
@@ -60,4 +60,5 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
    target_lib=silly_lib,
+    tags=(torch._C.Tag.cudagraph_unsafe, ),
 )
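The tags=(torch._C.Tag.cudagraph_unsafe, ) argument added here is what inductor graph partition splits around. A hedged sketch of attaching the same tag through the raw torch.library API rather than vLLM's direct_register_custom_op helper; the "my_silly" namespace and the trivial kernel are illustrative only, and both the tags keyword on define and the tag itself assume a recent PyTorch (2.9+):

import torch
from torch.library import Library

my_lib = Library("my_silly", "FRAGMENT")  # hypothetical op namespace
# Attach the tag at schema-definition time so Inductor's graph partition
# treats the op as a split point when use_inductor_graph_partition=True.
my_lib.define(
    "attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()",
    tags=(torch._C.Tag.cudagraph_unsafe, ),
)

@torch.library.impl(my_lib, "attention", "CUDA")
def _attention(q, k, v, out):
    out.copy_(q + k + v)  # stand-in for a real attention kernel

@torch.library.register_fake("my_silly::attention")
def _attention_fake(q, k, v, out):
    return None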

tests/compile/test_full_graph.py

Lines changed: 58 additions & 1 deletion
@@ -3,16 +3,21 @@
 
 from __future__ import annotations
 
+import logging
 import tempfile
 from typing import Any, Optional, Union
 
 import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
+from tests.v1.attention.utils import _Backend
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationConfig, CompilationLevel, PassConfig
+from vllm.attention.selector import global_force_attn_backend_context_manager
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         PassConfig)
 from vllm.platforms import current_platform
+from vllm.utils import is_torch_equal_or_newer
 
 from ..utils import create_new_process_for_each_test
 
@@ -105,18 +110,70 @@ def test_full_graph(
     (CompilationConfig(level=CompilationLevel.PIECEWISE,
                        debug_dump_path=tempfile.gettempdir()),
      ("facebook/opt-125m", {})),
+] + [
+    # graph inductor partition
+    (
+        CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
+            # inductor graph partition uses
+            # torch._C.Tag.cudagraph_unsafe to specify splitting ops
+            use_inductor_graph_partition=True,
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+            compile_sizes=[1, 2]),
+        model) for model in models_list(all=False)
+    if is_torch_equal_or_newer("2.9.0.dev")
 ])
 # only test some of the models
 @create_new_process_for_each_test()
 def test_custom_compile_config(
     compilation_config: CompilationConfig,
     model_info: tuple[str, dict[str, Any]],
 ):
+    if (compilation_config.use_inductor_graph_partition
+            and not is_torch_equal_or_newer("2.9.0.dev")):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
     model, model_kwargs = model_info
     print(f"MODEL={model}")
     run_model(compilation_config, model, model_kwargs)
 
 
+def test_inductor_graph_partition_attn_fusion(caplog_vllm):
+    if not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
+    model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
+    compilation_config = CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_inductor_graph_partition=True,
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        custom_ops=["+quant_fp8"],
+        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
+    )
+    model_kwargs = {
+        "kv_cache_dtype": "fp8",
+        "max_model_len": 1024,
+    }
+    with caplog_vllm.at_level(
+            logging.DEBUG), global_force_attn_backend_context_manager(
+                _Backend.FLASHINFER):
+        run_model(compilation_config, model, model_kwargs)
+
+    try:
+        assert ("Fused quantization onto 48 attention nodes"
+                in caplog_vllm.text), caplog_vllm.text
+    except AssertionError:
+        # Note: this message is only triggered when the compilation goes
+        # through the custom pass. Due to multiple layers of cache on
+        # PyTorch side, the compilation of a graph may be cached such
+        # that custom pass directly goes through cache. In this case,
+        # we go through this branch and assert that the pass is not
+        # triggered.
+        assert "Fused quantization" not in caplog_vllm.text
+
+
 def run_model(compile_config: Union[int, CompilationConfig], model: str,
               model_kwargs: dict[str, Any]):
     prompts = [
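The new parametrize entry above exercises the same configuration a user would pass when enabling inductor graph partition. A hedged end-to-end sketch of what run_model amounts to, with a placeholder model and sizes, requiring PyTorch 2.9+:

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode

# Enable Inductor graph partitioning together with piecewise CUDA graphs.
llm = LLM(
    model="facebook/opt-125m",
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor_graph_partition=True,
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
        compile_sizes=[1, 2],
    ),
)
print(llm.generate(["Hello, my name is"],
                   SamplingParams(temperature=0.0, max_tokens=8)))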

tests/compile/test_fusion_attn.py

Lines changed: 15 additions & 1 deletion
@@ -27,6 +27,7 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp)
 from vllm.platforms import current_platform
+from vllm.utils import is_torch_equal_or_newer
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -339,6 +340,10 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 @pytest.mark.parametrize(
     "split_attention",
     [False, True] if current_platform.is_rocm() else [False])
+# TODO(boyuan): test inductor graph partition on rocm
+@pytest.mark.parametrize(
+    "use_inductor_graph_partition",
+    [False] if current_platform.is_rocm() else [False, True])
 @pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Only test ROCm or CUDA")
 @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@@ -352,9 +357,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
                                  dtype: torch.dtype, model_name: str,
                                  model_class: type[AttentionQuantPatternModel],
                                  backend: _Backend, split_attention: bool,
-                                 monkeypatch, dist_init):
+                                 use_inductor_graph_partition: bool,
+                                 monkeypatch, dist_init, caplog_vllm):
     """Test AttentionStaticQuantPattern fusion pass"""
 
+    if use_inductor_graph_partition and not is_torch_equal_or_newer(
+            "2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
     monkeypatch.setenv("VLLM_USE_V1", "1")
     if split_attention:
         monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
@@ -372,6 +383,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             custom_ops=["+quant_fp8"],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         ),
         cache_config=CacheConfig(cache_dtype="fp8"))
 
@@ -444,6 +456,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
                                    backend=test_backend,
                                    fullgraph=True)
     assert model_compiled.attn._o_scale_float is None
+
     result_fused_1 = model_compiled(q, k, v)
 
     if backend == _Backend.FLASHINFER:
@@ -453,6 +466,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         # _o_scale_float
         assert model_compiled.attn._o_scale_float is not None
         result_fused_2 = model_compiled(q, k, v)
+
         assert model_compiled.attn._o_scale_float is not None
 
     torch.testing.assert_close(result_unfused,
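The PyTorch 2.9 guard recurs throughout this commit. A small sketch of packaging it as a reusable skipif marker, assuming only is_torch_equal_or_newer from vllm.utils; the marker name is hypothetical, and the tests in this commit use an in-body pytest.skip instead because their parametrized cases mix partition on and off:

import pytest
from vllm.utils import is_torch_equal_or_newer

# Skip a whole test when the running PyTorch cannot do inductor graph
# partition; the version string matches the guard used above.
requires_inductor_graph_partition = pytest.mark.skipif(
    not is_torch_equal_or_newer("2.9.0.dev"),
    reason="inductor graph partition is only available in PyTorch 2.9+",
)

@requires_inductor_graph_partition
def test_something_with_inductor_graph_partition():
    ...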

vllm/attention/layer.py

Lines changed: 2 additions & 0 deletions
@@ -577,6 +577,7 @@ def unified_attention_fake(
     mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
+    tags=(torch._C.Tag.cudagraph_unsafe, ),
 )
 
 
@@ -627,4 +628,5 @@ def unified_attention_with_output_fake(
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
+    tags=(torch._C.Tag.cudagraph_unsafe, ),
 )
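A hedged way to confirm the effect of these two registrations, assuming the ops land under torch.ops.vllm (as this module registers them) and a PyTorch build that defines the cudagraph_unsafe tag (2.9+):

import torch
import vllm.attention.layer  # noqa: F401  (importing registers the ops)

def has_cudagraph_unsafe_tag(packet) -> bool:
    # Each overload of an op packet exposes the tags passed at registration.
    return any(torch._C.Tag.cudagraph_unsafe in getattr(packet, name).tags
               for name in packet.overloads())

assert has_cudagraph_unsafe_tag(torch.ops.vllm.unified_attention)
assert has_cudagraph_unsafe_tag(torch.ops.vllm.unified_attention_with_output)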

vllm/compilation/backends.py

Lines changed: 8 additions & 2 deletions
@@ -329,6 +329,7 @@ def call_module(self, target: torch.fx.node.Target,
             i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
         ]
         global compilation_start_time
+
         compiled_graph_for_dynamic_shape = self.vllm_backend.\
             compiler_manager.compile(
                 submod,
@@ -339,15 +340,20 @@ def call_module(self, target: torch.fx.node.Target,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None)
         # Lazy import here to avoid circular import
-        from .cuda_graph import CUDAGraphOptions
         from .cuda_piecewise_backend import PiecewiseBackend
 
         piecewise_backend = PiecewiseBackend(
             submod, self.vllm_config, index,
             len(self.compile_submod_names), sym_shape_indices,
             compiled_graph_for_dynamic_shape, self.vllm_backend)
 
-        if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and
+                not self.compilation_config.use_inductor_graph_partition):
+            # We're using Dynamo-based piecewise splitting, so we wrap
+            # the whole subgraph with a static graph wrapper.
+            from .cuda_graph import CUDAGraphOptions
+
             # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
             # class) as platform dependent.
             static_graph_wrapper_class = resolve_obj_by_qualname(
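The condition change above decides which layer owns CUDA graph wrapping. Restated as an illustrative sketch, not code from this commit:

def who_wraps_with_static_graph(cudagraph_mode_is_none: bool,
                                use_inductor_graph_partition: bool) -> str:
    """Distilled from the branch above: who wraps the compiled piece."""
    if cudagraph_mode_is_none:
        return "nobody (CUDA graphs disabled)"
    if use_inductor_graph_partition:
        # Wrapping is deferred to the Inductor partition hook installed in
        # vllm/compilation/decorators.py (next file in this commit).
        return "inductor partition wrapper"
    return "dynamo-level piecewise backend (this branch)"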

vllm/compilation/decorators.py

Lines changed: 55 additions & 2 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import contextlib
 import inspect
 from typing import Callable, Optional, TypeVar, Union, overload
 from unittest.mock import patch
@@ -14,7 +15,7 @@
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
-from vllm.utils import supports_dynamo
+from vllm.utils import resolve_obj_by_qualname, supports_dynamo
 
 from .monitor import start_monitoring_torch_compile
 
@@ -301,8 +302,11 @@ def patched_inline_call(parent, func, args, kwargs):
 
         with patch.object(InliningInstructionTranslator, 'inline_call',
                           patched_inline_call), torch._dynamo.config.patch(
-                              **dynamo_config_patches):
+                              **dynamo_config_patches
+                          ), maybe_use_cudagraph_partition_wrapper(
+                              self.vllm_config):
             output = self.compiled_callable(*args, **kwargs)
+
         return output
 
     # usually, capturing the model once is enough, and then we can
@@ -314,3 +318,52 @@ def patched_inline_call(parent, func, args, kwargs):
 
     cls.__call__ = __call__
     return cls
+
+
+@contextlib.contextmanager
+def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
+    """
+    Context manager to set/unset customized cudagraph partition wrappers.
+
+    If we're using Inductor-based graph partitioning, we currently have the
+    whole `fx.Graph` before Inductor lowering, and the piecewise
+    splitting happens after all graph passes and fusions. Here, we add
+    a custom hook for Inductor to wrap each partition with our static
+    graph wrapper class to maintain more control over static graph
+    capture and replay.
+    """
+    from vllm.config import CUDAGraphMode
+
+    compilation_config = vllm_config.compilation_config
+    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+            and compilation_config.use_inductor_graph_partition):
+        from torch._inductor.utils import CUDAGraphWrapperMetadata
+
+        from vllm.compilation.cuda_graph import CUDAGraphOptions
+        from vllm.platforms import current_platform
+
+        static_graph_wrapper_class = resolve_obj_by_qualname(
+            current_platform.get_static_graph_wrapper_cls())
+
+        def customized_cudagraph_wrapper(f,
+                                         metadata: CUDAGraphWrapperMetadata):
+            partition_id = metadata.partition_index
+            num_partitions = metadata.num_partitions
+            return static_graph_wrapper_class(
+                runnable=f,
+                vllm_config=vllm_config,
+                runtime_mode=CUDAGraphMode.PIECEWISE,
+                cudagraph_options=CUDAGraphOptions(
+                    debug_log_enable=partition_id == 0,
+                    gc_disable=partition_id != 0,
+                    weak_ref_output=partition_id == num_partitions - 1,
+                ))
+
+        torch._inductor.utils.set_customized_partition_wrappers(
+            customized_cudagraph_wrapper)
+
+    yield
+
+    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+            and compilation_config.use_inductor_graph_partition):
+        torch._inductor.utils.set_customized_partition_wrappers(None)
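A hedged usage sketch for the new context manager; the wrapper function below is a placeholder, while the real call site is the patched with block earlier in this file:

from vllm.compilation.decorators import maybe_use_cudagraph_partition_wrapper

def run_compiled(compiled_callable, vllm_config, *args, **kwargs):
    # While the context is active, torch._inductor wraps each graph
    # partition it creates with vLLM's static graph wrapper; the hook is
    # reset to None when the context exits.
    with maybe_use_cudagraph_partition_wrapper(vllm_config):
        return compiled_callable(*args, **kwargs)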
