 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable

 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_fusion_shared_expert_enabled,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -212,11 +217,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]

-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.config.n_shared_experts
+                if is_rocm_aiter_fusion_shared_expert_enabled()
+                else 0
+            ),
         )

         params_dict = dict(self.named_parameters())
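Note on the `num_experts` argument above: when the ROCm AITER fused-shared-expert path is enabled, the shared experts receive extra slots appended after the routed experts, so the expert parameter mapping covers them as well. A minimal sketch of that arithmetic (the expert counts below are illustrative, not taken from any real config):

```python
# Hypothetical config values, for illustration only.
n_routed_experts = 256
n_shared_experts = 1


def mapped_expert_count(fusion_enabled: bool) -> int:
    # Mirrors the expression passed to make_expert_params_mapping above:
    # shared experts only get appended slots when the ROCm AITER
    # fused-shared-expert path is enabled.
    return n_routed_experts + (n_shared_experts if fusion_enabled else 0)


assert mapped_expert_count(False) == 256  # routed experts only
assert mapped_expert_count(True) == 257   # slot 256 is the appended shared expert
```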
@@ -227,6 +237,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is None:
                 continue
+            is_fuse_shared_experts_layer = (
+                is_rocm_aiter_fusion_shared_expert_enabled()
+                and ("mlp.shared_experts" in name)
+            )
             name = self._rewrite_spec_layer_name(spec_layer, name)
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -240,6 +254,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
+                if is_fuse_shared_experts_layer:
+                    continue
                 name_mapped = name.replace(weight_name, param_name)

                 # QKV fusion is optional, fall back to normal
@@ -260,45 +276,105 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-
-                    # According to DeepSeek-V3 Technical Report, MTP modules
-                    # shares embedding layer. We only load the first weights.
-                    if (
-                        spec_layer != self.model.mtp_start_layer_idx
-                        and ".layers" not in name
-                    ):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
+                num_chunks = 1
+                if is_fuse_shared_experts_layer:
+                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
+                    split_dim = 1 if "down_proj.weight" in name else 0
+                    total = loaded_weight.shape[split_dim]
+                    assert total % num_chunks == 0, (
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
                     )
-                    weight_loader(param, loaded_weight)
-                    loaded_params.add(name)
+                    chunk_size = total // num_chunks
+
+                for j in range(num_chunks):
+                    chunk_name = name
+                    weight_to_load = loaded_weight
+
+                    if is_fuse_shared_experts_layer:
+                        if split_dim == 0:
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
+                        else:
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
+                        chunk_name = name.replace(
+                            "mlp.shared_experts",
+                            f"mlp.experts.{self.config.n_routed_experts + j}",
+                        )
+
+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in chunk_name:
+                            continue
+
+                        # Do not modify `name` since the loop may continue here
+                        # Instead, create a new variable
+                        name_mapped = chunk_name.replace(weight_name, param_name)
+
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = typing.cast(
+                            Callable[..., bool], param.weight_loader
+                        )
+                        success = weight_loader(
+                            param,
+                            weight_to_load,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True,
+                        )
+                        if success:
+                            if not is_fuse_shared_experts_layer:
+                                name = name_mapped
+                            else:
+                                loaded_params.add(name_mapped)
+                            break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+
+                        name = maybe_remap_kv_scale_name(name, params_dict)
+                        if name is None:
+                            continue
+
+                        # According to DeepSeek-V3 Technical Report, MTP modules
+                        # shares embedding layer. We only load the first weights.
+                        if (
+                            spec_layer != self.model.mtp_start_layer_idx
+                            and ".layers" not in name
+                        ):
+                            continue
+
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                        if not is_fuse_shared_experts_layer:
+                            loaded_params.add(name)
         return loaded_params

     def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
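For reference, the new branch above splits a single widened `mlp.shared_experts.*` checkpoint tensor into per-shared-expert slices and renames each slice to an appended `mlp.experts.{n_routed_experts + j}.*` slot before delegating to the expert-aware weight loader. Below is a self-contained sketch of that slicing and renaming logic, assuming 2D weights; the helper name, layer index, and tensor shapes are hypothetical:

```python
import torch


def split_shared_expert_weight(
    name: str,
    weight: torch.Tensor,
    n_routed_experts: int,
    n_shared_experts: int,
) -> list[tuple[str, torch.Tensor]]:
    # gate_proj/up_proj are column-parallel (split along dim 0);
    # down_proj is row-parallel (split along dim 1), as in the hunk above.
    split_dim = 1 if "down_proj.weight" in name else 0
    assert weight.shape[split_dim] % n_shared_experts == 0
    chunk = weight.shape[split_dim] // n_shared_experts
    slices = []
    for j in range(n_shared_experts):
        idx = [slice(None), slice(None)]
        idx[split_dim] = slice(j * chunk, (j + 1) * chunk)
        # Rename the slice to an appended expert slot so the regular
        # expert mapping can route it.
        new_name = name.replace(
            "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
        )
        slices.append((new_name, weight[tuple(idx)]))
    return slices


# Example with made-up shapes: two shared experts stacked along dim 0.
w = torch.randn(2 * 512, 1024)
for new_name, part in split_shared_expert_weight(
    "model.layers.61.mlp.shared_experts.gate_proj.weight",
    w,
    n_routed_experts=256,
    n_shared_experts=2,
):
    print(new_name, tuple(part.shape))
# -> model.layers.61.mlp.experts.256.gate_proj.weight (512, 1024)
#    model.layers.61.mlp.experts.257.gate_proj.weight (512, 1024)
```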