 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceNoOP)
+from vllm.triton_utils import tl, triton
 from vllm.utils import has_triton_kernels
 
 logger = init_logger(__name__)
         import triton_kernels.swiglu
         from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
                                                matmul_ogs)
-        from triton_kernels.routing import routing
+        from triton_kernels.routing import (RoutingData, routing,
+                                            routing_from_bitmatrix)
+        from triton_kernels.tensor import Bitmatrix
     except (ModuleNotFoundError, AttributeError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
             "version is compatible. Error: %s", e)
 
 
+@triton.jit
+def pack_bitmatrix(
+    bitmatrix,
+    topk_ids,
+    n_rows,  # n_rows in bitmatrix / topk_ids
+    bm_cols: tl.constexpr,  # n int32_t bitpacks in bitmatrix
+    n_expts_act,  # num_topk
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    """
+    Packs topk_ids into a bitmatrix.
+    code reference:
+    https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264
+    """
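+    # Illustrative example: with n_expts_act=2, more than 32 local experts and
+    # a row of topk_ids equal to [3, 37], word 0 gets bit 3 set (0x8) and
+    # word 1 gets bit 5 set (0x20), so the row packs to uint32 words [0x8, 0x20].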
+    pid_m = tl.program_id(0)
+    offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offsets_k = tl.arange(0, BLOCK_SIZE_K)
+    offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
+    mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
+    indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
+    div = indices // 32
+    rem = indices % 32
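+    # div selects the uint32 word that holds each expert's bit; rem is the
+    # bit position within that word.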
+    one = tl.cast(1, tl.uint32)
+
+    # Iterate through all the relevant bitmatrix columns.
+    for i in range(bm_cols):
+        # When BLOCK_SIZE_K=32, offs is just the column index.
+        offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
+        # All topks that need to go into this column have the correct bit set.
+        # Other bits are 0. x is a 3D tensor.
+        x = tl.where(div[:, :, None] == offs[None, None, :],
+                     (one << rem)[:, :, None], 0)
+        # Reduce x to get a single int32_t bitpack.
+        y = tl.reduce_or(x, axis=1)
+        bitmatrix_ptrs = bitmatrix + offsets_m[:,
+                                               None] * bm_cols + offs[None, :]
+        tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1,  # Tensor or triton_kernels.Tensor
@@ -124,48 +167,99 @@ def triton_kernel_fused_experts(
     return intermediate_cache3
 
 
-class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+def make_routing_data(
+    topk_ids: torch.Tensor,
+    topk_weights: torch.Tensor,
+    num_local_experts: int,
+) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
+
+    topk_ids = topk_ids.to(torch.int16)
+    topk_weights = topk_weights.to(torch.bfloat16)
+
+    n_rows, num_topk = topk_ids.size()
+
+    BLOCK_SIZE_M = 512
+    BLOCK_SIZE_K = 32
+
+    bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K)  # n_bitpacks
+    bitmatrix = torch.zeros((n_rows, bm_cols),
+                            dtype=torch.uint32,
+                            device=topk_ids.device)
+
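+    # Each program instance packs BLOCK_SIZE_M rows of topk_ids.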
+    grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), )
+    pack_bitmatrix[grid](
+        bitmatrix,
+        topk_ids,
+        n_rows,
+        bm_cols,
+        num_topk,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+
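+    # Wrap the raw uint32 tensor in a Bitmatrix; the logical shape is given in
+    # bits (n_rows x bm_cols * 32).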
+    bitmatrix_shape = [n_rows, bm_cols * 32]
+    bitmatrix_shape_max = [n_rows, None]
+    bitmatrix = Bitmatrix(bitmatrix,
+                          shape=bitmatrix_shape,
+                          shape_max=bitmatrix_shape_max,
+                          scratchpad=None)
+
+    # matmul_ogs expects invalid topk_weights to be -1s
+    topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
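+    # routing_from_bitmatrix builds the RoutingData plus the gather/scatter
+    # indices that matmul_ogs consumes.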
+    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
+        bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk)
+
+    return routing_data, gather_indx, scatter_indx
+
+
+class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        super().__init__(quant_config)
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Weight application and reduction happen in the fused_experts kernel.
+        return TopKWeightAndReduceNoOP()
 
-    def __init__(
+    def _make_routing_data(
         self,
-        max_num_tokens: int,
-        num_dispatchers: int,
-        quant_config: FusedMoEQuantConfig,
-    ):
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_local_experts: int,
+    ) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
+        return make_routing_data(topk_ids, topk_weights, num_local_experts)
+
+
+class OAITritonExperts(BaseOAITritonExperts):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun) : Enable activation quantization
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
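+        # mxfp4_w4a16: mxfp4 weights with 16-bit (unquantized) activations.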
         super().__init__(quant_config)
-        self.max_num_tokens = max_num_tokens
-        self.num_dispatchers = num_dispatchers
 
     @property
     def activation_formats(
         self
     ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        return (mk.FusedMoEActivationFormat.BatchedExperts,
-                mk.FusedMoEActivationFormat.BatchedExperts)
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
 
     def supports_chunking(self) -> bool:
-        return False
-
-    def supports_expert_map(self) -> bool:
-        return False
-
-    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return True
 
     def workspace_shapes(
         self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
         topk: int, global_num_experts: int, local_num_experts: int,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # workspace are allocated inside the kernel
-        assert a.dim() == 2
-        num_dp = self.num_dispatchers
-        num_experts = local_num_experts
-        max_num_tokens = self.max_num_tokens
-        workspace2 = (0, 0, 0)
-        output = (num_experts, max_num_tokens * num_dp, N)
-        return (output, workspace2, output, a.dtype)
+        workspace1 = (M, K)
+        workspace2 = (0, 0)
+        output = (M, K)
+        return (workspace1, workspace2, output, a.dtype)
 
     def apply(
         self,
@@ -185,17 +279,29 @@ def apply(
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
-        return triton_kernel_fused_experts(
-            output,
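+        # expert_map maps global expert ids to this rank's local expert ids,
+        # with -1 marking experts that are not hosted locally.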
+        if expert_map is not None:
+            topk_ids = expert_map[topk_ids]
+
+        local_num_experts = w1.size(0)
+        if global_num_experts == -1:
+            global_num_experts = local_num_experts
+
+        routing_data, gather_indx, scatter_indx = self._make_routing_data(
+            topk_ids, topk_weights, local_num_experts)
+
+        experts_output = triton_kernel_fused_experts(
+            None,
             hidden_states,
             w1,
             w2,
-            routing_data=None,
-            gather_indx=None,
-            scatter_indx=None,
+            routing_data,
+            gather_indx,
+            scatter_indx,
             activation=activation,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
-            global_num_experts=global_num_experts,
-            expert_map=expert_map,
+            global_num_experts=local_num_experts,
+            expert_map=None,  # applied already
             a1q_scale=a1q_scale)
+
+        output.copy_(experts_output, non_blocking=True)
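+        # triton_kernel_fused_experts allocates its result internally (the
+        # output arg above is None), so copy it into the caller's buffer.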