Skip to content

Commit b545a0b

Browse files
authored
fix test_simple_inductor_graph_partition (#26522)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
1 parent 29255cf commit b545a0b

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

tests/compile/piecewise/test_simple.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,14 @@ def test_simple_piecewise_compile(use_inductor):
143143

144144
@torch.inference_mode()
145145
@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
146-
def test_simple_inductor_graph_partition(splitting_ops):
146+
def test_simple_inductor_graph_partition(splitting_ops, monkeypatch):
147147
if not is_torch_equal_or_newer("2.9.0.dev"):
148148
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
149149

150+
# disable compile cache so that we run separately for different splitting_ops
151+
# and get the expected number of cudagraphs captured.
152+
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
153+
150154
_run_simple_model(
151155
# Inductor graph partition automatically resets splitting_ops to an empty list
152156
splitting_ops=splitting_ops,

vllm/compilation/compiler_interface.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,10 @@ def hijack_load(*args, **kwargs):
332332
nonlocal file_path
333333
compiled_fn = inductor_compiled_graph.current_callable
334334
file_path = compiled_fn.__code__.co_filename # noqa
335-
if not file_path.startswith(self.base_cache_dir):
335+
if (
336+
not file_path.startswith(self.base_cache_dir)
337+
and compiled_fn.__closure__ is not None
338+
):
336339
# hooked in the align_inputs_from_check_idxs function
337340
# in torch/_inductor/utils.py
338341
for cell in compiled_fn.__closure__:
@@ -359,7 +362,10 @@ def hijacked_compile_fx_inner(*args, **kwargs):
359362
nonlocal file_path
360363
compiled_fn = inductor_compiled_graph.current_callable
361364
file_path = compiled_fn.__code__.co_filename # noqa
362-
if not file_path.startswith(self.base_cache_dir):
365+
if (
366+
not file_path.startswith(self.base_cache_dir)
367+
and compiled_fn.__closure__ is not None
368+
):
363369
# hooked in the align_inputs_from_check_idxs function
364370
# in torch/_inductor/utils.py
365371
for cell in compiled_fn.__closure__:

0 commit comments

Comments
 (0)