                         VllmConfig, set_current_vllm_config)
 from vllm.envs import VLLM_USE_V1
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer
 
 # This import automatically registers `torch.ops.silly.attention`
 from ..silly_attention import get_global_counter, reset_global_counter
@@ -50,16 +51,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-@pytest.mark.parametrize("use_inductor", [True, False])
-@torch.inference_mode()
-def test_simple_piecewise_compile(use_inductor):
-    assert VLLM_USE_V1
-
+def _run_simple_model(
+    splitting_ops,
+    use_inductor_graph_partition,
+    use_inductor,
+    expected_num_piecewise_graphs_seen,
+    expected_num_piecewise_capturable_graphs_seen,
+    expected_num_backend_compilations,
+    expected_num_cudagraph_captured,
+):
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
         use_inductor=use_inductor,
-        splitting_ops=["silly.attention"],
+        splitting_ops=splitting_ops,
+        use_inductor_graph_partition=use_inductor_graph_partition,
         cudagraph_copy_inputs=True,
         cudagraph_capture_sizes=[1, 2],
     ))
@@ -70,11 +76,11 @@ def test_simple_piecewise_compile(use_inductor):
 
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
-            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
-            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
-            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
-            num_cudagraph_captured=
-            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+            num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+            num_piecewise_capturable_graphs_seen=
+            expected_num_piecewise_capturable_graphs_seen,
+            num_backend_compilations=expected_num_backend_compilations,
+            num_cudagraph_captured=expected_num_cudagraph_captured,
     ), set_forward_context(None,
                            vllm_config=vllm_config):  # background context
         # warm up with background context
@@ -104,3 +110,46 @@ def test_simple_piecewise_compile(use_inductor):
             output = model(input)
         assert get_global_counter() == 2
         assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
+
+
+@pytest.mark.parametrize("use_inductor", [True, False])
+@torch.inference_mode()
+def test_simple_piecewise_compile(use_inductor):
+    assert VLLM_USE_V1
+    _run_simple_model(
+        splitting_ops=["silly.attention"],
+        use_inductor_graph_partition=False,
+        use_inductor=use_inductor,
+        expected_num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
+        expected_num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
+        expected_num_backend_compilations=
+        3,  # num_piecewise_capturable_graphs_seen
+        expected_num_cudagraph_captured=
+        6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+    )
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
+def test_simple_inductor_graph_partition(splitting_ops):
+    assert VLLM_USE_V1
+    if not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
+    _run_simple_model(
+        # inductor graph partition automatically resets splitting_ops
+        # to be an empty list
+        splitting_ops=splitting_ops,
+        use_inductor_graph_partition=True,
+        use_inductor=True,
+        expected_num_piecewise_graphs_seen=
+        1,  # since not splitting at fx graph level
+        expected_num_piecewise_capturable_graphs_seen=
+        1,  # since not splitting at fx graph level
+        expected_num_backend_compilations=
+        1,  # since not splitting at fx graph level
+        expected_num_cudagraph_captured=
+        6,  # inductor graph partition still captures 6
+        # graphs, same as fx graph partition.
+    )
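
For reference, the expected counter values asserted above follow from the comments in the diff itself. A minimal sketch of the arithmetic, assuming the model under test has num_layers = 2 attention layers (implied by the "2 * num_layers + 1" comment equaling 5) and two capture sizes from cudagraph_capture_sizes=[1, 2]:

# Counter arithmetic behind test_simple_piecewise_compile.
# num_layers = 2 is inferred from the "2 * num_layers + 1" comment in the
# diff; it is not stated explicitly in this hunk.
num_layers = 2
num_cudagraph_sizes = 2  # cudagraph_capture_sizes=[1, 2]

expected_num_piecewise_graphs_seen = 2 * num_layers + 1  # 5
expected_num_piecewise_capturable_graphs_seen = 1 + num_layers  # 3
expected_num_backend_compilations = (
    expected_num_piecewise_capturable_graphs_seen)  # 3
expected_num_cudagraph_captured = (
    num_cudagraph_sizes * expected_num_piecewise_capturable_graphs_seen)  # 6

assert (expected_num_piecewise_graphs_seen,
        expected_num_piecewise_capturable_graphs_seen,
        expected_num_backend_compilations,
        expected_num_cudagraph_captured) == (5, 3, 3, 6)

With use_inductor_graph_partition=True the fx graph is not split, so the three piecewise counters drop to 1, while the number of captured CUDA graphs stays at 6, as the comments in test_simple_inductor_graph_partition note.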