
Commit db16ee1

Merge branch 'main' into luka/custom-op-matching-2
2 parents 12a7c6d + f0862ea commit db16ee1

File tree

1 file changed (+74, -26)

tests/compile/test_decorator.py

Lines changed: 74 additions & 26 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
 import torch
 from torch import nn
 
@@ -14,6 +15,7 @@
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer
 
 # This import automatically registers `torch.ops.silly.attention`
 from . import silly_attention  # noqa: F401
@@ -65,19 +67,40 @@ def run_model(
     return output.cpu()
 
 
-def test_ignore_torch_compile_decorator():
-    # vllm compile
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
+def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch):
+    # disable compile cache so that we can count the number of compilations
+    # appropriately
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
+
+    # piecewise
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
-            use_inductor_graph_partition=False,  # TODO test both?
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
+    expected_num_graphs_seen = 1
+    expected_num_cudagraph_captured = (
+        4  # num_cudagraph_sizes * num cudagraphs to capture
+    )
+    if use_inductor_graph_partition:
+        expected_num_piecewise_graphs_seen = 1
+        expected_num_piecewise_capturable_graphs_seen = 1
+        expected_num_backend_compilations = 1
+    else:
+        expected_num_piecewise_graphs_seen = 3
+        expected_num_piecewise_capturable_graphs_seen = 2
+        expected_num_backend_compilations = 2
+
     @support_torch_compile
     class A(nn.Module):
         def __init__(
@@ -104,12 +127,11 @@ class C(B): ...
 
     # A has support_torch_compile
     with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
     ):
         run_model(vllm_config, mod_A, cudagraph_runtime_mode)
 
@@ -131,12 +153,11 @@ class C(B): ...
 
     # C's support_torch_compile should override B's ignore_torch_compile
    with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
     ):
         run_model(vllm_config, mod_C, cudagraph_runtime_mode)
 
@@ -179,7 +200,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-def test_conditional_compile_enable_if():
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
+def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch):
+    # disable compile cache so that we can count the number of compilations
+    # appropriately
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
+
     vllm_config = VllmConfig(
         cache_config=CacheConfig(
             kv_sharing_fast_prefill=True,
@@ -189,25 +218,34 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
-            use_inductor_graph_partition=False,  # TODO test both
+            use_inductor_graph_partition=use_inductor_graph_partition,
         ),
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
 
+    if use_inductor_graph_partition:
+        expected_num_piecewise_graphs_seen = 2
+        expected_num_piecewise_capturable_graphs_seen = 2
+        expected_num_backend_compilations = 2
+    else:
+        expected_num_piecewise_graphs_seen = 6
+        expected_num_piecewise_capturable_graphs_seen = 4
+        expected_num_backend_compilations = 4
+
     # A has support_torch_compile but enable_if fn returns False
     # enable_if will be True for B, so we expect mod1 and mod2
     # to be compiled
     with compilation_counter.expect(
         num_graphs_seen=2,
-        num_piecewise_graphs_seen=6,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
         # 3 piecewise graphs per instance of B()
-        num_piecewise_capturable_graphs_seen=4,
-        num_backend_compilations=4,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
         num_cudagraph_captured=8,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        # num_cudagraph_sizes * num cudagraphable graphs to capture
     ):
         run_model(vllm_config, mod_A, cudagraph_runtime_mode)
 
@@ -222,20 +260,30 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
-            use_inductor_graph_partition=False,  # TODO test both?
+            use_inductor_graph_partition=use_inductor_graph_partition,
         ),
     )
 
     with set_current_vllm_config(vllm_config):
         mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
 
+    if use_inductor_graph_partition:
+        expected_num_piecewise_graphs_seen = 1
+        expected_num_piecewise_capturable_graphs_seen = 1
+        expected_num_backend_compilations = 1
+    else:
+        # 3 attn ops and 4 non-attn ops
+        expected_num_piecewise_graphs_seen = 7
+        expected_num_piecewise_capturable_graphs_seen = 4
+        expected_num_backend_compilations = 4
+
     with compilation_counter.expect(
         num_graphs_seen=1,
-        num_piecewise_graphs_seen=7,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
         # 3 attn ops and 4 non-attn ops
-        num_piecewise_capturable_graphs_seen=4,
-        num_backend_compilations=4,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
         num_cudagraph_captured=8,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        # num_cudagraph_sizes * num cudagraphable graphs to capture
     ):
         run_model(vllm_config, mod_A, cudagraph_runtime_mode)
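
Aside: the gating pattern this commit applies is a general pytest idiom, namely parametrize a test over a feature flag, pin cache-related env vars with monkeypatch so each parametrization is counted independently, and skip the flag-on case when a dependency is too old. Below is a minimal, self-contained sketch of that idiom; it is not part of vLLM, and the names test_feature_gated, MY_DISABLE_CACHE, and run_workload are hypothetical stand-ins.

# Sketch of the parametrize + version-gate + cache-pin pattern
# (hypothetical names; mirrors the structure of the tests above).
import pytest
import torch
from packaging.version import Version


def run_workload(flag: bool) -> int:
    # Stand-in for the real workload (run_model in the diff above).
    return 1 if flag else 0


@pytest.mark.parametrize("use_new_feature", [True, False])
def test_feature_gated(use_new_feature, monkeypatch):
    # Pin the env var so a warm cache cannot hide recompilations,
    # analogous to VLLM_DISABLE_COMPILE_CACHE=1 in the diff.
    monkeypatch.setenv("MY_DISABLE_CACHE", "1")

    # Skip only the configuration that needs the newer dependency,
    # analogous to the is_torch_equal_or_newer("2.9.0.dev") check.
    if use_new_feature and Version(torch.__version__) < Version("2.9.0.dev"):
        pytest.skip("feature requires PyTorch 2.9+")

    assert run_workload(use_new_feature) in (0, 1)

The packaging comparison is a simplification; vLLM's own is_torch_equal_or_newer exists precisely to handle torch dev builds robustly, which is why the diff uses it instead.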
