
Commit d844ae2

Fix MiDashengLM audio encoder by replacing incorrect attention with manual implementation
Signed-off-by: zhoukz <me@zhoukz.com>
1 parent 2317dae commit d844ae2

1 file changed: +51 -43 lines changed
vllm/model_executor/models/midashenglm.py

Lines changed: 51 additions & 43 deletions
@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiDashengLM model compatible with HuggingFace weights."""
+
 import collections
 import collections.abc
 from collections.abc import Iterable, Mapping, Sequence
@@ -31,9 +32,9 @@
 import torch
 import torch.nn as nn
 import torchaudio.functional as F
+from torch.nn.functional import scaled_dot_product_attention
 from transformers import BatchFeature

-from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -146,15 +147,19 @@ def __init__(
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = ColumnParallelLinear(input_size=in_features,
-                                        output_size=hidden_features,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+        self.fc1 = ColumnParallelLinear(
+            input_size=in_features,
+            output_size=hidden_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
         self.act = get_act_fn("gelu")
-        self.fc2 = RowParallelLinear(input_size=hidden_features,
-                                     output_size=out_features,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=out_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.fc1(x)
@@ -170,7 +175,6 @@ def __init__(
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
-        causal: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -204,33 +208,30 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
         )
-        self.attn = MultiHeadAttention(
-            self.num_heads,
-            self.head_dim,
-            self.scale,
-            num_kv_heads=self.num_kv_heads,
-        )
         self.proj = RowParallelLinear(
             input_size=dim,
             output_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
         )
-        self.causal = causal

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape

-        qkv_out, _ = self.qkv(x)
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-
-        attn_out = self.attn(q, k, v)
-        C_local = attn_out.numel() // (B * N)  # C_local for parallel
-        attn_out = attn_out.view(B, N, C_local)
+        qkv, _ = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)

-        x, _ = self.proj(attn_out)
+        x = scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask[:, None, None, :] if mask is not None else None,
+        )

+        x = x.transpose(1, 2).reshape(B, N, C)
+        x, _ = self.proj(x)
         return x
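For reference, the new forward pass can be reproduced outside vLLM with plain PyTorch. The sketch below is an illustration only: it substitutes nn.Linear for the ColumnParallelLinear/RowParallelLinear layers used in the real model, and the module name and sizes are made up. What it does follow from the diff above is the reshape/permute into per-head q, k, v and the key-padding mask broadcast into scaled_dot_product_attention, which the previous code never passed to its attention call.

from typing import Optional

import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention


class MaskedSelfAttention(nn.Module):
    """Minimal stand-in for the encoder attention block (illustrative only)."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)  # stands in for the parallel QKV layer
        self.proj = nn.Linear(dim, dim)     # stands in for RowParallelLinear

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, N, D)
        # A boolean key-padding mask of shape (B, N) broadcasts to
        # (B, 1, 1, N); keys where the mask is False receive no attention.
        out = scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask[:, None, None, :] if mask is not None else None,
        )
        return self.proj(out.transpose(1, 2).reshape(B, N, C))


# Tokens past each sequence's true length are excluded from attention:
attn = MaskedSelfAttention(dim=64, num_heads=8)
x = torch.randn(2, 10, 64)
lengths = torch.tensor([10, 6])
mask = torch.arange(10).unsqueeze(0) < lengths.unsqueeze(1)  # (B, N) bool
out = attn(x, mask)
print(out.shape)  # torch.Size([2, 10, 64])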

@@ -462,14 +463,16 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.0",
                 return_bias=False,
-            ), get_act_fn("gelu"),
+            ),
+            get_act_fn("gelu"),
             RowParallelLinear(
                 input_size=out_dim,
                 output_size=out_dim,
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.2",
                 return_bias=False,
-            ))
+            ),
+        )

     def forward(self, x, mask=None):
         batch_size, seq_len, dim = x.shape
@@ -566,9 +569,12 @@ def _call_hf_processor(
         # + Padding
         min_audio_len = self.info.get_min_audio_len()
         processed_audios = [
-            np.pad(audio, (0, min_audio_len - audio.shape[-1]),
-                   mode='constant',
-                   constant_values=0) if isinstance(audio, np.ndarray)
+            np.pad(
+                audio,
+                (0, min_audio_len - audio.shape[-1]),
+                mode="constant",
+                constant_values=0,
+            ) if isinstance(audio, np.ndarray)
             and audio.shape[-1] < min_audio_len else audio for audio in audios
         ]
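The padding rule above only touches NumPy arrays shorter than the processor's minimum length; longer clips pass through unchanged. A standalone sketch of the same rule, where the 16000-sample minimum is a made-up value rather than the model's actual get_min_audio_len():

import numpy as np

MIN_AUDIO_LEN = 16000  # hypothetical minimum sample count for illustration


def pad_to_min(audio: np.ndarray) -> np.ndarray:
    """Right-pad a 1-D clip with zeros up to MIN_AUDIO_LEN samples."""
    if audio.shape[-1] >= MIN_AUDIO_LEN:
        return audio
    return np.pad(
        audio,
        (0, MIN_AUDIO_LEN - audio.shape[-1]),
        mode="constant",
        constant_values=0,
    )


clip = np.zeros(12345, dtype=np.float32)
print(pad_to_min(clip).shape)  # (16000,)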

@@ -617,8 +623,8 @@ def _get_prompt_updates(
         if audio_length is None:
             audio_output_lengths = []
         else:
-            audio_length_np = audio_length.cpu().numpy() if isinstance(
-                audio_length, torch.Tensor) else audio_length
+            audio_length_np = (audio_length.cpu().numpy() if isinstance(
+                audio_length, torch.Tensor) else audio_length)
             audio_output_lengths = [
                 max(1, calculate_mel_frames_dasheng(
                     int(length)))  # at least one frame
@@ -703,8 +709,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(
+                f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])

@@ -753,8 +759,8 @@ def _process_audio_input(
             audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape

-        audio_length_np = audio_length.cpu().numpy() if isinstance(
-            audio_length, torch.Tensor) else audio_length
+        audio_length_np = (audio_length.cpu().numpy() if isinstance(
+            audio_length, torch.Tensor) else audio_length)
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(
                 int(length)))  # at least one frame
@@ -763,11 +769,11 @@ def _process_audio_input(
         audio_output_lengths = torch.tensor(audio_output_lengths).to(
             audio_embeddings.device)

-        audio_feature_mask = (torch.arange(
+        audio_feature_mask = torch.arange(
             max_audio_tokens,
             device=audio_embeddings.device).unsqueeze(0).expand(
-                batch_size, max_audio_tokens)
-            < audio_output_lengths.unsqueeze(1))
+                batch_size,
+                max_audio_tokens) < audio_output_lengths.unsqueeze(1)

         masked_audio_features = audio_embeddings[audio_feature_mask].view(
             -1, embed_dim)
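The audio_feature_mask expression above is a lengths-to-mask broadcast: a row of position indices is expanded across the batch and compared against each sample's output length. A tiny self-contained example with made-up lengths:

import torch

audio_output_lengths = torch.tensor([3, 5, 1])  # hypothetical per-sample lengths
max_audio_tokens = 5
audio_feature_mask = torch.arange(max_audio_tokens).unsqueeze(0).expand(
    audio_output_lengths.shape[0],
    max_audio_tokens) < audio_output_lengths.unsqueeze(1)
print(audio_feature_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True, False, False, False, False]])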
@@ -805,10 +811,12 @@ def forward(
         )
         input_ids = None

-        return self.decoder.model(input_ids,
-                                  positions,
-                                  intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+        return self.decoder.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )

     def compute_logits(
         self,
