
Commit 58ecf1d

kliuae, valarLip, and tjtanaa authored and committed
[ROCm][FEAT] Fuse DeepSeek shared experts into AITER fused_moe ops (vllm-project#24097)
Signed-off-by: chenjun <junchen2@amd.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: valarLip <103567126+valarLip@users.noreply.github.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent c4f0811 commit 58ecf1d

File tree

8 files changed: +352 -88 lines changed


tests/distributed/test_expert_placement.py

Lines changed: 3 additions & 3 deletions
@@ -85,7 +85,7 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
             else:
                 expected_test_local = base_experts

-            test_local_experts, test_expert_map = determine_expert_map(
+            test_local_experts, test_expert_map, _ = determine_expert_map(
                 ep_size=test_ep_size,
                 ep_rank=ep_rank,
                 global_num_experts=test_global_experts,
@@ -116,7 +116,7 @@ def test_expert_placement_edge_cases(expert_placement_strategy, world_size):
     """Test edge cases for round_robin expert placement."""

     # Test case 1: ep_size = 1 (should return None for expert_map)
-    local_num_experts, expert_map = determine_expert_map(
+    local_num_experts, expert_map, _ = determine_expert_map(
         ep_size=1,
         ep_rank=0,
         global_num_experts=8,
@@ -217,7 +217,7 @@ def test_determine_expert_map_comprehensive():
         expected_local,
         expected_map_pattern,
     ) in test_cases:
-        local_num_experts, expert_map = determine_expert_map(
+        local_num_experts, expert_map, _ = determine_expert_map(
             ep_size=ep_size,
             ep_rank=ep_rank,
             global_num_experts=global_num_experts,

tests/kernels/moe/test_moe_permute_unpermute.py

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ def test_moe_permute_unpermute(
     expert_map = None
     n_local_expert = n_expert
     if ep_size != 1:
-        n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert)
+        n_local_expert, expert_map, _ = determine_expert_map(ep_size, ep_rank, n_expert)
         expert_map = expert_map.cuda()
     start_expert = n_local_expert * ep_rank
     current_platform.seed_everything(0)

vllm/envs.py

Lines changed: 7 additions & 0 deletions
@@ -113,6 +113,7 @@
     VLLM_ROCM_USE_TRITON_ROPE: bool = False
     VLLM_ROCM_USE_AITER_FP8BMM: bool = True
     VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
+    VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -914,6 +915,12 @@ def get_vllm_port() -> int | None:
         os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
         in ("true", "1")
     ),
+    # Whether to use aiter fusion shared experts ops.
+    # Enabled by default.
+    "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: (
+        os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower()
+        in ("true", "1")
+    ),
     # use rocm skinny gemms
     "VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
         os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 105 additions & 11 deletions
@@ -5,14 +5,15 @@
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
 from enum import Enum
+from functools import partial
 from typing import Literal, get_args, overload

 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter

 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.distributed import (
     get_dp_group,
@@ -39,6 +40,8 @@
     FusedMoEPrepareAndFinalize,
 )
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    init_aiter_topK_meta_data,
+    is_rocm_aiter_fusion_shared_expert_enabled,
     is_rocm_aiter_moe_enabled,
 )
 from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
@@ -87,7 +90,7 @@ def _eplb_map_to_physical_and_record(

 if is_rocm_aiter_moe_enabled():
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
-        rocm_aiter_grouped_topk as grouped_topk,
+        rocm_aiter_grouped_topk as grouped_topk_aiter,
     )
 else:
     from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
@@ -634,6 +637,7 @@ def forward_cuda(
             global_num_experts=global_num_experts,
             zero_expert_num=zero_expert_num,
             zero_expert_type=zero_expert_type,
+            num_fused_shared_experts=layer.num_fused_shared_experts,
         )

         if self.rocm_aiter_moe_enabled:
@@ -860,7 +864,8 @@ def determine_expert_map(
     ep_rank: int,
     global_num_experts: int,
     expert_placement_strategy: ExpertPlacementStrategy = "linear",
-) -> tuple[int, torch.Tensor | None]:
+    num_fused_shared_experts: int = 0,
+) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
     """
     Calculates how many experts should be assigned to each rank for EP and
     creates a mapping from global to local expert index. Experts are
@@ -882,10 +887,16 @@ def determine_expert_map(
             (global_num_experts,) mapping from global to local index.
             Contains -1 for experts not assigned to the current rank.
             Returns None if ep_size is 1.
+        - expert_mask (Optional[torch.Tensor]): A tensor of shape
+            (global_num_experts + num_fused_shared_experts + 1,)
+            containing 1 for experts assigned to the current rank
+            and 0 for the sentinel.
+            Returns None if ep_size is 1.
+            Used only when AITER MoE is enabled.
     """
     assert ep_size > 0
     if ep_size == 1:
-        return (global_num_experts, None)
+        return (global_num_experts, None, None)

     # Distribute experts as evenly as possible to each rank.
     base_experts = global_num_experts // ep_size
@@ -914,7 +925,26 @@
             f"'{expert_placement_strategy}', expected one of "
             f"{get_args(ExpertPlacementStrategy)}"
         )
-    return (local_num_experts, expert_map)
+
+    expert_mask = None
+    if is_rocm_aiter_moe_enabled():
+        expert_mask = torch.ones(
+            (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
+        )
+        expert_mask[-1] = 0
+        expert_mask[:global_num_experts] = expert_map > -1
+        expert_map = torch.cat(
+            (
+                expert_map,
+                torch.tensor(
+                    [local_num_experts + i for i in range(num_fused_shared_experts)],
+                    dtype=torch.int32,
+                ),
+            ),
+            dim=0,
+        )
+
+    return (local_num_experts, expert_map, expert_mask)


 def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
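
To make the new return values concrete, here is an illustrative example (not taken from the diff) of what the function would yield for a small configuration, assuming the AITER MoE path is enabled so the mask branch above runs.

import torch
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map

# 8 routed experts split over 2 EP ranks, plus 1 fused shared expert, on rank 0.
local_num_experts, expert_map, expert_mask = determine_expert_map(
    ep_size=2,
    ep_rank=0,
    global_num_experts=8,
    num_fused_shared_experts=1,
)
# local_num_experts == 4 (routed experts only; the shared expert is added to the
#                         count later by _init_aiter_shared_experts_topK_buffer)
# expert_map  == tensor([0, 1, 2, 3, -1, -1, -1, -1, 4], dtype=torch.int32)
#                -> global experts 0-3 are local, and the fused shared expert is
#                   appended as local slot 4
# expert_mask == tensor([1, 1, 1, 1, 0, 0, 0, 0, 1, 0], dtype=torch.int32)
#                -> one entry per routed expert, one per fused shared expert,
#                   plus the trailing 0 sentinel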
@@ -1040,6 +1070,7 @@ def __init__(
         zero_expert_num: int | None = 0,
         zero_expert_type: str | None = None,
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
+        n_shared_experts: int | None = None,
     ):
         super().__init__()
         if params_dtype is None:
@@ -1096,6 +1127,22 @@ def __init__(
         self.logical_to_physical_map: torch.Tensor | None = None
         self.logical_replica_count: torch.Tensor | None = None

+        # ROCm aiter shared experts fusion
+        self.num_fused_shared_experts = (
+            n_shared_experts
+            if n_shared_experts is not None
+            and is_rocm_aiter_fusion_shared_expert_enabled()
+            else 0
+        )
+        if (
+            not is_rocm_aiter_fusion_shared_expert_enabled()
+            and self.num_fused_shared_experts != 0
+        ):
+            raise ValueError(
+                "n_shared_experts is only supported on ROCm aiter when "
+                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
+            )
+
         # Determine expert maps
         if self.use_ep:
             if self.enable_eplb:
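
As a standalone illustration of the resolution logic added above (the function below is hypothetical, written only to show the behaviour): the shared-expert count is honoured only when the AITER fusion flag is enabled, otherwise it falls back to 0 and the unfused path is used.

def resolve_num_fused_shared_experts(
    n_shared_experts: int | None, fusion_enabled: bool
) -> int:
    # Same conditional as in FusedMoE.__init__, with the env-dependent check
    # replaced by an explicit flag for readability.
    return (
        n_shared_experts
        if n_shared_experts is not None and fusion_enabled
        else 0
    )

print(resolve_num_fused_shared_experts(2, True))     # 2 -> shared experts fused
print(resolve_num_fused_shared_experts(2, False))    # 0 -> unfused fallback
print(resolve_num_fused_shared_experts(None, True))  # 0 -> model has no shared experts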
@@ -1129,14 +1176,16 @@ def __init__(
                 expert_placement_strategy = "linear"

             self.expert_map: torch.Tensor | None
-            local_num_experts, expert_map = determine_expert_map(
+            local_num_experts, expert_map, expert_mask = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
                 expert_placement_strategy=expert_placement_strategy,
+                num_fused_shared_experts=self.num_fused_shared_experts,
             )
             self.local_num_experts = local_num_experts
             self.register_buffer("expert_map", expert_map)
+            self.register_buffer("expert_mask", expert_mask)
             logger.info_once(
                 "[EP Rank %s/%s] Expert parallelism is enabled. Expert "
                 "placement strategy: %s. Local/global"
@@ -1150,10 +1199,18 @@ def __init__(
                 get_compressed_expert_map(self.expert_map),
             )
         else:
-            self.local_num_experts, self.expert_map = (self.global_num_experts, None)
+            self.local_num_experts, self.expert_map, self.expert_mask = (
+                self.global_num_experts,
+                None,
+                None,
+            )

         self.top_k = top_k

+        self._init_aiter_shared_experts_topK_buffer(
+            vllm_config=vllm_config, dp_size=dp_size_
+        )
+
         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
@@ -1327,13 +1384,18 @@ def update_expert_map(self):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
         with self.expert_map.device:
-            local_num_experts, expert_map = determine_expert_map(
+            local_num_experts, expert_map, expert_mask = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
+                num_fused_shared_experts=self.num_fused_shared_experts,
             )
             self.local_num_experts = local_num_experts
             self.register_buffer("expert_map", expert_map)
+            self.register_buffer("expert_mask", expert_mask)
+            self._init_aiter_shared_experts_topK_buffer(
+                vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
+            )

     def _load_per_tensor_weight_scale(
         self,
@@ -1504,6 +1566,24 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
             return expert_id
         return self.expert_map[expert_id].item()

+    def _init_aiter_shared_experts_topK_buffer(
+        self, vllm_config: VllmConfig, dp_size: int
+    ):
+        if is_rocm_aiter_fusion_shared_expert_enabled():
+            if self.num_fused_shared_experts > 0:
+                init_aiter_topK_meta_data(
+                    n_routed_experts=self.global_num_experts,
+                    n_shared_experts=self.num_fused_shared_experts,
+                    top_k=self.top_k,
+                    tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
+                    tp_size=self.ep_size if self.use_ep else self.tp_size,
+                    shared_experts_score=1.0,
+                    max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
+                    * dp_size,
+                    is_EP=self.use_ep,
+                )
+            self.local_num_experts += self.num_fused_shared_experts
+
     @overload
     def weight_loader(
         self,
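
One consequence of the buffer init above is that local_num_experts grows by the number of fused shared experts. A back-of-envelope illustration with assumed, DeepSeek-style numbers (not values taken from this diff):

# Illustrative arithmetic only.
global_num_experts = 256       # routed experts in the model
ep_size = 8                    # expert-parallel world size
num_fused_shared_experts = 1   # shared experts folded into the fused MoE op

local_routed = global_num_experts // ep_size            # 32 routed experts per rank
local_total = local_routed + num_fused_shared_experts   # 33 after the increment above
print(local_routed, local_total)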
@@ -1866,6 +1946,7 @@ def select_experts(
         global_num_experts: int | None = None,
         zero_expert_num: int | None = None,
         zero_expert_type: str | None = None,
+        num_fused_shared_experts: int = 0,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Route the input hidden states to the top-k experts based on the
@@ -1900,7 +1981,16 @@ def select_experts(
         if use_grouped_topk:
             assert topk_group is not None
             assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
+            if is_rocm_aiter_moe_enabled():
+                if not is_rocm_aiter_fusion_shared_expert_enabled():
+                    assert num_fused_shared_experts == 0
+                grouped_topk_impl = partial(
+                    grouped_topk_aiter,
+                    num_fused_shared_experts=num_fused_shared_experts,
+                )
+            else:
+                grouped_topk_impl = grouped_topk
+            topk_weights, topk_ids = grouped_topk_impl(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
                 topk=top_k,
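
The dispatch above keeps a single call site for both topk implementations by pre-binding the extra keyword with functools.partial. A standalone sketch of the pattern (all names below are stand-ins, not the real vllm functions):

from functools import partial

def grouped_topk_plain(hidden_states, gating_output, topk):
    return f"plain: topk={topk}"

def grouped_topk_aiter_like(hidden_states, gating_output, topk, num_fused_shared_experts=0):
    return f"aiter-like: topk={topk}, fused shared experts={num_fused_shared_experts}"

use_aiter = True  # stand-in for is_rocm_aiter_moe_enabled()
grouped_topk_impl = (
    partial(grouped_topk_aiter_like, num_fused_shared_experts=1)
    if use_aiter
    else grouped_topk_plain
)
# The call site is identical either way.
print(grouped_topk_impl(hidden_states=None, gating_output=None, topk=8))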
@@ -2119,7 +2209,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 renormalize=self.renormalize,
                 use_grouped_topk=self.use_grouped_topk,
                 global_num_experts=self.global_num_experts,
-                expert_map=self.expert_map,
+                expert_map=self.expert_map
+                if not is_rocm_aiter_moe_enabled()
+                else self.expert_mask,
                 topk_group=self.topk_group,
                 num_expert_group=self.num_expert_group,
                 custom_routing_function=self.custom_routing_function,
@@ -2244,7 +2336,9 @@ def forward_impl(
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
             global_num_experts=self.global_num_experts,
-            expert_map=self.expert_map,
+            expert_map=self.expert_map
+            if not is_rocm_aiter_moe_enabled()
+            else self.expert_mask,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
