
Commit 6a0fa78

zhyajieroot authored and committed
[Bugfix][Rocm] Fix shared expert weight loading failure in DeepSeek-MTP
1 parent: cbd5e07 · commit: 6a0fa78

File tree

1 file changed: +118 -42 lines changed

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 118 additions & 42 deletions
@@ -1,14 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_fusion_shared_expert_enabled,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -212,11 +217,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.config.n_shared_experts
+                if is_rocm_aiter_fusion_shared_expert_enabled()
+                else 0
+            ),
         )
 
         params_dict = dict(self.named_parameters())
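Note on the hunk above: with AITER fusion of shared experts enabled, the shared experts occupy extra slots appended after the routed experts, so the expert params mapping has to cover those slots as well. A minimal sketch of the arithmetic, using illustrative DeepSeek-V3-style values (n_routed_experts=256, n_shared_experts=1) rather than values read from a real config:

n_routed_experts = 256       # illustrative, not read from a real config
n_shared_experts = 1         # illustrative
aiter_fusion_enabled = True  # stands in for is_rocm_aiter_fusion_shared_expert_enabled()

num_experts = n_routed_experts + (n_shared_experts if aiter_fusion_enabled else 0)
print(num_experts)  # 257: slots 0..255 are routed experts, slot 256 holds the shared expert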
@@ -227,6 +237,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is None:
                 continue
+            is_fuse_shared_experts_layer = (
+                is_rocm_aiter_fusion_shared_expert_enabled()
+                and ("mlp.shared_experts" in name)
+            )
             name = self._rewrite_spec_layer_name(spec_layer, name)
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
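For context, the flag introduced above only marks checkpoint tensors that live under a shared_experts submodule, and only when the AITER fusion path is active. A small sketch with made-up weight names (the layer index and prefixes are illustrative, not taken from an actual DeepSeek-MTP checkpoint):

def is_fused_shared_expert_weight(name: str, aiter_fusion_enabled: bool) -> bool:
    # Same condition as is_fuse_shared_experts_layer in the diff above.
    return aiter_fusion_enabled and "mlp.shared_experts" in name

print(is_fused_shared_expert_weight(
    "model.layers.61.mlp.shared_experts.gate_proj.weight", True))   # True
print(is_fused_shared_expert_weight(
    "model.layers.61.mlp.experts.0.gate_proj.weight", True))        # False
print(is_fused_shared_expert_weight(
    "model.layers.61.mlp.shared_experts.gate_proj.weight", False))  # False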
@@ -240,6 +254,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
+                if is_fuse_shared_experts_layer:
+                    continue
                 name_mapped = name.replace(weight_name, param_name)
 
                 # QKV fusion is optional, fall back to normal
@@ -260,45 +276,105 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-
-                    # According to DeepSeek-V3 Technical Report, MTP modules
-                    # shares embedding layer. We only load the first weights.
-                    if (
-                        spec_layer != self.model.mtp_start_layer_idx
-                        and ".layers" not in name
-                    ):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
+                num_chunks = 1
+                if is_fuse_shared_experts_layer:
+                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
+                    split_dim = 1 if "down_proj.weight" in name else 0
+                    total = loaded_weight.shape[split_dim]
+                    assert total % num_chunks == 0, (
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
                     )
-                    weight_loader(param, loaded_weight)
-                loaded_params.add(name)
+                    chunk_size = total // num_chunks
+
+                for j in range(num_chunks):
+                    chunk_name = name
+                    weight_to_load = loaded_weight
+
+                    if is_fuse_shared_experts_layer:
+                        if split_dim == 0:
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
+                        else:
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
+                        chunk_name = name.replace(
+                            "mlp.shared_experts",
+                            f"mlp.experts.{self.config.n_routed_experts + j}",
+                        )
+
+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in chunk_name:
+                            continue
+
+                        # Do not modify `name` since the loop may continue here
+                        # Instead, create a new variable
+                        name_mapped = chunk_name.replace(weight_name, param_name)
+
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = typing.cast(
+                            Callable[..., bool], param.weight_loader
+                        )
+                        success = weight_loader(
+                            param,
+                            weight_to_load,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True,
+                        )
+                        if success:
+                            if not is_fuse_shared_experts_layer:
+                                name = name_mapped
+                            else:
+                                loaded_params.add(name_mapped)
+                            break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+
+                        name = maybe_remap_kv_scale_name(name, params_dict)
+                        if name is None:
+                            continue
+
+                        # According to DeepSeek-V3 Technical Report, MTP modules
+                        # shares embedding layer. We only load the first weights.
+                        if (
+                            spec_layer != self.model.mtp_start_layer_idx
+                            and ".layers" not in name
+                        ):
+                            continue
+
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+            if not is_fuse_shared_experts_layer:
+                loaded_params.add(name)
         return loaded_params
 
     def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
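The core of the fix is the chunking logic in the last hunk: a single widened shared_experts tensor is sliced into per-expert chunks, and each chunk is renamed onto an appended expert slot so the expert-aware weight loader can route it. A standalone sketch of that scheme using plain torch tensors; the shapes, layer index, and expert counts below are illustrative assumptions, not values read from a real DeepSeek-MTP checkpoint:

import torch

n_routed_experts = 256             # illustrative
n_shared_experts = 2               # pretend the checkpoint fused two shared experts
intermediate, hidden = 2048, 7168  # illustrative projection shapes

name = "model.layers.61.mlp.shared_experts.gate_proj.weight"
loaded_weight = torch.randn(n_shared_experts * intermediate, hidden)

# gate/up projections are column-parallel, so the widened tensor splits on dim 0;
# a down_proj weight would split on dim 1 instead (mirroring split_dim above).
split_dim = 1 if "down_proj.weight" in name else 0
chunk_size = loaded_weight.shape[split_dim] // n_shared_experts

for j in range(n_shared_experts):
    chunk = loaded_weight.narrow(split_dim, j * chunk_size, chunk_size)
    target = name.replace(
        "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
    )
    print(target, tuple(chunk.shape))
# model.layers.61.mlp.experts.256.gate_proj.weight (2048, 7168)
# model.layers.61.mlp.experts.257.gate_proj.weight (2048, 7168)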
