Skip to content

Commit de59b95

Browse files
baonudesifeizhai and Isotr0py
authored and committed
Feature/vit attention unification #23880 (vllm-project#23978)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 239f853 commit de59b95

File tree

9 files changed

+70
-58
lines changed

9 files changed

+70
-58
lines changed

tests/kernels/attention/test_mha_attn.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,9 @@ def clear_cache():
2323
"""Clear lru cache to ensure each test case runs without caching.
2424
"""
2525
_cached_get_attn_backend.cache_clear()
26+
# Clear xformers availability cache
27+
import vllm.attention.layer as layer_module
28+
layer_module.USE_XFORMERS_OPS = None
2629

2730

2831
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
@@ -33,19 +36,28 @@ def test_mha_attn_platform(device: str):
3336
torch.set_default_dtype(torch.float16)
3437

3538
if device == "cpu":
36-
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
39+
with patch("vllm.attention.selector.current_platform",
40+
CpuPlatform()), \
41+
patch("vllm.platforms.current_platform", CpuPlatform()):
3742
attn = MultiHeadAttention(16, 64, scale=1)
38-
assert attn.attn_backend == _Backend.TORCH_SDPA
43+
assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
3944
elif device == "hip":
40-
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
45+
with patch("vllm.attention.selector.current_platform",
46+
RocmPlatform()), \
47+
patch("vllm.platforms.current_platform", RocmPlatform()), \
48+
patch("vllm.attention.layer.current_platform", RocmPlatform()):
4149
attn = MultiHeadAttention(16, 64, scale=1)
4250
assert attn.attn_backend == _Backend.TORCH_SDPA
4351
else:
44-
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
52+
with patch("vllm.attention.selector.current_platform",
53+
CudaPlatform()), \
54+
patch("vllm.platforms.current_platform", CudaPlatform()):
4555
attn = MultiHeadAttention(16, 64, scale=1)
4656
assert attn.attn_backend == _Backend.XFORMERS
4757

48-
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
58+
with patch("vllm.attention.selector.current_platform",
59+
CudaPlatform()), \
60+
patch("vllm.platforms.current_platform", CudaPlatform()):
4961
attn = MultiHeadAttention(16, 72, scale=1)
5062
assert attn.attn_backend == _Backend.XFORMERS
5163

vllm/attention/layer.py

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -360,13 +360,13 @@ def __init__(
360360
# currently, only torch_sdpa is supported on rocm
361361
self.attn_backend = _Backend.TORCH_SDPA
362362
else:
363-
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
364-
_Backend.FLEX_ATTENTION):
365-
backend = _Backend.XFORMERS
366-
367363
self.attn_backend = backend if backend in {
368-
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
369-
} else _Backend.TORCH_SDPA
364+
_Backend.TORCH_SDPA,
365+
_Backend.TORCH_SDPA_VLLM_V1,
366+
_Backend.XFORMERS,
367+
_Backend.PALLAS_VLLM_V1,
368+
_Backend.ROCM_AITER_FA,
369+
} else current_platform.get_vit_attn_backend()
370370

371371
if (self.attn_backend == _Backend.XFORMERS
372372
and not check_xformers_availability()):
@@ -413,6 +413,19 @@ def forward(
413413
from torch_xla.experimental.custom_kernel import flash_attention
414414
out = flash_attention(query, key, value, sm_scale=self.scale)
415415
out = out.transpose(1, 2)
416+
elif self.attn_backend == _Backend.ROCM_AITER_FA:
417+
from aiter import flash_attn_varlen_func
418+
419+
# ROCm Flash Attention expects (batch, seq, heads, head_dim)
420+
out = flash_attn_varlen_func(query,
421+
key,
422+
value,
423+
softmax_scale=self.scale)
424+
else:
425+
# ViT attention hasn't supported this backend yet
426+
raise NotImplementedError(
427+
f"ViT attention hasn't supported {self.attn_backend} "
428+
f"backend yet.")
416429

417430
return out.reshape(bsz, q_len, -1)
418431

vllm/model_executor/models/idefics2_vision_model.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,6 +170,7 @@ def __init__(
170170
quant_config=quant_config,
171171
prefix=f"{prefix}.out_proj",
172172
)
173+
# Use unified MultiHeadAttention with Flash Attention support
173174
self.attn = MultiHeadAttention(self.num_heads_per_partition,
174175
self.head_dim, self.scale)
175176

@@ -181,6 +182,8 @@ def forward(
181182
hidden_states
182183
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
183184
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
185+
186+
# Use unified MultiHeadAttention implementation
184187
out = self.attn(query_states, key_states, value_states)
185188
attn_output, _ = self.out_proj(out)
186189
return attn_output

vllm/model_executor/models/intern_vit.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -255,6 +255,10 @@ def __init__(
255255

256256
self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
257257

258+
# Use unified MultiHeadAttention with automatic backend selection
259+
self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
260+
self.scale)
261+
258262
def forward(self, x: torch.Tensor) -> torch.Tensor:
259263
B, N, C = x.shape
260264
qkv = self.qkv(x)
@@ -268,12 +272,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
268272
B_, N_, H_, D_ = q.shape
269273
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
270274
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
271-
q = q.transpose(1, 2)
272-
k = k.transpose(1, 2)
273-
v = v.transpose(1, 2)
274275

275-
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
276-
x = x.transpose(1, 2).reshape(B, N, -1)
276+
# Use unified MultiHeadAttention with automatic backend selection
277+
x = self.attn(q, k, v)
277278

278279
x = self.proj(x)
279280
return x

vllm/model_executor/models/interns1_vit.py

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,10 @@
1212

1313
import torch
1414
import torch.nn as nn
15-
import torch.nn.functional as F
1615
from transformers import PretrainedConfig
1716
from transformers.utils import torch_int
1817

18+
from vllm.attention.layer import MultiHeadAttention
1919
from vllm.model_executor.layers.activation import get_act_fn
2020
from vllm.model_executor.layers.layernorm import RMSNorm
2121
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -206,27 +206,24 @@ def __init__(
206206

207207
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
208208

209+
# Use unified MultiHeadAttention with automatic backend selection
210+
self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
211+
self.scale)
212+
209213
def forward(self, x: torch.Tensor) -> torch.Tensor:
210214
B, N, C = x.shape
211215

212216
q = self.q_proj(x)
213217
k = self.k_proj(x)
214218
v = self.v_proj(x)
215219

216-
q = q.view(B, N, self.num_heads, self.head_dim)
217-
k = k.view(B, N, self.num_heads, self.head_dim)
218-
v = v.view(B, N, self.num_heads, self.head_dim)
219-
220220
if self.qk_normalization:
221221
B_, N_, H_, D_ = q.shape
222222
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
223223
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
224-
q = q.transpose(1, 2)
225-
k = k.transpose(1, 2)
226-
v = v.transpose(1, 2)
227224

228-
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
229-
x = x.transpose(1, 2).reshape(B, N, -1)
225+
# Use unified MultiHeadAttention with automatic backend selection
226+
x = self.attn(q, k, v)
230227

231228
x = self.projection_layer(x)
232229
return x

vllm/model_executor/models/mllama.py

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535

3636
import vllm.distributed.parallel_state as ps
3737
from vllm.attention import Attention, AttentionMetadata, AttentionType
38+
from vllm.attention.layer import MultiHeadAttention
3839
from vllm.attention.ops.paged_attn import PagedAttention
3940
from vllm.attention.selector import _Backend
4041
from vllm.config import VllmConfig
@@ -517,28 +518,21 @@ def __init__(self,
517518
prefix=f"{prefix}.o_proj",
518519
)
519520

521+
# Use unified MultiHeadAttention with automatic backend selection
522+
self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
523+
1.0 / math.sqrt(self.head_dim))
524+
520525
def forward(
521526
self,
522527
hidden_state: torch.Tensor,
523528
attention_mask: Optional[torch.Tensor] = None,
524529
) -> torch.Tensor:
525530
qkv, _ = self.qkv_proj(hidden_state)
526531
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
527-
q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
528-
self.head_dim).transpose(1, 2)
529-
k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
530-
self.head_dim).transpose(1, 2)
531-
v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
532-
self.head_dim).transpose(1, 2)
533-
534-
# TODO: remove padding in image encoder
535-
attn_output = F.scaled_dot_product_attention(q,
536-
k,
537-
v,
538-
attn_mask=attention_mask,
539-
dropout_p=0.0)
540-
541-
attn_output = attn_output.transpose(1, 2).contiguous()
532+
533+
# Use unified MultiHeadAttention with automatic backend selection
534+
attn_output = self.attn(q, k, v)
535+
542536
attn_output = attn_output.reshape(attn_output.shape[0],
543537
attn_output.shape[1], -1)
544538
output, _ = self.o_proj(attn_output)

vllm/model_executor/models/step3_vl.py

Lines changed: 7 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from torchvision.transforms.functional import InterpolationMode
1717
from transformers import BatchFeature, PretrainedConfig, TensorType
1818

19+
from vllm.attention.layer import MultiHeadAttention
1920
from vllm.config import VllmConfig
2021
from vllm.distributed import get_tensor_model_parallel_world_size
2122
from vllm.model_executor.layers.activation import get_act_fn
@@ -682,9 +683,9 @@ def __init__(self,
682683
prefix=f"{prefix}.out_proj",
683684
disable_tp=use_data_parallel)
684685

685-
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
686-
return tensor.view(bsz, seq_len, self.num_heads,
687-
self.head_dim).transpose(1, 2).contiguous()
686+
# Use unified MultiHeadAttention with automatic backend selection
687+
self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
688+
self.scale)
688689

689690
def forward(
690691
self,
@@ -696,19 +697,9 @@ def forward(
696697
# get query proj
697698
qkv, _ = self.qkv_proj(hidden_states)
698699
q, k, v = qkv.chunk(chunks=3, dim=-1)
699-
q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
700-
k = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
701-
v = v.view(bsz, tgt_len, self.num_heads, self.head_dim)
702-
q = q.transpose(1, 2)
703-
k = k.transpose(1, 2)
704-
v = v.transpose(1, 2)
705-
attn_output = F.scaled_dot_product_attention(q,
706-
k,
707-
v,
708-
scale=self.scale,
709-
is_causal=False)
710-
attn_output = attn_output.transpose(1, 2).reshape(
711-
bsz, tgt_len, self.num_heads * self.head_dim)
700+
701+
# Use unified MultiHeadAttention with automatic backend selection
702+
attn_output = self.attn(q, k, v)
712703

713704
attn_output, _ = self.out_proj(attn_output)
714705

vllm/model_executor/models/vision.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,4 +122,4 @@ def resolve_visual_encoder_outputs(
122122
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
123123
if post_layer_norm is not None and uses_last_layer:
124124
hs_pool[-1] = post_layer_norm(encoder_outputs)
125-
return torch.cat(hs_pool, dim=-1)
125+
return torch.cat(hs_pool, dim=-1)

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ class _Backend(enum.Enum):
4848
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
4949
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
5050
TORCH_SDPA = enum.auto()
51+
TORCH_SDPA_VLLM_V1 = enum.auto()
5152
FLASHINFER = enum.auto()
5253
FLASHINFER_VLLM_V1 = enum.auto()
5354
TRITON_MLA = enum.auto() # Supported by V1

0 commit comments

Comments
 (0)