Merged

67 commits
8fe6f82
[Feature] Core EPLB algorithm
abmfy May 14, 2025
bdda8dc
[Feature] Register expert weights for DeepSeek MoE
abmfy May 16, 2025
43d52ac
[Chore] Rename EPLB rebalance algo module name
abmfy May 16, 2025
58bf9fd
[Feature] Store EPLB states in model runner
abmfy May 16, 2025
52b141f
[Feature] EPLB rearrangement execution
abmfy May 16, 2025
98312d3
[Feature] Add expert load metrics collection during forward
abmfy May 19, 2025
22a963d
[Feature] Rearrange experts after a preset step interval
abmfy May 19, 2025
f88d836
Merge branch 'main' into eplb
abmfy May 19, 2025
43ac672
[Feature] Use unified `FusedMoE` in DeepSeek-V3/R1
abmfy May 20, 2025
f7ba162
[Bugfix] Copy expert mappings after rearrangement
abmfy May 20, 2025
ba3d60f
[Chore] Move implementations to `deepseek_v2.py`
abmfy May 23, 2025
ebcfcc7
[Chore] Remove expert load stats from forward context
abmfy May 23, 2025
620f59a
[Feature] Weight loading for redundant experts
abmfy May 23, 2025
90f3ed5
[Feature] Expert replica selection and load metrics recording
abmfy May 27, 2025
b3697de
[Feature] Map logical experts in weight loading
abmfy May 27, 2025
5d85f61
[Bugfix] Use `scatter_add_` instead of `bincount` for compile
abmfy May 27, 2025
e416e3c
[Bugfix] Add EPLB args in `EngineArgs`
abmfy May 27, 2025
233741c
[Bugfix] Sum up steps on EPLB rearrange
abmfy May 27, 2025
cfcd42c
[Bugfix] Collect expert weights into a list
abmfy May 27, 2025
36b0b11
[Bugfix] Fix typo in assertion
abmfy May 27, 2025
d5add3a
[Bugfix] Pad `log2phy` mapping in rebalance algo
abmfy May 27, 2025
b00bdb9
[Bugfix] Fix EP group in `DeepseekV2MoE`
abmfy May 27, 2025
c9cf2d4
[Refactor] Use local physical ids in expert load collection
abmfy May 27, 2025
4f79fef
[Bugfix] Map physical id before recording expert load metrics
abmfy May 27, 2025
a97ee39
[Perf] Reduce overhead of expert load recording
abmfy May 28, 2025
0c9340d
Merge branch 'main' into eplb
abmfy May 28, 2025
2b14d51
[Bugfix] Step EPLB state in dummy run to avoid blocking DP
abmfy May 29, 2025
306b21a
[Feature] Do not record expert loads for dummy batches
abmfy May 30, 2025
021578e
[Bugfix] Collect expert weights after weight post-processing
abmfy Jun 2, 2025
c2e0516
[Bugfix] Fix weight loading of replica experts
abmfy Jun 3, 2025
0071b24
Merge branch 'main' into eplb
abmfy Jun 6, 2025
38f9218
Merge branch 'main' into eplb
abmfy Jun 9, 2025
79c0d41
[Bugfix] Remove `e_score_correction_bias` in expert weights
abmfy Jun 9, 2025
b011065
[Bugfix] Fix shapes and dtypes in `FusedMoE`
abmfy Jun 10, 2025
82a6299
Merge branch 'main' into eplb
abmfy Jun 12, 2025
90706aa
[Feature] Disable EPLB step during profile run
abmfy Jun 16, 2025
f1f62b2
[Bugfix] Synchronize CUDA before shuffling layer to avoid hang
abmfy Jun 17, 2025
332a4d6
Merge branch 'main' into eplb
abmfy Jun 18, 2025
90d23ec
Merge branch 'eplb-graph' into eplb
abmfy Jun 19, 2025
993d7d7
[Style] Rename module `eplb.states` to `eplb.eplb_state`
abmfy Jun 19, 2025
90afdaf
[Feature] Run a dummy rearrangement during profile run for CUDA graphs
abmfy Jun 20, 2025
7774e0a
Merge branch 'eplb-graph' into eplb
abmfy Jun 20, 2025
f5d171f
[Feature] Constrain EPLB to main models
abmfy Jun 20, 2025
aaa66a2
[Refactor] Move out `EplbState` in model runner from classvars
abmfy Jun 20, 2025
934bbf0
Merge branch 'main' into eplb
abmfy Jun 23, 2025
4e346be
[Style] Rename `--num-extra-experts` to `--num-redundant-experts`
abmfy Jun 23, 2025
2496a54
[Doc] Add glossary for different types of experts
abmfy Jun 23, 2025
9916913
[Doc] Add statements in `EplbState` that some vars are just config
abmfy Jun 23, 2025
420cb99
[Doc] Add notes on synchronization of rearrangement step
abmfy Jun 23, 2025
ff368a1
[Doc] Add examples for expert mappings
abmfy Jun 23, 2025
425d56c
[Doc] Add explanation on why picking the last layer for MoE config
abmfy Jun 23, 2025
76fbdf8
[Refactor] Revert `fused_moe.py` since not used
abmfy Jun 23, 2025
6777877
[Doc] Add explanations for calling points of `_dummy_run`
abmfy Jun 23, 2025
12401b1
[Doc] Add comments on when real communication happens
abmfy Jun 23, 2025
80b3a1b
[Doc] Add comments that only the last `eplb_window_size` steps will be used
abmfy Jun 23, 2025
3ea6f2c
[Feature] Disable balancedness logging by default
abmfy Jun 23, 2025
aff7991
[Style] Rename shadowed variables to make linter happy
abmfy Jun 24, 2025
8ac089e
[Style] Add parameters of `apply` for subclasses of `FusedMoEMethodBase`
abmfy Jun 24, 2025
a6a4a3a
[Test] Add test for EPLB algo
abmfy Jun 24, 2025
1ed45b2
[Test] Add test for EPLB execute
abmfy Jun 25, 2025
4eeb0ff
[Style] Split some long lines
abmfy Jun 25, 2025
0c177d0
Merge branch 'main' into eplb
abmfy Jun 25, 2025
5b1e354
[Feature] Use `get_node_count` and remove magic number
abmfy Jun 25, 2025
495f782
[Test] Disable `first_k_dense_replace` in `test_initialization`
abmfy Jun 26, 2025
66fe93f
[Test] Use only 2 experts in `test_initialization`
abmfy Jun 26, 2025
3ec9032
[Test] Get at least `n_group` experts in `test_initialization`
abmfy Jun 26, 2025
c479d2c
[Test] Allow 2 experts per group in `test_initialization`
abmfy Jun 26, 2025
17 changes: 17 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -168,6 +168,23 @@ steps:
  - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
  - popd

- label: EPLB Algorithm Test
  working_dir: "/vllm-workspace/tests"
  source_file_dependencies:
  - vllm/distributed/eplb
  - tests/distributed/test_eplb_algo.py
  commands:
  - pytest -v -s distributed/test_eplb_algo.py

- label: EPLB Execution Test # 5min
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  source_file_dependencies:
  - vllm/distributed/eplb
  - tests/distributed/test_eplb_execute.py
  commands:
  - pytest -v -s distributed/test_eplb_execute.py

- label: Metrics, Tracing Test # 10min
  mirror_hardwares: [amdexperimental, amdproduction]
  num_gpus: 2
292 changes: 292 additions & 0 deletions tests/distributed/test_eplb_algo.py
@@ -0,0 +1,292 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.distributed.eplb.rebalance_algo import rebalance_experts


def test_basic_rebalance():
    """Test basic rebalancing functionality"""
    # Example from https://github.com/deepseek-ai/eplb
    weight = torch.tensor([
        [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
        [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
    ])

    num_layers = weight.shape[0]
    num_replicas = 16
    num_groups = 4
    num_nodes = 2
    num_gpus = 8

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Verify output shapes
    assert phy2log.shape == (
        2,
        16,
    ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
    assert (log2phy.shape[0] == 2
            ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
    assert (
        log2phy.shape[1] == 12
    ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
    assert logcnt.shape == (
        2,
        12,
    ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"

    # Verify physical to logical expert mapping range is correct
    assert torch.all(phy2log >= 0) and torch.all(
        phy2log < 12), "Physical to logical mapping should be in range [0, 12)"

    # Verify expert count reasonableness
    assert torch.all(
        logcnt >= 1), "Each logical expert should have at least 1 replica"
    assert (
        torch.sum(logcnt, dim=1).sum() == num_replicas *
        num_layers), f"Total replicas should be {num_replicas * num_layers}"

    # Verify expected output
    expected_phy2log = torch.tensor([
        [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
        [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
    ])
    assert torch.all(phy2log == expected_phy2log)

    expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1],
                                    [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]])
    assert torch.all(logcnt == expected_logcnt)
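
    # Extra consistency sketch (not part of the original test): `logcnt` is
    # expected to equal the number of physical replicas that `phy2log`
    # assigns to each logical expert, so recomputing the counts from
    # `phy2log` with `bincount` should reproduce the expected tensor.
    for layer in range(num_layers):
        replica_counts = torch.bincount(phy2log[layer], minlength=12)
        assert torch.all(replica_counts == logcnt[layer])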


def test_single_gpu_case():
    """Test single GPU case"""
    weight = torch.tensor([[10, 20, 30, 40]])
    num_replicas = 4
    num_groups = 1
    num_nodes = 1
    num_gpus = 1

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Verify shapes
    assert phy2log.shape == (1, 4)
    assert log2phy.shape[0] == 1
    assert log2phy.shape[1] == 4
    assert logcnt.shape == (1, 4)

    # Verify all logical experts are mapped
    assert set(phy2log[0].tolist()) == {0, 1, 2, 3}


def test_equal_weights():
    """Test case with equal weights"""
    weight = torch.tensor([[50, 50, 50, 50, 50, 50, 50, 50]])
    num_replicas = 8
    num_groups = 2
    num_nodes = 2
    num_gpus = 4

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Verify shapes
    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 8)

    # With equal weights, each expert should have exactly one replica
    assert torch.all(
        logcnt == 1
    ), "With equal weights and no replication, " \
        "each expert should have exactly 1 replica"


def test_extreme_weight_imbalance():
    """Test extreme weight imbalance case"""
    weight = torch.tensor([[1000, 1, 1, 1, 1, 1, 1, 1]])
    num_replicas = 12
    num_groups = 2
    num_nodes = 2
    num_gpus = 4

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Verify shapes
    assert phy2log.shape == (1, 12)
    assert logcnt.shape == (1, 8)

    # Expert with highest weight (index 0) should have more replicas
    assert (
        logcnt[0, 0]
        > logcnt[0, 1]), "Expert with highest weight should have more replicas"


def test_multiple_layers():
    """Test multiple layers case"""
    weight = torch.tensor([
        [10, 20, 30, 40, 50, 60],  # First layer
        [60, 50, 40, 30, 20, 10],  # Second layer (opposite weight pattern)
        [25, 25, 25, 25, 25, 25],  # Third layer (equal weights)
    ])
    num_replicas = 8
    num_groups = 2
    num_nodes = 2
    num_gpus = 4

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Verify shapes
    assert phy2log.shape == (3, 8)
    assert logcnt.shape == (3, 6)

    # Verify expert allocation is reasonable for each layer
    for layer in range(3):
        assert torch.all(phy2log[layer] >= 0) and torch.all(
            phy2log[layer] < 6
        ), f"Layer {layer} physical to logical mapping " \
            "should be in range [0, 6)"
        assert (torch.sum(logcnt[layer]) == num_replicas
                ), f"Layer {layer} total replicas should be {num_replicas}"


def test_parameter_validation():
    """Test parameter validation"""
    weight = torch.tensor([[10, 20, 30, 40]])

    # Non-divisible group case: this should be handled without errors,
    # because the function falls back to the global load balancing
    # strategy
    phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 4)

    # Test cases that will actually cause errors:
    # num_physical_experts not divisible by num_gpus
    with pytest.raises(AssertionError):
        rebalance_experts(weight, 7, 2, 2, 4)  # 7 not divisible by 4


def test_small_scale_hierarchical():
    """Test small-scale hierarchical load balancing"""
    weight = torch.tensor([
        [100, 50, 200, 75, 150, 25, 300, 80],  # 8 experts
    ])
    num_replicas = 12
    num_groups = 4  # 4 groups, 2 experts each
    num_nodes = 2  # 2 nodes
    num_gpus = 4  # 4 GPUs

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Verify basic constraints
    assert phy2log.shape == (1, 12)
    assert logcnt.shape == (1, 8)
    assert torch.sum(logcnt) == num_replicas
    assert torch.all(logcnt >= 1)

    # Expert with highest weight should have more replicas
    max_weight_expert = torch.argmax(weight[0])
    assert (logcnt[0, max_weight_expert]
            >= 2), "Highest weight expert should have multiple replicas"


def test_global_load_balance_fallback():
    """Test global load balancing fallback case"""
    # When num_groups % num_nodes != 0, should fall back to global load
    # balancing
    weight = torch.tensor([[10, 20, 30, 40, 50, 60]])
    num_replicas = 8
    num_groups = 3  # Cannot be divided evenly by num_nodes=2
    num_nodes = 2
    num_gpus = 4

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Should work normally, just using global load balancing strategy
    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 6)
    assert torch.sum(logcnt) == num_replicas


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_device_compatibility(device):
    """Test device compatibility"""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    weight = torch.tensor([[10, 20, 30, 40]], device=device)
    num_replicas = 6
    num_groups = 2
    num_nodes = 1
    num_gpus = 2

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Function will convert to CPU internally, but should handle different
    # device inputs normally
    assert phy2log.shape == (1, 6)
    assert logcnt.shape == (1, 4)


def test_additional_cases():
    """Test more edge cases and different parameter combinations"""

    # Test case 1: Large-scale distributed setup
    weight1 = torch.tensor(
        [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]])
    phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)

    assert phy2log1.shape == (1, 24)
    assert logcnt1.shape == (1, 16)
    assert torch.sum(logcnt1) == 24

    # Test case 2: Different weight distributions
    weight2 = torch.tensor([
        [200, 150, 100, 50, 25, 12],  # Decreasing weights
        [12, 25, 50, 100, 150, 200],  # Increasing weights
    ])
    phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)

    assert phy2log2.shape == (2, 10)
    assert logcnt2.shape == (2, 6)

    # Verify high-weight experts have more replicas
    for layer in range(2):
        max_weight_idx = torch.argmax(weight2[layer])
        assert logcnt2[layer, max_weight_idx] >= 2


if __name__ == "__main__":
    weight = torch.tensor([
        [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
        [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
    ])

    num_replicas = 16
    num_groups = 4
    num_nodes = 2
    num_gpus = 8

    phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)
    print(phy2log)

    test_basic_rebalance()
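
    # Illustrative sketch (not part of the original script): assuming
    # physical expert slots are assigned to GPUs contiguously, splitting
    # each layer's `phy2log` into `num_gpus` equal chunks shows which
    # logical experts each GPU would host after rebalancing.
    num_layers = weight.shape[0]
    slots_per_gpu = num_replicas // num_gpus
    for layer in range(num_layers):
        placement = phy2log[layer].reshape(num_gpus, slots_per_gpu)
        for gpu, experts in enumerate(placement.tolist()):
            print(f"layer {layer}, GPU {gpu}: logical experts {experts}")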