@@ -920,7 +920,7 @@ def __init__(
         self.batched_router_logits: Optional[torch.Tensor] = None
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
-                or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
+                or self.moe_config.use_flashinfer_cutlass_kernels):
             self.batched_hidden_states = torch.zeros(
                 (moe.max_num_tokens, self.hidden_size),
                 dtype=moe.in_dtype,
@@ -974,7 +974,7 @@ def use_deepep_ll_kernels(self):
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+        return self.moe_config.use_flashinfer_cutlass_kernels
 
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated
@@ -1665,7 +1665,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
         # only when data parallelism (DP) is enabled.
         use_flashinfer_cutlass_kernels = (
             self.dp_size > 1
-            and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and self.moe_config.use_flashinfer_cutlass_kernels)
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
                 or use_flashinfer_cutlass_kernels):
@@ -1674,7 +1674,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1
             and not self.moe_parallel_config.use_deepep_ht_kernels
-            and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and not self.moe_config.use_flashinfer_cutlass_kernels)
         if do_naive_dispatch_combine:
             hidden_states, router_logits = get_ep_group().dispatch(
                 hidden_states, router_logits)
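
For context outside the diff view, here is a minimal, self-contained sketch of the delegation pattern this change moves toward: the layer reads use_flashinfer_cutlass_kernels from the full MoE config object (moe_config) rather than from the parallel-only config (moe_parallel_config). The config classes and their fields below are simplified assumptions for illustration, not vLLM's actual definitions.

# Minimal sketch of the pattern in this diff: the layer reads
# `use_flashinfer_cutlass_kernels` from the full MoE config instead of the
# parallel-only config.  All names and fields are simplified assumptions
# for illustration; they are not vLLM's actual config definitions.
from dataclasses import dataclass, field


@dataclass
class FusedMoEParallelConfig:
    # parallelism-only knobs (assumed fields)
    use_pplx_kernels: bool = False
    use_deepep_ll_kernels: bool = False
    use_deepep_ht_kernels: bool = False


@dataclass
class FusedMoEConfig:
    parallel_config: FusedMoEParallelConfig = field(
        default_factory=FusedMoEParallelConfig)
    # non-parallel state the full config can also consult (assumed fields)
    flashinfer_cutlass_available: bool = False
    quant_dtype_supported: bool = False

    @property
    def use_flashinfer_cutlass_kernels(self) -> bool:
        # the full config can combine kernel availability with quantization
        # support, which the parallel-only config cannot see
        return (self.flashinfer_cutlass_available
                and self.quant_dtype_supported)


class FusedMoELayerSketch:
    def __init__(self, moe_config: FusedMoEConfig):
        self.moe_config = moe_config
        self.moe_parallel_config = moe_config.parallel_config

    @property
    def use_flashinfer_cutlass_kernels(self) -> bool:
        # delegate to the full MoE config, mirroring the change above
        return self.moe_config.use_flashinfer_cutlass_kernels


layer = FusedMoELayerSketch(
    FusedMoEConfig(flashinfer_cutlass_available=True,
                   quant_dtype_supported=True))
assert layer.use_flashinfer_cutlass_kernels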