@@ -8,9 +8,9 @@
 import triton
 import triton.language as tl
 
+import vllm._moe_C as moe_kernels
 from vllm._C import ops
 from vllm.logger import init_logger
-from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -108,8 +108,8 @@ def fused_moe_kernel(
                       offs_k[None, :] * stride_ak)
 
     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
-                                                offs_bn[None, :] * stride_bn)
+    b_ptrs = (b_ptr + off_experts * stride_be +
+              (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -121,10 +121,12 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(a_ptrs,
-                    mask=token_mask[:, None] &
-                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                    other=0.0)
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] &
+            (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
         b = tl.load(b_ptrs,
                     mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                     other=0.0)
@@ -144,8 +146,8 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = (c_ptr + stride_cm * offs_token[:, None] +
+              stride_cn * offs_cn[None, :])
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
@@ -193,31 +195,46 @@ def moe_align_block_size(
     sorted_ids = torch.empty(
         (topk_ids.numel() + num_experts * (block_size - 1), ),
         dtype=torch.int32,
-        device=topk_ids.device)
-    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
+        device=topk_ids.device,
+    )
+    expert_ids = torch.empty(
+        (topk_ids.numel() + num_experts, ),
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
     sorted_ids.fill_(topk_ids.numel())
     num_tokens_post_pad = torch.empty((1),
                                       dtype=torch.int32,
                                       device=topk_ids.device)
-    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                             expert_ids, num_tokens_post_pad)
+    ops.moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
-def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                            sorted_token_ids: torch.Tensor,
-                            expert_ids: torch.Tensor,
-                            num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any]) -> None:
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+        "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), )
 
     fused_moe_kernel[grid](
         A,
@@ -310,8 +327,8 @@ def fused_moe(
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    assert (hidden_states.shape[0] == gating_output.shape[0]
+            ), "Number of tokens mismatch"
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -323,34 +340,26 @@ def fused_moe(
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output,
-                                        dim=-1,
-                                        dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import vllm._moe_C as moe_kernels
-
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=hidden_states.device)
-        topk_ids = torch.empty(M,
+    topk_weights = torch.empty(M,
                                topk,
-                               dtype=torch.int32,
+                               dtype=torch.float32,
                                device=hidden_states.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=hidden_states.device)
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
+    topk_ids = torch.empty(M,
+                           topk,
+                           dtype=torch.int32,
+                           device=hidden_states.device)
+    token_expert_indicies = torch.empty(M,
+                                        topk,
+                                        dtype=torch.int32,
+                                        device=hidden_states.device)
+    moe_kernels.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
@@ -367,48 +376,74 @@ def fused_moe(
     else:
         # Else use the default config
         config = {
-            'BLOCK_SIZE_M': 64,
-            'BLOCK_SIZE_N': 64,
-            'BLOCK_SIZE_K': 32,
-            'GROUP_SIZE_M': 8
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
         }
 
         if M <= E:
             config = {
-                'BLOCK_SIZE_M': 16,
-                'BLOCK_SIZE_N': 32,
-                'BLOCK_SIZE_K': 64,
-                'GROUP_SIZE_M': 1
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
             }
 
-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
 
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
+        topk_ids, config["BLOCK_SIZE_M"], E)
 
-    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, False,
-                            topk_ids.shape[1], config)
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+    )
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, True, 1,
-                            config)
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+    )
 
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                      dim=1)
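
For context, a minimal usage sketch of the fused_moe entry point touched by this commit. The import path (vllm.model_executor.layers.fused_moe), the toy sizes, and the dtype below are assumptions for illustration, not part of the diff; the call only relies on the signature and shape checks visible in the hunks above (w1 stacks the gate/up projections, so its second dimension is twice the intermediate size that silu_and_mul later splits). Since topk_softmax now runs unconditionally, the compiled vllm._moe_C extension must be available.

# Illustrative sketch only: assumed import path and toy sizes.
import torch

from vllm.model_executor.layers.fused_moe import fused_moe  # assumed path

num_tokens, hidden_size = 16, 128
num_experts, inter_size = 8, 256

hidden_states = torch.randn(num_tokens, hidden_size,
                            dtype=torch.float16, device="cuda")
# w1: (E, 2 * inter, hidden) feeds silu_and_mul; w2: (E, hidden, inter).
w1 = torch.randn(num_experts, 2 * inter_size, hidden_size,
                 dtype=torch.float16, device="cuda")
w2 = torch.randn(num_experts, hidden_size, inter_size,
                 dtype=torch.float16, device="cuda")
gating_output = torch.randn(num_tokens, num_experts,
                            dtype=torch.float16, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output,
                topk=2, renormalize=True, inplace=False)
assert out.shape == (num_tokens, hidden_size)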