
Commit 44de6d0

angelayi authored and ProExpertProg committed
[BugFix] Patch inductor partitioning logic
Signed-off-by: angelayi <yiangela7@gmail.com>
1 parent 4faad1e commit 44de6d0

File tree

1 file changed: +96 −0 lines changed

vllm/env_override.py

Lines changed: 96 additions & 0 deletions
@@ -3,8 +3,10 @@
 import os
 
 import torch
+from torch._inductor.graph import GraphLowering
 
 from vllm.logger import init_logger
+from vllm.utils import is_torch_equal_or_newer
 
 logger = init_logger(__name__)
 
@@ -21,3 +23,97 @@
 os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 # see https://github.com/vllm-project/vllm/issues/10619
 torch._inductor.config.compile_threads = 1
+
+
+def should_partition_patched(self, node, should_log: bool = False) -> bool:
+    # Copied from torch._inductor.scheduler.Scheduler.should_partition. Patches
+    # [this code](https://github.com/pytorch/pytorch/blob/ecb53078faf86ca1b33277df33b82985675bb011/torch/_inductor/scheduler.py#L4712-L4724)
+    # so that we always return True.
+    """Return True if we should partition the inductor graph on this node"""
+
+    import torch._inductor.ir as ir
+    from torch._inductor.scheduler import (
+        BaseSchedulerNode,
+        FusedSchedulerNode,
+        _custom_should_partition_fns,
+    )
+    from torch._inductor.utils import (
+        _unstable_customized_partition_wrapper,
+        is_cudagraph_unsafe_op,
+        maybe_log_cudagraph_partition,
+    )
+
+    # Allow users to manually specify if a node should be partitioned
+    # Can only do this for FallbackKernels
+    ir_node = node.node
+    if isinstance(ir_node, ir.FallbackKernel):
+        operator = ir_node.op_overload
+        if operator is not None and operator in _custom_should_partition_fns:
+            return True
+
+    # When not using cudagraphs, keep all kernels in the `call` function
+    # instead of graph partition functions, since graph partition only brings
+    # benefit to cudagraph
+    if (
+        not torch._inductor.config.triton.cudagraphs
+        and _unstable_customized_partition_wrapper.wrapper is None
+    ):
+        return True
+
+    # avoid duplicating logs when should_partition is called multiple times
+    # on the same node
+    def noop_log(msg: str, node: BaseSchedulerNode | None) -> None:
+        return
+
+    log_partition_reason = maybe_log_cudagraph_partition if should_log else noop_log
+
+    if isinstance(node, FusedSchedulerNode):
+        return any(self.should_partition(snode) for snode in node.snodes)
+
+    assert node.node is not None
+
+    if not node.is_gpu():
+        log_partition_reason("non gpu ops", node=node)
+
+        return True
+
+    if isinstance(node.node, ir.DeviceCopy):
+        log_partition_reason("DeviceCopy ops", node=node)
+        return True
+
+    if isinstance(node.node, ir.Conditional):
+        log_partition_reason("Conditional ops", node=node)
+        return True
+
+    if getattr(node.node, "unbacked_bindings", None):
+        log_partition_reason("unbacked binding ops", node=node)
+        return True
+
+    if is_cudagraph_unsafe_op(node.node):
+        log_partition_reason("CUDAGraph-unsafe custom ops", node=node)
+        return True
+
+    return False
+
+
+def _update_scheduler_patched(self) -> None:
+    # Copied from torch._inductor.graph.GraphLowering._update_scheduler. Patches
+    # this method so that we can patch Scheduler.should_partition with the
+    # function above
+    """
+    (Re)initializes the scheduler member. When initializing the scheduler, no CUBIN
+    files should be generated (to avoid biasing any benchmarks and pessimizing
+    fusion decisions).
+    """
+    import torch._inductor.config as config
+    from torch._inductor.scheduler import Scheduler
+
+    Scheduler.should_partition = should_partition_patched
+
+    with config.patch("triton.store_cubin", False):
+        self.scheduler = Scheduler(self.operations)
+
+
+# see https://github.com/vllm-project/vllm/issues/26678
+if is_torch_equal_or_newer("2.9.0.dev"):
+    GraphLowering._update_scheduler = _update_scheduler_patched
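
For readers unfamiliar with the mechanism: the change relies on ordinary Python monkey-patching, where assigning a new function to a class attribute rewires that method for every existing and future instance. Below is a minimal, self-contained sketch of the same pattern using toy stand-ins; ToyScheduler, ToyGraphLowering, and toy_version_at_least are hypothetical names for illustration, not the real torch or vLLM APIs.

# Toy stand-ins for inductor's Scheduler and GraphLowering; these only
# illustrate the patching mechanics, not the real torch classes.
class ToyScheduler:
    def should_partition(self, node) -> bool:
        return False  # original behavior: do not partition here


class ToyGraphLowering:
    def _update_scheduler(self) -> None:
        self.scheduler = ToyScheduler()  # original: build an unpatched scheduler


def should_partition_patched_toy(self, node) -> bool:
    return True  # patched behavior: always partition


def _update_scheduler_patched_toy(self) -> None:
    # Swap the method on the class first, then build the scheduler as before.
    ToyScheduler.should_partition = should_partition_patched_toy
    self.scheduler = ToyScheduler()


def toy_version_at_least(current: str, minimum: str) -> bool:
    # Hypothetical stand-in for vllm.utils.is_torch_equal_or_newer.
    def as_tuple(v: str) -> tuple:
        return tuple(int(p) for p in v.split(".") if p.isdigit())
    return as_tuple(current) >= as_tuple(minimum)


# Install the patch at import time, gated on a version check, as the diff does.
if toy_version_at_least("2.9.0", "2.9.0"):
    ToyGraphLowering._update_scheduler = _update_scheduler_patched_toy

gl = ToyGraphLowering()
gl._update_scheduler()
assert gl.scheduler.should_partition(node=None) is True

Because the real assignment in the diff happens at module scope, the override is installed as a side effect of importing vllm/env_override.py, and only when is_torch_equal_or_newer("2.9.0.dev") reports a new enough PyTorch.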

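To confirm the override is active in a given environment, one possible sanity check (assuming vLLM is installed and PyTorch is at least 2.9.0.dev; otherwise the version gate leaves the original method in place and this prints False) is:

from torch._inductor.graph import GraphLowering

import vllm.env_override as env_override  # importing the module applies the patch

# True only when the version gate passed and the method was swapped.
print(GraphLowering._update_scheduler is env_override._update_scheduler_patched)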