11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+ import pytest
34import torch
45from torch import nn
56
1415 set_current_vllm_config ,
1516)
1617from vllm .forward_context import BatchDescriptor , set_forward_context
18+ from vllm .utils import is_torch_equal_or_newer
1719
1820# This import automatically registers `torch.ops.silly.attention`
1921from . import silly_attention # noqa: F401
@@ -65,19 +67,40 @@ def run_model(
6567 return output .cpu ()
6668
6769
68- def test_ignore_torch_compile_decorator ():
69- # vllmcompile
70+ @pytest .mark .parametrize ("use_inductor_graph_partition" , [True , False ])
71+ def test_ignore_torch_compile_decorator (use_inductor_graph_partition , monkeypatch ):
72+ # disable compile cache so that we can count the number of compilations
73+ # appropriately
74+ monkeypatch .setenv ("VLLM_DISABLE_COMPILE_CACHE" , "1" )
75+
76+ if use_inductor_graph_partition and not is_torch_equal_or_newer ("2.9.0.dev" ):
77+ pytest .skip ("inductor graph partition is only available in PyTorch 2.9+" )
78+
79+ # piecewise
7080 vllm_config = VllmConfig (
7181 compilation_config = CompilationConfig (
7282 mode = CompilationMode .VLLM_COMPILE ,
7383 use_cudagraph = True ,
7484 splitting_ops = ["silly::attention" ],
7585 cudagraph_capture_sizes = [1 , 2 ],
76- use_inductor_graph_partition = False , # TODO test both?
86+ use_inductor_graph_partition = use_inductor_graph_partition ,
7787 )
7888 )
7989 cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
8090
91+ expected_num_graphs_seen = 1
92+ expected_num_cudagraph_captured = (
93+ 4 # num_cudagraph_sizes * num cudagraphs to capture
94+ )
95+ if use_inductor_graph_partition :
96+ expected_num_piecewise_graphs_seen = 1
97+ expected_num_piecewise_capturable_graphs_seen = 1
98+ expected_num_backend_compilations = 1
99+ else :
100+ expected_num_piecewise_graphs_seen = 3
101+ expected_num_piecewise_capturable_graphs_seen = 2
102+ expected_num_backend_compilations = 2
103+
81104 @support_torch_compile
82105 class A (nn .Module ):
83106 def __init__ (
@@ -104,12 +127,11 @@ class C(B): ...
104127
105128 # A has support_torch_compile
106129 with compilation_counter .expect (
107- num_graphs_seen = 1 ,
108- num_piecewise_graphs_seen = 3 ,
109- num_piecewise_capturable_graphs_seen = 2 ,
110- num_backend_compilations = 2 ,
111- num_cudagraph_captured = 4 ,
112- # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
130+ num_graphs_seen = expected_num_graphs_seen ,
131+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
132+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
133+ num_backend_compilations = expected_num_backend_compilations ,
134+ num_cudagraph_captured = expected_num_cudagraph_captured ,
113135 ):
114136 run_model (vllm_config , mod_A , cudagraph_runtime_mode )
115137
@@ -131,12 +153,11 @@ class C(B): ...
131153
132154 # C's support_torch_compile should override B's ignore_torch_compile
133155 with compilation_counter .expect (
134- num_graphs_seen = 1 ,
135- num_piecewise_graphs_seen = 3 ,
136- num_piecewise_capturable_graphs_seen = 2 ,
137- num_backend_compilations = 2 ,
138- num_cudagraph_captured = 4 ,
139- # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
156+ num_graphs_seen = expected_num_graphs_seen ,
157+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
158+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
159+ num_backend_compilations = expected_num_backend_compilations ,
160+ num_cudagraph_captured = expected_num_cudagraph_captured ,
140161 ):
141162 run_model (vllm_config , mod_C , cudagraph_runtime_mode )
142163
@@ -179,7 +200,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
179200 return x
180201
181202
182- def test_conditional_compile_enable_if ():
203+ @pytest .mark .parametrize ("use_inductor_graph_partition" , [True , False ])
204+ def test_conditional_compile_enable_if (use_inductor_graph_partition , monkeypatch ):
205+ # disable compile cache so that we can count the number of compilations
206+ # appropriately
207+ monkeypatch .setenv ("VLLM_DISABLE_COMPILE_CACHE" , "1" )
208+
209+ if use_inductor_graph_partition and not is_torch_equal_or_newer ("2.9.0.dev" ):
210+ pytest .skip ("inductor graph partition is only available in PyTorch 2.9+" )
211+
183212 vllm_config = VllmConfig (
184213 cache_config = CacheConfig (
185214 kv_sharing_fast_prefill = True ,
@@ -189,25 +218,34 @@ def test_conditional_compile_enable_if():
189218 use_cudagraph = True ,
190219 splitting_ops = ["silly::attention" ],
191220 cudagraph_capture_sizes = [1 , 2 ],
192- use_inductor_graph_partition = False , # TODO test both
221+ use_inductor_graph_partition = use_inductor_graph_partition ,
193222 ),
194223 )
195224 cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
196225
197226 with set_current_vllm_config (vllm_config ):
198227 mod_A = A (vllm_config = vllm_config , prefix = "" ).eval ().cuda ()
199228
229+ if use_inductor_graph_partition :
230+ expected_num_piecewise_graphs_seen = 2
231+ expected_num_piecewise_capturable_graphs_seen = 2
232+ expected_num_backend_compilations = 2
233+ else :
234+ expected_num_piecewise_graphs_seen = 6
235+ expected_num_piecewise_capturable_graphs_seen = 4
236+ expected_num_backend_compilations = 4
237+
200238 # A has support_torch_compile but enable_if fn returns False
201239 # enalbe_if will be True for B, so we expect mod1 and mod2
202240 # to be compiled
203241 with compilation_counter .expect (
204242 num_graphs_seen = 2 ,
205- num_piecewise_graphs_seen = 6 ,
243+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
206244 # 3 piecewise graphs per instance of B()
207- num_piecewise_capturable_graphs_seen = 4 ,
208- num_backend_compilations = 4 ,
245+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
246+ num_backend_compilations = expected_num_backend_compilations ,
209247 num_cudagraph_captured = 8 ,
210- # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
248+ # num_cudagraph_sizes * num cudagraphable graphs to capture
211249 ):
212250 run_model (vllm_config , mod_A , cudagraph_runtime_mode )
213251
@@ -222,20 +260,30 @@ def test_conditional_compile_enable_if():
222260 use_cudagraph = True ,
223261 splitting_ops = ["silly::attention" ],
224262 cudagraph_capture_sizes = [1 , 2 ],
225- use_inductor_graph_partition = False , # TODO test both?
263+ use_inductor_graph_partition = use_inductor_graph_partition ,
226264 ),
227265 )
228266
229267 with set_current_vllm_config (vllm_config ):
230268 mod_A = A (vllm_config = vllm_config , prefix = "" ).eval ().cuda ()
231269
270+ if use_inductor_graph_partition :
271+ expected_num_piecewise_graphs_seen = 1
272+ expected_num_piecewise_capturable_graphs_seen = 1
273+ expected_num_backend_compilations = 1
274+ else :
275+ # 3 attn ops and 4 non-attn ops
276+ expected_num_piecewise_graphs_seen = 7
277+ expected_num_piecewise_capturable_graphs_seen = 4
278+ expected_num_backend_compilations = 4
279+
232280 with compilation_counter .expect (
233281 num_graphs_seen = 1 ,
234- num_piecewise_graphs_seen = 7 ,
282+ num_piecewise_graphs_seen = expected_num_piecewise_graphs_seen ,
235283 # 3 attn ops and 4 non-attn ops
236- num_piecewise_capturable_graphs_seen = 4 ,
237- num_backend_compilations = 4 ,
284+ num_piecewise_capturable_graphs_seen = expected_num_piecewise_capturable_graphs_seen ,
285+ num_backend_compilations = expected_num_backend_compilations ,
238286 num_cudagraph_captured = 8 ,
239- # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
287+ # num_cudagraph_sizes * num cudagraphable graphs to capture
240288 ):
241289 run_model (vllm_config , mod_A , cudagraph_runtime_mode )
0 commit comments