3333import torchaudio .functional as F
3434from transformers import BatchFeature
3535
36- from vllm .attention .layer import MultiHeadAttention
3736from vllm .config import VllmConfig
3837from vllm .distributed import get_tensor_model_parallel_world_size
3938from vllm .model_executor .layers .activation import get_act_fn
@@ -204,12 +203,6 @@ def __init__(
204203 quant_config = quant_config ,
205204 prefix = f"{ prefix } .qkv" ,
206205 )
207- self .attn = MultiHeadAttention (
208- self .num_heads ,
209- self .head_dim ,
210- self .scale ,
211- num_kv_heads = self .num_kv_heads ,
212- )
213206 self .proj = RowParallelLinear (
214207 input_size = dim ,
215208 output_size = dim ,
@@ -221,15 +214,27 @@ def __init__(
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
    """Multi-head self-attention over a batch of sequences.

    Args:
        x: input of shape (B, N, C) = (batch, seq_len, embed_dim).
        mask: optional boolean padding mask; indexed as ``mask[:, None,
            None, :]`` and expanded to (B, 1, N, N), so it is presumably
            (B, N) with True marking positions to mask out — TODO confirm
            against callers.

    Returns:
        Tensor of shape (B, N, C) after attention and output projection.
    """
    B, N, C = x.shape

    qkv, _ = self.qkv(x)
    # (B, N, 3*C) -> (3, B, num_heads, N, head_dim), then split into q/k/v.
    qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
    qkv = qkv.permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    # Scaled dot-product attention scores: (B, num_heads, N, N).
    attn = (q @ k.transpose(-2, -1)) * self.scale
    if self.causal:
        # BUGFIX: bind the causal mask to its own local name instead of
        # rebinding the `mask` parameter. The original clobbered the
        # caller's padding mask, so the `mask is not None` branch below
        # always fired and tried to expand the 2-D (N, N) causal tensor
        # as if it were a (B, N) padding mask.
        mask_value = -torch.finfo(attn.dtype).max
        i, j = attn.shape[-2:]
        causal_mask = torch.ones(i, j, device=q.device,
                                 dtype=torch.bool).triu(j - i + 1)
        attn = attn.masked_fill(causal_mask, mask_value)
    if mask is not None:
        mask_value = torch.finfo(attn.dtype).min
        attn_mask = mask[:, None, None, :].expand(B, 1, N, N)
        attn = attn.masked_fill(attn_mask, mask_value)
    attn = attn.softmax(dim=-1)
    # Rows where every key is masked produce NaNs after softmax; zero them.
    attn = torch.nan_to_num(attn)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)

    x, _ = self.proj(x)

    return x
235240
0 commit comments