12 changes: 6 additions & 6 deletions vllm/attention/backends/hpu_attn.py
@@ -11,7 +11,8 @@
 import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
 from vllm_hpu_extension.flags import enabled_flags
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
+                                      VLLMKVCache)
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -133,7 +134,9 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+        HPUFusedSDPA = kernels.fsdpa()
+        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
+            else ModuleFusedSDPA(HPUFusedSDPA)
 
         self.prefill_impl = 'naive'
         if "flex_attention" in enabled_flags():
@@ -272,16 +275,13 @@ def common_attention_args(self,
                               block_list=None,
                               key_cache=None,
                               value_cache=None):
-        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
-            if self.fused_scaled_dot_product_attention is not None else None
-
         return {
             'scale': self.scale,
             'matmul_qk_op': self.matmul_qk,
             'matmul_av_op': self.matmul_av,
             'batch2block_matmul_op': self.batch2block_matmul,
             'block2batch_matmul_op': self.block2batch_matmul,
-            'fsdpa_op': fsdpa_op,
+            'fsdpa_op': self.fused_scaled_dot_product_attention,
             'keys_fetch_func': self.k_cache.fetch_from_cache,
             'values_fetch_func': self.v_cache.fetch_from_cache,
             'softmax_op': self.softmax,
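With the None check folded into __init__, common_attention_args() no longer has to unwrap a raw-kernel .apply at every call; it forwards the wrapper module (or None) as 'fsdpa_op'. The sketch below shows how a downstream consumer of that dict entry might look. run_attention, naive_sdpa, and the three-argument call are illustrative assumptions only; the real fused op in vLLM's HPU backend is invoked with additional arguments.

import torch


def naive_sdpa(q, k, v, scale):
    # Plain softmax(Q K^T * scale) V fallback used when no fused op is given.
    attn = torch.softmax(torch.matmul(q, k.transpose(-2, -1)) * scale, dim=-1)
    return torch.matmul(attn, v)


def run_attention(q, k, v, fsdpa_op=None, scale=1.0, **_kwargs):
    # Illustrative consumer of the args dict built by common_attention_args():
    # the wrapper module is callable, so it is invoked directly instead of
    # reconstructing a raw-kernel .apply reference at each call site.
    if fsdpa_op is not None:
        return fsdpa_op(q, k, v)  # hypothetical call; real usage passes more args
    return naive_sdpa(q, k, v, scale)


# Example usage exercising the naive fallback path:
q = k = v = torch.randn(1, 4, 8, 16)
out = run_attention(q, k, v, fsdpa_op=None, scale=16 ** -0.5)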