Commit 10eab15

fix vit attn for models like THUDM/GLM-4v-9B on xpu (vllm-project#339)
Signed-off-by: Yan Ma <yan.ma@intel.com>
1 parent 72177d9 commit 10eab15

File tree

4 files changed: +11 −8


vllm/attention/layer.py

Lines changed: 5 additions & 3 deletions
@@ -505,8 +505,7 @@ def __init__(
         use_upstream_fa = False
 
         if current_platform.is_xpu():
-            # currently, only torch_sdpa is supported on xpu
-            self.attn_backend = _Backend.TORCH_SDPA
+            self.attn_backend = _Backend.IPEX
         else:
             self.attn_backend = (
                 backend
@@ -593,7 +592,10 @@ def forward(
             out = xops.memory_efficient_attention_forward(
                 query, key, value, scale=self.scale
             )
-        elif self.attn_backend == _Backend.TORCH_SDPA:
+        elif (
+            self.attn_backend == _Backend.TORCH_SDPA
+            or self.attn_backend == _Backend.IPEX
+        ):
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
            out = out.transpose(1, 2)
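With this change, the IPEX backend shares the torch scaled_dot_product_attention branch shown above instead of taking a separate path. A minimal standalone sketch of that shared path (the helper name and tensor shapes are illustrative, not taken from the vLLM source):

    import torch
    import torch.nn.functional as F

    def sdpa_vit_attention(query, key, value, scale):
        # Inputs arrive as (batch, seq_len, num_heads, head_dim); SDPA expects
        # (batch, num_heads, seq_len, head_dim), so swap the two middle dims.
        query, key, value = (x.transpose(1, 2) for x in (query, key, value))
        out = F.scaled_dot_product_attention(query, key, value, scale=scale)
        # Swap back to (batch, seq_len, num_heads, head_dim) for the caller.
        return out.transpose(1, 2)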

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 2 additions & 2 deletions
@@ -719,6 +719,7 @@ def __init__(
             _Backend.TORCH_SDPA,
             _Backend.XFORMERS,
             _Backend.ROCM_AITER_FA,
+            _Backend.IPEX,
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
@@ -855,12 +856,11 @@ def compute_attn_mask_seqlen(
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
+            or self.attn_backend == _Backend.IPEX
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        elif self.attn_backend == _Backend.IPEX:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         return max_seqlen, seqlens
 
     @staticmethod
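The second hunk only folds the IPEX case into the existing FLASH_ATTN/ROCM_AITER_FA branch; both compute the same first differences over cu_seqlens (cumulative sequence lengths). A small worked sketch with made-up values:

    import torch

    cu_seqlens = torch.tensor([0, 16, 48, 80])  # three sequences of length 16, 32, 32

    # FLASH_ATTN / ROCM_AITER_FA / IPEX path: only the maximum length is needed.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()   # 32

    # XFORMERS path: the full list of per-sequence lengths is needed.
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()          # [16, 32, 32]

The same consolidation is applied to qwen2_vl.py below.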

vllm/model_executor/models/qwen2_vl.py

Lines changed: 1 addition & 2 deletions
@@ -816,12 +816,11 @@ def compute_attn_mask_seqlen(
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
+            or self.attn_backend == _Backend.IPEX
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        elif self.attn_backend == _Backend.IPEX:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         return max_seqlen, seqlens
 
     def forward(

vllm/platforms/xpu.py

Lines changed: 3 additions & 1 deletion
@@ -116,7 +116,9 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
         return device_props.total_memory
 
     @classmethod
-    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
+        from vllm.attention.backends.registry import _Backend
+
         return _Backend.IPEX
 
     @classmethod
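Read together with the layer.py change, the XPU platform hook after this commit looks roughly like the sketch below. This is a paraphrase of the diff, not the full source file; the comments are interpretation (the local import is presumably there to avoid a circular dependency at module load time):

    import torch

    class XPUPlatform:
        @classmethod
        def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
            # Imported locally rather than at module scope, presumably to
            # avoid a circular import between the platform and backend registry.
            from vllm.attention.backends.registry import _Backend
            # ViT attention on XPU is always routed to the IPEX backend, which
            # now shares the scaled_dot_product_attention path in attention/layer.py.
            return _Backend.IPEX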
