Skip to content

Commit c8675ff

Browse files
committed
log depyf folder, fix context for TestBackend, fix pattern dump
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent d09a278 commit c8675ff

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

tests/compile/backend.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import weakref
55
from collections.abc import Sequence
6+
from contextlib import nullcontext
67
from copy import deepcopy
78
from typing import Callable, Union
89

@@ -16,6 +17,9 @@
1617
from vllm.compilation.pass_manager import with_pattern_match_debug
1718
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
1819
from vllm.config import VllmConfig, get_current_vllm_config
20+
from vllm.logger import init_logger
21+
22+
logger = init_logger("vllm.tests.compile.backend")
1923

2024

2125
class LazyInitPass(InductorPass):
@@ -55,16 +59,19 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
5559
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
5660

5761
if debug_dump_path := vllm_config.compile_debug_dump_path():
58-
self.ctx = depyf.prepare_debug(debug_dump_path.as_posix())
59-
self.ctx.__enter__()
62+
logger.debug("Dumping depyf output to %s", debug_dump_path)
63+
self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
6064
else:
61-
self.ctx = None
65+
self.debug_ctx = nullcontext()
6266

6367
def __call__(self, graph: fx.GraphModule, example_inputs):
6468
self.graph_pre_compile = deepcopy(graph)
6569
from torch._inductor.compile_fx import compile_fx
6670

67-
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
71+
with self.debug_ctx:
72+
return compile_fx(
73+
graph, example_inputs, config_patches=self.inductor_config
74+
)
6875

6976
@with_pattern_match_debug
7077
def post_pass(self, graph: fx.Graph):
@@ -83,9 +90,6 @@ def post_pass(self, graph: fx.Graph):
8390
# assign by reference, will reflect the final state of the graph
8491
self.final_graph = graph
8592

86-
if self.ctx is not None:
87-
self.ctx.__exit__(None, None, None)
88-
8993
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
9094
for op in ops:
9195
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))

vllm/compilation/monitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
2222
import depyf
2323

2424
path.mkdir(parents=True, exist_ok=True)
25+
logger.debug("Dumping depyf output to %s", path)
2526
global context_manager
2627
context_manager = depyf.prepare_debug(path.as_posix())
2728
context_manager.__enter__()

vllm/compilation/vllm_inductor_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
115115
f" please add to dump_patterns if there are any errors.\n\n"
116116
f"from torch._higher_order_ops.auto_functionalize import "
117117
f"auto_functionalized as auto_functionalized\n"
118-
f"from torch._inductor.pattern_matcher import *",
118+
f"from torch._inductor.pattern_matcher import *\n"
119+
f"vllm = torch.ops.vllm",
119120
file=f,
120121
)
121122

0 commit comments

Comments
 (0)