@@ -1,16 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import abstractmethod
+from dataclasses import dataclass
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
-from dataclasses import dataclass
 
+import pplx_kernels as pplx
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 
-import pplx_kernels as pplx
-
 import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group,
@@ -47,6 +46,7 @@
 
 MOE_DP_CHUNK_SIZE = 256
 
+
 # Adapted from pplx-kernels tests/all_to_all_utils.py
 @dataclass
 class MoEConfig:
@@ -64,6 +64,7 @@ class MoEConfig:
     out_dtype: torch.dtype = torch.bfloat16
     block_size: int = 128
 
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
@@ -100,26 +101,14 @@ def apply(
     ) -> torch.Tensor:
         raise NotImplementedError
 
+
+#TODO: Every change in this class is a broken hack!!
 @CustomOp.register("unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
-    def __init__(self, moe: MoEConfig):
-        self.all_to_all = pplx.AllToAll(
-            max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size,
-            num_experts=moe.num_experts,
-            experts_per_token=moe.experts_per_token,
-            rank=moe.ep_rank,
-            world_size=moe.ep_size,
-            dp_size=moe.dp_size,
-            hidden_dim=moe.hidden_dim,
-            hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
-            hidden_dim_scale_bytes=0,
-        )
-
 
-    def __init__(self):
+    def __init__(self, moe: MoEConfig):
         super().__init__()
-
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
         if self.rocm_aiter_moe_enabled:
             from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
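
For context on what the deleted constructor above was computing: the `pplx.AllToAll` buffers are sized directly from `MoEConfig` fields. The sketch below is not part of this diff and does not touch `pplx_kernels`; it only reproduces that sizing arithmetic with an assumed helper name (`all_to_all_sizes`) and illustrative config values.

```python
from dataclasses import dataclass

import torch

MOE_DP_CHUNK_SIZE = 256  # matches the constant added in this diff


@dataclass
class _SketchMoEConfig:
    # Subset of the MoEConfig fields referenced by the removed constructor.
    num_experts: int
    experts_per_token: int
    hidden_dim: int
    ep_rank: int
    ep_size: int
    dp_size: int
    in_dtype: torch.dtype = torch.bfloat16


def all_to_all_sizes(moe: _SketchMoEConfig) -> dict:
    """Reproduce the buffer-sizing arithmetic used for pplx.AllToAll."""
    return {
        # Each DP rank contributes at most this many tokens per chunk.
        "max_num_tokens": MOE_DP_CHUNK_SIZE // moe.dp_size,
        # Payload per token: hidden_dim activations in the input dtype.
        "hidden_dim_bytes": moe.hidden_dim * moe.in_dtype.itemsize,
        # No per-token scales on the unquantized path.
        "hidden_dim_scale_bytes": 0,
    }


if __name__ == "__main__":
    cfg = _SketchMoEConfig(num_experts=64, experts_per_token=6,
                           hidden_dim=4096, ep_rank=0, ep_size=8, dp_size=2)
    print(all_to_all_sizes(cfg))
    # {'max_num_tokens': 128, 'hidden_dim_bytes': 8192, 'hidden_dim_scale_bytes': 0}
```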
@@ -903,7 +892,7 @@ def forward(self, hidden_states: torch.Tensor,
                                           self.layer_name)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
-                             full_router_logits: torch.Tensor):
+                             full_router_logits: torch.Tensor):
         max_tokens_across_dp = get_forward_context(
         ).dp_metadata.max_tokens_across_dp
         cu_tokens_across_dp_cpu = get_forward_context(
@@ -919,21 +908,23 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
 
         num_tokens_remaining_across_dp = num_tokens_across_dp
         chunk_start = 0
-        chunk_end = min(moe_dp_chunk_size_per_rank, full_hidden_states.shape[0])
+        chunk_end = min(moe_dp_chunk_size_per_rank,
+                        full_hidden_states.shape[0])
         full_final_hidden_states = torch.empty_like(full_hidden_states)
 
         for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
-            hidden_states = full_hidden_states[chunk_start:chunk_end,:]
-            router_logits = full_router_logits[chunk_start:chunk_end,:]
+            hidden_states = full_hidden_states[chunk_start:chunk_end, :]
+            router_logits = full_router_logits[chunk_start:chunk_end, :]
 
             cu_tokens_across_dp_this_iter = torch.cumsum(
-                num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank),
+                num_tokens_remaining_across_dp.clamp(
+                    max=moe_dp_chunk_size_per_rank),
                 dim=0)
 
-            hidden_states = self.naive_multicast(hidden_states,
-                                                 cu_tokens_across_dp_this_iter)
-            router_logits = self.naive_multicast(router_logits,
-                                                 cu_tokens_across_dp_this_iter)
+            hidden_states = self.naive_multicast(
+                hidden_states, cu_tokens_across_dp_this_iter)
+            router_logits = self.naive_multicast(
+                router_logits, cu_tokens_across_dp_this_iter)
 
             # Matrix multiply.
             final_hidden_states = self.quant_method.apply(
@@ -954,7 +945,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
             )
 
             if self.dp_size > 1:
-                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[self.dp_rank - 1]
+                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[
+                    self.dp_rank - 1]
                 end = cu_tokens_across_dp_this_iter[self.dp_rank]
 
                 all_hidden_states = get_dp_group().all_reduce(
@@ -963,20 +955,26 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
 
             if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
                 # Default set to False. (May have to add shared expert outputs.)
-                final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+                final_hidden_states = tensor_model_parallel_all_reduce(
+                    final_hidden_states)
 
-            full_final_hidden_states[chunk_start:chunk_end, :].copy_(final_hidden_states)
+            full_final_hidden_states[chunk_start:chunk_end, :].copy_(
+                final_hidden_states)
 
             # Update bounds
-            num_tokens_remaining_across_dp = torch.clamp(num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0)
+            num_tokens_remaining_across_dp = torch.clamp(
+                num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank,
+                min=0)
+
             def update_chunk_bound(x: int):
-                return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0])
+                return min(x + moe_dp_chunk_size_per_rank,
+                           full_hidden_states.shape[0])
+
             chunk_start = update_chunk_bound(chunk_start)
             chunk_end = update_chunk_bound(chunk_end)
 
         return full_final_hidden_states
 
-
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
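
The reflowed loop above keeps every DP rank running the same number of chunk iterations (driven by `max_tokens_across_dp`) while clamping its local `chunk_start`/`chunk_end` to its own token count. Below is a minimal, standalone sketch of that bookkeeping only, with an assumed helper name (`chunk_schedule`) and no vLLM or distributed state involved.

```python
import torch

MOE_DP_CHUNK_SIZE = 256


def chunk_schedule(num_tokens_across_dp: torch.Tensor, dp_size: int,
                   my_num_tokens: int):
    """Yield (chunk_start, chunk_end, tokens_this_iter_per_rank) tuples,
    mirroring the bound updates in forward_impl_chunked."""
    chunk = MOE_DP_CHUNK_SIZE // dp_size
    max_tokens_across_dp = int(num_tokens_across_dp.max())
    remaining = num_tokens_across_dp.clone()

    chunk_start = 0
    chunk_end = min(chunk, my_num_tokens)

    # Every rank runs the same number of iterations, even if it has fewer
    # tokens, so cross-rank collectives stay in lockstep.
    for _ in range(0, max_tokens_across_dp, chunk):
        yield chunk_start, chunk_end, remaining.clamp(max=chunk)

        remaining = torch.clamp(remaining - chunk, min=0)
        chunk_start = min(chunk_start + chunk, my_num_tokens)
        chunk_end = min(chunk_end + chunk, my_num_tokens)


if __name__ == "__main__":
    # Two DP ranks: rank 0 has 300 tokens, rank 1 has 40; we are rank 1.
    tokens = torch.tensor([300, 40])
    for start, end, per_iter in chunk_schedule(tokens, dp_size=2,
                                               my_num_tokens=40):
        print(start, end, per_iter.tolist())
    # Prints (0, 40), then (40, 40) twice more: rank 1's later chunks are
    # empty slices, but it still participates in every iteration.
```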