Commit 44f77ec

expert histogram
Signed-off-by: Patryk Saffer <patryk.saffer99@gmail.com>
1 parent 6c9fdbf commit 44f77ec

20 files changed, +409 -11 lines changed
Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fused_moe.utils import (
    collect_expert_usage_histogram)


@pytest.mark.parametrize("topk_experts,expert_count,topk_ids_dtype",
                         [(8, 264, torch.int32), (4, 32, torch.int32),
                          (1, 1, torch.int64)])
@pytest.mark.parametrize("token_count", [1, 7, 256, 1024])
def test_collect_expert_usage_histogram(topk_experts: int, expert_count: int,
                                        token_count: int,
                                        topk_ids_dtype: torch.dtype):
    device = torch.device('cuda')

    # Make a uniform distribution of expert usage
    topk_ids = torch.stack([torch.arange(topk_experts, dtype=torch.int32)] *
                           token_count)

    topk_ids_gpu = topk_ids.to(device)

    expert_usage_histogram_gpu = torch.zeros(expert_count,
                                             dtype=topk_ids_dtype,
                                             device=device)

    collect_expert_usage_histogram(topk_ids_gpu, expert_usage_histogram_gpu)

    # Every expert is used the same amount, so expecting token_count for
    # each expert set in the topk_ids tensor.
    assert torch.equal(
        expert_usage_histogram_gpu[:topk_experts],
        torch.full([topk_experts],
                   token_count,
                   dtype=topk_ids_dtype,
                   device=device))

    # The rest of the experts weren't used, so they should be zero.
    assert expert_usage_histogram_gpu[topk_experts:].sum() == 0


@pytest.mark.parametrize("topk_experts,expert_count", [(16, 32)])
@pytest.mark.parametrize("token_count", [1])
@pytest.mark.parametrize("seed", [0xDEADBEEF, 0xCAFEBABE])
def test_collect_expert_usage_histogram_random(topk_experts: int,
                                               expert_count: int,
                                               token_count: int, seed: int):
    device = torch.device('cuda')

    generator = torch.Generator()
    generator.manual_seed(seed)

    # Make a random distribution of expert usage
    topk_ids_cpu = torch.stack(
        [torch.randperm(topk_experts, generator=generator, dtype=torch.int32)
         ] * token_count)

    # Compute the ground truth with torch.histogram
    torch_histogram = torch.histogram(topk_ids_cpu.to(torch.float),
                                      bins=expert_count,
                                      range=(0, expert_count - 1))

    # Collect the histogram with the function under test
    expert_usage_histogram_gpu = torch.zeros(expert_count,
                                             dtype=torch.int32,
                                             device=device)

    topk_ids_gpu = topk_ids_cpu.to(device)

    collect_expert_usage_histogram(topk_ids_gpu, expert_usage_histogram_gpu)

    assert torch.equal(expert_usage_histogram_gpu,
                       torch_histogram.hist.to(torch.int32).to(device))
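The function under test, `collect_expert_usage_histogram` from `vllm/model_executor/layers/fused_moe/utils.py`, is added elsewhere in this commit and is not shown in this excerpt. As a mental model only, a minimal eager-mode sketch with the same contract — accumulate the count of each expert id in `topk_ids` into a preallocated per-expert histogram — could look like the code below; the committed implementation may be a fused or Triton kernel and is not assumed to match this sketch.

```python
import torch


def collect_expert_usage_histogram_ref(topk_ids: torch.Tensor,
                                       expert_usage_histogram: torch.Tensor) -> None:
    """Eager reference sketch, not the committed implementation.

    Assumes every id in topk_ids lies in [0, expert_usage_histogram.numel()).
    """
    counts = torch.bincount(topk_ids.flatten().to(torch.int64),
                            minlength=expert_usage_histogram.numel())
    # Accumulate in place so repeated calls keep adding to the same histogram.
    expert_usage_histogram += counts.to(expert_usage_histogram.dtype)
```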

vllm/config/model.py

Lines changed: 15 additions & 4 deletions
@@ -1388,11 +1388,14 @@ def get_num_experts(self) -> int:
             return num_experts[0]
         return num_experts
 
-    def get_layers_start_end_indices(
-        self, parallel_config: ParallelConfig
-    ) -> tuple[int, int]:
-        from vllm.distributed.utils import get_pp_indices
+    def get_total_num_dense_moe_layers(self) -> int:
+        return getattr(self.hf_text_config, "first_k_dense_replace", 0)
+
+    def get_total_num_moe_layers(self) -> int:
+        return self.get_total_num_hidden_layers(
+        ) - self.get_total_num_dense_moe_layers()
 
+    def get_total_num_hidden_layers(self) -> int:
         if (
             self.hf_text_config.model_type == "deepseek_mtp"
             or self.hf_config.model_type == "mimo_mtp"
@@ -1411,6 +1414,14 @@ def get_layers_start_end_indices(
         total_num_hidden_layers = getattr(
             self.hf_text_config, "num_hidden_layers", 0
         )
+        return total_num_hidden_layers
+
+    def get_layers_start_end_indices(
+        self, parallel_config: ParallelConfig
+    ) -> tuple[int, int]:
+        from vllm.distributed.utils import get_pp_indices
+        total_num_hidden_layers = self.get_total_num_hidden_layers()
+
         # the layout order is: DP x PP x TP
         pp_rank = (
             parallel_config.rank // parallel_config.tensor_parallel_size
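For intuition, the split these helpers compute can be illustrated with numbers typical of a DeepSeek-V2-style config; the values below are illustrative and not read from this commit.

```python
# Illustrative, DeepSeek-V2-style values; not taken from this diff.
num_hidden_layers = 60       # what get_total_num_hidden_layers() would return
first_k_dense_replace = 1    # what get_total_num_dense_moe_layers() would return

# get_total_num_moe_layers(): hidden layers minus the leading dense layers.
num_moe_layers = num_hidden_layers - first_k_dense_replace
assert num_moe_layers == 59  # one histogram row per MoE layer
```

The dense-layer count is what `FusedMoE.__init__` (in the layer.py diff below) subtracts from the extracted layer index, so that row 0 of the histogram presumably corresponds to the first MoE layer rather than the first hidden layer.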

vllm/envs.py

Lines changed: 12 additions & 0 deletions
@@ -154,6 +154,8 @@
         "full",
         "relax",
     ] = "relax"
+    VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL: int = 100
+    VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM: bool = False
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
@@ -1117,6 +1119,14 @@ def get_vllm_port() -> int | None:
         "relax",
     ],
     ),
+    # Collects expert routing histogram per layer.
+    "VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM":
+    lambda: bool(
+        int(os.getenv("VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM", "0"))),
+
+    # How often the expert usage histogram should be saved.
+    "VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL":
+    lambda: int(os.getenv("VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL", "100")),
     # Whether to use fused grouped_topk used for MoE expert selection.
     "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
         int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))
@@ -1504,6 +1514,8 @@ def compute_hash() -> str:
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
         "VLLM_USE_DEEP_GEMM_E8M0",
+        "VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM",
+        "VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP16",
         "VLLM_USE_FLASHINFER_MOE_FP8",

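Both new variables are read through `os.getenv` lambdas, so one way to try the feature is to set them in the environment before vLLM starts its workers. A minimal sketch, assuming the default unit used by the saving code (which is not part of this excerpt):

```python
import os

# Enable per-layer expert routing histograms; "0" (the default) disables them.
os.environ["VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM"] = "1"

# Save the histogram every 50 intervals instead of the default 100. The exact
# unit (e.g. scheduler steps) is defined by the saving code elsewhere in this
# commit.
os.environ["VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL"] = "50"
```

Both variables are also added to the `compute_hash()` factor list, so they participate in the environment hash vLLM computes.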
vllm/forward_context.py

Lines changed: 7 additions & 0 deletions
@@ -200,6 +200,9 @@ class ForwardContext:
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
     batch_descriptor: BatchDescriptor | None = None
 
+    # Set when recording usage histogram
+    expert_usage_histogram: torch.Tensor | None = None
+
     ubatch_slices: UBatchSlices | None = None
 
     def __post_init__(self):
@@ -227,6 +230,7 @@ def create_forward_context(
     dp_metadata: DPMetadata | None = None,
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
+    expert_usage_histogram: torch.Tensor | None = None,
     ubatch_slices: UBatchSlices | None = None,
 ):
     return ForwardContext(
@@ -236,6 +240,7 @@
         dp_metadata=dp_metadata,
         cudagraph_runtime_mode=cudagraph_runtime_mode,
         batch_descriptor=batch_descriptor,
+        expert_usage_histogram=expert_usage_histogram,
         ubatch_slices=ubatch_slices,
     )
 
@@ -264,6 +269,7 @@ def set_forward_context(
     num_tokens_across_dp: torch.Tensor | None = None,
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
+    expert_usage_histogram: torch.Tensor | None = None,
     ubatch_slices: UBatchSlices | None = None,
 ):
     """A context manager that stores the current forward context,
@@ -309,6 +315,7 @@
         dp_metadata,
         cudagraph_runtime_mode,
         batch_descriptor,
+        expert_usage_histogram,
         ubatch_slices,
     )
 
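The histogram tensor itself is allocated by the model runner, in files of this commit that are not shown here. The sketch below is only a guess at how a caller might thread it through the new `set_forward_context` parameter; `attn_metadata` and `vllm_config` stand in for the runner's real objects, and the shape is illustrative.

```python
import torch

from vllm import envs
from vllm.forward_context import set_forward_context

# Sketch only: the real allocation and wiring live in the model runner.
num_moe_layers, num_experts = 59, 264  # illustrative shape

expert_usage_histogram = None
if envs.VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM:
    expert_usage_histogram = torch.zeros(num_moe_layers,
                                         num_experts,
                                         dtype=torch.int32,
                                         device="cuda")

# attn_metadata and vllm_config are placeholders for the runner's objects.
with set_forward_context(attn_metadata,
                         vllm_config,
                         expert_usage_histogram=expert_usage_histogram):
    # Inside the context, FusedMoE.select_experts reads
    # get_forward_context().expert_usage_histogram and accumulates into it.
    ...
```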

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 36 additions & 6 deletions
@@ -44,6 +44,8 @@
     is_rocm_aiter_fusion_shared_expert_enabled,
     is_rocm_aiter_moe_enabled,
 )
+from vllm.model_executor.layers.fused_moe.utils import (
+    collect_expert_usage_histogram)
 from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
@@ -298,6 +300,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
@@ -534,6 +537,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
@@ -598,6 +602,7 @@ def forward_cuda(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         global_num_experts: int = -1,
@@ -709,6 +714,7 @@ def forward_cpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         global_num_experts: int = -1,
@@ -758,6 +764,7 @@ def forward_xpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         global_num_experts: int = -1,
@@ -799,6 +806,7 @@ def forward_tpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         global_num_experts: int = -1,
@@ -1132,6 +1140,11 @@ def __init__(
         self.logical_to_physical_map: torch.Tensor | None = None
         self.logical_replica_count: torch.Tensor | None = None
 
+        from vllm.model_executor.models.utils import extract_layer_index
+        self.layer_index = extract_layer_index(
+            prefix) - vllm_config.model_config.get_total_num_dense_moe_layers(
+            )
+
         # ROCm aiter shared experts fusion
         self.num_fused_shared_experts = (
             n_shared_experts
@@ -1936,6 +1949,7 @@ def select_experts(
         top_k: int,
         use_grouped_topk: bool,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         custom_routing_function: Callable | None = None,
@@ -2067,6 +2081,13 @@ def select_experts(
             )
         else:
             zero_expert_result = None
+
+        expert_usage_histogram = get_forward_context().expert_usage_histogram
+
+        if expert_usage_histogram is not None:
+            collect_expert_usage_histogram(topk_ids,
+                                           expert_usage_histogram[layer_index])
+
         return topk_weights, topk_ids, zero_expert_result
 
     def must_reduce_shared_expert_outputs(self) -> bool:
@@ -2115,23 +2136,25 @@ def forward_native(
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
                 # will switch to using the moe_forward custom op.
-                fused_output = self.forward_impl(hidden_states, router_logits)
+                fused_output = self.forward_impl(hidden_states, router_logits,
+                                                 self.layer_index)
                 assert not isinstance(fused_output, tuple)
             else:
                 fused_output = torch.ops.vllm.moe_forward(
-                    hidden_states, router_logits, self.layer_name
+                    hidden_states, router_logits, self.layer_name,
+                    self.layer_index
                 )
             return fused_output[..., :og_hidden_states]
         else:
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
                 # will switch to using the moe_forward custom op.
                 shared_output, fused_output = self.forward_impl(
-                    hidden_states, router_logits
+                    hidden_states, router_logits, self.layer_index
                 )
             else:
                 shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
-                    hidden_states, router_logits, self.layer_name
+                    hidden_states, router_logits, self.layer_name, self.layer_index
                 )
             return (
                 shared_output[..., :og_hidden_states],
@@ -2212,6 +2235,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 router_logits=staged_router_logits,
                 top_k=self.top_k,
                 renormalize=self.renormalize,
+                layer_index=self.layer_index,
                 use_grouped_topk=self.use_grouped_topk,
                 global_num_experts=self.global_num_experts,
                 expert_map=self.expert_map
@@ -2297,6 +2321,7 @@ def forward_impl(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
+        layer_index: int,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.quant_method is not None
 
@@ -2339,6 +2364,7 @@ def forward_impl(
             router_logits=router_logits,
             top_k=self.top_k,
             renormalize=self.renormalize,
+            layer_index=layer_index,
             use_grouped_topk=self.use_grouped_topk,
             global_num_experts=self.global_num_experts,
             expert_map=self.expert_map
@@ -2459,17 +2485,19 @@ def moe_forward(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    layer_index: int,
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
     assert self.shared_experts is None
-    return self.forward_impl(hidden_states, router_logits)
+    return self.forward_impl(hidden_states, router_logits, layer_index)
 
 
 def moe_forward_fake(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    layer_index: int,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -2487,17 +2515,19 @@ def moe_forward_shared(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    layer_index: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
     assert self.shared_experts is not None
-    return self.forward_impl(hidden_states, router_logits)
+    return self.forward_impl(hidden_states, router_logits, layer_index)
 
 
 def moe_forward_shared_fake(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    layer_index: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     shared_out = torch.empty_like(hidden_states)
     fused_out = torch.empty_like(hidden_states)
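Since `select_experts` indexes the context tensor as `expert_usage_histogram[layer_index]`, the collected result is naturally a per-layer matrix of routing counts. The sketch below is commit-independent post-processing, not code from this change; it assumes a `[num_moe_layers, num_experts]` histogram like the one the layers accumulate into.

```python
import torch


def expert_load_imbalance(hist: torch.Tensor) -> torch.Tensor:
    """Per-layer max/mean load ratio from a [num_moe_layers, num_experts]
    histogram; 1.0 means perfectly balanced routing."""
    totals = hist.sum(dim=-1, keepdim=True).clamp(min=1)
    load_fraction = hist.float() / totals  # fractions sum to 1 per layer
    return load_fraction.max(dim=-1).values * hist.shape[-1]


# Example: layer 0 routes evenly, layer 1 sends everything to one expert.
hist = torch.tensor([[10, 10, 10, 10],
                     [40, 0, 0, 0]])
print(expert_load_imbalance(hist))  # tensor([1., 4.])
```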
