@@ -42,7 +42,8 @@ def fused_marlin_moe(
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
     activation: Optional[str] = "silu",
-    activation_func: Optional[str] = None,
+    activation_func: Optional[str] = None,  # FIXME: type Callable
+    moe_sum: Optional[str] = None,  # FIXME: type Callable
     expert_map: Optional[torch.Tensor] = None,
     global_scale1: Optional[torch.Tensor] = None,
     global_scale2: Optional[torch.Tensor] = None,
@@ -240,12 +241,16 @@ def fused_marlin_moe(
         is_k_full=is_k_full,
         use_atomic_add=use_atomic_add,
         use_fp32_reduce=True,
+
         is_zp_float=False,
     ).view(-1, topk, K)
 
     if output is None:
         output = hidden_states if inplace else torch.empty_like(hidden_states)
-    return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
+    if moe_sum is None:
+        return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
+    else:
+        return moe_sum(intermediate_cache3, output)
 
 
 def fused_marlin_moe_fake(
@@ -407,6 +412,7 @@ def apply(
             global_num_experts=global_num_experts,
             activation=activation,
             activation_func=self.activation,
+            moe_sum=self.moe_sum,
             expert_map=expert_map,
             output=output,
             # Workspaces are swapped in workspace_shapes() to account for proper
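
A minimal sketch of what a `moe_sum` callable compatible with this new hook could look like, given the call site `moe_sum(intermediate_cache3, output)` in the diff: it receives the per-token expert outputs of shape (M, topk, K) and the preallocated output buffer of shape (M, K), and returns the reduced result. The function name `naive_moe_sum` is illustrative only and not part of vLLM.

import torch

def naive_moe_sum(intermediate: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    # Reduce over the top-k expert dimension into the preallocated buffer,
    # mirroring the torch.sum fallback kept in fused_marlin_moe.
    return torch.sum(intermediate, dim=1, out=output)

Such a callable would be passed as `moe_sum=naive_moe_sum` when calling fused_marlin_moe directly, or stored on the kernel object so that apply() can forward it via `self.moe_sum`, as shown in the last hunk above.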