|
35 | 35 |
|
36 | 36 | import vllm.distributed.parallel_state as ps |
37 | 37 | from vllm.attention import Attention, AttentionMetadata, AttentionType |
| 38 | +from vllm.attention.layer import MultiHeadAttention |
38 | 39 | from vllm.attention.ops.paged_attn import PagedAttention |
39 | 40 | from vllm.attention.selector import _Backend |
40 | 41 | from vllm.config import VllmConfig |
@@ -517,28 +518,21 @@ def __init__(self, |
517 | 518 | prefix=f"{prefix}.o_proj", |
518 | 519 | ) |
519 | 520 |
|
| 521 | + # Use unified MultiHeadAttention with automatic backend selection |
| 522 | + self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim, |
| 523 | + 1.0 / math.sqrt(self.head_dim)) |
| 524 | + |
520 | 525 | def forward( |
521 | 526 | self, |
522 | 527 | hidden_state: torch.Tensor, |
523 | 528 | attention_mask: Optional[torch.Tensor] = None, |
524 | 529 | ) -> torch.Tensor: |
525 | 530 | qkv, _ = self.qkv_proj(hidden_state) |
526 | 531 | q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) |
527 | | - q = q.view(q.shape[0], q.shape[1], self.num_local_heads, |
528 | | - self.head_dim).transpose(1, 2) |
529 | | - k = k.view(k.shape[0], k.shape[1], self.num_local_heads, |
530 | | - self.head_dim).transpose(1, 2) |
531 | | - v = v.view(v.shape[0], v.shape[1], self.num_local_heads, |
532 | | - self.head_dim).transpose(1, 2) |
533 | | - |
534 | | - # TODO: remove padding in image encoder |
535 | | - attn_output = F.scaled_dot_product_attention(q, |
536 | | - k, |
537 | | - v, |
538 | | - attn_mask=attention_mask, |
539 | | - dropout_p=0.0) |
540 | | - |
541 | | - attn_output = attn_output.transpose(1, 2).contiguous() |
| 532 | + |
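| 533 | + # Use unified MultiHeadAttention with automatic backend selection. NOTE(review): attention_mask is no longer passed to the attention call (the removed SDPA call used attn_mask=attention_mask) - confirm padded positions are handled elsewhere |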
| 534 | + attn_output = self.attn(q, k, v) |
| 535 | + |
542 | 536 | attn_output = attn_output.reshape(attn_output.shape[0], |
543 | 537 | attn_output.shape[1], -1) |
544 | 538 | output, _ = self.o_proj(attn_output) |
|
0 commit comments