
Commit 06d939f

baonudesifeizhai authored and ProExpertProg committed
[torch.compile] Make inductor partition rules respect splitting_ops vllm-project#25691 (vllm-project#25845)
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
1 parent d6bb074 commit 06d939f

File tree

9 files changed: +267 −112 lines changed

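In short, this commit makes Inductor's graph-partition path honor the same splitting_ops list that piecewise FX splitting uses, instead of clearing it, and resolves those op names to concrete torch OpOverload objects. A minimal configuration sketch based on the updated tests below (op names are illustrative; use_inductor_graph_partition needs PyTorch 2.9+):

from vllm.config import CompilationConfig, VllmConfig
from vllm.config.compilation import CompilationLevel

# Splitting ops are given as fully qualified custom-op names ("namespace::name").
config = VllmConfig(
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor_graph_partition=True,  # PyTorch 2.9+ only
        splitting_ops=["vllm::unified_attention"],
    )
)
# After this change the ops are kept and drive Inductor's partition rules.
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]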

tests/compile/piecewise/test_multiple_graphs.py

Lines changed: 2 additions & 2 deletions
@@ -198,7 +198,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
@@ -267,7 +267,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=False,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
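Note on the renames above: the tests switch from the dotted form "silly.attention" to the fully qualified "namespace::name" form "silly::attention", which is how PyTorch's operator registry identifies custom ops and what the new partition code resolves. A small self-contained sketch of that naming scheme (the "silly" namespace and signature here are stand-ins for the test helper op, not code from this commit):

import torch

# Register a toy custom op under the "silly" namespace. Its qualified name is
# "silly::attention" and its default overload is "silly::attention.default".
lib = torch.library.Library("silly", "FRAGMENT")
lib.define("attention(Tensor q, Tensor k, Tensor v) -> Tensor")


def _attention_impl(q, k, v):
    # Placeholder math; the real test op wraps an attention kernel.
    return q + k + v


lib.impl("attention", _attention_impl, "CompositeExplicitAutograd")

# The qualified string maps onto the torch.ops hierarchy:
packet = torch.ops.silly.attention            # OpOverloadPacket
overload = torch.ops.silly.attention.default  # concrete OpOverload
print(overload)                                # prints: silly.attention.default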

tests/compile/piecewise/test_simple.py

Lines changed: 2 additions & 2 deletions
@@ -127,7 +127,7 @@ def _run_simple_model(
 @torch.inference_mode()
 def test_simple_piecewise_compile(use_inductor):
     _run_simple_model(
-        splitting_ops=["silly.attention"],
+        splitting_ops=["silly::attention"],
         use_inductor_graph_partition=False,
         use_inductor=use_inductor,
         # 2 * num_layers + 1
@@ -142,7 +142,7 @@ def test_simple_piecewise_compile(use_inductor):


 @torch.inference_mode()
-@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
+@pytest.mark.parametrize("splitting_ops", [["silly::attention"], []])
 def test_simple_inductor_graph_partition(splitting_ops, monkeypatch):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

tests/compile/piecewise/test_toy_llama.py

Lines changed: 2 additions & 2 deletions
@@ -268,7 +268,7 @@ def run_model(
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
-            compilation_config.splitting_ops = ["silly.attention"]
+            compilation_config.splitting_ops = ["silly::attention"]
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
     else:
         compilation_config = CompilationConfig(
@@ -438,7 +438,7 @@ def benchmark():
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=cudagraph_sizes,
         )
     else:

tests/compile/test_config.py

Lines changed: 53 additions & 29 deletions
@@ -4,10 +4,12 @@

 from vllm.compilation.counter import compilation_counter
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.utils import _is_torch_equal_or_newer
+from vllm.config.compilation import CompilationLevel
+from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer


 def test_version():
+    # Test the version comparison logic using the private function
     assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
@@ -17,6 +19,9 @@ def test_version():

 def test_use_cudagraphs_dynamic():
     vllm_config = VllmConfig()
+    # Default V1 configuration now starts without cudagraphs enabled; the
+    # engine decides when to capture based on runtime settings instead of a
+    # blanket default.
     assert vllm_config.compilation_config.use_cudagraph


@@ -137,58 +142,77 @@ def test_enforce_eager(vllm_runner, monkeypatch):
 def test_splitting_ops_dynamic():
     # Default config
     config = VllmConfig()
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
-    assert config.compilation_config.splitting_ops_contain_attention()
+    # Default V1 config leaves cudagraph mode unset; splitting ops are only
+    # populated when the engine decides to use piecewise compilation.
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+    assert not config.compilation_config.splitting_ops_contain_attention()

     # When use_inductor_graph_partition=True
-    if _is_torch_equal_or_newer("2.9.0.dev"):
-        # inductor graph partition is only available in PyTorch 2.9+.
-        # this is a fast config check so we are not using pytest.skip.
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
-                use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
+                level=CompilationLevel.PIECEWISE,
+                use_inductor_graph_partition=True,
+                splitting_ops=["vllm::unified_attention"],
             )
         )
-        # should ignore splitting_ops
-        assert config.compilation_config.splitting_ops == []
+        # with inductor partition we use splitting_ops directly for
+        # partition rules
+        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

-    # When attn_fusion pass enabled.
+    # When attn_fusion pass enabled, splitting_ops now default to attention ops.
     config = VllmConfig(
         compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
             pass_config={"enable_attn_fusion": True, "enable_noop": True},
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
     )
-    assert config.compilation_config.splitting_ops == []
-    # cudagraph mode also fall back to FULL
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
-
-    # splitting_ops can not contain attention ops when attn_fusion
-    # pass enabled.
-    with pytest.raises(AssertionError):
-        config = VllmConfig(
-            compilation_config=CompilationConfig(
-                pass_config={"enable_attn_fusion": True, "enable_noop": True},
-                custom_ops=["+quant_fp8"],
-                cudagraph_mode=CUDAGraphMode.PIECEWISE,
-                # work around for accessing all attntion ops
-                splitting_ops=CompilationConfig()._attention_ops,
-            )
-        )
+    # With the new simplified logic, attention fusion works with splitting_ops
+    assert config.compilation_config.splitting_ops_contain_attention()
+    # cudagraph mode remains PIECEWISE
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE

     # When both use_inductor_graph_partition and attn_fusion pass enabled.
-    if _is_torch_equal_or_newer("2.9.0.dev"):
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
+                level=CompilationLevel.PIECEWISE,
                 use_inductor_graph_partition=True,
                 pass_config={"enable_attn_fusion": True, "enable_noop": True},
                 custom_ops=["+quant_fp8"],
                 cudagraph_mode=CUDAGraphMode.PIECEWISE,
             )
         )
-        assert config.compilation_config.splitting_ops == []
-        # enable_attn_fusion is directly support under
+        # With inductor graph partition, attn_fusion and splitting_ops
+        # work together. Default splitting_ops include attention ops.
+        assert config.compilation_config.splitting_ops_contain_attention()
+        # enable_attn_fusion is directly supported under
         # use_inductor_graph_partition=True, and cudagraph_mode
         # is unchanged.
         assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
+
+
+def test_resolve_operator_overload():
+    import torch
+
+    from vllm.compilation.partition_rules import resolve_defined_ops
+
+    # Test valid operator names
+    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
+    assert len(resolved) == 2
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
+
+    # Test that invalid operators are skipped (not raising exceptions)
+    resolved = resolve_defined_ops(
+        [
+            "aten::mm.default",
+            "aten::nonexistent_op.default",  # This should be skipped
+            "aten::addmm.default",
+        ]
+    )
+    assert len(resolved) == 2  # Only 2 valid ops
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
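The new test_resolve_operator_overload above pins down the contract of resolve_defined_ops: map "namespace::name.overload" strings to torch OpOverload objects and silently skip names that are not registered. A minimal sketch of that lookup logic, written here as a stand-alone helper (the real implementation in vllm.compilation.partition_rules may differ in details):

import torch


def resolve_defined_ops_sketch(op_names: list[str]) -> list[torch._ops.OpOverload]:
    """Resolve "namespace::name[.overload]" strings, skipping unregistered ops."""
    resolved = []
    for qualified_name in op_names:
        try:
            namespace, rest = qualified_name.split("::", 1)
            name, _, overload = rest.partition(".")
            packet = getattr(getattr(torch.ops, namespace), name)
            resolved.append(getattr(packet, overload or "default"))
        except (AttributeError, ValueError, RuntimeError):
            # Unknown op (e.g. "aten::nonexistent_op.default"): skip, don't raise.
            continue
    return resolved


assert resolve_defined_ops_sketch(
    ["aten::mm.default", "aten::nonexistent_op.default", "aten::addmm.default"]
) == [torch.ops.aten.mm.default, torch.ops.aten.addmm.default]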

tests/compile/test_decorator.py

Lines changed: 3 additions & 3 deletions
@@ -71,7 +71,7 @@ def test_ignore_torch_compile_decorator():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
@@ -186,7 +186,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )
@@ -218,7 +218,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )

vllm/compilation/backends.py

Lines changed: 44 additions & 8 deletions
@@ -15,6 +15,11 @@
 from torch._dispatch.python import enable_python_dispatcher

 import vllm.envs as envs
+from vllm.compilation.inductor_pass import pass_context
+from vllm.compilation.partition_rules import (
+    inductor_partition_rule_context,
+    resolve_defined_ops,
+)
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -76,6 +81,21 @@ def __init__(self, compilation_config: CompilationConfig):
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)

+    @contextmanager
+    def compile_context(self, runtime_shape: Optional[int] = None):
+        """Provide compilation context for the duration of compilation to set
+        any torch global properties we want to scope to a single Inductor
+        compilation (e.g. partition rules, pass context)."""
+        with pass_context(runtime_shape):
+            if self.compilation_config.use_inductor_graph_partition:
+                inductor_partition_ops = resolve_defined_ops(
+                    self.compilation_config.splitting_ops
+                )
+                with inductor_partition_rule_context(inductor_partition_ops):
+                    yield
+            else:
+                yield
+
     def initialize_cache(
         self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
     ):
@@ -197,9 +217,15 @@ def compile(
             maybe_key = None
         else:
             maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
-        compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key
-        )
+
+        with self.compile_context(runtime_shape):
+            compiled_graph, handle = self.compiler.compile(
+                graph,
+                example_inputs,
+                additional_inductor_config,
+                runtime_shape,
+                maybe_key,
+            )

         assert compiled_graph is not None, "Failed to compile the graph"

@@ -258,7 +284,7 @@ class SplitItem:


 def split_graph(
-    graph: fx.GraphModule, ops: list[str]
+    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -267,7 +293,12 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        if node.op == "call_function" and str(node.target) in ops:
+        # Match node.target against resolved_ops
+        # node.target can be OpOverloadPacket, need to check .default
+        if node.op == "call_function" and (
+            node.target in resolved_ops
+            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
+        ):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
@@ -615,9 +646,14 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.graph = graph
         self.configure_post_pass()

-        self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_config.splitting_ops
-        )
+        if self.compilation_config.use_inductor_graph_partition:
+            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
+            fx_split_ops: list[str] = []
+        else:
+            fx_split_ops = self.compilation_config.splitting_ops or []
+
+        resolved_split_ops = resolve_defined_ops(fx_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)

         from torch._dynamo.utils import lazy_format_graph_code
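The split_graph change above matches FX nodes against resolved OpOverload objects instead of comparing str(node.target) to raw strings; the extra .default check covers nodes whose target is an OpOverloadPacket rather than a specific overload. A small self-contained illustration of that predicate on a toy graph (not vLLM code):

import torch
from torch import fx

# Build a tiny FX graph by hand: one node calls a concrete overload,
# another calls the overload packet.
graph = fx.Graph()
x = graph.placeholder("x")
a = graph.call_function(torch.ops.aten.mm.default, (x, x))  # target: OpOverload
b = graph.call_function(torch.ops.aten.mm, (a, x))          # target: OpOverloadPacket
graph.output(b)

resolved_ops = [torch.ops.aten.mm.default]

for node in graph.nodes:
    if node.op != "call_function":
        continue
    # Same predicate as the new split_graph: direct hit, or a packet whose
    # .default overload is in the resolved list.
    is_split_op = node.target in resolved_ops or (
        hasattr(node.target, "default") and node.target.default in resolved_ops
    )
    print(node.target, "-> split point" if is_split_op else "-> regular node")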

vllm/compilation/compiler_interface.py

Lines changed: 12 additions & 16 deletions
@@ -17,8 +17,6 @@
 from vllm.config import VllmConfig
 from vllm.utils import is_torch_equal_or_newer

-from .inductor_pass import pass_context
-

 class CompilerInterface:
     """
@@ -209,13 +207,12 @@ def compile(

         from torch._inductor import standalone_compile

-        with pass_context(runtime_shape):
-            compiled_graph = standalone_compile(
-                graph,
-                example_inputs,
-                dynamic_shapes=dynamic_shapes,
-                options={"config_patches": current_config},
-            )
+        compiled_graph = standalone_compile(
+            graph,
+            example_inputs,
+            dynamic_shapes=dynamic_shapes,
+            options={"config_patches": current_config},
+        )

         # Save the compiled artifact to disk in the specified path
         assert key is not None
@@ -462,13 +459,12 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             torch._functorch.config.patch(enable_remote_autograd_cache=False)
         )

-        with pass_context(runtime_shape):
-            compiled_graph = compile_fx(
-                graph,
-                example_inputs,
-                inner_compile=hijacked_compile_fx_inner,
-                config_patches=current_config,
-            )
+        compiled_graph = compile_fx(
+            graph,
+            example_inputs,
+            inner_compile=hijacked_compile_fx_inner,
+            config_patches=current_config,
+        )

         # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
         # compilation cache. So turn off the checks if we disable the