@@ -11,7 +11,8 @@
 import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
 from vllm_hpu_extension.flags import enabled_flags
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
+                                      VLLMKVCache)
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -133,7 +134,9 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+        HPUFusedSDPA = kernels.fsdpa()
+        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
+            else ModuleFusedSDPA(HPUFusedSDPA)
 
         self.prefill_impl = 'naive'
         if "flex_attention" in enabled_flags():
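
The `__init__` hunk above stops storing the raw `kernels.fsdpa()` handle and instead wraps it in `ModuleFusedSDPA`, keeping `None` as the fallback when the fused kernel is unavailable, so the fused path sits behind a module just like the other wrapped ops (`Matmul`, `Softmax`, `VLLMKVCache`). Below is a minimal sketch of the same guard-and-wrap pattern; `get_optional_kernel`, `FusedKernelWrapper`, and `AttentionImplSketch` are illustrative stand-ins, not the `vllm_hpu_extension` API.

    from typing import Callable, Optional

    import torch.nn as nn


    def get_optional_kernel() -> Optional[Callable]:
        # Stand-in for kernels.fsdpa(): return the fused kernel if the
        # platform provides one, otherwise None.
        return None


    class FusedKernelWrapper(nn.Module):
        # Stand-in for ModuleFusedSDPA: exposes a raw kernel as an nn.Module
        # so it can be registered and called like the other wrapped ops.

        def __init__(self, kernel: Callable):
            super().__init__()
            self._kernel = kernel

        def forward(self, *args, **kwargs):
            return self._kernel(*args, **kwargs)


    class AttentionImplSketch:
        def __init__(self):
            kernel = get_optional_kernel()
            # Wrap only when the kernel exists; downstream code checks for
            # None and falls back to an unfused path.
            self.fused_scaled_dot_product_attention = (
                None if kernel is None else FusedKernelWrapper(kernel))
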
@@ -272,16 +275,13 @@ def common_attention_args(self,
                               block_list=None,
                               key_cache=None,
                               value_cache=None):
-        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
-            if self.fused_scaled_dot_product_attention is not None else None
-
         return {
             'scale': self.scale,
             'matmul_qk_op': self.matmul_qk,
             'matmul_av_op': self.matmul_av,
             'batch2block_matmul_op': self.batch2block_matmul,
             'block2batch_matmul_op': self.block2batch_matmul,
-            'fsdpa_op': fsdpa_op,
+            'fsdpa_op': self.fused_scaled_dot_product_attention,
             'keys_fetch_func': self.k_cache.fetch_from_cache,
             'values_fetch_func': self.v_cache.fetch_from_cache,
             'softmax_op': self.softmax,
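
With the wrapper in place, the `common_attention_args` hunk drops the defensive unwrapping of `.apply` and passes the module itself (or `None`) straight through as `fsdpa_op`. A consumer can then treat the entry as "call it if present", as in this hedged sketch; `naive_sdpa` and the call shape of `fsdpa_op` are assumptions for illustration, not the actual `vllm_hpu_extension` signature.

    import torch


    def naive_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   scale: float) -> torch.Tensor:
        # Hypothetical unfused fallback: softmax(Q @ K^T * scale) @ V.
        attn = torch.softmax(
            torch.matmul(q, k.transpose(-2, -1)) * scale, dim=-1)
        return torch.matmul(attn, v)


    def run_attention(q, k, v, *, scale, fsdpa_op=None, **_unused):
        # An nn.Module is callable, so the wrapper handed over as `fsdpa_op`
        # can be invoked directly when it is available.
        if fsdpa_op is not None:
            return fsdpa_op(q, k, v)  # assumed call shape, illustration only
        return naive_sdpa(q, k, v, scale)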