
Commit b761fd8

Fix MiDashengLM audio encoder by replacing incorrect attention with manual implementation
Signed-off-by: zhoukz <me@zhoukz.com>
Parent: 2317dae

File tree

1 file changed: +21 -16 lines changed


vllm/model_executor/models/midashenglm.py

Lines changed: 21 additions & 16 deletions
@@ -33,7 +33,6 @@
 import torchaudio.functional as F
 from transformers import BatchFeature
 
-from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -204,12 +203,6 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
         )
-        self.attn = MultiHeadAttention(
-            self.num_heads,
-            self.head_dim,
-            self.scale,
-            num_kv_heads=self.num_kv_heads,
-        )
         self.proj = RowParallelLinear(
             input_size=dim,
             output_size=dim,
@@ -221,15 +214,27 @@
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape
 
-        qkv_out, _ = self.qkv(x)
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-
-        attn_out = self.attn(q, k, v)
-        C_local = attn_out.numel() // (B * N)  # C_local for parallel
-        attn_out = attn_out.view(B, N, C_local)
-
-        x, _ = self.proj(attn_out)
+        qkv, _ = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        if self.causal:
+            mask_value = -torch.finfo(attn.dtype).max
+            i, j = attn.shape[-2:]
+            mask = torch.ones(i, j, device=q.device,
+                              dtype=torch.bool).triu(j - i + 1)
+            attn = attn.masked_fill(mask, mask_value)
+        if mask is not None:
+            mask_value = torch.finfo(attn.dtype).min
+            attn_mask = mask[:, None, None, :].expand(B, 1, N, N)
+            attn = attn.masked_fill(attn_mask, mask_value)
+        attn = attn.softmax(dim=-1)
+        attn = torch.nan_to_num(attn)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+
+        x, _ = self.proj(x)
 
         return x

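For readers who want to try the pattern outside the model, below is a minimal, self-contained sketch of the manual attention that the new forward pass follows: split a fused QKV projection into per-head tensors, compute scaled dot-product scores, apply an optional key-side padding mask, softmax, zero out NaNs, and merge the heads back. The function name manual_attention, the tensor shapes, and the demo values are illustrative assumptions rather than part of vLLM's API; the sketch runs on a single device and omits the causal branch and the tensor-parallel qkv/proj layers used in midashenglm.py.

from typing import Optional

import torch


def manual_attention(
        qkv: torch.Tensor,  # (B, N, 3 * C): fused query/key/value projection
        num_heads: int,
        padding_mask: Optional[torch.Tensor] = None,  # (B, N), True = padded
) -> torch.Tensor:
    """Hypothetical single-device sketch of the manual attention pattern."""
    B, N, three_c = qkv.shape
    C = three_c // 3
    head_dim = C // num_heads
    scale = head_dim ** -0.5

    # Split into per-head q, k, v, each of shape (B, num_heads, N, head_dim).
    qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    # Scaled dot-product scores: (B, num_heads, N, N).
    attn = (q @ k.transpose(-2, -1)) * scale

    if padding_mask is not None:
        # Broadcast the key-side padding mask over batch, head and query dims
        # and push masked scores to the most negative representable value.
        attn = attn.masked_fill(padding_mask[:, None, None, :],
                                torch.finfo(attn.dtype).min)

    attn = attn.softmax(dim=-1)
    # Rows whose keys are all masked produce NaNs after softmax; zero them.
    attn = torch.nan_to_num(attn)

    # Weighted sum of values, then merge heads back to (B, N, C).
    return (attn @ v).transpose(1, 2).reshape(B, N, C)


if __name__ == "__main__":
    B, N, C, H = 2, 16, 64, 8
    qkv = torch.randn(B, N, 3 * C)
    pad = torch.zeros(B, N, dtype=torch.bool)
    pad[:, 12:] = True  # pretend the last four positions are padding
    print(manual_attention(qkv, num_heads=H, padding_mask=pad).shape)
    # torch.Size([2, 16, 64])

The explicit torch.nan_to_num mirrors the diff and only matters when an entire key row is masked out; otherwise the softmax already yields finite weights.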