
Commit dc41934

Lucaskabela authored and ilmarkov committed
[Misc][qwen2_5_vl][torch.compile] Enable supports_torch_compile on generic nn.Module and demonstrate speedup on Qwen Vision model (vllm-project#23207)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: Lucas Kabela <lucasakabela@gmail.com>
1 parent c17817c commit dc41934

File tree

6 files changed: +335 -98 lines changed

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.compilation.counter import compilation_counter
from vllm.config.compilation import CompilationMode


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
    """Test that Qwen2.5-VL vision submodules are compiled.

    This test verifies that the 3 vision submodules (Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionBlock, and Qwen2_5_VisionPatchMerger) are properly tagged
    for compilation by checking that num_models_seen increases by at least 3.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        # NOTE: Qwen2.5-VL has 35 models in total - the LLM backend
        # Vision Patch Embed, Vision Patch Merger, and then 32 Vision Blocks
        # (one for each layer) - in the future, we should fix vLLM compilation
        # logic to handle this case and only compile the Vision submodules once
        # and reuse the compiled code for all layers
        # See https://github.com/vllm-project/vllm/issues/27590
        compilation_counter.expect(num_models_seen=35),
        vllm_runner(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            max_model_len=2048,
            gpu_memory_utilization=0.7,
            compilation_config={"mode": CompilationMode.VLLM_COMPILE},
        ) as _,
    ):
        pass
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains ops for ViT attention to be compatible with torch.compile
as there are operations here not supported by torch.compile (for instance,
`to_list` in xformers attn, or `.item()` in flash attention)

Using these ops and wrapping vision blocks with `torch.compile` can speed up
throughput in vision models by ~5% relative on H100, and improve token
latencies by ~7% (see qwen2_5_vl for example usage)

To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)
"""

import einops
import torch

from vllm.utils.torch_utils import direct_register_custom_op


def xformers_attn_seqlens_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    attn_bias = BlockDiagonalMask.from_seqlens(
        q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device
    )
    context_layer = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=attn_bias, p=0, scale=None
    )
    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
    return context_layer


def xformers_attn_seqlens_wrapper_fake(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="xformers_attn_seqlens_wrapper",
    op_func=xformers_attn_seqlens_wrapper,
    fake_impl=xformers_attn_seqlens_wrapper_fake,
)


def vit_xformers_attn_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens)


def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    else:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen.item(),
        max_seqlen_k=max_seqlen.item(),
        dropout_p=0.0,
        causal=False,
    )
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer


def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)


direct_register_custom_op(
    op_name="flash_attn_maxseqlen_wrapper",
    op_func=flash_attn_maxseqlen_wrapper,
    fake_impl=flash_attn_maxseqlen_wrapper_fake,
)


def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
    )
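
The wrappers above all follow one pattern: code that torch.compile cannot trace (a `.tolist()` or `.item()` on a tensor) is moved inside a registered custom op with a shape-only fake implementation, so the compiled graph treats the op as opaque. Below is a minimal, self-contained sketch of that same pattern in plain PyTorch (>= 2.4, per the module docstring above); the `demo::lengths_sum` op, its shapes, and the example inputs are invented for illustration and are not part of this commit.

import torch


@torch.library.custom_op("demo::lengths_sum", mutates_args=())
def lengths_sum(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # The data-dependent .item() lives inside the op, invisible to dynamo
    n = int(lengths.sum().item())
    return x[:n].sum(dim=0, keepdim=True)


@lengths_sum.register_fake
def _(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # The fake impl only reports output shape/dtype; no data-dependent work here
    return torch.empty((1, x.shape[1]), dtype=x.dtype, device=x.device)


@torch.compile(fullgraph=True)
def run(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Without the custom-op indirection, the .item() above would break the graph
    return lengths_sum(x, lengths)


print(run(torch.randn(8, 4), torch.tensor([2, 3])).shape)  # torch.Size([1, 4])

Only the fake implementation is traced into the graph; the real kernel (with its `.item()`) runs outside of it, which is exactly how the xformers and flash-attention wrappers above stay torch.compile-friendly.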

vllm/compilation/decorators.py

Lines changed: 56 additions & 4 deletions
@@ -18,7 +18,12 @@
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
+from vllm.config import (
+    CompilationMode,
+    VllmConfig,
+    get_current_vllm_config,
+    set_current_vllm_config,
+)
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname

@@ -74,6 +79,21 @@ def support_torch_compile(
 ) -> Callable[[_T], _T]: ...


+@overload
+def support_torch_compile(
+    *,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
+
+
+@overload
+def support_torch_compile(
+    *,
+    dynamic_arg_dims: dict[str, int | list[int]] | None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
+
+
 @overload
 def support_torch_compile(cls: _T) -> _T: ...

@@ -82,6 +102,7 @@ def support_torch_compile(
     cls: _T | None = None,
     *,
     dynamic_arg_dims: dict[str, int | list[int]] | None = None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
 ) -> Callable[[_T], _T] | _T:
     """

@@ -135,6 +156,11 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
     returns a boolean value indicating whether to compile the model or not.
     This is useful if you want to compile the model only when certain
     conditions are met.
+
+    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
+    dim to be decorated with `mark_unbacked`. This is useful if we would like to
+    enforce that dynamo do not specialize on 0/1 values in the case of dummy input
+    such as for vision model compilation
     """

     def cls_decorator_helper(cls: _T) -> _T:

@@ -172,7 +198,9 @@ def cls_decorator_helper(cls: _T) -> _T:
                 raise ValueError(
                     f"Argument {k} not found in the forward method of {cls}"
                 )
-        return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
+        return _support_torch_compile(
+            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
+        )

     if cls is not None:
         # use `support_torch_compile` as a decorator without arguments

@@ -212,6 +240,7 @@ def _verify_source_unchanged(source_info, vllm_config) -> None:
 def _support_torch_compile(
     cls: _T,
     dynamic_arg_dims: dict[str, int | list[int]],
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
 ) -> _T:
     """

@@ -230,8 +259,22 @@

     setattr(cls, IGNORE_COMPILE_KEY, False)

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
-        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+    def __init__(
+        self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
+    ):
+        if vllm_config is None:
+            vllm_config = get_current_vllm_config()
+
+        # NOTE: to support multimodal models (such as encoder),
+        # we may not have vllm_config so we may need to patch
+        # it
+        sig = inspect.signature(old_init)
+        if "vllm_config" in sig.parameters:
+            kwargs["vllm_config"] = vllm_config
+        if "prefix" in sig.parameters:
+            kwargs["prefix"] = prefix
+        old_init(self, **kwargs)
+
         self.vllm_config = vllm_config
         enable_compile = enable_if is None or enable_if(vllm_config)
         # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner

@@ -344,6 +387,15 @@ def __call__(self, *args, **kwargs):
                        "Unsupported dynamic dimensions"
                        f" {dims} for argument {k} with type {type(arg)}."
                    )
+            if mark_unbacked_dims:
+                for k, dims in mark_unbacked_dims.items():
+                    arg = bound_args.arguments.get(k)
+                    if arg is not None:
+                        dims = [dims] if isinstance(dims, int) else dims
+                        if isinstance(arg, torch.Tensor):
+                            # In case dims is specified with negative indexing
+                            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                            torch._dynamo.decorators.mark_unbacked(arg, dims)
             # here, it is the starting point of the `torch.compile` process
             start_monitoring_torch_compile(self.vllm_config)
             logger.debug("Start compiling function %s", self.original_code_object)
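
To illustrate what the decorator change enables, here is a hypothetical usage sketch (not part of the commit): a generic nn.Module whose __init__ takes neither vllm_config nor prefix can now be tagged with support_torch_compile, because the patched __init__ falls back to get_current_vllm_config() and only forwards the arguments the wrapped class actually accepts. The ToyVisionBlock class and its sizes below are invented for illustration, and instantiation still has to happen while a VllmConfig is current (for example during vLLM model construction).

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile(
    dynamic_arg_dims={"x": 0},    # sequence length varies between calls
    mark_unbacked_dims={"x": 0},  # keep dynamo from specializing on 0/1 dummy sizes
)
class ToyVisionBlock(nn.Module):
    # No vllm_config/prefix parameters needed; the wrapped __init__ supplies them.
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

Here mark_unbacked_dims complements dynamic_arg_dims: the latter marks the dimension as dynamic across calls, while the former (per the new docstring above) stops dynamo from specializing on 0/1-sized dummy inputs such as those used when profiling vision models.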

vllm/config/compilation.py

Lines changed: 2 additions & 0 deletions
@@ -684,6 +684,8 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:

         from vllm.compilation.backends import VllmBackend

+        # TODO[@lucaskabela]: See if we can forward prefix
+        # https://github.com/vllm-project/vllm/issues/27045
         return VllmBackend(vllm_config)

     def post_init_cudagraph_sizes(self) -> None:

vllm/model_executor/models/qwen2_5_omni_thinker.py

Lines changed: 6 additions & 2 deletions
@@ -45,6 +45,7 @@

 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2_5_vl import (

@@ -759,7 +760,8 @@ def _process_image_input(
         assert grid_thw.ndim == 2

         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+        with set_forward_context(None, self.vllm_config):
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size

@@ -779,7 +781,8 @@ def _process_video_input(
         assert grid_thw.ndim == 2

         pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
-        video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+        with set_forward_context(None, self.vllm_config):
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size

@@ -839,6 +842,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> str | None:

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        self.vllm_config = vllm_config
         thinker_config: Qwen2_5OmniThinkerConfig = (
             vllm_config.model_config.hf_config.thinker_config
         )
