
Commit 28ea554

BoyuanFeng authored and hmellor committed

[BugFix] Patch inductor memory plan logic (vllm-project#26878)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>

1 parent ca8de05, commit 28ea554

File tree: 4 files changed, +108 -6 lines changed

docs/mkdocs/hooks/generate_argparse.py
Lines changed: 8 additions & 1 deletion

@@ -22,6 +22,11 @@
 class PydanticMagicMock(MagicMock):
     """`MagicMock` that's able to generate pydantic-core schemas."""
 
+    def __init__(self, *args, **kwargs):
+        name = kwargs.pop("name", None)
+        super().__init__(*args, **kwargs)
+        self.__spec__ = importlib.machinery.ModuleSpec(name, None)
+
     def __get_pydantic_core_schema__(self, source_type, handler):
         return core_schema.any_schema()
 
@@ -42,7 +47,9 @@ def auto_mock(module, attr, max_mocks=50):
             raise e
         except ModuleNotFoundError as e:
             logger.info("Mocking %s for argparse doc generation", e.name)
-            sys.modules[e.name] = PydanticMagicMock()
+            sys.modules[e.name] = PydanticMagicMock(name=e.name)
+        except Exception as e:
+            logger.warning("Failed to import %s.%s: %s", module, attr, e)
 
     raise ImportError(
         f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
tests/compile/piecewise/test_multiple_graphs.py
Lines changed: 3 additions & 3 deletions

@@ -20,6 +20,7 @@
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer
 
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
@@ -193,9 +194,8 @@ def run_model(
 
 @pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
 def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
-    if use_inductor_graph_partition:
-        # FIXME(luka/boyuan): this currently fails
-        pytest.skip("Inductor graph partition not supported with multi-graph")
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
 
     outputs = []
vllm/env_override.py
Lines changed: 70 additions & 2 deletions

@@ -3,9 +3,9 @@
 import os
 
 import torch
-from packaging import version
 
 from vllm.logger import init_logger
+from vllm.utils import is_torch_equal
 
 logger = init_logger(__name__)
 
@@ -23,6 +23,72 @@
 # see https://github.com/vllm-project/vllm/issues/10619
 torch._inductor.config.compile_threads = 1
 
+# ===================================================
+# torch 2.9 Inductor PythonWrapperCodegen monkeypatch
+# ===================================================
+# This change monkeypatches memory_plan_reuse in pytorch 2.9.0 to work around
+# a test failure for test_multi_graph_piecewise_compile_outputs_equal.
+# For more context, see https://github.com/pytorch/pytorch/pull/165514.
+
+
+def memory_plan_reuse_patched(self):
+    import torch._inductor.ir as ir
+    from torch._inductor.codegen.wrapper import (
+        EnterSubgraphLine,
+        ExitSubgraphLine,
+        MemoryPlanningLine,
+        MemoryPlanningState,
+        SubgraphPythonWrapperCodegen,
+    )
+    from torch._inductor.virtualized import V
+
+    def get_output_names(graph_outputs) -> list[str]:
+        import itertools
+
+        names = []
+        shape_counter = itertools.count(0)
+        none_counter = itertools.count(0)
+        for node in graph_outputs:
+            if isinstance(node, ir.NoneAsConstantBuffer):
+                names.append(f"{V.graph.name}_none{next(none_counter)}")
+            elif isinstance(node, ir.ShapeAsConstantBuffer):
+                names.append(f"{V.graph.name}_shape{next(shape_counter)}")
+            else:
+                names.append(node.get_name())
+        return names
+
+    if (
+        isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen)
+        and V.graph.wrapper_code.partition_signatures is not None
+    ):
+        out_names = get_output_names(
+            V.graph.wrapper_code.partition_signatures.output_nodes
+        )
+    else:
+        out_names = V.graph.get_output_names()
+
+    while (
+        self.lines
+        and isinstance(self.lines[-1], MemoryPlanningLine)
+        and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
+    ):
+        # these lines will be pointless
+        self.lines.pop()
+
+    # codegen allocations in two passes
+    planning_states = [MemoryPlanningState()]
+    past_planning_states = []
+    for i in range(len(self.lines)):
+        line = self.lines[i]
+        if isinstance(line, MemoryPlanningLine):
+            self.lines[i] = line.plan(planning_states[-1])
+        elif isinstance(line, EnterSubgraphLine):
+            planning_states.append(MemoryPlanningState())
+        elif isinstance(line, ExitSubgraphLine):
+            past_planning_states.append(planning_states.pop())
+    past_planning_states.append(planning_states.pop())
+    assert len(planning_states) == 0
+
 
 # ========================================
 # torch 2.9 Inductor Scheduler monkeypatch
@@ -135,7 +201,9 @@ def _update_scheduler_patched(self) -> None:
     self.scheduler = Scheduler(self.operations)
 
 
-if version.parse(str(torch.__version__)) == version.parse("2.9.0"):
+if is_torch_equal("2.9.0"):
+    from torch._inductor.codegen.wrapper import PythonWrapperCodegen
     from torch._inductor.graph import GraphLowering
 
+    PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
     GraphLowering._update_scheduler = _update_scheduler_patched
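The heart of the patched memory_plan_reuse is subgraph-aware bookkeeping: planning keeps a stack with one MemoryPlanningState per open (sub)graph, pushing on EnterSubgraphLine and popping on ExitSubgraphLine, so buffer reuse stays scoped to the graph partition that owns the buffers. A minimal, self-contained sketch of that stack discipline, using hypothetical stand-in classes rather than Inductor's real types:

class PlanningState:
    def __init__(self):
        self.planned = []

class PlanLine:
    """Stand-in for a MemoryPlanningLine."""
    def __init__(self, name):
        self.name = name
    def plan(self, state):
        state.planned.append(self.name)  # record which state planned us
        return self

class Enter:  # marks descent into a subgraph
    pass

class Exit:  # marks return from a subgraph
    pass

def plan_lines(lines):
    # One state per open (sub)graph; the root state sits at the bottom.
    states = [PlanningState()]
    finished = []
    for line in lines:
        if isinstance(line, PlanLine):
            line.plan(states[-1])           # plan against the innermost graph
        elif isinstance(line, Enter):
            states.append(PlanningState())  # fresh state per subgraph
        elif isinstance(line, Exit):
            finished.append(states.pop())   # subgraph done, retire its state
    finished.append(states.pop())           # retire the root state
    assert not states
    return finished

states = plan_lines([PlanLine("a"), Enter(), PlanLine("b"), Exit(), PlanLine("c")])
print([s.planned for s in states])  # -> [['b'], ['a', 'c']]

The final pop retires the root state, and the assert mirrors the patch's invariant that every EnterSubgraphLine was matched by an ExitSubgraphLine.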

vllm/utils/__init__.py
Lines changed: 27 additions & 0 deletions

@@ -3263,6 +3263,33 @@ def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
     return torch_version >= version.parse(target)
 
 
+def _is_torch_equal(target: str) -> bool:
+    assert target.count(".") == 2
+    torch_version = str(torch.__version__)
+    torch_version = version.parse(torch_version)
+    # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu"
+    # or "2.6.0+cu128" but never "2.6.0.1"
+    return (
+        torch_version >= version.parse(target)
+        and version.parse(target + ".1") > torch_version
+    )
+
+
+def is_torch_equal(target: str) -> bool:
+    """Check if the installed torch version is == the target version.
+
+    Args:
+        target: a version string, like "2.6.0".
+
+    Returns:
+        Whether the installed torch version equals the target.
+    """
+    try:
+        return _is_torch_equal(target)
+    except Exception:
+        return Version(importlib.metadata.version("torch")) == Version(target)
+
+
 @cache
 def _has_module(module_name: str) -> bool:
     """Return True if *module_name* can be found in the current environment.
