@@ -4,14 +4,14 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################

-import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type

 import torch
+import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
-                                      VLLMKVCache)
+from vllm_hpu_extension.flags import enabled_flags
+from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
@@ -126,7 +126,15 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        ops.pa_impl = ops.pa
+        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+
+        self.prefill_impl = 'naive'
+        if "flex_attention" in enabled_flags():
+            self.prefill_impl = 'flex'
+        if "fsdpa" in enabled_flags():
+            assert alibi_slopes is None, \
+                'Prefill with FusedSDPA not supported with alibi slopes!'
+            self.prefill_impl = 'fsdpa'

         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
@@ -138,27 +146,18 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
-                                              '0').lower() in ['1', 'true']
-        self.fused_scaled_dot_product_attention = None
-        if self.prefill_usefusedsdpa:
+        if self.prefill_impl == 'fsdpa':
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
-            try:
-                from habana_frameworks.torch.hpex.kernels import FusedSDPA
-                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
-                    FusedSDPA)
-            except ImportError:
-                logger.warning("Could not import HPU FusedSDPA kernel. "
-                               "vLLM will use native implementation.")

         supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")

-        if attn_type != AttentionType.DECODER:
+        self.attn_type = attn_type
+        if self.attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
@@ -192,15 +191,18 @@ def forward(
         batch_size, seq_len, hidden_size = query.shape
         _, seq_len_kv, _ = key.shape

-        query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
         block_indices = attn_metadata.block_indices
         block_offsets = attn_metadata.block_offsets
-        if attn_metadata.is_prompt:
+        key_cache = None
+        value_cache = None
+        if attn_metadata.is_prompt and self.attn_type \
+                is not AttentionType.ENCODER_ONLY \
+                and attn_metadata.block_list is None:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
-        if kv_cache is not None:
+        if kv_cache is not None and isinstance(kv_cache, tuple):
             key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)

@@ -214,36 +216,28 @@ def forward(

         if attn_metadata.is_prompt:
             # Prompt run.
-            if not self.prefill_usefusedsdpa:
-                # TODO: move this outside of model
-                assert attn_metadata.attn_bias is not None, \
-                    'attn_bias must be set before calling model.forward!'
-                attn_bias = attn_metadata.attn_bias
-                if self.alibi_slopes is not None:
-                    position_bias = _make_alibi_bias(self.alibi_slopes,
-                                                     self.num_kv_heads,
-                                                     attn_bias.dtype,
-                                                     attn_bias.shape[-1])
-                    attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
-                    attn_bias.add_(position_bias)
-                else:
-                    attn_bias = None
-
             query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
             kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                         self.head_size)
+
+            attn_bias = attn_metadata.attn_bias
+            if attn_bias is not None and self.alibi_slopes is not None:
+                position_bias = _make_alibi_bias(self.alibi_slopes,
+                                                 self.num_kv_heads,
+                                                 attn_bias.dtype,
+                                                 attn_bias.shape[-1])
+                attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
+                attn_bias.add_(position_bias)
+
             out = ops.prompt_attention(
-                query.view(query_shape),
-                key.view(kv_shape),
-                value.view(kv_shape),
+                impl=self.prefill_impl,
+                query=query.view(query_shape),
+                key=key.view(kv_shape),
+                value=value.view(kv_shape),
+                is_causal=True,
                 attn_bias=attn_bias,
-                p=0.0,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                softmax_op=self.softmax,
-                matmul_av_op=self.matmul_av,
-                fsdpa_op=self.fused_scaled_dot_product_attention,
-            )
+                valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                **self.common_attention_args())
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
@@ -254,18 +248,26 @@ def forward(
                 block_list=attn_metadata.block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
-                block_scales=attn_metadata.block_scales,
                 block_groups=attn_metadata.block_groups,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                matmul_av_op=self.matmul_av,
-                batch2block_matmul_op=self.batch2block_matmul,
-                block2batch_matmul_op=self.block2batch_matmul,
-                keys_fetch_func=self.k_cache.fetch_from_cache,
-                values_fetch_func=self.v_cache.fetch_from_cache)
+                **self.common_attention_args())
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)

+    def common_attention_args(self):
+        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
+            if self.fused_scaled_dot_product_attention is not None else None
+        return {
+            'scale': self.scale,
+            'matmul_qk_op': self.matmul_qk,
+            'matmul_av_op': self.matmul_av,
+            'batch2block_matmul_op': self.batch2block_matmul,
+            'block2batch_matmul_op': self.block2batch_matmul,
+            'fsdpa_op': fsdpa_op,
+            'keys_fetch_func': self.k_cache.fetch_from_cache,
+            'values_fetch_func': self.v_cache.fetch_from_cache,
+            'softmax_op': self.softmax,
+        }
+

 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
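
For reference, below is a minimal standalone sketch of the flag-driven prefill selection that the __init__ change above introduces. enabled_flags() is stubbed here so the snippet runs without vllm_hpu_extension installed, and select_prefill_impl is a hypothetical helper written for illustration, not code from this change.

from typing import List, Optional, Set


def enabled_flags() -> Set[str]:
    # Stub standing in for vllm_hpu_extension.flags.enabled_flags(),
    # which reports the HPU kernel features currently enabled.
    return {"fsdpa"}


def select_prefill_impl(alibi_slopes: Optional[List[float]] = None) -> str:
    # Mirrors the selection order used in __init__ above: start from 'naive',
    # switch to 'flex' if flex_attention is enabled, then to 'fsdpa' if the
    # fused kernel is enabled (alibi slopes are unsupported in that path).
    impl = 'naive'
    if "flex_attention" in enabled_flags():
        impl = 'flex'
    if "fsdpa" in enabled_flags():
        assert alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        impl = 'fsdpa'
    return impl


if __name__ == "__main__":
    print(select_prefill_impl())  # prints 'fsdpa' with the stub above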