Commit 55b1a85

fix vit attn for models like THUDM/GLM-4v-9B on xpu (vllm-project#339)
Signed-off-by: Yan Ma <yan.ma@intel.com>
1 parent 436cc5b

4 files changed: +11 -8 lines changed

vllm/attention/layer.py

Lines changed: 5 additions & 3 deletions
@@ -522,8 +522,7 @@ def __init__(
         use_upstream_fa = False

         if current_platform.is_xpu():
-            # currently, only torch_sdpa is supported on xpu
-            self.attn_backend = _Backend.TORCH_SDPA
+            self.attn_backend = _Backend.IPEX
         else:
             self.attn_backend = (
                 backend
@@ -611,7 +610,10 @@ def forward(
            out = xops.memory_efficient_attention_forward(
                query, key, value, scale=self.scale
            )
-        elif self.attn_backend == _Backend.TORCH_SDPA:
+        elif (
+            self.attn_backend == _Backend.TORCH_SDPA
+            or self.attn_backend == _Backend.IPEX
+        ):
             query, key, value = (x.transpose(1, 2) for x in (query, key, value))
             out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
             out = out.transpose(1, 2)
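For reference, both TORCH_SDPA and IPEX now share the scaled_dot_product_attention path shown in the second hunk. A minimal standalone sketch of that computation, with made-up tensor sizes (nothing below is taken from the commit beyond the transpose/SDPA pattern):

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, seq_len, num_heads, head_dim); sizes are made up
batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
scale = head_dim ** -0.5
query = torch.randn(batch, seq_len, num_heads, head_dim)
key = torch.randn(batch, seq_len, num_heads, head_dim)
value = torch.randn(batch, seq_len, num_heads, head_dim)

# SDPA expects (batch, num_heads, seq_len, head_dim), hence the transposes
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=scale)
out = out.transpose(1, 2)  # back to (batch, seq_len, num_heads, head_dim)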

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 2 additions & 2 deletions
@@ -725,6 +725,7 @@ def __init__(
             _Backend.TORCH_SDPA,
             _Backend.XFORMERS,
             _Backend.ROCM_AITER_FA,
+            _Backend.IPEX,
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
@@ -861,12 +862,11 @@ def compute_attn_mask_seqlen(
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
+            or self.attn_backend == _Backend.IPEX
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        elif self.attn_backend == _Backend.IPEX:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         return max_seqlen, seqlens

     @staticmethod
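
The IPEX case now folds into the FLASH_ATTN/ROCM_AITER_FA branch, which only needs the longest sequence length derived from the cumulative-length tensor. A small illustration of that arithmetic, with made-up cu_seqlens values:

import torch

# Cumulative sequence lengths for three sequences of 4, 7 and 5 tokens (made-up values)
cu_seqlens = torch.tensor([0, 4, 11, 16])

# FLASH_ATTN / ROCM_AITER_FA / IPEX path: only the longest sequence length is needed
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 7

# XFORMERS path: the full list of per-sequence lengths
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()  # [4, 7, 5]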

vllm/model_executor/models/qwen2_vl.py

Lines changed: 1 addition & 2 deletions
@@ -822,12 +822,11 @@ def compute_attn_mask_seqlen(
         if (
             self.attn_backend == _Backend.FLASH_ATTN
             or self.attn_backend == _Backend.ROCM_AITER_FA
+            or self.attn_backend == _Backend.IPEX
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        elif self.attn_backend == _Backend.IPEX:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         return max_seqlen, seqlens

     def forward(

vllm/platforms/xpu.py

Lines changed: 3 additions & 1 deletion
@@ -116,7 +116,9 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
         return device_props.total_memory

     @classmethod
-    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
+        from vllm.attention.backends.registry import _Backend
+
         return _Backend.IPEX

     @classmethod
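
As a rough usage sketch, the updated platform hook would be consulted along these lines when picking the ViT attention backend; the call site and the head_size/dtype values here are illustrative, not part of the commit:

import torch
from vllm.platforms import current_platform

# Illustrative call: ask the active platform which ViT attention backend to use.
# On XPU this now returns _Backend.IPEX, which vllm/attention/layer.py routes to
# torch.nn.functional.scaled_dot_product_attention (see the first diff above).
backend = current_platform.get_vit_attn_backend(head_size=64, dtype=torch.float16)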
