Commit 19ca497

[BugFix] Work around graph partition x torch.compile cache issue
In PyTorch 2.9, torch.compile has a bug where the graph partition is not taken into account during caching. Because vLLM's Mode.VLLM_COMPILE is the only mode that uses Inductor graph partition, and VLLM_COMPILE implies there is a PostGradPassManager, we put the list of operators to graph partition into the PostGradPassManager's uuid (which then gets incorporated into Inductor's FX graph cache key). Remove this hack whenever torch.compile fixes it.

Signed-off-by: Richard Zou <zou3519@gmail.com>
1 parent 87efc68 commit 19ca497
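
The sketch below is a standalone illustration of the caching idea behind this workaround, not vLLM's actual implementation: once the (sorted) list of graph-partition operators is folded into the state that gets hashed into the compilation cache key, configurations that differ only in their splitting ops can no longer collide on the same Inductor FX graph cache entry. The make_cache_key helper and the placeholder op/pass names below are hypothetical.

# Hypothetical sketch of the cache-key idea; not vLLM's real uuid/hash_dict code.
import hashlib
import json


def make_cache_key(pass_config_uuid, pass_uuids, splitting_ops=None):
    """Hash a state dict the way a pass-manager uuid might be built."""
    state = {
        "pass_config": pass_config_uuid,
        "passes": list(pass_uuids),
        # Sorting makes the key independent of the order the ops were listed in.
        "splitting_ops": sorted(splitting_ops) if splitting_ops is not None else [],
    }
    return hashlib.sha256(json.dumps(state, sort_keys=True).encode()).hexdigest()


# Same passes, different partition ops -> different keys, so a graph compiled
# for one partitioning is not served from cache for the other.
key_a = make_cache_key("cfg-v1", ["noop-pass"], ["op_a"])
key_b = make_cache_key("cfg-v1", ["noop-pass"], ["op_a", "op_b"])
assert key_a != key_b

# Same ops listed in a different order -> same key, thanks to the sort.
key_c = make_cache_key("cfg-v1", ["noop-pass"], ["op_b", "op_a"])
assert key_b == key_c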

File tree: 1 file changed (+27, -1 lines)
vllm/compilation/pass_manager.py

Lines changed: 27 additions & 1 deletion
@@ -110,6 +110,23 @@ def configure(self, config: VllmConfig):
         self.post_cleanup = PostCleanupPass(config)
         self.fix_functionalization = FixFunctionalizationPass(config)
 
+        # [HACK: Bug with Inductor graph partition and torch.compile cache]
+        # In PyTorch 2.9, torch.compile has a bug where the graph
+        # partition is not taken into account during caching.
+        # Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
+        # Inductor graph partition, and VLLM_COMPILE implies there
+        # is a PostGradPassManager, we put the list of operators to graph
+        # partition into the PostGradPassManager's uuid (which
+        # then gets incorporated into Inductor's FX graph cache key).
+        # Remove this hack whenever torch.compile fixes it.
+        self.splitting_ops = None
+        if config.compilation_config.use_inductor_graph_partition:
+            if config.compilation_config.splitting_ops is None:
+                self.splitting_ops = []
+            else:
+                # Sort them so we're not dependent on the ordering.
+                self.splitting_ops = sorted(config.compilation_config.splitting_ops)
+
     def add(self, pass_: InductorPass):
         assert isinstance(pass_, InductorPass)
         self.passes.append(pass_)
@@ -120,8 +137,17 @@ def uuid(self):
         affects compilation caching. Its uuid depends on the UUIDs of all
         dependent passes and the pass config. See InductorPass for more info.
         """
-        state = {"pass_config": self.pass_config.uuid(), "passes": []}
+        state = {
+            "pass_config": self.pass_config.uuid(),
+            "passes": [],
+            "splitting_ops": [],
+        }
         for pass_ in self.passes:
             state["passes"].append(pass_.uuid())
         state["passes"].append(self.fix_functionalization.uuid())
+
+        # See [HACK: Bug with Inductor graph partition and torch.compile cache]
+        if self.splitting_ops is not None:
+            state["splitting_ops"].extend(self.splitting_ops)
+
         return InductorPass.hash_dict(state)
