76 changes: 76 additions & 0 deletions tests/kernels/moe/test_expert_usage_histogram.py
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.layers.fused_moe.utils import collect_expert_usage_histogram


@pytest.mark.parametrize(
"topk_experts,expert_count,topk_ids_dtype",
[(8, 264, torch.int32), (4, 32, torch.int32), (1, 1, torch.int64)],
)
@pytest.mark.parametrize("token_count", [1, 7, 256, 1024])
def test_collect_expert_usage_histogram(
topk_experts: int, expert_count: int, token_count: int, topk_ids_dtype: torch.dtype
):
device = torch.device("cuda")

# Make a uniform distribution of expert usage
topk_ids = torch.stack(
[torch.arange(topk_experts, dtype=torch.int32)] * token_count
)

topk_ids_gpu = topk_ids.to(device)

expert_usage_histogram_gpu = torch.zeros(
expert_count, dtype=topk_ids_dtype, device=device
)

collect_expert_usage_histogram(topk_ids_gpu, expert_usage_histogram_gpu)

# Every expert is used equally, so each expert that appears in topk_ids
# should have been counted exactly token_count times.
assert torch.equal(
expert_usage_histogram_gpu[:topk_experts],
torch.full([topk_experts], token_count, dtype=topk_ids_dtype, device=device),
)

# The rest of the experts weren't used, so they should be zero.
assert expert_usage_histogram_gpu[topk_experts:].sum() == 0


@pytest.mark.parametrize("topk_experts,expert_count", [(16, 32)])
@pytest.mark.parametrize("token_count", [1])
@pytest.mark.parametrize("seed", [0xDEADBEEF, 0xCAFEBABE])
def test_collect_expert_usage_histogram_random(
topk_experts: int, expert_count: int, token_count: int, seed: int
):
device = torch.device("cuda")

generator = torch.Generator()
generator.manual_seed(seed)

# Make a random distribution of expert usage
topk_ids_cpu = torch.stack(
[torch.randperm(topk_experts, generator=generator, dtype=torch.int32)]
* token_count
)

# Compute ground truth
torch_histogram = torch.histogram(
topk_ids_cpu.to(torch.float), bins=expert_count, range=(0, expert_count - 1)
)

# Use our function
expert_usage_histogram_gpu = torch.zeros(
expert_count, dtype=torch.int32, device=device
)

topk_ids_gpu = topk_ids_cpu.to(device)

collect_expert_usage_histogram(topk_ids_gpu, expert_usage_histogram_gpu)

assert torch.equal(
expert_usage_histogram_gpu, torch_histogram.hist.to(torch.int32).to(device)
)
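For context: `collect_expert_usage_histogram` itself lives in `vllm/model_executor/layers/fused_moe/utils.py` and is not shown in this diff. Below is a minimal CPU sketch of the behavior the tests assume, namely that every expert id appearing in `topk_ids` increments that expert's counter in a preallocated histogram; the `torch.bincount` formulation is an illustration, not the actual kernel.

```python
import torch


def collect_expert_usage_histogram_sketch(
    topk_ids: torch.Tensor,  # [token_count, topk], integer expert ids
    expert_usage_histogram: torch.Tensor,  # [expert_count], integer counters
) -> None:
    # Count occurrences of each expert id and accumulate them in place,
    # matching what the tests above expect from the real helper.
    counts = torch.bincount(
        topk_ids.flatten().to(torch.int64),
        minlength=expert_usage_histogram.numel(),
    )
    expert_usage_histogram += counts.to(expert_usage_histogram.dtype)


# Mirrors the uniform-distribution test: experts 0..7 are each used once per token.
hist = torch.zeros(16, dtype=torch.int32)
collect_expert_usage_histogram_sketch(
    torch.stack([torch.arange(8, dtype=torch.int32)] * 4), hist
)
assert torch.equal(hist[:8], torch.full([8], 4, dtype=torch.int32))
assert hist[8:].sum() == 0
```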
21 changes: 17 additions & 4 deletions vllm/config/model.py
@@ -1388,11 +1388,15 @@ def get_num_experts(self) -> int:
return num_experts[0]
return num_experts

def get_layers_start_end_indices(
self, parallel_config: ParallelConfig
) -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
def get_total_num_dense_moe_layers(self) -> int:
return getattr(self.hf_text_config, "first_k_dense_replace", 0)

def get_total_num_moe_layers(self) -> int:
return (
self.get_total_num_hidden_layers() - self.get_total_num_dense_moe_layers()
)

def get_total_num_hidden_layers(self) -> int:
if (
self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
@@ -1411,6 +1415,15 @@ def get_layers_start_end_indices(
total_num_hidden_layers = getattr(
self.hf_text_config, "num_hidden_layers", 0
)
return total_num_hidden_layers

def get_layers_start_end_indices(
self, parallel_config: ParallelConfig
) -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices

total_num_hidden_layers = self.get_total_num_hidden_layers()

# the layout order is: DP x PP x TP
pp_rank = (
parallel_config.rank // parallel_config.tensor_parallel_size
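The two new helpers split the layer count into dense and MoE layers: in DeepSeek-style configs the first `first_k_dense_replace` blocks use dense MLPs instead of MoE, so the MoE-layer count is the total hidden-layer count minus that prefix. A self-contained illustration with a stand-in config object (the values are illustrative, not taken from any particular checkpoint):

```python
from types import SimpleNamespace

# Stand-in for ModelConfig.hf_text_config exposing only the two fields the
# new helpers read; values are illustrative.
hf_text_config = SimpleNamespace(num_hidden_layers=61, first_k_dense_replace=3)

total_hidden_layers = hf_text_config.num_hidden_layers                   # 61
dense_moe_layers = getattr(hf_text_config, "first_k_dense_replace", 0)   # 3
moe_layers = total_hidden_layers - dense_moe_layers                      # 58

# A count like this can size a per-layer expert-usage histogram, and the
# dense prefix matches what FusedMoE subtracts from extract_layer_index()
# later in this diff to get a zero-based MoE layer_index.
assert moe_layers == 58
```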
12 changes: 12 additions & 0 deletions vllm/envs.py
@@ -154,6 +154,8 @@
"full",
"relax",
] = "relax"
VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL: int = 100
VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM: bool = False
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
@@ -1117,6 +1119,14 @@ def get_vllm_port() -> int | None:
"relax",
],
),
# Collects a per-layer histogram of expert routing decisions.
"VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM": lambda: bool(
int(os.getenv("VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM", "0"))
),
# How often the expert usage histogram should be saved.
"VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL": lambda: int(
os.getenv("VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL", "100")
),
# Whether to use fused grouped_topk used for MoE expert selection.
"VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))
@@ -1504,6 +1514,8 @@ def compute_hash() -> str:
"VLLM_DISABLED_KERNELS",
"VLLM_USE_DEEP_GEMM",
"VLLM_USE_DEEP_GEMM_E8M0",
"VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM",
"VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP16",
"VLLM_USE_FLASHINFER_MOE_FP8",
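The two new environment variables follow the existing on/off and integer conventions in `envs.py`. A small standalone check of how the lambdas above parse them (the variable names are from this diff; the values here are illustrative):

```python
import os

# Enable per-layer expert-usage collection and use a save interval of 50
# (illustrative values; the defaults are "0" and "100" respectively).
os.environ["VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM"] = "1"
os.environ["VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL"] = "50"

collect = bool(int(os.getenv("VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM", "0")))
interval = int(os.getenv("VLLM_EXPERT_USAGE_HISTOGRAM_SAVE_INTERVAL", "100"))
assert collect is True and interval == 50
```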
7 changes: 7 additions & 0 deletions vllm/forward_context.py
@@ -200,6 +200,9 @@ class ForwardContext:
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
batch_descriptor: BatchDescriptor | None = None

# Set when recording the expert usage histogram
expert_usage_histogram: torch.Tensor | None = None

ubatch_slices: UBatchSlices | None = None

def __post_init__(self):
@@ -227,6 +230,7 @@ def create_forward_context(
dp_metadata: DPMetadata | None = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
expert_usage_histogram: torch.Tensor | None = None,
ubatch_slices: UBatchSlices | None = None,
):
return ForwardContext(
Expand All @@ -236,6 +240,7 @@ def create_forward_context(
dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
expert_usage_histogram=expert_usage_histogram,
ubatch_slices=ubatch_slices,
)

@@ -264,6 +269,7 @@ def set_forward_context(
num_tokens_across_dp: torch.Tensor | None = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
expert_usage_histogram: torch.Tensor | None = None,
ubatch_slices: UBatchSlices | None = None,
):
"""A context manager that stores the current forward context,
@@ -309,6 +315,7 @@
dp_metadata,
cudagraph_runtime_mode,
batch_descriptor,
expert_usage_histogram,
ubatch_slices,
)

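The forward-context change only threads an optional tensor through; its shape is not pinned down in this file, but the `expert_usage_histogram[layer_index]` indexing added to `select_experts` below suggests a `[num_moe_layers, num_experts]` layout. A plain-torch sketch of that contract (sizes illustrative, `torch.bincount` standing in for the real helper):

```python
import torch

# One row of counters per MoE layer, indexed by the zero-based layer_index
# that FusedMoE computes; the real tensor is allocated by the model runner,
# which is outside this diff.
num_moe_layers, num_experts, num_tokens, top_k = 4, 16, 8, 2
expert_usage_histogram = torch.zeros(num_moe_layers, num_experts, dtype=torch.int32)

layer_index = 2
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))

# Stand-in for collect_expert_usage_histogram(topk_ids, histogram[layer_index]):
expert_usage_histogram[layer_index] += torch.bincount(
    topk_ids.flatten(), minlength=num_experts
).to(torch.int32)

assert int(expert_usage_histogram[layer_index].sum()) == num_tokens * top_k
```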
45 changes: 39 additions & 6 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -45,6 +45,7 @@
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.fused_moe.utils import collect_expert_usage_histogram
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -298,6 +299,7 @@ def apply(
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
layer_index: int,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
@@ -534,6 +536,7 @@ def apply(
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
layer_index: int,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
@@ -598,6 +601,7 @@ def forward_cuda(
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
layer_index: int,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
@@ -622,6 +626,7 @@
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
layer_index=layer_index,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
@@ -709,6 +714,7 @@ def forward_cpu(
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
layer_index: int,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
@@ -758,6 +764,7 @@ def forward_xpu(
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
layer_index: int,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
@@ -799,6 +806,7 @@ def forward_tpu(
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
layer_index: int,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
@@ -1132,6 +1140,13 @@ def __init__(
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None

from vllm.model_executor.models.utils import extract_layer_index

self.layer_index = (
extract_layer_index(prefix)
- vllm_config.model_config.get_total_num_dense_moe_layers()
)

# ROCm aiter shared experts fusion
self.num_fused_shared_experts = (
n_shared_experts
@@ -1936,6 +1951,7 @@ def select_experts(
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
layer_index: int,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
@@ -2067,6 +2083,14 @@ def select_experts(
)
else:
zero_expert_result = None

expert_usage_histogram = get_forward_context().expert_usage_histogram

if expert_usage_histogram is not None:
collect_expert_usage_histogram(
topk_ids, expert_usage_histogram[layer_index]
)

return topk_weights, topk_ids, zero_expert_result

def must_reduce_shared_expert_outputs(self) -> bool:
@@ -2115,23 +2139,25 @@ def forward_native(
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
fused_output = self.forward_impl(hidden_states, router_logits)
fused_output = self.forward_impl(
hidden_states, router_logits, self.layer_index
)
assert not isinstance(fused_output, tuple)
else:
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, self.layer_name
hidden_states, router_logits, self.layer_name, self.layer_index
)
return fused_output[..., :og_hidden_states]
else:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
shared_output, fused_output = self.forward_impl(
hidden_states, router_logits
hidden_states, router_logits, self.layer_index
)
else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, self.layer_name
hidden_states, router_logits, self.layer_name, self.layer_index
)
return (
shared_output[..., :og_hidden_states],
@@ -2212,6 +2238,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
router_logits=staged_router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
layer_index=self.layer_index,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
@@ -2297,6 +2324,7 @@ def forward_impl(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_index: int,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None

@@ -2339,6 +2367,7 @@ def forward_impl(
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
layer_index=layer_index,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
@@ -2459,17 +2488,19 @@ def moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
layer_index: int,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(hidden_states, router_logits, layer_index)


def moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
layer_index: int,
) -> torch.Tensor:
return torch.empty_like(hidden_states)

@@ -2487,17 +2518,19 @@ def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
layer_index: int,
) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(hidden_states, router_logits, layer_index)


def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
layer_index: int,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
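Not part of the diff, but a natural way to read the collected counts once they have been saved: normalize each layer's row to see how evenly tokens are spread across experts. A hypothetical post-processing snippet with made-up counts:

```python
import torch

# Made-up counts standing in for a saved [num_moe_layers, num_experts] histogram.
histogram = torch.tensor(
    [[40, 10, 10, 4], [16, 16, 16, 16], [0, 0, 60, 4], [20, 20, 12, 12]]
)

totals = histogram.sum(dim=-1, keepdim=True).clamp(min=1)
load_fraction = histogram.float() / totals       # per-layer expert load share
max_share = load_fraction.max(dim=-1).values     # 1/num_experts means perfectly balanced
print(max_share)  # tensor([0.6250, 0.2500, 0.9375, 0.3125])
```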