@@ -5,14 +5,15 @@
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
 from enum import Enum
+from functools import partial
 from typing import Literal, get_args, overload

 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter

 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.distributed import (
     get_dp_group,
@@ -39,6 +40,8 @@
     FusedMoEPrepareAndFinalize,
 )
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    init_aiter_topK_meta_data,
+    is_rocm_aiter_fusion_shared_expert_enabled,
     is_rocm_aiter_moe_enabled,
 )
 from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
@@ -87,7 +90,7 @@ def _eplb_map_to_physical_and_record(

 if is_rocm_aiter_moe_enabled():
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
-        rocm_aiter_grouped_topk as grouped_topk,
+        rocm_aiter_grouped_topk as grouped_topk_aiter,
     )
 else:
     from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
@@ -634,6 +637,7 @@ def forward_cuda(
             global_num_experts=global_num_experts,
             zero_expert_num=zero_expert_num,
             zero_expert_type=zero_expert_type,
+            num_fused_shared_experts=layer.num_fused_shared_experts,
         )

         if self.rocm_aiter_moe_enabled:
@@ -860,7 +864,8 @@ def determine_expert_map(
     ep_rank: int,
     global_num_experts: int,
     expert_placement_strategy: ExpertPlacementStrategy = "linear",
-) -> tuple[int, torch.Tensor | None]:
+    num_fused_shared_experts: int = 0,
+) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
     """
     Calculates how many experts should be assigned to each rank for EP and
     creates a mapping from global to local expert index. Experts are
@@ -882,10 +887,16 @@ def determine_expert_map(
             (global_num_experts,) mapping from global to local index.
             Contains -1 for experts not assigned to the current rank.
             Returns None if ep_size is 1.
+        - expert_mask (Optional[torch.Tensor]): A tensor of shape
+            (global_num_experts + num_fused_shared_experts + 1,),
+            containing 1 for experts assigned to the current rank
+            (including fused shared experts) and 0 otherwise; the final
+            element is a sentinel fixed to 0.
+            Returns None if ep_size is 1.
+            Used only when AITER MoE is enabled.
     """
     assert ep_size > 0
     if ep_size == 1:
-        return (global_num_experts, None)
+        return (global_num_experts, None, None)

     # Distribute experts as evenly as possible to each rank.
     base_experts = global_num_experts // ep_size
@@ -914,7 +925,26 @@ def determine_expert_map(
             f"'{expert_placement_strategy}', expected one of "
             f"{get_args(ExpertPlacementStrategy)}"
         )
-    return (local_num_experts, expert_map)
+
+    expert_mask = None
+    if is_rocm_aiter_moe_enabled():
+        expert_mask = torch.ones(
+            (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
+        )
+        expert_mask[-1] = 0
+        expert_mask[:global_num_experts] = expert_map > -1
+        expert_map = torch.cat(
+            (
+                expert_map,
+                torch.tensor(
+                    [local_num_experts + i for i in range(num_fused_shared_experts)],
+                    dtype=torch.int32,
+                ),
+            ),
+            dim=0,
+        )
+
+    return (local_num_experts, expert_map, expert_mask)


 def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
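For intuition, here is a small standalone sketch (not part of the diff) of what the new return values look like. It mimics the linear-placement branch plus the AITER mask construction for a toy configuration; the helper name and the even-split assumption are illustrative only.

import torch

def toy_expert_map(ep_size: int, ep_rank: int, global_num_experts: int,
                   num_fused_shared_experts: int = 0):
    # Assume experts divide evenly across ranks (linear placement).
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * local_num_experts
    expert_map[start:start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32
    )

    # AITER-style mask: 1 for experts owned by this rank (plus the fused
    # shared experts at the tail), 0 elsewhere, trailing sentinel set to 0.
    expert_mask = torch.ones(
        global_num_experts + num_fused_shared_experts + 1, dtype=torch.int32
    )
    expert_mask[-1] = 0
    expert_mask[:global_num_experts] = (expert_map > -1).to(torch.int32)

    # Fused shared experts are appended after the routed local experts.
    shared_ids = torch.arange(
        local_num_experts, local_num_experts + num_fused_shared_experts,
        dtype=torch.int32,
    )
    expert_map = torch.cat((expert_map, shared_ids))
    return local_num_experts, expert_map, expert_mask

# 8 routed experts, EP=2, 1 fused shared expert, rank 0:
#   expert_map  -> [0, 1, 2, 3, -1, -1, -1, -1, 4]
#   expert_mask -> [1, 1, 1, 1, 0, 0, 0, 0, 1, 0]
print(toy_expert_map(ep_size=2, ep_rank=0, global_num_experts=8,
                     num_fused_shared_experts=1))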
@@ -1040,6 +1070,7 @@ def __init__(
         zero_expert_num: int | None = 0,
         zero_expert_type: str | None = None,
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
+        n_shared_experts: int | None = None,
     ):
         super().__init__()
         if params_dtype is None:
@@ -1096,6 +1127,22 @@ def __init__(
         self.logical_to_physical_map: torch.Tensor | None = None
         self.logical_replica_count: torch.Tensor | None = None

+        # ROCm aiter shared experts fusion
+        self.num_fused_shared_experts = (
+            n_shared_experts
+            if n_shared_experts is not None
+            and is_rocm_aiter_fusion_shared_expert_enabled()
+            else 0
+        )
+        if (
+            not is_rocm_aiter_fusion_shared_expert_enabled()
+            and self.num_fused_shared_experts != 0
+        ):
+            raise ValueError(
+                "n_shared_experts is only supported on ROCm aiter when "
+                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
+            )
+
         # Determine expert maps
         if self.use_ep:
             if self.enable_eplb:
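The guard above means the fused shared-expert count only takes effect when the ROCm AITER fusion flag is on; a minimal sketch of that resolution follows (the flag is stubbed as a plain boolean here, not the real env check):

def resolve_num_fused_shared_experts(n_shared_experts, fusion_enabled):
    # Mirrors the constructor logic: fall back to 0 unless a shared-expert
    # count is provided and the AITER fusion path is enabled.
    return n_shared_experts if n_shared_experts is not None and fusion_enabled else 0

assert resolve_num_fused_shared_experts(2, True) == 2
assert resolve_num_fused_shared_experts(2, False) == 0
assert resolve_num_fused_shared_experts(None, True) == 0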
@@ -1129,14 +1176,16 @@ def __init__(
                 expert_placement_strategy = "linear"

             self.expert_map: torch.Tensor | None
-            local_num_experts, expert_map = determine_expert_map(
+            local_num_experts, expert_map, expert_mask = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
                 expert_placement_strategy=expert_placement_strategy,
+                num_fused_shared_experts=self.num_fused_shared_experts,
             )
             self.local_num_experts = local_num_experts
             self.register_buffer("expert_map", expert_map)
+            self.register_buffer("expert_mask", expert_mask)
             logger.info_once(
                 "[EP Rank %s/%s] Expert parallelism is enabled. Expert "
                 "placement strategy: %s. Local/global"
@@ -1150,10 +1199,18 @@ def __init__(
                 get_compressed_expert_map(self.expert_map),
             )
         else:
-            self.local_num_experts, self.expert_map = (self.global_num_experts, None)
+            self.local_num_experts, self.expert_map, self.expert_mask = (
+                self.global_num_experts,
+                None,
+                None,
+            )

         self.top_k = top_k

+        self._init_aiter_shared_experts_topK_buffer(
+            vllm_config=vllm_config, dp_size=dp_size_
+        )
+
         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
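Registering `expert_map` and `expert_mask` via `register_buffer` (rather than as plain attributes) keeps them on the module's device and out of the parameter list; a generic sketch of that pattern follows (toy module, not the FusedMoE class):

import torch
import torch.nn as nn

class _ToyLayer(nn.Module):
    def __init__(self, use_ep: bool):
        super().__init__()
        expert_map = (
            torch.tensor([0, 1, -1, -1], dtype=torch.int32) if use_ep else None
        )
        # register_buffer accepts None, so the no-EP case can skip the tensor
        # entirely; when set, the buffer follows .to()/.cuda() with the module.
        self.register_buffer("expert_map", expert_map)

layer = _ToyLayer(use_ep=True)
print(layer.expert_map.dtype)    # torch.int32
print(list(layer.parameters()))  # [] -- buffers are not trainable parameters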
@@ -1327,13 +1384,18 @@ def update_expert_map(self):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
         with self.expert_map.device:
-            local_num_experts, expert_map = determine_expert_map(
+            local_num_experts, expert_map, expert_mask = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
+                num_fused_shared_experts=self.num_fused_shared_experts,
             )
             self.local_num_experts = local_num_experts
             self.register_buffer("expert_map", expert_map)
+            self.register_buffer("expert_mask", expert_mask)
+            self._init_aiter_shared_experts_topK_buffer(
+                vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
+            )

     def _load_per_tensor_weight_scale(
         self,
@@ -1504,6 +1566,24 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
             return expert_id
         return self.expert_map[expert_id].item()

+    def _init_aiter_shared_experts_topK_buffer(
+        self, vllm_config: VllmConfig, dp_size: int
+    ):
+        if is_rocm_aiter_fusion_shared_expert_enabled():
+            if self.num_fused_shared_experts > 0:
+                init_aiter_topK_meta_data(
+                    n_routed_experts=self.global_num_experts,
+                    n_shared_experts=self.num_fused_shared_experts,
+                    top_k=self.top_k,
+                    tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
+                    tp_size=self.ep_size if self.use_ep else self.tp_size,
+                    shared_experts_score=1.0,
+                    max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
+                    * dp_size,
+                    is_EP=self.use_ep,
+                )
+                self.local_num_experts += self.num_fused_shared_experts
+
     @overload
     def weight_loader(
         self,
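The sizing passed to `init_aiter_topK_meta_data` above is driven by the worst-case token count and the widened top-k; roughly, it works out as in this illustrative sketch (the function name and return shape below are made up, only the arithmetic mirrors the arguments being passed):

def sketch_topk_metadata_shape(max_num_batched_tokens: int, dp_size: int,
                               top_k: int, num_fused_shared_experts: int):
    # Worst case: every DP rank contributes a full batch of tokens.
    max_num_tokens = max_num_batched_tokens * dp_size
    # Each token keeps its routed top-k slots plus one slot per fused
    # shared expert (shared experts enter routing with a fixed score of 1.0).
    slots_per_token = top_k + num_fused_shared_experts
    return max_num_tokens, slots_per_token

# e.g. 8192 batched tokens, DP=2, top_k=8, one fused shared expert:
print(sketch_topk_metadata_shape(8192, 2, 8, 1))  # (16384, 9)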
@@ -1866,6 +1946,7 @@ def select_experts(
         global_num_experts: int | None = None,
         zero_expert_num: int | None = None,
         zero_expert_type: str | None = None,
+        num_fused_shared_experts: int = 0,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Route the input hidden states to the top-k experts based on the
@@ -1900,7 +1981,16 @@ def select_experts(
         if use_grouped_topk:
             assert topk_group is not None
             assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
+            if is_rocm_aiter_moe_enabled():
+                if not is_rocm_aiter_fusion_shared_expert_enabled():
+                    assert num_fused_shared_experts == 0
+                grouped_topk_impl = partial(
+                    grouped_topk_aiter,
+                    num_fused_shared_experts=num_fused_shared_experts,
+                )
+            else:
+                grouped_topk_impl = grouped_topk
+            topk_weights, topk_ids = grouped_topk_impl(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
                 topk=top_k,
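The `partial` indirection keeps the call site identical for both backends even though only the AITER variant accepts the extra keyword; a minimal standalone sketch of the same dispatch pattern (stub functions, not the real kernels):

from functools import partial

def grouped_topk_plain(gating_output, topk):
    return f"plain: topk={topk}"

def grouped_topk_with_shared(gating_output, topk, num_fused_shared_experts=0):
    return f"aiter: topk={topk}, fused_shared={num_fused_shared_experts}"

def pick_impl(aiter_enabled: bool, num_fused_shared_experts: int):
    if aiter_enabled:
        # Bind the backend-specific kwarg up front so callers stay uniform.
        return partial(grouped_topk_with_shared,
                       num_fused_shared_experts=num_fused_shared_experts)
    return grouped_topk_plain

impl = pick_impl(aiter_enabled=True, num_fused_shared_experts=1)
print(impl(gating_output=None, topk=8))                  # aiter: topk=8, fused_shared=1
print(pick_impl(False, 0)(gating_output=None, topk=8))   # plain: topk=8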
@@ -2119,7 +2209,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 renormalize=self.renormalize,
                 use_grouped_topk=self.use_grouped_topk,
                 global_num_experts=self.global_num_experts,
-                expert_map=self.expert_map,
+                expert_map=self.expert_map
+                if not is_rocm_aiter_moe_enabled()
+                else self.expert_mask,
                 topk_group=self.topk_group,
                 num_expert_group=self.num_expert_group,
                 custom_routing_function=self.custom_routing_function,
@@ -2244,7 +2336,9 @@ def forward_impl(
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
             global_num_experts=self.global_num_experts,
-            expert_map=self.expert_map,
+            expert_map=self.expert_map
+            if not is_rocm_aiter_moe_enabled()
+            else self.expert_mask,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
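On the AITER path the kernels consume the 0/1 ownership mask instead of the -1-based remap table; both encode the same per-rank ownership information, as this small illustrative conversion shows (toy tensors only):

import torch

# -1 marks experts not owned by this rank; owned experts get local indices.
expert_map = torch.tensor([0, 1, 2, 3, -1, -1, -1, -1], dtype=torch.int32)

# Ownership mask over the routed experts (the diff additionally appends
# entries for fused shared experts and a trailing sentinel for AITER).
expert_mask = (expert_map > -1).to(torch.int32)
print(expert_mask)  # tensor([1, 1, 1, 1, 0, 0, 0, 0], dtype=torch.int32)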