Commit ef3b428

Restore fsdpa calibration (#1086)
#942 introduced an fsdpa calibration regression. This PR restores the functionality.

Signed-off-by: Michal Adamczyk <madamczyk@habana.ai>
Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>
1 parent f3d849c commit ef3b428


vllm/attention/backends/hpu_attn.py

Lines changed: 6 additions & 6 deletions
@@ -11,7 +11,8 @@
 import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
 from vllm_hpu_extension.flags import enabled_flags
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
+                                      VLLMKVCache)
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -133,7 +134,9 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+        HPUFusedSDPA = kernels.fsdpa()
+        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
+            else ModuleFusedSDPA(HPUFusedSDPA)
 
         self.prefill_impl = 'naive'
         if "flex_attention" in enabled_flags():
@@ -272,16 +275,13 @@ def common_attention_args(self,
                               block_list=None,
                               key_cache=None,
                               value_cache=None):
-        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
-            if self.fused_scaled_dot_product_attention is not None else None
-
         return {
             'scale': self.scale,
             'matmul_qk_op': self.matmul_qk,
             'matmul_av_op': self.matmul_av,
             'batch2block_matmul_op': self.batch2block_matmul,
             'block2batch_matmul_op': self.block2batch_matmul,
-            'fsdpa_op': fsdpa_op,
+            'fsdpa_op': self.fused_scaled_dot_product_attention,
             'keys_fetch_func': self.k_cache.fetch_from_cache,
             'values_fetch_func': self.v_cache.fetch_from_cache,
             'softmax_op': self.softmax,
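
For context, the restored pattern is: wrap the raw kernel returned by kernels.fsdpa() in ModuleFusedSDPA once at construction, and pass the resulting module itself (rather than its bare .apply function) as 'fsdpa_op'. Below is a minimal sketch of that pattern, assuming ModuleFusedSDPA is essentially a thin torch.nn.Module whose forward delegates to the kernel, which is what lets calibration hooks attach to the op like any other layer. DummyFSDPAKernel and ModuleFusedSDPASketch are hypothetical stand-ins, not the vllm_hpu_extension implementations.

import torch


class DummyFSDPAKernel:
    """Hypothetical stand-in for the object returned by kernels.fsdpa();
    the real HPU kernel exposes the fused op through .apply(...)."""

    @staticmethod
    def apply(q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


class ModuleFusedSDPASketch(torch.nn.Module):
    """Sketch of the assumed role of ModuleFusedSDPA: wrapping the kernel in
    an nn.Module so calibration/quantization hooks can observe its calls."""

    def __init__(self, fsdpa_kernel):
        super().__init__()
        self._kernel = fsdpa_kernel

    def forward(self, *args, **kwargs):
        return self._kernel.apply(*args, **kwargs)


# Usage mirroring the diff: wrap once at init, then hand the module itself
# to the attention args instead of the bare .apply function.
kernel = DummyFSDPAKernel()
fused_sdpa = None if kernel is None else ModuleFusedSDPASketch(kernel)
common_args = {'fsdpa_op': fused_sdpa}

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
out = common_args['fsdpa_op'](q, k, v)  # __call__ -> forward -> kernel.apply
print(out.shape)  # torch.Size([1, 2, 4, 8])

Passing a module instead of a raw function keeps the call site in common_attention_args() trivial while preserving the None fallback for the case where no fused kernel is available.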
