
Commit 9dce93e

CSWYF3634076 authored and yewentao256 committed

[Bugfix][Model] fix ernie45 moe gate & bias dtype to float32 (#25936)

Signed-off-by: wangyafeng <wangyafeng@baidu.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

1 parent c0734fc · commit 9dce93e
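In short, the change keeps the MoE router in float32: the gate weights and the e_score_correction_bias parameters are created with float32 dtype, and the hidden states are cast to float32 only for computing router logits, while the experts still consume the activations in the model's compute dtype. The following is a minimal stand-alone PyTorch sketch of that pattern, for illustration only; it uses a plain nn.Linear and made-up toy sizes instead of vLLM's ReplicatedLinear/FusedMoE, and the way the correction bias is folded into the scores is a simplification of the real routing logic.

import torch
import torch.nn as nn

# Minimal sketch (assumption, not vLLM code): route in float32, dispatch in the
# model's compute dtype. hidden_size / num_experts / top_k are toy values.
hidden_size, num_experts, top_k = 8, 4, 2

gate = nn.Linear(hidden_size, num_experts, bias=False, dtype=torch.float32)
e_score_correction_bias = nn.Parameter(torch.zeros(num_experts, dtype=torch.float32))

hidden_states = torch.randn(3, hidden_size, dtype=torch.bfloat16)  # model runs in bf16

# Cast only for the gate, mirroring hidden_states.to(dtype=torch.float32) in the diff.
router_logits = gate(hidden_states.to(dtype=torch.float32))
scores = torch.softmax(router_logits, dim=-1) + e_score_correction_bias  # simplified
topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)

# The experts would still consume the original bf16 hidden_states; only the
# routing decision (topk_ids, topk_weights) is computed in float32.
print(router_logits.dtype, topk_weights.dtype)  # torch.float32 torch.float32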

File tree

2 files changed: +13 additions, -7 deletions

vllm/model_executor/models/ernie45_moe.py

Lines changed: 3 additions & 2 deletions
@@ -120,11 +120,12 @@ def __init__(
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.moe_num_experts,
                                      bias=False,
+                                     params_dtype=torch.float32,
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")
 
         self.gate.e_score_correction_bias = nn.Parameter(
-            torch.empty(config.moe_num_experts))
+            torch.empty(config.moe_num_experts, dtype=torch.float32))
 
         self.experts = FusedMoE(
             num_experts=config.moe_num_experts,
@@ -157,7 +158,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.has_shared_experts:
             shared_output = self.shared_experts(hidden_states)
 
-        router_logits, _ = self.gate(hidden_states)
+        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
 
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
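A side note on why e_score_correction_bias is also moved to float32 (my reading, not stated in the commit message): the bias is added to the float32 router scores, and if the parameter were created in the model's default dtype such as bfloat16, its value would already be rounded before the addition ever happens. A tiny illustration:

import torch

# Illustration only (assumed failure mode): a small correction value stored in
# bf16 is rounded before it is ever added to the float32 router scores, whereas
# a float32 parameter keeps it (nearly) intact.
value = 0.1234567
print(torch.tensor(value, dtype=torch.bfloat16).item())  # ~0.12353515625 after bf16 rounding
print(torch.tensor(value, dtype=torch.float32).item())   # ~0.1234567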

vllm/model_executor/models/ernie45_vl_moe.py

Lines changed: 10 additions & 5 deletions
@@ -199,7 +199,7 @@ def __init__(
 
         assert config.moe_num_experts[0] == config.moe_num_experts[1]
         self.e_score_correction_bias = nn.Parameter(
-            torch.empty(2, config.moe_num_experts[0]))
+            torch.empty(2, config.moe_num_experts[0], dtype=torch.float32))
 
         assert text_moe_layer_start_index <= text_moe_layer_end_index
 
@@ -209,6 +209,7 @@ def __init__(
                 config.hidden_size,
                 config.moe_num_experts[0],
                 bias=False,
+                params_dtype=torch.float32,
                 quant_config=quant_config,
                 prefix=f"{prefix}.text_experts_gate")
 
@@ -238,6 +239,7 @@ def __init__(
                 config.hidden_size,
                 config.moe_num_experts[1],
                 bias=False,
+                params_dtype=torch.float32,
                 quant_config=quant_config,
                 prefix=f"{prefix}.vision_experts_gate")
 
@@ -288,7 +290,8 @@ def forward(
 
         if visual_token_mask is not None and visual_token_mask.all():
             # only vision modal input
-            router_logits, _ = self.vision_experts_gate(hidden_states)
+            router_logits, _ = self.vision_experts_gate(
+                hidden_states.to(dtype=torch.float32))
             final_hidden_states = self.vision_experts(
                 hidden_states=hidden_states, router_logits=router_logits)
         elif visual_token_mask is not None and visual_token_mask.any():
@@ -303,19 +306,21 @@ def forward(
             vision_hidden_states = hidden_states[visual_token_mask].reshape(
                 -1, self.hidden_size)
 
-            text_router_logits, _ = self.text_experts_gate(text_hidden_states)
+            text_router_logits, _ = self.text_experts_gate(
+                text_hidden_states.to(dtype=torch.float32))
             final_hidden_states[text_token_mask] = self.text_experts(
                 hidden_states=text_hidden_states,
                 router_logits=text_router_logits).flatten()
 
             vision_router_logits, _ = self.vision_experts_gate(
-                vision_hidden_states)
+                vision_hidden_states.to(dtype=torch.float32))
             final_hidden_states[visual_token_mask] = self.vision_experts(
                 hidden_states=vision_hidden_states,
                 router_logits=vision_router_logits).flatten()
         else:
            # only text modal input
-            text_router_logits, _ = self.text_experts_gate(hidden_states)
+            text_router_logits, _ = self.text_experts_gate(
+                hidden_states.to(dtype=torch.float32))
 
             final_hidden_states = self.text_experts(
                 hidden_states=hidden_states, router_logits=text_router_logits)
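In the multimodal file the same cast is applied on all three paths (vision-only, mixed, text-only). Below is a rough stand-alone sketch of the mixed path with toy shapes and plain nn.Linear gates standing in for ReplicatedLinear; the names and sizes are assumptions for illustration, not the vLLM code.

import torch
import torch.nn as nn

# Hypothetical toy setup: 5 tokens, 2 of them visual, separate text/vision gates in float32.
hidden_size, num_experts = 8, 4
text_gate = nn.Linear(hidden_size, num_experts, bias=False, dtype=torch.float32)
vision_gate = nn.Linear(hidden_size, num_experts, bias=False, dtype=torch.float32)

hidden_states = torch.randn(5, hidden_size, dtype=torch.bfloat16)
visual_token_mask = torch.tensor([False, True, False, True, False])
text_token_mask = ~visual_token_mask

# Split by modality and cast only the gate inputs to float32 (as in the diff above);
# the expert inputs would remain in the original dtype.
text_hidden_states = hidden_states[text_token_mask]
vision_hidden_states = hidden_states[visual_token_mask]
text_router_logits = text_gate(text_hidden_states.to(dtype=torch.float32))
vision_router_logits = vision_gate(vision_hidden_states.to(dtype=torch.float32))
print(text_router_logits.shape, vision_router_logits.shape)  # torch.Size([3, 4]) torch.Size([2, 4])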
