
Commit 79e8784

BoyuanFeng authored and xuebwang-amd committed
[Graph Partition][Cache] Use inductor partition ops config (vllm-project#27702)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent e9d39da commit 79e8784

File tree

4 files changed: +37, -63 lines changed

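For context on where these ops come from: use_inductor_graph_partition and splitting_ops live on vLLM's compilation config, and with this commit the op names are handed straight to Inductor's custom_should_partition_ops config instead of being resolved to OpOverload objects first. A minimal, hedged sketch of such a config (the op name below is illustrative, not vLLM's default list):

from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    use_inductor_graph_partition=True,
    # Illustrative op name; in practice vLLM supplies its own splitting_ops defaults.
    splitting_ops=["vllm.unified_attention"],
)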

vllm/compilation/backends.py

Lines changed: 2 additions & 3 deletions
@@ -97,10 +97,9 @@ def compile_context(self, runtime_shape: int | None = None):
         compilation (e.g. partition rules, pass context)."""
         with pass_context(runtime_shape):
             if self.compilation_config.use_inductor_graph_partition:
-                inductor_partition_ops = resolve_defined_ops(
+                with inductor_partition_rule_context(
                     self.compilation_config.splitting_ops
-                )
-                with inductor_partition_rule_context(inductor_partition_ops):
+                ):
                     yield
             else:
                 yield
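
The removed resolve_defined_ops() call is what previously turned op-name strings into OpOverload objects before entering the rule context; it is no longer needed here because inductor_partition_rule_context() now takes the names directly. For reference, a hedged sketch of what that resolution step amounts to (simplified reconstruction; the real helper in vllm/compilation/partition_rules.py may normalize names differently or skip ops that are not registered):

from torch._library.utils import lookup_op


def resolve_defined_ops_sketch(op_names: list[str]) -> list["torch._ops.OpOverload"]:
    # Map qualified op-name strings (e.g. "aten::mm") to OpOverload objects.
    return [lookup_op(name) for name in op_names]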

vllm/compilation/partition_rules.py

Lines changed: 12 additions & 27 deletions
@@ -3,15 +3,12 @@
 
 import contextlib
 import logging
-from typing import TYPE_CHECKING
 
+import torch
 from torch._library.utils import lookup_op
 
 from vllm.logger import init_logger
 
-if TYPE_CHECKING:
-    import torch
-
 logger = init_logger(__name__)
 
 
@@ -56,47 +53,35 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
 
 
 @contextlib.contextmanager
-def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]):
+def inductor_partition_rule_context(splitting_ops: list[str]):
     """Context manager to temporarily register Inductor partition rules.
 
     Registers custom partition rules for specified operators, forcing the
     Inductor scheduler to partition the graph at these operators. The rules
     are automatically restored to their previous state on exit.
 
-    Note: Callers should use resolve_defined_ops() to convert operator names
-    to OpOverload objects before calling this function.
-
     Args:
-        overloads: List of resolved operator overload objects.
+        splitting_ops: List of operator names to partition on.
     """
-    if not overloads:
+    if not splitting_ops:
         logger.debug("No partition ops provided; skipping rule registration.")
         yield
         return
 
-    from torch._inductor.scheduler import (  # type: ignore
-        _custom_should_partition_fns,
-        register_should_partition_rule,
-    )
-
-    def _always_partition(*_args, **_kwargs):
-        return True
-
     # Save current state before registering
-    saved_rules = _custom_should_partition_fns.copy()
 
-    for overload in overloads:
-        register_should_partition_rule(
-            overload,
-            _always_partition,
-        )
+    saved_splitting_ops: list[str] = list(
+        torch._inductor.config.custom_should_partition_ops
+    )
+    torch._inductor.config.custom_should_partition_ops = splitting_ops
 
-    logger.debug("Registered inductor partition rules for %d operators", len(overloads))
+    logger.debug(
+        "Registered inductor partition rules for %d operators", len(splitting_ops)
+    )
 
     try:
         yield
     finally:
         # Clear and restore previous state
-        _custom_should_partition_fns.clear()
-        _custom_should_partition_fns.update(saved_rules)
+        torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
         logger.debug("Restored previous partition rules state.")

vllm/compilation/pass_manager.py

Lines changed: 1 addition & 29 deletions
@@ -113,27 +113,6 @@ def configure(self, config: VllmConfig):
         self.post_cleanup = PostCleanupPass(config)
         self.fix_functionalization = FixFunctionalizationPass(config)
 
-        # [HACK: Bug with Inductor graph partition and torch.compile cache]
-        # In PyTorch 2.9, torch.compile has a bug where the graph
-        # partition is not taken into account during caching.
-        # Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
-        # Inductor graph partition, and VLLM_COMPILE implies there
-        # is a PostGradPassManager, we put the list of operators to graph
-        # partition into the PostGradPassManager's uuid (which
-        # then gets incorporated into Inductor's FX graph cache key).
-        # Remove this hack whenever torch.compile fixes it.
-
-        # This is the list of operators that vLLM asks Inductor to split.
-        self.inductor_splitting_ops = []
-        if (
-            config.compilation_config.use_inductor_graph_partition
-            and config.compilation_config.splitting_ops is not None
-        ):
-            # Sort them so we're not dependent on the ordering.
-            self.inductor_splitting_ops = sorted(
-                config.compilation_config.splitting_ops
-            )
-
     def add(self, pass_: InductorPass):
         assert isinstance(pass_, InductorPass)
         self.passes.append(pass_)
@@ -144,16 +123,9 @@ def uuid(self):
         affects compilation caching. Its uuid depends on the UUIDs of all
         dependent passes and the pass config. See InductorPass for more info.
         """
-        state = {
-            "pass_config": self.pass_config.uuid(),
-            "passes": [],
-            "inductor_splitting_ops": [],
-        }
+        state = {"pass_config": self.pass_config.uuid(), "passes": []}
         for pass_ in self.passes:
             state["passes"].append(pass_.uuid())
         state["passes"].append(self.fix_functionalization.uuid())
 
-        # See [HACK: Bug with Inductor graph partition and torch.compile cache]
-        state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
-
         return InductorPass.hash_dict(state)
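
With the op list now flowing through torch._inductor.config, it participates in Inductor's FX graph cache key via the config itself, so the PostGradPassManager uuid no longer needs the extra inductor_splitting_ops entry. For intuition, a hedged sketch of deriving a uuid-style digest from the state dict (this assumes a JSON-plus-SHA-256 scheme; InductorPass.hash_dict's actual implementation may differ):

import hashlib
import json


def hash_dict_sketch(state: dict) -> str:
    # Serialize deterministically so dict key order cannot change the digest.
    encoded = json.dumps(state, sort_keys=True).encode("utf-8")
    return hashlib.sha256(encoded).hexdigest()


# Shape mirrors the uuid() method above: pass-config uuid plus per-pass uuids.
print(hash_dict_sketch({"pass_config": "cfg-uuid", "passes": ["pass-a-uuid"]}))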

vllm/env_override.py

Lines changed: 22 additions & 4 deletions
@@ -272,7 +272,6 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
     from torch._inductor.scheduler import (
         BaseSchedulerNode,
         FusedSchedulerNode,
-        _custom_should_partition_fns,
     )
     from torch._inductor.utils import (
         _unstable_customized_partition_wrapper,
@@ -283,9 +282,21 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
     # Allow users to manually specify if a node should be partitioned
     # Can only do this for FallbackKernels
     ir_node = node.node
-    if isinstance(ir_node, ir.FallbackKernel):
-        operator = ir_node.op_overload
-        if operator is not None and operator in _custom_should_partition_fns:
+    if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
+        op := ir_node.op_overload
+    ):
+        op_overload_packet_name = op.name()
+        op_overload_name = (
+            f"{op_overload_packet_name}.{op._overloadname}"
+            if isinstance(op, torch._ops.OpOverload)
+            else op_overload_packet_name
+        )
+        if (
+            op_overload_packet_name
+            in torch._inductor.config.custom_should_partition_ops
+            or op_overload_name in torch._inductor.config.custom_should_partition_ops
+        ):
+            assert isinstance(op, torch._ops.OpOverload)
             return True
 
     # When not using cudagraphs, keep all kernels in the `call` function
@@ -355,6 +366,13 @@ def _update_scheduler_patched(self) -> None:
 if is_torch_equal("2.9.0"):
     from torch._inductor.codegen.wrapper import PythonWrapperCodegen
     from torch._inductor.graph import GraphLowering
+    from torch.utils._config_module import _Config, _ConfigEntry
+
+    # `custom_should_partition_ops` is a new config after 2.9.0. So this would
+    # not overwrite any user configs.
+    torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry(
+        _Config(default=[])
+    )
 
     PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
     GraphLowering._update_scheduler = _update_scheduler_patched
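
The last hunk backfills the custom_should_partition_ops entry on torch 2.9.0, where the config does not exist yet, using torch's own _Config/_ConfigEntry machinery. A hedged sketch of a slightly more defensive variant of the same idea (the membership guard is an addition of this sketch, not part of the commit):

import torch._inductor.config
from torch.utils._config_module import _Config, _ConfigEntry

# Register the entry only if this torch build does not already define it.
if "custom_should_partition_ops" not in torch._inductor.config._config:
    torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry(
        _Config(default=[])
    )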
