@@ -80,11 +80,13 @@ def workspace_shapes(
         topk: int,
         num_experts: int,
     ) -> tuple[int, int, torch.dtype]:
+
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
         workspace1 = M_sum * max(N * 2, K)
-        workspace2 = M_sum * N
+        workspace2 = M_sum * max(N, K)
+
         return (workspace1, workspace2, a.dtype)

     def apply(
@@ -135,26 +137,31 @@ def apply(

         # Note: M_sum is different than the pre-permuted shape of a1q.
         M_sum = a1q.size(0)
-        workspace1 = _resize_cache(workspace13, (M_sum, N))
-        workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
-        workspace3 = _resize_cache(workspace13, (M_sum, K))
+
+        mm1_out = _resize_cache(workspace13, (M_sum, N))
+        act_out = _resize_cache(workspace2, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+                                  (M_sum, N // 2))
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
+        out = _resize_cache(workspace13, (inv_perm.size(0), K))

         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
+            (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)

-        self.activation(activation, workspace2, workspace1.view(-1, N))
+        self.activation(activation, act_out, mm1_out.view(-1, N))

         a2q_scale: Optional[torch.Tensor] = None
-        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
+        a2q, a2q_scale = per_token_group_quant_fp8(act_out,
                                                    self.block_shape[1],
-                                                   column_major_scales=True)
+                                                   column_major_scales=True,
+                                                   out_q=quant_out)

         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
+            (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

-        workspace3 = workspace3[inv_perm, ...]
+        torch.index_select(mm2_out, 0, inv_perm, out=out)

-        return workspace3
+        return out


 def deep_gemm_moe_fp8(
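
The _resize_cache calls above take their storage from the preallocated workspace13 and workspace2 buffers sized by workspace_shapes, so the heart of this change is writing each stage's result into existing scratch instead of allocating fresh tensors. Below is a minimal standalone sketch of that pattern; the toy shapes and tensor names are illustrative only, not the kernel's real buffers or sizes.

import torch

# Toy stand-ins for M_sum, N and K; the real sizes come from workspace_shapes().
M_sum, N, K = 8, 6, 4

# One flat buffer plays the role of workspace13. Each stage views a prefix of it
# with the shape it needs, so no per-stage tensor is allocated. mm1_out and out
# intentionally alias the same bytes, mirroring how workspace13 is reused once an
# earlier result is no longer needed. (The real code also reinterprets the same
# bytes as float8_e4m3fn for quant_out via .view(dtype=...).)
workspace13 = torch.empty(M_sum * max(N, K), dtype=torch.bfloat16)
mm1_out = workspace13[:M_sum * N].view(M_sum, N)
out = workspace13[:M_sum * K].view(M_sum, K)

# Unpermuting with advanced indexing, as the old "workspace3[inv_perm, ...]" did,
# materializes a fresh tensor ...
mm2_out = torch.randn(M_sum, K, dtype=torch.bfloat16)
inv_perm = torch.randperm(M_sum)
copied = mm2_out[inv_perm, ...]

# ... whereas index_select with out= writes the same result into scratch that is
# already owned, which is what the new torch.index_select(mm2_out, 0, inv_perm,
# out=out) line does.
torch.index_select(mm2_out, 0, inv_perm, out=out)
assert torch.equal(out, copied)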