
Commit 4a770f2

fix vit attn for models like THUDM/GLM-4v-9B on xpu (vllm-project#339)

Signed-off-by: Yan Ma <yan.ma@intel.com>

Parent: 3bbd945

4 files changed: +11 −6 lines

vllm/attention/layer.py

Lines changed: 5 additions & 3 deletions

@@ -522,8 +522,7 @@ def __init__(
         use_upstream_fa = False

         if current_platform.is_xpu():
-            # currently, only torch_sdpa is supported on xpu
-            self.attn_backend = _Backend.TORCH_SDPA
+            self.attn_backend = _Backend.IPEX
         else:
             self.attn_backend = (
                 backend
@@ -611,7 +610,10 @@ def forward(
             out = xops.memory_efficient_attention_forward(
                 query, key, value, scale=self.scale
             )
-        elif self.attn_backend == _Backend.TORCH_SDPA:
+        elif (
+            self.attn_backend == _Backend.TORCH_SDPA
+            or self.attn_backend == _Backend.IPEX
+        ):
             query, key, value = (x.transpose(1, 2) for x in (query, key, value))
             out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
             out = out.transpose(1, 2)
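For reference, a minimal runnable sketch of the SDPA path that the TORCH_SDPA and IPEX backends now share; the shapes and scale are illustrative, not taken from the source:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch=2, seq_len=16, num_heads=8, head_dim=64.
query = torch.randn(2, 16, 8, 64)
key = torch.randn(2, 16, 8, 64)
value = torch.randn(2, 16, 8, 64)
scale = 64 ** -0.5

# F.scaled_dot_product_attention expects (batch, heads, seq, head_dim),
# hence the transposes on either side of the call, as in the hunk above.
q, k, v = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(q, k, v, scale=scale)
out = out.transpose(1, 2)  # back to (batch, seq, heads, head_dim)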

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 2 additions & 0 deletions

@@ -711,6 +711,7 @@ def __init__(
             _Backend.TORCH_SDPA,
             _Backend.XFORMERS,
             _Backend.ROCM_AITER_FA,
+            _Backend.IPEX,
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
@@ -851,6 +852,7 @@ def compute_attn_mask_seqlen(
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
+            or self.attn_backend == _Backend.IPEX
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
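For intuition, a worked example of the cu_seqlens arithmetic that compute_attn_mask_seqlen relies on; the boundary values here are made up:

import torch

# Cumulative boundaries for three sequences of lengths 3, 7, and 6.
cu_seqlens = torch.tensor([0, 3, 10, 16])

lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([3, 7, 6])

max_seqlen = lengths.max()   # tensor(7): FLASH_ATTN / ROCM_AITER_FA / IPEX branch
seqlens = lengths.tolist()   # [3, 7, 6]: XFORMERS branch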

vllm/model_executor/models/qwen2_vl.py

Lines changed: 1 addition & 2 deletions

@@ -822,12 +822,11 @@ def compute_attn_mask_seqlen(
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
+            or self.attn_backend == _Backend.IPEX
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        elif self.attn_backend == _Backend.IPEX:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         return max_seqlen, seqlens

     def forward(
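The deleted elif computed exactly what the FLASH_ATTN branch already computes, so folding _Backend.IPEX into the first condition is a pure simplification. Roughly, the method now reads as this paraphrase (the initial values of max_seqlen and seqlens sit outside the hunk and are assumed here):

def compute_attn_mask_seqlen(self, cu_seqlens):
    max_seqlen, seqlens = None, None  # assumed defaults, not shown in the diff
    if self.attn_backend in (
        _Backend.FLASH_ATTN,
        _Backend.ROCM_AITER_FA,
        _Backend.IPEX,
    ):
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    elif self.attn_backend == _Backend.XFORMERS:
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return max_seqlen, seqlens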

vllm/platforms/xpu.py

Lines changed: 3 additions & 1 deletion

@@ -116,7 +116,9 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
         return device_props.total_memory

     @classmethod
-    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
+        from vllm.attention.backends.registry import _Backend
+
         return _Backend.IPEX

     @classmethod
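A hypothetical call site, assuming vLLM is installed and running on an XPU host; the head_size and dtype values are illustrative, and on other platforms current_platform resolves to a different Platform subclass with its own override:

import torch
from vllm.platforms import current_platform

# XPUPlatform's override above ignores head_size and dtype and always
# selects the IPEX backend for ViT attention.
backend = current_platform.get_vit_attn_backend(head_size=64, dtype=torch.float16)
print(backend)  # _Backend.IPEX on XPU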
