diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 1536759c06bd..26f70ad457b6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -168,6 +168,23 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd +- label: EPLB Algorithm Test + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_algo.py + commands: + - pytest -v -s distributed/test_eplb_algo.py + +- label: EPLB Execution Test # 5min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_execute.py + commands: + - pytest -v -s distributed/test_eplb_execute.py + - label: Metrics, Tracing Test # 10min mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 2 diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py new file mode 100644 index 000000000000..e47ccba99c81 --- /dev/null +++ b/tests/distributed/test_eplb_algo.py @@ -0,0 +1,292 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.distributed.eplb.rebalance_algo import rebalance_experts + + +def test_basic_rebalance(): + """Test basic rebalancing functionality""" + # Example from https://github.com/deepseek-ai/eplb + weight = torch.tensor([ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ]) + + num_layers = weight.shape[0] + num_replicas = 16 + num_groups = 4 + num_nodes = 2 + num_gpus = 8 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify output shapes + assert phy2log.shape == ( + 2, + 16, + ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}" + assert (log2phy.shape[0] == 2 + ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" + assert ( + log2phy.shape[1] == 12 + ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" + assert logcnt.shape == ( + 2, + 12, + ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}" + + # Verify physical to logical expert mapping range is correct + assert torch.all(phy2log >= 0) and torch.all( + phy2log < 12), "Physical to logical mapping should be in range [0, 12)" + + # Verify expert count reasonableness + assert torch.all( + logcnt >= 1), "Each logical expert should have at least 1 replica" + assert ( + torch.sum(logcnt, dim=1).sum() == num_replicas * + num_layers), f"Total replicas should be {num_replicas * num_layers}" + + # Verify expected output + expected_phy2log = torch.tensor([ + [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], + [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], + ]) + assert torch.all(phy2log == expected_phy2log) + + expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], + [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]) + assert torch.all(logcnt == expected_logcnt) + + +def test_single_gpu_case(): + """Test single GPU case""" + weight = torch.tensor([[10, 20, 30, 40]]) + num_replicas = 4 + num_groups = 1 + num_nodes = 1 + num_gpus = 1 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (1, 4) + assert log2phy.shape[0] == 1 + assert log2phy.shape[1] == 4 + assert logcnt.shape == (1, 4) + + # Verify all logical experts are mapped + assert 
set(phy2log[0].tolist()) == {0, 1, 2, 3} + + +def test_equal_weights(): + """Test case with equal weights""" + weight = torch.tensor([[50, 50, 50, 50, 50, 50, 50, 50]]) + num_replicas = 8 + num_groups = 2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (1, 8) + assert logcnt.shape == (1, 8) + + # With equal weights, each expert should have exactly one replica + assert torch.all( + logcnt == 1 + ), "With equal weights and no replication, " \ + "each expert should have exactly 1 replica" + + +def test_extreme_weight_imbalance(): + """Test extreme weight imbalance case""" + weight = torch.tensor([[1000, 1, 1, 1, 1, 1, 1, 1]]) + num_replicas = 12 + num_groups = 2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (1, 12) + assert logcnt.shape == (1, 8) + + # Expert with highest weight (index 0) should have more replicas + assert ( + logcnt[0, 0] + > logcnt[0, 1]), "Expert with highest weight should have more replicas" + + +def test_multiple_layers(): + """Test multiple layers case""" + weight = torch.tensor([ + [10, 20, 30, 40, 50, 60], # First layer + [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) + [25, 25, 25, 25, 25, 25], # Third layer (equal weights) + ]) + num_replicas = 8 + num_groups = 2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (3, 8) + assert logcnt.shape == (3, 6) + + # Verify expert allocation is reasonable for each layer + for layer in range(3): + assert torch.all(phy2log[layer] >= 0) and torch.all( + phy2log[layer] < 6 + ), f"Layer {layer} physical to logical mapping" \ + "should be in range [0, 6)" + assert (torch.sum(logcnt[layer]) == num_replicas + ), f"Layer {layer} total replicas should be {num_replicas}" + + +def test_parameter_validation(): + """Test parameter validation""" + weight = torch.tensor([[10, 20, 30, 40]]) + + # Test non-divisible case - this should handle normally without throwing + # errors because the function will fall back to global load balancing + # strategy + phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4) + assert phy2log.shape == (1, 8) + assert logcnt.shape == (1, 4) + + # Test cases that will actually cause errors: + # num_physical_experts not divisible by num_gpus + with pytest.raises(AssertionError): + rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4 + + +def test_small_scale_hierarchical(): + """Test small-scale hierarchical load balancing""" + weight = torch.tensor([ + [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts + ]) + num_replicas = 12 + num_groups = 4 # 4 groups, 2 experts each + num_nodes = 2 # 2 nodes + num_gpus = 4 # 4 GPUs + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify basic constraints + assert phy2log.shape == (1, 12) + assert logcnt.shape == (1, 8) + assert torch.sum(logcnt) == num_replicas + assert torch.all(logcnt >= 1) + + # Expert with highest weight should have more replicas + max_weight_expert = torch.argmax(weight[0]) + assert (logcnt[0, max_weight_expert] + >= 2), "Highest weight expert should have multiple replicas" + + +def test_global_load_balance_fallback(): + """Test global load balancing 
fallback case""" + # When num_groups % num_nodes != 0, should fall back to global load + # balancing + weight = torch.tensor([[10, 20, 30, 40, 50, 60]]) + num_replicas = 8 + num_groups = 3 # Cannot be divided evenly by num_nodes=2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Should work normally, just using global load balancing strategy + assert phy2log.shape == (1, 8) + assert logcnt.shape == (1, 6) + assert torch.sum(logcnt) == num_replicas + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_device_compatibility(device): + """Test device compatibility""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + weight = torch.tensor([[10, 20, 30, 40]], device=device) + num_replicas = 6 + num_groups = 2 + num_nodes = 1 + num_gpus = 2 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Function will convert to CPU internally, but should handle different + # device inputs normally + assert phy2log.shape == (1, 6) + assert logcnt.shape == (1, 4) + + +def test_additional_cases(): + """Test more edge cases and different parameter combinations""" + + # Test case 1: Large-scale distributed setup + weight1 = torch.tensor( + [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]) + phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) + + assert phy2log1.shape == (1, 24) + assert logcnt1.shape == (1, 16) + assert torch.sum(logcnt1) == 24 + + # Test case 2: Different weight distributions + weight2 = torch.tensor([ + [200, 150, 100, 50, 25, 12], # Decreasing weights + [12, 25, 50, 100, 150, 200], # Increasing weights + ]) + phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) + + assert phy2log2.shape == (2, 10) + assert logcnt2.shape == (2, 6) + + # Verify high-weight experts have more replicas + for layer in range(2): + max_weight_idx = torch.argmax(weight2[layer]) + assert logcnt2[layer, max_weight_idx] >= 2 + + +if __name__ == "__main__": + weight = torch.tensor([ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ]) + + num_replicas = 16 + num_groups = 4 + num_nodes = 2 + num_gpus = 8 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + print(phy2log) + + test_basic_rebalance() diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py new file mode 100644 index 000000000000..de9ed1eabbac --- /dev/null +++ b/tests/distributed/test_eplb_execute.py @@ -0,0 +1,504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing +import os +import random + +import pytest +import torch +import torch.distributed + +from vllm.distributed.eplb.rebalance_execute import ( + rearrange_expert_weights_inplace) +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + get_tp_group, + init_distributed_environment) +from vllm.utils import update_environment_variables + + +def distributed_run(fn, world_size): + number_of_processes = world_size + processes: list[multiprocessing.Process] = [] + for i in range(number_of_processes): + env: dict[str, str] = {} + env['RANK'] = str(i) + env['LOCAL_RANK'] = str(i) + env['WORLD_SIZE'] = str(number_of_processes) + env['LOCAL_WORLD_SIZE'] = 
str(number_of_processes) + env['MASTER_ADDR'] = 'localhost' + env['MASTER_PORT'] = '12345' + p = multiprocessing.Process(target=fn, args=(env, )) + processes.append(p) + p.start() + + for p in processes: + p.join() + + for p in processes: + assert p.exitcode == 0 + + +def worker_fn_wrapper(fn): + # `multiprocessing.Process` cannot accept environment variables directly + # so we need to pass the environment variables as arguments + # and update the environment variables in the function + def wrapped_fn(env): + update_environment_variables(env) + local_rank = os.environ['LOCAL_RANK'] + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + init_distributed_environment() + + # Ensure each worker process has the same random seed + random.seed(42) + torch.manual_seed(42) + + fn() + + return wrapped_fn + + +def create_expert_indices_with_redundancy( + num_layers: int, + num_logical_experts: int, + total_physical_experts: int, + redundancy_config: list[int], # redundancy for each logical expert +) -> torch.Tensor: + """ + Create expert indices with redundancy. + + Args: + num_layers: number of layers + num_logical_experts: number of logical experts + total_physical_experts: total number of physical experts + redundancy_config: redundancy for each logical expert + + Returns: + indices: Shape (num_layers, total_physical_experts) + """ + assert sum(redundancy_config) == total_physical_experts + assert len(redundancy_config) == num_logical_experts + + indices = torch.zeros(num_layers, total_physical_experts, dtype=torch.long) + + for layer in range(num_layers): + physical_pos = 0 + for logical_expert_id, redundancy in enumerate(redundancy_config): + for _ in range(redundancy): + indices[layer, physical_pos] = logical_expert_id + physical_pos += 1 + + # Shuffle the indices at dim 1 + for layer in range(num_layers): + indices[layer] = indices[layer][torch.randperm(indices.shape[1])] + + return indices + + +def create_expert_weights( + num_layers: int, + num_local_experts: int, + hidden_sizes: list[int], + rank: int, + device: torch.device, + physical_to_logical_mapping: torch.Tensor, +) -> list[list[torch.Tensor]]: + """ + Create fake expert weights tensor for testing. + + Use `arange` to generate predictable weights values, based on logical + expert ID. + All replicas of the same logical expert should have the same weights. 
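+
+    For example, logical expert 3 in layer 1 for the second weight matrix
+    (weight_idx 1) is filled with `arange` values starting at
+    3 * 1000 + 1 * 100 + 1 * 10 = 3110, so every rank that holds a replica of
+    logical expert 3 produces identical values.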
+ + Args: + physical_to_logical_mapping: Shape (num_layers, num_local_experts) + mapping[layer, physical_pos] = logical_expert_id + """ + expert_weights = [] + + for layer in range(num_layers): + layer_weights = [] + for weight_idx, hidden_size in enumerate(hidden_sizes): + weight_tensor = torch.zeros(num_local_experts, + hidden_size, + device=device, + dtype=torch.float32) + + for local_expert in range(num_local_experts): + # Get the logical expert ID for this physical expert + global_pos = rank * num_local_experts + local_expert + logical_expert_id = physical_to_logical_mapping[ + layer, global_pos].item() + + # Generate weights based on logical expert ID + # (so that all replicas of the same logical expert have the + # same weights) + base_value = (logical_expert_id * 1000 + layer * 100 + + weight_idx * 10) + weight_tensor[local_expert] = torch.arange(base_value, + base_value + + hidden_size, + device=device, + dtype=torch.float32) + + layer_weights.append(weight_tensor) + expert_weights.append(layer_weights) + + return expert_weights + + +def create_redundancy_config( + num_logical_experts: int, + num_physical_experts: int, +) -> list[int]: + """Create a redundancy configuration.""" + redundancy_config = [1] * num_logical_experts + remaining = num_physical_experts - num_logical_experts + # Randomly assign the remaining physical experts to the logical experts + for _ in range(remaining): + redundancy_config[random.choice(range(num_logical_experts))] += 1 + return redundancy_config + + +def verify_expert_weights_after_shuffle( + expert_weights: list[list[torch.Tensor]], + new_indices: torch.Tensor, + hidden_sizes: list[int], + ep_rank: int, + num_local_experts: int, +): + """Verify the weights after shuffling are correct.""" + num_layers = len(expert_weights) + + for layer in range(num_layers): + for weight_idx, hidden_size in enumerate(hidden_sizes): + weight_tensor = expert_weights[layer][weight_idx] + + for local_expert in range(num_local_experts): + # Calculate the global expert ID for this local expert + global_pos = ep_rank * num_local_experts + local_expert + expected_logical_expert = new_indices[layer, global_pos].item() + + # Check if the weights are correct + actual_weights = weight_tensor[local_expert] + expected_base = (expected_logical_expert * 1000 + layer * 100 + + weight_idx * 10) + expected_weights = torch.arange(expected_base, + expected_base + hidden_size, + device=actual_weights.device, + dtype=actual_weights.dtype) + + torch.testing.assert_close( + actual_weights, + expected_weights, + msg=f"Layer {layer}, weight {weight_idx}," + f"local expert {local_expert}: " + f"weights do not match. " + f"Expected logical expert {expected_logical_expert}") + + +def verify_redundant_experts_have_same_weights( + expert_weights: list[list[torch.Tensor]], + indices: torch.Tensor, + hidden_sizes: list[int], + world_size: int, + num_local_experts: int, +): + """ + Verify that all replicas of the same logical expert have the same weights. 
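+
+    Weights are gathered from every rank via `all_gather`, then each physical
+    replica is compared against the first replica seen for its logical expert.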
+ """ + num_layers = len(expert_weights) + total_physical_experts = world_size * num_local_experts + + for layer in range(num_layers): + # Collect weights for all physical experts for each weight matrix + all_weights: list[torch.Tensor] = [] + + for weight_idx, hidden_size in enumerate(hidden_sizes): + # Create tensor to store all expert weights + # Shape: [total_physical_experts, hidden_size] + gathered_weights = torch.zeros( + total_physical_experts, + hidden_size, + device=expert_weights[layer][weight_idx].device, + dtype=expert_weights[layer][weight_idx].dtype) + + # Use all_gather to collect expert weights from current node + # expert_weights[layer][weight_idx] shape: + # [num_local_experts, hidden_size] + local_weights = expert_weights[layer][ + weight_idx] # [num_local_experts, hidden_size] + + # Split tensor along dim 0 into a list for all_gather + gathered_weights_list = torch.chunk(gathered_weights, + world_size, + dim=0) + + torch.distributed.all_gather( + # Output list: each element corresponds to one rank's weights + list(gathered_weights_list), + local_weights # Input: current rank's local weights + ) + + all_weights.append(gathered_weights) + + # Verify that all replicas of the same logical expert have the same + # weights + logical_expert_weights: dict[int, dict[int, torch.Tensor]] = {} + + for physical_pos in range(total_physical_experts): + logical_expert_id = int(indices[layer, physical_pos].item()) + + if logical_expert_id not in logical_expert_weights: + # First time encountering this logical expert, save its weights + logical_expert_weights[logical_expert_id] = { + weight_idx: all_weights[weight_idx][physical_pos] + for weight_idx in range(len(hidden_sizes)) + } + else: + # Verify that current physical expert's weights match the + # previously saved logical expert weights + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + all_weights[weight_idx][physical_pos], + logical_expert_weights[logical_expert_id][weight_idx], + msg=f"Layer {layer}, weight {weight_idx}," + f"logical expert {logical_expert_id}: " + f"Physical expert {physical_pos} has different weights" + f"than expected") + + +@pytest.mark.parametrize( + "world_size,num_layers,num_local_experts,num_logical_experts", + [ + # 2 GPU, 2 experts per GPU + # 3 logical experts, 4 physical experts, 1 redundant experts + (2, 1, 2, 3), + # 2 GPU, 3 experts per GPU + # 4 logical experts, 6 physical experts, 2 redundant experts + (2, 2, 3, 4), + # 2 GPU, 8 experts per GPU + # 16 logical experts, 16 physical experts, 0 redundant experts + (2, 4, 8, 16), + # 4 GPU, 2 experts per GPU + # 6 logical experts, 8 physical experts, 2 redundant experts + (4, 1, 2, 6), + # 4 GPU, 2 experts per GPU + # 5 logical experts, 8 physical experts, 3 redundant experts + (4, 2, 2, 5), + # 4 GPU, 8 experts per GPU + # 16 logical experts, 32 physical experts, 16 redundant experts + (4, 8, 8, 16), + ]) +def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, + num_local_experts, + num_logical_experts): + """Test the functionality of rearranging expert weights with redundancy.""" + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + @worker_fn_wrapper + def worker_fn(): + # Initialize model parallel (using tensor parallel as an entrypoint + # to expert parallel) + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) + + ep_group = get_tp_group().cpu_group + ep_rank = 
torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + # Test parameters + total_physical_experts = world_size * num_local_experts + hidden_sizes = [32, 64] # Two different weight matrices + + # Create old expert indices (with redundancy) + redundancy_config = create_redundancy_config(num_logical_experts, + total_physical_experts) + + old_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + redundancy_config, + ) + + # Create new expert indices (with redundancy) + new_redundancy_config = create_redundancy_config( + num_logical_experts, total_physical_experts) + new_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + new_redundancy_config, + ) + + # Create expert weights + expert_weights = create_expert_weights(num_layers, num_local_experts, + hidden_sizes, ep_rank, device, + old_indices) + + # Execute weight rearrangement + rearrange_expert_weights_inplace( + old_indices, + new_indices, + expert_weights, + ep_group, + is_profile=False, + ) + + # Verify the rearrangement result + verify_expert_weights_after_shuffle( + expert_weights, + new_indices, + hidden_sizes, + ep_rank, + num_local_experts, + ) + + verify_redundant_experts_have_same_weights( + expert_weights, + new_indices, + hidden_sizes, + world_size, + num_local_experts, + ) + + distributed_run(worker_fn, world_size) + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_rearrange_expert_weights_no_change(world_size): + """ + Test that when the indices do not change, the weights should remain + unchanged. + """ + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + @worker_fn_wrapper + def worker_fn(): + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) + + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + num_layers = 2 + num_local_experts = 2 + total_physical_experts = world_size * num_local_experts + num_logical_experts = total_physical_experts // 2 # Some redundancy + hidden_sizes = [32, 64] + + # Create redundancy configuration + redundancy_config = [2] * num_logical_experts + + # Same indices - no change + indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, + redundancy_config) + + expert_weights = create_expert_weights(num_layers, num_local_experts, + hidden_sizes, ep_rank, device, + indices) + + # Save original weights + original_weights = [] + for layer_weights in expert_weights: + layer_copy = [] + for weight in layer_weights: + layer_copy.append(weight.clone()) + original_weights.append(layer_copy) + + # Execute rearrangement (should be no change) + rearrange_expert_weights_inplace( + indices, + indices, # Same indices + expert_weights, + ep_group, + is_profile=False) + + # Verify that the weights have not changed + for layer in range(num_layers): + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + expert_weights[layer][weight_idx], + original_weights[layer][weight_idx], + msg=f"Layer {layer}, weight {weight_idx} should remain " + f"unchanged") + + distributed_run(worker_fn, world_size) + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_rearrange_expert_weights_profile_mode(world_size): + """Test profile mode (should not copy actual weights)""" + + if torch.cuda.device_count() < world_size: 
+ pytest.skip(f"Need at least {world_size} GPUs to run the test") + + @worker_fn_wrapper + def worker_fn(): + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) + + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + num_layers = 1 + num_local_experts = 2 + total_physical_experts = world_size * num_local_experts + num_logical_experts = total_physical_experts // 2 + hidden_sizes = [32] + + # Create different index distributions + old_redundancy = create_redundancy_config(num_logical_experts, + total_physical_experts) + new_redundancy = create_redundancy_config(num_logical_experts, + total_physical_experts) + + old_indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, + old_redundancy) + new_indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, + new_redundancy) + + expert_weights = create_expert_weights(num_layers, num_local_experts, + hidden_sizes, ep_rank, device, + old_indices) + + # Save original weights + original_weights = [] + for layer_weights in expert_weights: + layer_copy = [] + for weight in layer_weights: + layer_copy.append(weight.clone()) + original_weights.append(layer_copy) + + # Execute profile mode rearrangement + rearrange_expert_weights_inplace( + old_indices, + new_indices, + expert_weights, + ep_group, + is_profile=True # Profile mode + ) + + # In profile mode, the weights should remain unchanged + for layer in range(num_layers): + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + expert_weights[layer][weight_idx], + original_weights[layer][weight_idx], + msg="In profile mode, the weights should remain unchanged") + + distributed_run(worker_fn, world_size) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 54e8cd597bfc..e56bc925c9c4 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -31,12 +31,20 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: text_config = hf_config.get_text_config() + # Ensure at least 2 expert per group + # Since `grouped_topk` assums top-2 + num_experts = getattr(text_config, 'n_group', 1) * 2 + text_config.update({ "num_layers": 1, "num_hidden_layers": 1, - "num_experts": 2, + "num_experts": num_experts, "num_experts_per_tok": 2, - "num_local_experts": 2, + "num_local_experts": num_experts, + # Otherwise there will not be any expert layers + "first_k_dense_replace": 0, + # To avoid OOM on DeepSeek-V3 + "n_routed_experts": num_experts, }) if hasattr(hf_config, "vision_config"): diff --git a/vllm/config.py b/vllm/config.py index e90ad5e9c8b6..14d59b9d39dc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1775,6 +1775,25 @@ class ParallelConfig: """Backend to use for data parallel, either "mp" or "ray".""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" + enable_eplb: bool = False + """Enable expert parallelism load balancing for MoE layers.""" + num_redundant_experts: int = 0 + """Number of redundant experts to use for expert parallelism.""" + eplb_window_size: int = 1000 + """Window size for expert load recording.""" + eplb_step_interval: int = 3000 + """ + Interval for rearranging experts in expert parallelism. 
+ + Note that if this is greater than the EPLB window size, only the metrics + of the last `eplb_window_size` steps will be used for rearranging experts. + """ + eplb_log_balancedness: bool = False + """ + Log the balancedness each step of expert parallelism. + This is turned off by default since it will cause communication overhead. + """ + max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model sequentially in multiple batches. To avoid RAM OOM when using tensor @@ -1913,6 +1932,20 @@ def __post_init__(self) -> None: os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" logger.info("Disabling V1 multiprocessing for external launcher.") + if self.enable_eplb: + if not current_platform.is_cuda(): + raise ValueError( + "Expert parallelism load balancing is only supported on " + "CUDA devices now.") + if self.num_redundant_experts < 0: + raise ValueError( + "num_redundant_experts must be non-negative, but got " + f"{self.num_redundant_experts}.") + else: + if self.num_redundant_experts != 0: + raise ValueError( + "num_redundant_experts should be used with EPLB." + f"{self.num_redundant_experts}.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py new file mode 100644 index 000000000000..c87b039afd73 --- /dev/null +++ b/vllm/distributed/eplb/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +''' +Expert parallelism load balancer (EPLB). +''' + +from .eplb_state import * +from .rebalance_algo import * diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py new file mode 100644 index 000000000000..2185df865c1f --- /dev/null +++ b/vllm/distributed/eplb/eplb_state.py @@ -0,0 +1,431 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Expert parallelism load balancer (EPLB) metrics and states. + +# Glossary + +- **Logical Expert**: An expert that is part of the model's logical structure. + It holds a set of weights and is replicated across multiple physical + experts. +- **Redundant Expert**: To achieve load balancing, for some popular logical + experts, we create additional copies of the expert weights. During inference, + each of these copies can be routed to by the same set of tokens. +- **Physical Expert**: An expert that is instantiated on a specific device. + It is a replica of a logical expert and can be rearranged across devices. + I.e., one logical expert may have multiple sets of weights initialized on + different devices, and each of these sets is a physical expert. +- **Local Physical Expert**: A physical expert that is instantiated on the + current device. + +For example: DeepSeek-R1 has 256 logical experts, so each MoE layer +has 256 sets of linear layer weights in the model parameters. If we add 32 +redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in +total. And when deploying, we'll have 288 sets of linear layer weights for each +MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local +physical experts. 
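+
+In the `EplbState` below, `physical_to_logical_map` records, per MoE layer,
+the logical expert id behind every physical expert slot, and
+`logical_to_physical_map` is its (-1 padded) inverse.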
+""" + +import time +from collections.abc import Sequence +from dataclasses import dataclass + +import torch +from torch.distributed import all_gather, all_reduce + +from vllm.config import ParallelConfig +from vllm.distributed.parallel_state import get_ep_group, get_node_count +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import MixtureOfExperts + +from .rebalance_algo import rebalance_experts +from .rebalance_execute import rearrange_expert_weights_inplace + +logger = init_logger(__name__) + + +@dataclass +class EplbState: + """EPLB metrics.""" + + physical_to_logical_map: torch.Tensor + """ + Mapping from physical experts to logical experts. + + Shape: (num_moe_layers, num_physical_experts) + + # Example + + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the mapping could look like this: + + ``` + [[0, 1, 2, 3, 0, 1], + [0, 2, 0, 1, 0, 3]] + ``` + """ + logical_to_physical_map: torch.Tensor + """ + Mapping from logical experts to physical experts. + + This is a sparse matrix, where -1 indicates no mapping. + + Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1) + + # Example + + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the mapping could look like this: + + ``` + [[[0, 4, -1], + [1, 5, -1], + [2, -1, -1], + [3, -1, -1]], + [[0, 2, 4], + [3, -1, -1], + [1, -1, -1], + [5, -1, -1]]] + ``` + """ + logical_replica_count: torch.Tensor + """ + Number of replicas for each logical expert. + This is exactly the non-`-1` count in the `logical_to_physical_map`. + + Shape: (num_moe_layers, num_logical_experts) + + # Example + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the count could look like this: + + ``` + [[2, 2, 1, 1], + [3, 1, 1, 1]] + """ + + expert_load_pass: torch.Tensor + """ + Expert load during this forward pass. + We use the token count each expert processes as the load. + + Shape: (num_moe_layers, num_local_physical_experts) + """ + expert_load_window: torch.Tensor + """ + A sliding window of expert load. + + Shape: (window_size, num_moe_layers, num_local_physical_experts) + """ + expert_load_window_step: int = 0 + """ + Current step in the sliding window. + + Different from `expert_rearrangement_step`, each EP rank may have its own + `expert_load_window_step`. + """ + expert_load_window_size: int = 0 + """ + Size of the expert load sliding window. + This is a constant and is taken from the config. + """ + + expert_rearrangement_step: int = 0 + """ + Steps after last rearrangement. + Will trigger a rearrangement if it exceeds the threshold. + + NOTE: Keep in mind that all EP ranks need to have the same + `expert_rearrangement_step` value to ensure synchronization. + Otherwise, the rearrangement will hang at collective + communication calls. + """ + expert_rearrangement_step_interval: int = 0 + """ + Interval for expert rearrangement steps. + This is a constant and is taken from the config. + """ + + @staticmethod + def build_initial_global_physical_to_logical_map( + num_routed_experts: int, + num_redundant_experts: int, + ) -> Sequence[int]: + """ + Build an initial expert arrangement using the following structure: + [original routed experts, redundant experts] + + Returns: + physical_to_logical_map (Sequence[int]): A list of integers, + where each integer is the index of the logical expert + that the corresponding physical expert maps to. 
+ """ + global_physical_to_logical_map = list(range(num_routed_experts)) + global_physical_to_logical_map += [ + i % num_routed_experts for i in range(num_redundant_experts) + ] + return global_physical_to_logical_map + + @classmethod + def build( + cls, + model: MixtureOfExperts, + device: torch.device, + parallel_config: ParallelConfig, + ) -> "EplbState": + """ + Build the initial EPLB state. + """ + physical_to_logical_map_list = ( + cls.build_initial_global_physical_to_logical_map( + model.num_routed_experts, + model.num_redundant_experts, + )) + physical_to_logical_map = torch.tensor( + physical_to_logical_map_list, + device=device, + ) + logical_to_physical_map = torch.full( + (model.num_logical_experts, model.num_redundant_experts + 1), + -1, + device=device, + ) + logical_replica_count = torch.zeros( + (model.num_logical_experts, ), + device=device, + dtype=torch.long, + ) + + for i in range(model.num_physical_experts): + logical_idx = physical_to_logical_map[i] + logical_to_physical_map[logical_idx, + logical_replica_count[logical_idx]] = i + logical_replica_count[logical_idx] += 1 + + # Duplicate initial mapping for all layers + physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand( + model.num_moe_layers, + -1, + ).contiguous() + logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand( + model.num_moe_layers, + -1, + -1, + ).contiguous() + logical_replica_count = logical_replica_count.unsqueeze(0).expand( + model.num_moe_layers, + -1, + ).contiguous() + + expert_load_pass = torch.zeros( + (model.num_moe_layers, model.num_local_physical_experts), + dtype=torch.int32, + device=device, + ) + expert_load_window_size = parallel_config.eplb_window_size + expert_load_window = torch.zeros( + (expert_load_window_size, model.num_moe_layers, + model.num_local_physical_experts), + dtype=torch.int32, + device=device, + ) + + # Set the initial progress of rearrangement to 3/4 + eplb_step_interval = parallel_config.eplb_step_interval + expert_rearrangement_step = max( + 0, eplb_step_interval - eplb_step_interval // 4) + + model.set_eplb_state( + expert_load_pass, + logical_to_physical_map, + logical_replica_count, + ) + + return cls( + physical_to_logical_map, + logical_to_physical_map, + logical_replica_count, + expert_load_pass, + expert_load_window, + expert_load_window_size=expert_load_window_size, + expert_rearrangement_step=expert_rearrangement_step, + expert_rearrangement_step_interval=eplb_step_interval, + ) + + def step(self, + model: MixtureOfExperts, + is_dummy: bool = False, + is_profile: bool = False, + log_stats: bool = False) -> None: + """ + Step the EPLB state. + + Args: + model (MixtureOfExperts): The MoE model. + is_dummy (bool): If `True`, this is a dummy step and the load + metrics recorded in this forward pass will not count. Defaults + to `False`. + is_profile (bool): If `True`, perform a dummy rearrangement + with maximum communication cost. This is used in `profile_run` + to reserve enough memory for the communication buffer. + log_stats (bool): If `True`, log the expert load metrics. + + # Stats + The metrics are all summed up across layers. + - `avg_tokens`: The average load across ranks. + - `max_tokens`: The maximum load across ranks. + - `balancedness`: The ratio of average load to maximum load. 
+ """ + + if is_profile: + self.rearrange(model, is_profile=True) + return + + if is_dummy: + # Do not record load metrics for dummy steps + self.expert_load_pass.zero_() + + if log_stats: + # `num_tokens`: (num_moe_layers,) + num_tokens = self.expert_load_pass.sum(dim=-1) + + # Collect load metrics from all ranks + ep_group = get_ep_group().device_group + num_tokens_list = [ + torch.empty_like(num_tokens) for _ in range(ep_group.size()) + ] + all_gather(num_tokens_list, num_tokens, group=ep_group) + # Stack to get (num_ranks, num_moe_layers) + num_tokens_per_rank = torch.stack(num_tokens_list).float() + + # Compute balancedness ratio: + # for each layer: + # (mean load across ranks) / (max load across ranks) + avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) + max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum( + dim=0) + + # Just to make type checker happy + tokens_tensors: list[float] = torch.stack( + [avg_tokens_tensor, max_tokens_tensor]).tolist() + avg_tokens, max_tokens = tokens_tensors + balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 + + if ep_group.rank() == 0: + logger.info( + "EPLB step: avg_tokens=%.2f, max_tokens=%d, " + "balancedness=%.4f", avg_tokens, max_tokens, balancedness) + + # Update the expert load sliding window + if not is_dummy: + self.expert_load_window[self.expert_load_window_step] = ( + self.expert_load_pass.clone()) + self.expert_load_window_step += 1 + if self.expert_load_window_step >= self.expert_load_window_size: + self.expert_load_window_step = 0 + self.expert_load_pass.zero_() + + # Step the expert rearrangement step + # Note that even if this is a dummy step, we still increment the + # rearrangement step and perform rearrangement to ensure all ranks are + # performing collective communication. + self.expert_rearrangement_step += 1 + if (self.expert_rearrangement_step + >= self.expert_rearrangement_step_interval): + self.expert_rearrangement_step = 0 + self.rearrange(model) + + def rearrange(self, + model: MixtureOfExperts, + is_profile: bool = False) -> None: + """ + Rearrange the experts according to the current load. 
+ """ + + ep_group = get_ep_group().device_group + ep_rank = ep_group.rank() + + time_start = None + is_main_rank = ep_rank == 0 + if is_main_rank: + torch.cuda.synchronize() + time_start = time.perf_counter() + logger.info("Rearranging experts %s...", + "(profile)" if is_profile else "") + + # This mapping is only used here, so we do not store it in the state + physical_expert_start = ep_rank * model.num_local_physical_experts + physical_expert_end = (physical_expert_start + + model.num_local_physical_experts) + # (num_moe_layers, num_local_physical_experts) + local_physical_to_logical_map = self.physical_to_logical_map[ + :, + physical_expert_start:physical_expert_end, + ] + + # Map the local physical expert load to global logical experts + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + model.num_moe_layers, + model.num_logical_experts, + dtype=self.expert_load_window.dtype, + device=self.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=local_physical_to_logical_map.unsqueeze(0).expand_as( + self.expert_load_window).long(), + src=self.expert_load_window, + ) + + # Perform all-reduce to get the expert load across all ranks + global_expert_load_window = logical_expert_load_window.sum(dim=0) + all_reduce(global_expert_load_window, group=ep_group) + + # TODO(bowen): Treat differently for prefill and decode nodes + num_replicas = model.num_physical_experts + num_groups = model.num_expert_groups + num_nodes = get_node_count() + num_gpus = ep_group.size() + + if num_gpus % num_nodes != 0: + logger.warning_once( + f"num_gpus % num_nodes != 0, " + "not using hierarchical rearrangement algorithm.\n" + f"{num_gpus=}, {num_nodes=}") + + # Get new expert mappings + ( + new_physical_to_logical_map, + new_logical_to_physical_map, + new_logical_replica_count, + ) = (rebalance_experts( + global_expert_load_window, + num_replicas, + num_groups, + num_nodes, + num_gpus, + )) + + # Update expert weights + rearrange_expert_weights_inplace( + self.physical_to_logical_map, + new_physical_to_logical_map, + model.expert_weights, + ep_group, + is_profile, + ) + + if not is_profile: + self.physical_to_logical_map.copy_(new_physical_to_logical_map) + self.logical_to_physical_map.copy_(new_logical_to_physical_map) + self.logical_replica_count.copy_(new_logical_replica_count) + + if is_main_rank: + assert time_start is not None + torch.cuda.synchronize() + time_end = time.perf_counter() + logger.info( + "Rearranged experts%sin %.2f seconds.", + " (profile) " if is_profile else " ", + time_end - time_start, + ) diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py new file mode 100644 index 000000000000..7ad6d566b55b --- /dev/null +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Expert parallelism load balancer (EPLB) for vLLM. + +This module implements the core rearrangement algorithm. + +The rearrangement algorithm is adapted from +[DeepSeek EPLB](https://github.com/deepseek-ai/eplb). + +Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example +on how the EPLB algorithm works. +""" + +import torch + + +def balanced_packing(weight: torch.Tensor, + num_packs: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly + n/m objects and the weights of all packs are as balanced as possible. 
+ + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), + dtype=torch.int64, + device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, + fill_value=-1, + dtype=torch.int64, + device="cpu") + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min( + (i + for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__, + ) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts( + weight: torch.Tensor, + num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum + load of all replicas is minimized. + + Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, + device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def rebalance_experts_hierarchical( + weight: torch.Tensor, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = 
torch.empty_like(perm) + inv.scatter_( + 1, + perm, + torch.arange(perm.size(1), dtype=torch.int64, + device=perm.device).expand(perm.shape), + ) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = balanced_packing( + tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * + group_size).unsqueeze(-1) + + torch.arange(group_size, + dtype=torch.int64, + device=group_pack_index.device)).flatten(-2) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + # [num_layers * num_nodes, num_logical_experts // num_nodes] + tokens_per_mlog = weight.gather(-1, mlog2log).view( + -1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = replicate_experts( + tokens_per_mlog, num_physical_experts // num_nodes) + + # Step 3: pack physical_experts to GPUs + # [num_layers * num_nodes, num_physical_experts // num_nodes] + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, + num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = phy2mlog.gather( + -1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + device=group_pack_index.device, + ).view(1, -1, 1)).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. 
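+
+    When `num_groups` is divisible by `num_nodes`, the hierarchical policy
+    (pack groups onto nodes, replicate within nodes, pack replicas onto GPUs)
+    is used; otherwise the algorithm falls back to a global, node-agnostic
+    balancing pass.
+
+    A minimal shape example, mirroring the unit tests:
+
+        weight = torch.tensor([[10, 20, 30, 40]])
+        phy2log, log2phy, logcnt = rebalance_experts(
+            weight, num_replicas=6, num_groups=2, num_nodes=1, num_gpus=2)
+        # phy2log: (1, 6), log2phy: (1, 4, 3), logcnt: (1, 4)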
+ + Parameters: + weight: [layers, num_logical_experts], the load statistics for all + logical experts + num_replicas: number of physical experts, must be a multiple of + `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of + each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica + indices for each expert + expert_count: [layers, num_logical_experts], number of physical + replicas for each logical expert + """ + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus) + num_redundant_experts = num_replicas - num_logical_experts + maxlogcnt = num_redundant_experts + 1 + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, + device=log2phy.device).expand(num_layers, -1), + ) + return phy2log, log2phy, logcnt + + +__all__ = ["rebalance_experts"] diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py new file mode 100644 index 000000000000..cf173c734afd --- /dev/null +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +The actual execution of the rearrangement. + +This involves the exchange of expert weights between GPUs. +""" + +from collections.abc import Iterable, MutableSequence, Sequence +from functools import partial + +import torch +from torch.distributed import (P2POp, ProcessGroup, all_gather, + batch_isend_irecv, get_global_rank) + + +def idx_local_to_global( + local_idx: int, + local_cnt: int, + ep_rank: int, +) -> int: + """ + Convert a local expert index to a global expert index. + """ + return ep_rank * local_cnt + local_idx + + +def idx_global_to_local( + global_idx: int, + local_cnt: int, + ep_rank: int, +) -> int: + """ + Convert a global expert index to a local expert index. + """ + return global_idx - ep_rank * local_cnt + + +def global_idx_to_rank( + global_idx: int, + local_cnt: int, +) -> int: + """ + Convert a global expert index to a rank index. + """ + return global_idx // local_cnt + + +def get_ep_ranks_with_expert( + idx: int, + num_local_experts: int, + old_indices: Sequence[int], + new_indices: Sequence[int], +) -> tuple[MutableSequence[int], MutableSequence[int]]: + """ + Get the ranks of the experts that need to be exchanged. + + Args: + idx: The index of the expert. + num_local_experts: The number of local experts. + old_indices: The old indices of the experts. + new_indices: The new indices of the experts. + + Returns: + A tuple of two lists: + - The ranks of the experts that need to be sent. + - The ranks of the experts that need to be received. 
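+
+    Ranks that already hold the expert locally (i.e. also appear in the send
+    list) are dropped from the receive list, since they can copy the weights
+    without any communication.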
+ """ + global2rank = partial( + global_idx_to_rank, + local_cnt=num_local_experts, + ) + + ranks_to_send: list[int] = [] + ranks_to_recv: list[int] = [] + + for i, e in enumerate(old_indices): + if e == idx: + rank = global2rank(i) + if not ranks_to_send or ranks_to_send[-1] != rank: + ranks_to_send.append(rank) + + for i, e in enumerate(new_indices): + if e == idx: + rank = global2rank(i) + if not ranks_to_recv or ranks_to_recv[-1] != rank: + ranks_to_recv.append(rank) + + # Remove those ranks that can get this expert locally. + ranks_to_send_set = set(ranks_to_send) + ranks_to_recv_actual = [ + rank for rank in ranks_to_recv if rank not in ranks_to_send_set + ] + + return ranks_to_send, ranks_to_recv_actual + + +def shuffle_layer( + num_local_experts: int, + ep_rank: int, + old_indices: Sequence[int], + new_indices: Sequence[int], + expert_weights: Iterable[torch.Tensor], + expert_weights_buffer: Sequence[torch.Tensor], + ep_group: ProcessGroup, +) -> None: + """ + Perform expert weights rearrangement of one layer. + """ + local2global = partial( + idx_local_to_global, + local_cnt=num_local_experts, + ep_rank=ep_rank, + ) + + # 0. Do nothing for experts that did not change. + is_unchanged = [ + old_indices[local2global(i)] == new_indices[local2global(i)] + for i in range(num_local_experts) + ] + + # 1. Perform weight copy inside the local rank. + is_received_locally = is_unchanged[:] + for src in range(num_local_experts): + src_global = local2global(src) + for dst in range(num_local_experts): + dst_global = local2global(dst) + if is_received_locally[dst]: + continue + if old_indices[src_global] == new_indices[dst_global]: + is_received_locally[dst] = True + for weight, buffer in zip(expert_weights, + expert_weights_buffer): + buffer[dst].copy_(weight[src]) + + p2p_ops: list[P2POp] = [] + + # 2. Initiate sending of weights. + experts_send_loc: dict[int, int] = {} + for src in range(num_local_experts): + expert = old_indices[local2global(src)] + if expert in experts_send_loc: + continue + experts_send_loc[expert] = src + + # We need to sort here to match send/recv + for expert, src in sorted(experts_send_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + + # Calculate the ranks to send by this rank + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + sender_pos = ranks_to_send.index(ep_rank) + recv_begin = sender_pos * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + recv_ranks = ranks_to_recv[recv_begin:recv_end] + + # Tackle remainders + remainder_start = len(ranks_to_send) * num_dst_per_sender + recver_pos = remainder_start + sender_pos + if recver_pos < len(ranks_to_recv): + recv_ranks.append(ranks_to_recv[recver_pos]) + + for dst in recv_ranks: + dst_global = get_global_rank(ep_group, dst) + p2p_ops += [ + P2POp( + torch.distributed.isend, + weight[src], + dst_global, + ) for weight in expert_weights + ] + + # 3. Initiate receiving of weights. 
+ experts_recv_loc: dict[int, int] = {} + for dst in range(num_local_experts): + if is_received_locally[dst]: + continue + expert = new_indices[local2global(dst)] + if expert in experts_recv_loc: + continue + experts_recv_loc[expert] = dst + + # We need to sort here to match send/recv + for expert, dst in sorted(experts_recv_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + + # Calculate the rank to recv by this rank + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + recver_pos = ranks_to_recv.index(ep_rank) + remainder_start = len(ranks_to_send) * num_dst_per_sender + if recver_pos < remainder_start: + src = ranks_to_send[recver_pos // num_dst_per_sender] + else: + src = ranks_to_send[recver_pos - remainder_start] + + src_global = get_global_rank(ep_group, src) + p2p_ops += [ + P2POp( + torch.distributed.irecv, + weight[dst], + src_global, + ) for weight in expert_weights_buffer + ] + + # 4. Execute the P2P operations. The real communication happens here. + if p2p_ops: + reqs = batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + # 5. Copy the weights from the buffer back to the original weights. + for dst in range(num_local_experts): + if is_unchanged[dst]: + continue + if is_received_locally[dst]: + for weight, buffer in zip(expert_weights, expert_weights_buffer): + weight[dst].copy_(buffer[dst]) + else: + expert = new_indices[local2global(dst)] + src = experts_recv_loc[expert] + for weight, buffer in zip(expert_weights, expert_weights_buffer): + weight[dst].copy_(buffer[src]) + + +def rearrange_expert_weights_inplace( + old_global_expert_indices: torch.Tensor, + new_global_expert_indices: torch.Tensor, + expert_weights: Sequence[Iterable[torch.Tensor]], + ep_group: ProcessGroup, + is_profile: bool = False, +) -> None: + """ + Rearranges the expert weights in place according to the new expert indices. + + The value of the indices arguments are logical indices of the experts, + while keys are physical. + + Args: + old_global_expert_indices: Shape (num_moe_layers, num_physical_experts). + new_global_expert_indices: Shape (num_moe_layers, num_physical_experts). + expert_weights: A sequence of shape (num_moe_layers)(weight_count) + of tensors of shape (num_local_physical_experts, hidden_size_i). + For example, a linear layer may have up and down projection, + so weight_count = 2. Each weight's hidden size can be different. + ep_group: The device process group for expert parallelism. + is_profile (bool): If `True`, do not perform any actual weight copy. + This is used during profile run, where we only perform dummy + communications to reserve enough memory for the buffers. + """ + num_moe_layers, num_physical_experts = old_global_expert_indices.shape + assert len(expert_weights) == num_moe_layers + + num_local_physical_experts = next(iter(expert_weights[0])).shape[0] + assert new_global_expert_indices.shape == (num_moe_layers, + num_physical_experts) + + ep_rank = ep_group.rank() + ep_size = ep_group.size() + assert num_physical_experts == ep_size * num_local_physical_experts + + # A buffer to hold the expert weights in one layer during the exchange. + # NOTE: Currently we assume the same weights across different layers + # have the same shape. 
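+    # Hence a single per-layer buffer, allocated from the first layer's
+    # weights, can be reused when shuffling every layer.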
+ expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]] + + if is_profile: + # Maximum send size is to send all local experts to all ranks, + # So we use a dummy `all_gather` to reserve enough communication buffer + for weight, buffer in zip(expert_weights[0], expert_weights_buffer): + # A `/dev/null`-like buffer to avoid real memory allocation + dummy_recv_buffer = [buffer for _ in range(ep_size)] + # NOTE(bowen): Needed this barrier to avoid OOM during actual + # execution. I'm not very sure why this is needed + torch.distributed.barrier() + all_gather( + dummy_recv_buffer, + weight, + group=ep_group, + ) + return + + for layer in range(num_moe_layers): + # NOTE(bowen): We need this synchronize to run, but I don't know why. + # If you figure out the reason, please let me know -- thank you! + torch.cuda.synchronize() + shuffle_layer( + num_local_physical_experts, + ep_rank, + old_global_expert_indices[layer].tolist(), + new_global_expert_indices[layer].tolist(), + expert_weights[layer], + expert_weights_buffer, + ep_group, + ) + + +__all__ = ["rearrange_expert_weights_inplace"] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9d1008b6b350..6c908f88b9a9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -320,6 +320,11 @@ class EngineArgs: data_parallel_rpc_port: Optional[int] = None data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + enable_eplb: bool = ParallelConfig.enable_eplb + num_redundant_experts: int = ParallelConfig.num_redundant_experts + eplb_window_size: int = ParallelConfig.eplb_window_size + eplb_step_interval: int = ParallelConfig.eplb_step_interval + eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers block_size: Optional[BlockSize] = CacheConfig.block_size @@ -666,6 +671,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument("--enable-eplb", + **parallel_kwargs["enable_eplb"]) + parallel_group.add_argument("--num-redundant-experts", + **parallel_kwargs["num_redundant_experts"]) + parallel_group.add_argument("--eplb-window-size", + **parallel_kwargs["eplb_window_size"]) + parallel_group.add_argument("--eplb-step-interval", + **parallel_kwargs["eplb_step_interval"]) + parallel_group.add_argument("--eplb-log-balancedness", + **parallel_kwargs["eplb_log_balancedness"]) parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -1135,6 +1150,11 @@ def create_engine_config( data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_backend=data_parallel_backend, enable_expert_parallel=self.enable_expert_parallel, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.num_redundant_experts, + eplb_window_size=self.eplb_window_size, + eplb_step_interval=self.eplb_step_interval, + eplb_log_balancedness=self.eplb_log_balancedness, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c1bae033c2b4..de905a85cc18 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ 
b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,9 +3,10 @@ import importlib from abc import abstractmethod +from collections.abc import Iterable from dataclasses import dataclass from enum import Enum -from typing import Callable, Optional, Union +from typing import Callable, Literal, Optional, Union, overload import torch import torch.nn.functional as F @@ -20,6 +21,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.distributed.eplb.eplb_state import EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -433,6 +435,10 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError @@ -572,7 +578,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `UnquantizedFusedMoEMethod` yet.") + return self.forward( x=x, layer=layer, @@ -819,6 +833,7 @@ class FusedMoE(torch.nn.Module): reduce_results: Whether to all all_reduce on the output of the layer renomalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. + enable_eplb: Whether to enable expert parallelism load balancer. """ def __init__( @@ -843,6 +858,8 @@ def __init__( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + num_redundant_experts: int = 0, ): super().__init__() if params_dtype is None: @@ -858,7 +875,7 @@ def __init__( get_dp_group().world_size), vllm_parallel_config=vllm_config.parallel_config)) - self.global_num_experts = num_experts + self.global_num_experts = num_experts + num_redundant_experts # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config @@ -867,8 +884,20 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix + self.enable_eplb = enable_eplb + self.expert_load_view: Optional[torch.Tensor] = None + self.logical_to_physical_map: Optional[torch.Tensor] = None + self.logical_replica_count: Optional[torch.Tensor] = None + # Determine expert maps if self.use_ep: + if self.enable_eplb: + assert self.global_num_experts % self.ep_size == 0, \ + "EPLB currently only supports even distribution of " \ + "experts across ranks." + else: + assert num_redundant_experts == 0, \ + "Redundant experts are only supported with EPLB." 
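Editorial note (not part of the patch): with EPLB on, `global_num_experts` above counts physical experts (logical plus redundant), and the assertion requires an even split across EP ranks; router logits, by contrast, remain sized by the logical expert count (see the `batched_router_logits` change later in this file). Quick arithmetic with made-up sizes:

# Hypothetical sizes, for illustration only.
num_logical_experts = 256     # `num_experts` passed to the FusedMoE layer
num_redundant_experts = 32    # e.g. from the new --num-redundant-experts flag
ep_size = 32                  # expert-parallel world size

global_num_experts = num_logical_experts + num_redundant_experts   # 288 physical
assert global_num_experts % ep_size == 0, "EPLB needs an even split across ranks"
num_local_physical_experts = global_num_experts // ep_size          # 9 per rank

router_logits_dim = num_logical_experts   # the gate still scores logical experts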
self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, @@ -935,6 +964,20 @@ def __init__( assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method + if self.enable_eplb: + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8MoEMethod) + if not isinstance(quant_method, Fp8MoEMethod): + # TODO: Add support for additional quantization methods. + # The implementation for other quantization methods does not + # contain essential differences, but the current quant API + # design causes duplicated work when extending to new + # quantization methods, so I'm leaving it for now. + # If you plan to add support for more quantization methods, + # please refer to the implementation in `Fp8MoEMethod`. + raise NotImplementedError("EPLB is only supported for FP8 " + "quantization for now.") + moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, @@ -963,8 +1006,9 @@ def __init__( dtype=act_dtype, device=torch.cuda.current_device()) + # Note here we use `num_experts` which is logical expert count self.batched_router_logits = torch.zeros( - (envs.VLLM_MOE_DP_CHUNK_SIZE, self.global_num_experts), + (envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts), dtype=act_dtype, device=torch.cuda.current_device()) @@ -1128,13 +1172,33 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: return expert_id return self.expert_map[expert_id].item() + @overload def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int) -> None: + shard_id: str, expert_id: int, + return_success: Literal[False]) -> None: + ... + @overload + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int, + return_success: Literal[True]) -> bool: + ... + + def weight_loader(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False) -> Optional[bool]: expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: - return + # Failed to load this param since it's not local to this rank + return False if return_success else None + # Hereafter, `expert_id` is local physical id + quant_method_name = self.quant_method.__class__.__name__ # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format @@ -1161,7 +1225,7 @@ def weight_loader(self, param: torch.nn.Parameter, if is_gguf_weight_type: param.weight_type = loaded_weight.item() param.data.copy_(loaded_weight) - return + return True if return_success else None # is_transposed: if the dim to shard the weight # should be flipped. 
Required by GPTQ, compressed-tensors @@ -1200,7 +1264,7 @@ def weight_loader(self, param: torch.nn.Parameter, self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) - return + return True if return_success else None # Case g_idx if "g_idx" in weight_name: @@ -1209,7 +1273,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None if "ModelOpt" in quant_method_name: if ('weight_scale_2' in weight_name @@ -1225,7 +1289,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None # Case weight scales, zero_points and offset if ("scale" in weight_name or "zero" in weight_name @@ -1262,7 +1326,7 @@ def weight_loader(self, param: torch.nn.Parameter, else: raise ValueError( f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") - return + return True if return_success else None # Case weight_shape if "weight_shape" in weight_name: @@ -1270,7 +1334,7 @@ def weight_loader(self, param: torch.nn.Parameter, self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) - return + return True if return_success else None # Case model weights if "weight" in weight_name: @@ -1280,23 +1344,77 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None + + return False if return_success else None + + def get_expert_weights(self) -> Iterable[torch.Tensor]: + weights = list(self.named_parameters()) + assert all(weight.is_contiguous() for _, weight in weights) + + # Filter out the non-expert weights. + # `e_score_correction_bias` is a bias for each logical expert, + # with shape (num_logical_experts,), not an expert weight. + NON_EXPERT_WEIGHTS = { + "e_score_correction_bias", + } + + return [ + weight.view(self.local_num_experts, -1) for name, weight in weights + if name not in NON_EXPERT_WEIGHTS + ] + + def set_eplb_state( + self, + moe_layer_idx: int, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + """ + Register the EPLB state in this layer. + + This is used later in forward pass, where we get the expert mapping + and record the load metrics in `expert_load_view`. 
+ """ + self.expert_load_view = expert_load_view[moe_layer_idx] + self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.logical_replica_count = logical_replica_count[moe_layer_idx] @staticmethod - def select_experts(hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - indices_type: Optional[torch.dtype] = None): + def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None, + enable_eplb: bool = False, + expert_map: Optional[torch.Tensor] = None, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Route the input hidden states to the top-k experts based on the + router logits. + + Returns: + (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]): + The weights and *global physical* expert ids of the top-k experts. + + **Compatibility**: When EPLB is not enabled, the returned ids are + equivalent to global logical ids, so should be compatible with + plain MoE implementations without redundant experts. + """ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk - # DeekSeekv2 uses grouped_top_k + # DeepSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None @@ -1328,6 +1446,74 @@ def select_experts(hidden_states: torch.Tensor, if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + + # 1. Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + + # TODO: maybe optimize this by using specified kernels, + # or compute pseudo-random indices by modulo + + # In case `indices_type` is not `torch.long` or `torch.int`, + # e.g. `torch.uint32` as required by dispatch/combine kernels + topk_ids_long = topk_ids.long() + replica_indices = ( + torch.rand_like(topk_ids, dtype=torch.float) * + logical_replica_count[topk_ids_long]).long().unsqueeze(-1) + physical_ids = logical_to_physical_map[topk_ids_long].gather( + -1, replica_indices).squeeze(-1) + + topk_ids = physical_ids + + # 2. Record expert load metrics. + + # TODO(bowen): When using `FusedMoEModularKernel`, this + # can be done in a more unified way, since + # `FusedMoEPrepareAndFinalize` will return the expert + # token count, in some cases directly from the kernel. + # However, now there are many code paths not using + # the modular kernel, e.g. calling `fused_experts`, + # so we decide to keep the logic here. + # + # If later refactor moved all the MoE kernel calls + # to the modular kernel, we can move this logic there + # to achieve better efficiency. 
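Editorial note (not part of the patch): the branch above first maps each routed logical expert id to a uniformly random physical replica, and the accumulation code that follows records per-expert token counts with `scatter_add_` as a compile-friendly stand-in for `bincount`. A self-contained toy of both steps, using a hand-written mapping (with -1 as padding) rather than one produced by the rebalancing algorithm:

import torch

# Toy: 4 logical experts, 6 physical slots; logical 1 and 3 have two replicas.
logical_to_physical_map = torch.tensor([
    [0, -1],  # logical 0 -> physical 0
    [1, 4],   # logical 1 -> physical 1 or 4
    [2, -1],  # logical 2 -> physical 2
    [3, 5],   # logical 3 -> physical 3 or 5
])
logical_replica_count = torch.tensor([1, 2, 1, 2])

topk_ids = torch.tensor([[1, 3], [0, 1]])  # (num_tokens, top_k), logical ids

# 1. Random replica choice, mirroring the `rand * replica_count` trick above.
topk_long = topk_ids.long()
replica_idx = (torch.rand_like(topk_ids, dtype=torch.float) *
               logical_replica_count[topk_long]).long().unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_long].gather(-1, replica_idx).squeeze(-1)

# 2. Load recording: a masked bincount expressed as masked_fill_ + scatter_add_.
#    (Here every physical id is treated as local; the real code first applies
#    `expert_map` and masks out ids belonging to other ranks.)
expert_load_view = torch.zeros(6, dtype=torch.long)
flat = physical_ids.flatten()
invalid = flat < 0
index = flat.masked_fill_(invalid, 0)
src = (~invalid).to(expert_load_view)
expert_load_view.scatter_add_(dim=0, index=index.long(), src=src)

assert expert_load_view.sum() == topk_ids.numel()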
+ + # `expert_load_view`: (num_logical_experts,) + + # Mask out non-local experts + if expert_map is not None: + topk_ids_local = expert_map[topk_ids] + topk_ids_flatten = topk_ids_local.flatten() + else: + topk_ids_flatten = topk_ids.flatten() + + # Should be equivalent to: + # ``` + # topk_ids_masked = topk_ids_local[topk_ids_local >= 0] + # expert_load_view += topk_ids_masked.bincount( + # minlength=expert_load_view.shape[0]) + # ``` + # We use `scatter_add_` since `bincount` cannot be compiled + + # Performance optimization: + # `masked_fill` is significantly faster than `masked_select` + invalid_mask = topk_ids_flatten < 0 + # Replace invalid expert ids with 0 (just a dummy position) + # to avoid out-of-bounds errors in scatter_add_ + index = topk_ids_flatten.masked_fill_(invalid_mask, 0) + # `src` is the valid mask, which is 1 for valid and 0 for invalid + src = ~invalid_mask + + expert_load_view.scatter_add_(dim=0, + index=index.long(), + src=src.to(expert_load_view)) + + topk_ids = topk_ids.to(dtype=indices_type) + return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: @@ -1408,6 +1594,10 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) if not skip_result_store: @@ -1465,6 +1655,10 @@ def forward_impl(self, hidden_states: torch.Tensor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) if do_naive_dispatch_combine: @@ -1479,16 +1673,30 @@ def forward_impl(self, hidden_states: torch.Tensor, @classmethod def make_expert_params_mapping( - cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, ckpt_up_proj_name: str, - num_experts: int) -> list[tuple[str, str, int, str]]: + num_experts: int, + num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]: + + num_physical_experts = num_experts + num_redundant_experts + + # In the returned mapping: + # - `expert_id` is the physical expert id + # - `weight_name` contains the weight name of the logical expert + # So that we should map the expert id to logical in `weight_name` + physical_to_logical_map = \ + EplbState.build_initial_global_physical_to_logical_map( + num_experts, num_redundant_experts) return [ # (param_name, weight_name, expert_id, shard_id) ("experts.w13_" if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) - for expert_id in range(num_experts) for shard_id, weight_name in [ + f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", + expert_id, shard_id) for expert_id in range(num_physical_experts) + for shard_id, weight_name in [ ("w1", ckpt_gate_proj_name), ("w2", ckpt_down_proj_name), ("w3", ckpt_up_proj_name), diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 56d803c6baf1..aff54bc495b2 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ 
b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -482,7 +482,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `AWQMoEMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." if apply_router_weight_on_input: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f14131c5f05b..7703b9e687c4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -331,7 +331,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Fp8MoEMethod` yet.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -593,7 +601,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -722,7 +738,16 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Int8MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( @@ -1012,7 +1037,16 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsWNA16MarlinMoEMethod` yet.") + assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") assert not apply_router_weight_on_input, ( @@ -1228,7 +1262,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + 
enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError("EPLB not supported for " + "`CompressedTensorsWNA16MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 01b0064f0805..47eca80609e0 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -117,7 +117,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ExpertsInt8MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b3042bfaed3d..d2eda541f7a4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -825,7 +825,16 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -839,6 +848,11 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, ) if self.rocm_aiter_moe_enabled: diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 9c8f74545d37..86da04c39989 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -520,7 +520,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ): + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `GGUFMoEMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." 
if apply_router_weight_on_input: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index e9b8dc3266b4..48ab04c9ab37 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -635,7 +635,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `GPTQMarlinMoEMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." if apply_router_weight_on_input: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 3f79b203aa17..e35db5b31dba 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -664,7 +664,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ): + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") + if self.use_marlin: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 3aa23f068257..c5055a02fa3d 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -297,7 +297,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `MoeWNA16Method` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts assert activation == "silu", "Only SiLU activation is supported." 
topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 4c2da4c8b04e..a040c430cbca 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -205,7 +205,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 0f996d04e6e8..f712b626c74c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -23,7 +23,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2/DeepseekV3 model.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from typing import Any, Optional, Union import torch @@ -32,8 +33,10 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -51,7 +54,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import MixtureOfExperts, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -99,11 +102,17 @@ def __init__( config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_eplb: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor - self.n_shared_experts = config.n_shared_experts + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " @@ -120,6 +129,22 @@ def __init__( else: self.gate.e_score_correction_bias = None + # Load balancing settings. 
+ vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + self.experts = FusedMoE( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, @@ -133,7 +158,9 @@ def __init__( topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias) + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * @@ -503,6 +530,7 @@ def __init__( model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + enable_eplb: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -543,6 +571,7 @@ def __init__( config=config, quant_config=quant_config, prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, ) else: self.mlp = DeepseekV2MLP( @@ -615,6 +644,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb self.config = config self.vocab_size = config.vocab_size @@ -636,6 +666,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config=model_config, cache_config=cache_config, quant_config=quant_config, + enable_eplb=enable_eplb, ), prefix=f"{prefix}.layers") @@ -681,7 +712,7 @@ def forward( return hidden_states -class DeepseekV2ForCausalLM(nn.Module, SupportsPP): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -700,6 +731,44 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + self.moe_layers.append(layer.mlp.experts) + + # Pick last one layer since the first ones may be dense layers. 
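Editorial note (not part of the patch): the settings above carve the physical expert space into contiguous per-rank slices, with `num_redundant_experts` coming from the new engine/CLI option added earlier in this patch. Worked example with hypothetical sizes:

# Hypothetical: 64 routed (logical) experts, 8 redundant replicas, EP size 8.
n_routed_experts = 64
n_redundant_experts = 8
ep_size = 8

n_physical_experts = n_routed_experts + n_redundant_experts     # 72
n_local_physical_experts = n_physical_experts // ep_size        # 9 per rank

for ep_rank in range(ep_size):
    physical_expert_start = ep_rank * n_local_physical_experts
    physical_expert_end = physical_expert_start + n_local_physical_experts
    # e.g. rank 3 owns physical experts [27, 36)
    assert 0 <= physical_expert_start < physical_expert_end <= n_physical_experts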
+ example_moe = typing.cast( + DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -752,7 +821,8 @@ def load_weights(self, weights: Iterable[tuple[str, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + num_redundant_experts=self.num_redundant_experts) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -789,24 +859,44 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight, shard_id) break else: + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + break else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra bias for GPTQ models. 
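Editorial note (not part of the patch): `set_eplb_state` above hands each MoE layer a per-layer slice of the shared EPLB tensors. Indexing along the first dimension returns a view, so the layer's in-place load updates and the balancer's in-place remapping are visible to each other without any copying or extra synchronization. A minimal sketch of that property:

import torch

num_moe_layers, num_physical_experts = 3, 8
expert_load_view = torch.zeros(num_moe_layers, num_physical_experts,
                               dtype=torch.long)

# What a layer receives: a view of row `moe_layer_idx`, not a copy.
layer_view = expert_load_view[1]
layer_view += 1  # in-place update, e.g. recording routed tokens

# The shared tensor read by the load balancer reflects the update immediately.
assert expert_load_view[1].sum() == num_physical_experts
assert expert_load_view.sum() == num_physical_experts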
if name.endswith(".bias") and name not in params_dict: continue @@ -824,6 +914,7 @@ def load_weights(self, weights: Iterable[tuple[str, default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index f759f8f1f273..3ea424e44b62 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, MutableSequence from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, Union, overload, runtime_checkable) @@ -426,6 +427,73 @@ def is_hybrid( return isinstance(model, IsHybrid) +@runtime_checkable +class MixtureOfExperts(Protocol): + """ + Check if the model is a mixture of experts (MoE) model. + """ + + expert_weights: MutableSequence[Iterable[Tensor]] + """ + Expert weights saved in this rank. + + The first dimension is the layer, and the second dimension is different + parameters in the layer, e.g. up/down projection weights. + """ + + num_moe_layers: int + """Number of MoE layers in this model.""" + + num_expert_groups: int + """Number of expert groups in this model.""" + + num_logical_experts: int + """Number of logical experts in this model.""" + + num_physical_experts: int + """Number of physical experts in this model.""" + + num_local_physical_experts: int + """Number of local physical experts in this model.""" + + num_routed_experts: int + """Number of routed experts in this model.""" + + num_shared_experts: int + """Number of shared experts in this model.""" + + num_redundant_experts: int + """Number of redundant experts in this model.""" + + def set_eplb_state( + self, + expert_load_view: Tensor, + logical_to_physical_map: Tensor, + logical_replica_count: Tensor, + ) -> None: + """ + Register the EPLB state in the MoE model. + + Since these are views of the actual EPLB state, any changes made by + the EPLB algorithm are automatically reflected in the model's behavior + without requiring additional method calls to set new states. + + You should also collect model's `expert_weights` here instead of in + the weight loader, since after initial weight loading, further + processing like quantization may be applied to the weights. + + Args: + expert_load_view: A view of the expert load metrics tensor. + logical_to_physical_map: Mapping from logical to physical experts. + logical_replica_count: Count of replicas for each logical expert. + """ + ... 
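Editorial note (not part of the patch): because `MixtureOfExperts` is a `runtime_checkable` Protocol, the `is_mixture_of_experts` helper defined next only checks that the listed attributes and `set_eplb_state` are present at runtime; models conform structurally, with no explicit inheritance required. A toy conforming class (made-up values):

import torch

class ToyMoEModel:
    def __init__(self) -> None:
        self.expert_weights: list = []
        self.num_moe_layers = 2
        self.num_expert_groups = 1
        self.num_logical_experts = 8
        self.num_physical_experts = 10
        self.num_local_physical_experts = 5
        self.num_routed_experts = 8
        self.num_shared_experts = 1
        self.num_redundant_experts = 2

    def set_eplb_state(self, expert_load_view: torch.Tensor,
                       logical_to_physical_map: torch.Tensor,
                       logical_replica_count: torch.Tensor) -> None:
        pass  # a real model registers per-layer views here

# With vLLM importable, the structural check would then pass:
# from vllm.model_executor.models.interfaces import is_mixture_of_experts
# assert is_mixture_of_experts(ToyMoEModel())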
+ + +def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: + return isinstance(model, MixtureOfExperts) + + @runtime_checkable class HasNoOps(Protocol): has_noops: ClassVar[Literal[True]] = True diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 40639fdf2433..3c9de5720405 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -21,6 +21,7 @@ from vllm.compilation.counter import compilation_counter from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) +from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 @@ -33,7 +34,8 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader -from vllm.model_executor.models.interfaces import has_step_pooler +from vllm.model_executor.models.interfaces import (has_step_pooler, + is_mixture_of_experts) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality @@ -150,6 +152,13 @@ def __init__( # Sampler self.sampler = Sampler() + self.eplb_state: Optional[EplbState] = None + """ + State of the expert parallelism load balancer. + + Will be lazily initialized when the model is loaded. + """ + # Lazy initializations # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache @@ -1178,6 +1187,24 @@ def sync_and_slice_intermediate_tensors( for k, v in self.intermediate_tensors.items() }) + def eplb_step(self, + is_dummy: bool = False, + is_profile: bool = False) -> None: + """ + Step for the EPLB (Expert Parallelism Load Balancing) state. + """ + if not self.parallel_config.enable_eplb: + return + + assert self.eplb_state is not None + assert is_mixture_of_experts(self.model) + self.eplb_state.step( + self.model, + is_dummy, + is_profile, + log_stats=self.parallel_config.eplb_log_balancedness, + ) + def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: dp_size = self.vllm_config.parallel_config.data_parallel_size @@ -1595,6 +1622,8 @@ def execute_model( if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() + self.eplb_step() + return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, @@ -1729,6 +1758,16 @@ def load_model(self) -> None: time_after_load - time_before_load) prepare_communication_buffer_for_model(self.model) + if is_mixture_of_experts( + self.model) and self.parallel_config.enable_eplb: + logger.info("EPLB is enabled for model %s.", + self.model_config.model) + self.eplb_state = EplbState.build( + self.model, + self.device, + self.parallel_config, + ) + def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", @@ -1887,6 +1926,8 @@ def _dummy_run( self, num_tokens: int, capture_attn_cudagraph: bool = False, + skip_eplb: bool = False, + is_profile: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: # Padding for DP @@ -1983,6 +2024,16 @@ def _dummy_run( assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) + # This is necessary to avoid blocking DP. 
+ # For dummy runs, we typically skip EPLB since we don't have any real + # requests to process. + # However, in DP settings, there may be cases when some DP ranks do + # not have any requests to process, so they're executing dummy batches. + # In such cases, we still have to trigger EPLB to make sure + # ranks execute the rearrangement in synchronization. + if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states, hidden_states[logit_indices] @@ -2175,8 +2226,9 @@ def profile_run(self) -> None: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ - = self._dummy_run(self.max_num_tokens) + = self._dummy_run(self.max_num_tokens, is_profile=True) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -2210,10 +2262,15 @@ def capture_model(self) -> None: for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), desc="Capturing CUDA graphs", total=len(self.cudagraph_batch_sizes)): + # We skip EPLB here since we don't want to record dummy metrics for _ in range( self.compilation_config.cudagraph_num_of_warmups): - self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) - self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) + self._dummy_run(num_tokens, + capture_attn_cudagraph=full_cg, + skip_eplb=True) + self._dummy_run(num_tokens, + capture_attn_cudagraph=full_cg, + skip_eplb=True) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index b0f80c701325..9e7e44d06861 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -259,9 +259,10 @@ def compile_or_warm_up_model(self) -> None: x for x in warmup_sizes if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size) + self.model_runner._dummy_run(size, skip_eplb=True) if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -274,8 +275,12 @@ def compile_or_warm_up_model(self) -> None: max_num_reqs = min(self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens) + # We skip EPLB here since we don't want to record dummy metrics hidden_states, last_hidden_states = \ - self.model_runner._dummy_run(num_tokens=max_num_reqs) + self.model_runner._dummy_run( + num_tokens=max_num_reqs, + skip_eplb=True, + ) if self.model_runner.is_pooling_model: self.model_runner._dummy_pooler_run(hidden_states) else: