@@ -9,7 +9,8 @@
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceNoOP)
+from vllm.triton_utils import tl, triton
 from vllm.utils import has_triton_kernels
 
 logger = init_logger(__name__)
@@ -19,13 +20,55 @@
     import triton_kernels.swiglu
     from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
                                            matmul_ogs)
-    from triton_kernels.routing import routing
+    from triton_kernels.routing import (RoutingData, routing,
+                                        routing_from_bitmatrix)
+    from triton_kernels.tensor import Bitmatrix
 except (ModuleNotFoundError, AttributeError) as e:
     logger.error(
         "Failed to import Triton kernels. Please make sure your triton "
         "version is compatible. Error: %s", e)
 
 
+@triton.jit
+def pack_bitmatrix(
+    bitmatrix,
+    topk_ids,
+    n_rows,  # n_rows in bitmatrix / topk_ids
+    bm_cols: tl.constexpr,  # n int32_t bitpacks in bitmatrix
+    n_expts_act,  # num_topk
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    """
+    Packs topk_ids into a bitmatrix.
+    code reference:
+    https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264
+    """
+    pid_m = tl.program_id(0)
+    offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offsets_k = tl.arange(0, BLOCK_SIZE_K)
+    offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
+    mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
+    indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
+    div = indices // 32
+    rem = indices % 32
+    one = tl.cast(1, tl.uint32)
+
+    # Iterate through all the relevant bitmatrix columns.
+    for i in range(bm_cols):
+        # When BLOCK_SIZE_K=32, offs is just the column index.
+        offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
+        # Each topk entry that maps to this column gets its bit set;
+        # all other bits stay 0. x is a 3D (rows, topk, words) tensor.
+        x = tl.where(div[:, :, None] == offs[None, None, :],
+                     (one << rem)[:, :, None], 0)
+        # Reduce x to get a single int32_t bitpack.
+        y = tl.reduce_or(x, axis=1)
+        bitmatrix_ptrs = bitmatrix + offsets_m[:,
+                                               None] * bm_cols + offs[None, :]
+        tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1,  # Tensor or triton_kernels.Tensor
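
Reviewer note: a minimal pure-PyTorch reference of what pack_bitmatrix computes, handy for unit-testing the kernel. The helper name pack_bitmatrix_ref is illustrative and not part of this change; it returns int64 words, whereas the kernel stores uint32 words.

import torch

def pack_bitmatrix_ref(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    """For each row r and expert e in topk_ids[r], set bit (e % 32) of
    word (e // 32); entries of -1 mark invalid slots and are skipped."""
    n_rows, _ = topk_ids.shape
    bm_cols = (num_experts + 31) // 32
    words = torch.zeros((n_rows, bm_cols), dtype=torch.int64)
    for r in range(n_rows):
        for e in topk_ids[r].tolist():
            if e >= 0:
                words[r, e // 32] |= 1 << (e % 32)
    return words

# e.g. topk_ids = [[3, 35]] with 40 experts -> words [[0b1000, 0b1000]]
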
@@ -124,48 +167,99 @@ def triton_kernel_fused_experts(
     return intermediate_cache3
 
 
-class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+def make_routing_data(
+    topk_ids: torch.Tensor,
+    topk_weights: torch.Tensor,
+    num_local_experts: int,
+) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
+
+    topk_ids = topk_ids.to(torch.int16)
+    topk_weights = topk_weights.to(torch.bfloat16)
+
+    n_rows, num_topk = topk_ids.size()
+
+    BLOCK_SIZE_M = 512
+    BLOCK_SIZE_K = 32
+
+    bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K)  # n_bitpacks
+    bitmatrix = torch.zeros((n_rows, bm_cols),
+                            dtype=torch.uint32,
+                            device=topk_ids.device)
+
+    grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), )
+    pack_bitmatrix[grid](
+        bitmatrix,
+        topk_ids,
+        n_rows,
+        bm_cols,
+        num_topk,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+
+    bitmatrix_shape = [n_rows, bm_cols * 32]
+    bitmatrix_shape_max = [n_rows, None]
+    bitmatrix = Bitmatrix(bitmatrix,
+                          shape=bitmatrix_shape,
+                          shape_max=bitmatrix_shape_max,
+                          scratchpad=None)
+
+    # matmul_ogs expects invalid topk_weights to be -1s
+    topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
+    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
+        bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk)
+
+    return routing_data, gather_indx, scatter_indx
+
+
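
A quick sanity check of the sizing logic in make_routing_data (the numbers below are illustrative, not taken from the PR):

import math
num_local_experts, n_rows = 72, 1000
BLOCK_SIZE_M, BLOCK_SIZE_K = 512, 32
bm_cols = math.ceil(num_local_experts / BLOCK_SIZE_K)  # 3 uint32 words per row
logical_shape = [n_rows, bm_cols * 32]                  # [1000, 96] bit columns
grid = (math.ceil(n_rows / BLOCK_SIZE_M), )             # (2,) kernel programs
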
+class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        super().__init__(quant_config)
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Weight application and reduction happen in the fused_experts kernel.
+        return TopKWeightAndReduceNoOP()
 
-    def __init__(
+    def _make_routing_data(
         self,
-        max_num_tokens: int,
-        num_dispatchers: int,
-        quant_config: FusedMoEQuantConfig,
-    ):
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_local_experts: int,
+    ) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
+        return make_routing_data(topk_ids, topk_weights, num_local_experts)
+
+
+class OAITritonExperts(BaseOAITritonExperts):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun) : Enable activation quantization
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
         super().__init__(quant_config)
-        self.max_num_tokens = max_num_tokens
-        self.num_dispatchers = num_dispatchers
 
     @property
     def activation_formats(
         self
     ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        return (mk.FusedMoEActivationFormat.BatchedExperts,
-                mk.FusedMoEActivationFormat.BatchedExperts)
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
 
     def supports_chunking(self) -> bool:
-        return False
-
-    def supports_expert_map(self) -> bool:
-        return False
-
-    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return True
 
     def workspace_shapes(
         self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
         topk: int, global_num_experts: int, local_num_experts: int,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # workspace are allocated inside the kernel
-        assert a.dim() == 2
-        num_dp = self.num_dispatchers
-        num_experts = local_num_experts
-        max_num_tokens = self.max_num_tokens
-        workspace2 = (0, 0, 0)
-        output = (num_experts, max_num_tokens * num_dp, N)
-        return (output, workspace2, output, a.dtype)
+        workspace1 = (M, K)
+        workspace2 = (0, 0)
+        output = (M, K)
+        return (workspace1, workspace2, output, a.dtype)
 
     def apply(
         self,
@@ -185,17 +279,29 @@ def apply(
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
-        return triton_kernel_fused_experts(
-            output,
+        if expert_map is not None:
+            topk_ids = expert_map[topk_ids]
+
+        local_num_experts = w1.size(0)
+        if global_num_experts == -1:
+            global_num_experts = local_num_experts
+
+        routing_data, gather_indx, scatter_indx = self._make_routing_data(
+            topk_ids, topk_weights, local_num_experts)
+
+        experts_output = triton_kernel_fused_experts(
+            None,
             hidden_states,
             w1,
             w2,
-            routing_data=None,
-            gather_indx=None,
-            scatter_indx=None,
+            routing_data,
+            gather_indx,
+            scatter_indx,
             activation=activation,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
-            global_num_experts=global_num_experts,
-            expert_map=expert_map,
+            global_num_experts=local_num_experts,
+            expert_map=None,  # applied already
             a1q_scale=a1q_scale)
+
+        output.copy_(experts_output, non_blocking=True)
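
For context on the expert_map handling in apply(): in vLLM the map takes global expert ids to local ids, with -1 marking experts not owned by this rank, so after topk_ids = expert_map[topk_ids] the invalid slots carry -1 and make_routing_data masks their weights to -1.0 before routing. A small sketch with made-up values:

import torch

# This rank owns global experts 4..7 out of 8 (illustrative values).
expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])
topk_ids = torch.tensor([[1, 5], [6, 7]])        # global ids from top-k routing
topk_weights = torch.tensor([[0.6, 0.4], [0.7, 0.3]])

local_ids = expert_map[topk_ids]                 # [[-1, 1], [2, 3]]
# Invalid slots get weight -1, matching make_routing_data's masking step.
masked = torch.where(local_ids == -1, -1.0, topk_weights)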