@@ -626,8 +626,9 @@ def apply(
         raise NotImplementedError
 
 
-def _slice_scales(scales: Optional[torch.Tensor], start: int,
-                  end: int) -> Optional[torch.Tensor]:
+def _slice_scales(
+    scales: Optional[torch.Tensor], start: int, end: int
+) -> Optional[torch.Tensor]:
     if scales is not None:
         if scales.numel() == 1:
             return scales
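For reference, a minimal standalone sketch of what a scale-slicing helper like this does. Only the first two branches appear in the hunk, so the final row-slicing branch below is an assumption inferred from the name and signature of _slice_scales:

from typing import Optional

import torch


def slice_scales_sketch(
    scales: Optional[torch.Tensor], start: int, end: int
) -> Optional[torch.Tensor]:
    # Per-tensor scales (numel == 1) apply to all tokens and pass through
    # unchanged; otherwise keep only the rows for this token range.
    if scales is not None:
        if scales.numel() == 1:
            return scales
        return scales[start:end]  # assumption: per-token scales are row-sliced
    return None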
@@ -640,8 +641,9 @@ class SharedResizableBuffer:
     def __init__(self):
         self.buffer = None
 
-    def get(self, shape: tuple[int, ...], device: torch.device,
-            dtype: torch.dtype) -> torch.Tensor:
+    def get(
+        self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
+    ) -> torch.Tensor:
         assert shape != ()
         shape_numel = prod(shape)
         if (
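For reference, a minimal sketch of a grow-only shared buffer in the spirit of SharedResizableBuffer. The hunk cuts off at the `if (` condition, so the reallocation test and the returned view below are assumptions:

from math import prod

import torch


class ResizableBufferSketch:
    """Illustrative stand-in for a grow-only workspace cache (assumption)."""

    def __init__(self):
        self.buffer = None

    def get(
        self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        assert shape != ()
        shape_numel = prod(shape)
        # Reallocate only when the cached buffer is missing, too small, or on
        # the wrong device/dtype; otherwise hand out a view of the cached storage.
        if (
            self.buffer is None
            or self.buffer.numel() < shape_numel
            or self.buffer.device != device
            or self.buffer.dtype != dtype
        ):
            self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
        return self.buffer[:shape_numel].view(*shape)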
@@ -717,8 +719,11 @@ def _chunk_info(self, M: int) -> tuple[int, int]:
         get num_chunks == 1. Take max(M, 1) to avoid divide by zero.
         If there are no tokens to process, the number of chunks will be zero.
         """
-        CHUNK_SIZE = (max(M, 1) if not self.fused_experts.supports_chunking()
-                      else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE))
+        CHUNK_SIZE = (
+            max(M, 1)
+            if not self.fused_experts.supports_chunking()
+            else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
+        )
         num_chunks = cdiv(M, CHUNK_SIZE)
         # If there are no tokens, then there should be no loop iterations.
         assert M > 0 or num_chunks == 0
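For reference, the chunking arithmetic described in the docstring, as a self-contained sketch; supports_chunking and chunk_size_env are hypothetical stand-ins for self.fused_experts.supports_chunking() and envs.VLLM_FUSED_MOE_CHUNK_SIZE:

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as used for the chunk count.
    return -(-a // b)


def chunk_info_sketch(
    M: int, supports_chunking: bool, chunk_size_env: int
) -> tuple[int, int]:
    # Non-chunking experts get a single chunk spanning all M tokens;
    # max(M, 1) avoids a divide-by-zero when there are no tokens.
    CHUNK_SIZE = max(M, 1) if not supports_chunking else min(M, chunk_size_env)
    num_chunks = cdiv(M, CHUNK_SIZE)
    assert M > 0 or num_chunks == 0
    return num_chunks, CHUNK_SIZE


# e.g. 5000 tokens with a 2048-token chunk size -> 3 chunks.
assert chunk_info_sketch(5000, True, 2048) == (3, 2048)
# No tokens -> zero chunks, so the processing loop runs zero iterations.
assert chunk_info_sketch(0, False, 2048) == (0, 1)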
@@ -755,31 +760,37 @@ def _allocate_buffers(
         workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
 
         workspace13_shape, workspace2_shape, fused_out_shape = (
-            self.fused_experts.workspace_shapes(M_chunk, M_full, N, K, top_k,
-                                                global_num_experts,
-                                                local_num_experts,
-                                                expert_tokens_meta))
+            self.fused_experts.workspace_shapes(
+                M_chunk,
+                M_full,
+                N,
+                K,
+                top_k,
+                global_num_experts,
+                local_num_experts,
+                expert_tokens_meta,
+            )
+        )
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = buffers.workspace13.get(workspace13_shape,
-                                              device=device,
-                                              dtype=workspace_dtype)
-        workspace2 = buffers.workspace2.get(workspace2_shape,
-                                            device=device,
-                                            dtype=workspace_dtype)
+        workspace13 = buffers.workspace13.get(
+            workspace13_shape, device=device, dtype=workspace_dtype
+        )
+        workspace2 = buffers.workspace2.get(
+            workspace2_shape, device=device, dtype=workspace_dtype
+        )
 
         # Construct the entire output that can then be processed in chunks.
         # Reuse workspace13 for the output in the non-chunked case as long
         # as it is large enough. This will not always be the case for standard
         # format experts and with experts that have empty workspaces.
-        if num_chunks == 1 and prod(workspace13_shape) >= prod(
-                fused_out_shape):
+        if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
             fused_out = _resize_cache(workspace13, fused_out_shape)
         else:
-            fused_out = buffers.fused_out.get(fused_out_shape,
-                                              device=device,
-                                              dtype=out_dtype)
+            fused_out = buffers.fused_out.get(
+                fused_out_shape, device=device, dtype=out_dtype
+            )
 
         return workspace13, workspace2, fused_out
 
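For reference, a sketch of the workspace-reuse decision above: with a single chunk and a large enough workspace13, the output is taken as a view of workspace13 instead of a fresh allocation. The flatten/view emulation of _resize_cache is an assumption:

from math import prod

import torch


def pick_fused_out_sketch(
    workspace13: torch.Tensor, fused_out_shape: tuple[int, ...], num_chunks: int
) -> torch.Tensor:
    # Illustrative only: reuse workspace13 for the output when it is large
    # enough and there is only one chunk; otherwise allocate a separate buffer.
    out_numel = prod(fused_out_shape)
    if num_chunks == 1 and workspace13.numel() >= out_numel:
        return workspace13.flatten()[:out_numel].view(*fused_out_shape)
    return torch.empty(
        fused_out_shape, device=workspace13.device, dtype=workspace13.dtype
    )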
@@ -794,8 +805,7 @@ def _slice_output_tensor(
         if num_chunks == 1:
             return fused_out
 
-        assert fused_out.size(0) % M == 0, (
-            f"fused_out shape {fused_out.shape} {M}")
+        assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} {M}"
         factor = fused_out.size(0) // M
         out_chunk_size = CHUNK_SIZE * factor
         s = chunk_idx * out_chunk_size
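For reference, a sketch of the per-chunk output slicing this hunk reformats. The hunk ends at the computation of `s`, so the final slice and return below are assumptions; `factor` presumably accounts for formats whose output holds a multiple (e.g. top_k) of M rows:

import torch


def slice_output_sketch(
    fused_out: torch.Tensor, chunk_idx: int, num_chunks: int, CHUNK_SIZE: int, M: int
) -> torch.Tensor:
    # With a single chunk the whole output tensor is used directly.
    if num_chunks == 1:
        return fused_out
    assert fused_out.size(0) % M == 0
    # Scale the chunk size by how many output rows exist per input token.
    factor = fused_out.size(0) // M
    out_chunk_size = CHUNK_SIZE * factor
    s = chunk_idx * out_chunk_size
    return fused_out[s : s + out_chunk_size]  # assumption: contiguous row slice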
@@ -816,23 +826,24 @@ def _slice_expert_tokens_metadata(
         # The existing expert_num_tokens is for the entire a1q
         # input. Chunking forces recomputation of the number
         # of tokens assigned to each expert.
-        c_expert_num_tokens = count_expert_num_tokens(chunk_topk_ids,
-                                                      local_num_experts,
-                                                      expert_map)
+        c_expert_num_tokens = count_expert_num_tokens(
+            chunk_topk_ids, local_num_experts, expert_map
+        )
 
         c_expert_num_tokens_cpu = None
         need_expert_num_tokens_cpu = (
-            full_expert_tokens_meta.expert_num_tokens_cpu is not None)
+            full_expert_tokens_meta.expert_num_tokens_cpu is not None
+        )
         if need_expert_num_tokens_cpu:
             # This is blocking as some implementations need the count
             # on the CPU to determine appropriate input/out fused-moe
             # buffers
-            c_expert_num_tokens_cpu = c_expert_num_tokens.to(
-                "cpu", non_blocking=False)
+            c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False)
 
         return ExpertTokensMetadata(
             expert_num_tokens=c_expert_num_tokens,
-            expert_num_tokens_cpu=c_expert_num_tokens_cpu)
+            expert_num_tokens_cpu=c_expert_num_tokens_cpu,
+        )
 
     def _prepare(
         self,
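For reference, an approximation of the per-chunk recount that count_expert_num_tokens performs; count_expert_num_tokens itself is a vLLM helper not shown here, and the expert_map handling below (global ids mapped to local ids, negative entries dropped) is an assumption:

from typing import Optional

import torch


def count_expert_num_tokens_sketch(
    chunk_topk_ids: torch.Tensor,
    local_num_experts: int,
    expert_map: Optional[torch.Tensor],
) -> torch.Tensor:
    # Map global expert ids to local ids (if a map is given) and histogram how
    # many routed tokens each local expert receives in this chunk.
    ids = chunk_topk_ids.flatten().long()
    if expert_map is not None:
        ids = expert_map[ids]
    ids = ids[ids >= 0]  # drop tokens routed to experts not on this rank
    return torch.bincount(ids, minlength=local_num_experts)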
@@ -843,11 +854,11 @@ def _prepare(
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
     ) -> tuple[
-             torch.Tensor,
-             Optional[torch.Tensor],
-             Optional[ExpertTokensMetadata],
-             torch.Tensor,
-             torch.Tensor,
+        torch.Tensor,
+        Optional[torch.Tensor],
+        Optional[ExpertTokensMetadata],
+        torch.Tensor,
+        torch.Tensor,
     ]:
         """
         The _prepare method is a wrapper around self.prepare_finalize.prepare
@@ -859,16 +870,21 @@ def _prepare(
             # TODO(lucas): enable in follow-up
             assert not dbo_enabled()
 
-            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
-             _expert_topk_weights) = self.prepare_finalize.prepare(
-                 hidden_states,
-                 topk_weights,
-                 topk_ids,
-                 global_num_experts,
-                 expert_map,
-                 apply_router_weight_on_input,
-                 self.fused_experts.quant_config,
-             )
+            (
+                a1q,
+                a1q_scale,
+                expert_tokens_meta,
+                _expert_topk_ids,
+                _expert_topk_weights,
+            ) = self.prepare_finalize.prepare(
+                hidden_states,
+                topk_weights,
+                topk_ids,
+                global_num_experts,
+                expert_map,
+                apply_router_weight_on_input,
+                self.fused_experts.quant_config,
+            )
         else:
             # Overlap shared expert compute with all2all dispatch.
             dbo_maybe_run_recv_hook()
@@ -931,7 +947,9 @@ def _fused_experts(
         apply_router_weight_on_input: bool,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
     ) -> torch.Tensor:
-        _, M_full, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+        _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
+            a1q, w1, w2, topk_ids
+        )
 
         num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
 
@@ -959,19 +977,32 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
         else:
             assert num_chunks > 0
             workspace13, workspace2, fused_out = self._allocate_buffers(
-                in_dtype, a1q.device, CHUNK_SIZE, M_full, N, K, top_k,
-                global_num_experts, local_num_experts, expert_tokens_meta)
+                in_dtype,
+                a1q.device,
+                CHUNK_SIZE,
+                M_full,
+                N,
+                K,
+                top_k,
+                global_num_experts,
+                local_num_experts,
+                expert_tokens_meta,
+            )
 
         for chunk_idx in range(num_chunks):
             s, e = input_chunk_range(chunk_idx)
 
             c_expert_tokens_meta = self._slice_expert_tokens_metadata(
-                num_chunks, expert_tokens_meta, topk_ids[s:e],
-                local_num_experts, expert_map)
+                num_chunks,
+                expert_tokens_meta,
+                topk_ids[s:e],
+                local_num_experts,
+                expert_map,
+            )
 
-            c_fused_out = self._slice_output_tensor(fused_out, chunk_idx,
-                                                    num_chunks, CHUNK_SIZE,
-                                                    M_full)
+            c_fused_out = self._slice_output_tensor(
+                fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
+            )
 
             self.fused_experts.apply(
                 output=c_fused_out,
@@ -1111,15 +1142,14 @@ def forward(
         if global_num_experts == -1:
             global_num_experts = local_num_experts
 
-        a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = (
-            self._prepare(
-                hidden_states,
-                topk_weights,
-                topk_ids,
-                global_num_experts,
-                expert_map,
-                apply_router_weight_on_input,
-            ))
+        a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = self._prepare(
+            hidden_states,
+            topk_weights,
+            topk_ids,
+            global_num_experts,
+            expert_map,
+            apply_router_weight_on_input,
+        )
 
         fused_out = self._fused_experts(
             in_dtype=hidden_states.dtype,