33
44import weakref
55from collections .abc import Sequence
6+ from contextlib import nullcontext
67from copy import deepcopy
78from typing import Callable , Union
89
1617from vllm .compilation .pass_manager import with_pattern_match_debug
1718from vllm .compilation .vllm_inductor_pass import VllmInductorPass
1819from vllm .config import VllmConfig , get_current_vllm_config
20+ from vllm .logger import init_logger
21+
22+ logger = init_logger ("vllm.tests.compile.backend" )
1923
2024
2125class LazyInitPass (InductorPass ):
@@ -55,16 +59,19 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
5559 self .inductor_config ["post_grad_custom_post_pass" ] = self .post_pass
5660
5761 if debug_dump_path := vllm_config .compile_debug_dump_path ():
58- self . ctx = depyf . prepare_debug ( debug_dump_path . as_posix () )
59- self .ctx . __enter__ ( )
62+ logger . debug ( "Dumping depyf output to %s" , debug_dump_path )
63+ self .debug_ctx = depyf . prepare_debug ( debug_dump_path . as_posix () )
6064 else :
61- self .ctx = None
65+ self .debug_ctx = nullcontext ()
6266
def __call__(self, graph: fx.GraphModule, example_inputs):
    """Compile ``graph`` with Inductor while recording its pre-compile state.

    A deep copy of the incoming graph is stored on ``self.graph_pre_compile``
    before compilation starts, then the graph is handed to ``compile_fx``
    with ``self.inductor_config`` passed as ``config_patches``. Compilation
    runs inside ``self.debug_ctx``, which ``__init__`` sets to either a
    ``depyf.prepare_debug(...)`` context (when a debug dump path is
    configured) or a ``nullcontext()``.

    Args:
        graph: The FX graph module to compile.
        example_inputs: Example inputs forwarded unchanged to ``compile_fx``.

    Returns:
        Whatever ``compile_fx`` returns for this graph.
    """
    # Snapshot first, so the stored copy is taken before compile_fx runs.
    self.graph_pre_compile = deepcopy(graph)

    # Imported at call time rather than at module import time.
    from torch._inductor.compile_fx import compile_fx

    # NOTE(review): debug_ctx is created once in __init__ but entered on every
    # call; confirm the depyf context tolerates being entered more than once
    # if this backend can be invoked for several graphs.
    with self.debug_ctx:
        compiled = compile_fx(
            graph,
            example_inputs,
            config_patches=self.inductor_config,
        )
    return compiled
6976 @with_pattern_match_debug
7077 def post_pass (self , graph : fx .Graph ):
@@ -83,9 +90,6 @@ def post_pass(self, graph: fx.Graph):
8390 # assign by reference, will reflect the final state of the graph
8491 self .final_graph = graph
8592
86- if self .ctx is not None :
87- self .ctx .__exit__ (None , None , None )
88-
8993 def check_before_ops (self , ops : Sequence [OpOverload ], fully_replaced = True ):
9094 for op in ops :
9195 num_pre = len (list (find_op_nodes (op , self .graph_pre_pass )))
0 commit comments