@@ -81,13 +81,13 @@ struct sm90_8_bit_config_N8192 {
8181
8282template  <typename  InType, typename  OutType,
8383          template  <typename , typename , typename > typename  Epilogue>
84- struct  sm90_16_bit_config_default  {
85-   //  M in (16, inf) 
84+ struct  sm90_16_bit_config_M512  {
85+   //  M in [1, 512] 
8686  using  KernelSchedule =
8787      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
8888  using  EpilogueSchedule =
8989      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
90-   using  TileShape = cute::Shape<cute::_64, cute::_256 , cute::_128 >;
90+   using  TileShape = cute::Shape<cute::_64, cute::_128 , cute::_64 >;
9191  using  ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
9292
9393  using  Cutlass3xGemm =
@@ -97,46 +97,14 @@ struct sm90_16_bit_config_default {
9797
9898template  <typename  InType, typename  OutType,
9999          template  <typename , typename , typename > typename  Epilogue>
100- struct  sm90_16_bit_config_M16  {
101-   //  M in [1, 16]
102-   using  KernelSchedule =
103-       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
104-   using  EpilogueSchedule =
105-       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
106-   using  TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
107-   using  ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;
108- 
109-   using  Cutlass3xGemm =
110-       cutlass_3x_moe_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
111-                           KernelSchedule, EpilogueSchedule>;
112- };
113- 
114- template  <typename  InType, typename  OutType,
115-           template  <typename , typename , typename > typename  Epilogue>
116- struct  sm90_16_bit_config_K8192  {
117-   //  K in [8192, inf)
118-   using  KernelSchedule =
119-       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
120-   using  EpilogueSchedule =
121-       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
122-   using  TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
123-   using  ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
124- 
125-   using  Cutlass3xGemm =
126-       cutlass_3x_moe_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
127-                           KernelSchedule, EpilogueSchedule>;
128- };
129- 
130- template  <typename  InType, typename  OutType,
131-           template  <typename , typename , typename > typename  Epilogue>
132- struct  sm90_16_bit_config_N8192  {
133-   //  N in [8192, inf)
100+ struct  sm90_16_bit_config_default  {
101+   //  M in (512, inf)
134102  using  KernelSchedule =
135103      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
136104  using  EpilogueSchedule =
137105      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
138-   using  TileShape = cute::Shape<cute::_64, cute::_128 , cute::_256 >;
139-   using  ClusterShape = cute::Shape<cute::_1, cute::_8 , cute::_1>;
106+   using  TileShape = cute::Shape<cute::_64, cute::_256 , cute::_64 >;
107+   using  ClusterShape = cute::Shape<cute::_1, cute::_1 , cute::_1>;
140108
141109  using  Cutlass3xGemm =
142110      cutlass_3x_moe_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
@@ -204,11 +172,7 @@ void run_cutlass_moe_mm_sm90_16_bit(
204172  TORCH_CHECK (b_tensors.size (0 ) > 0 , " No input B tensors provided." 
205173  TORCH_CHECK (out_tensors.size (0 ) > 0 , " No output tensors provided." 
206174
207-   using  Cutlass3xGemmN8192 = typename  sm90_16_bit_config_N8192<
208-       InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
209-   using  Cutlass3xGemmK8192 = typename  sm90_16_bit_config_K8192<
210-       InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
211-   using  Cutlass3xGemmM16 = typename  sm90_16_bit_config_M16<
175+   using  Cutlass3xGemmM512 = typename  sm90_16_bit_config_M512<
212176      InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
213177  using  Cutlass3xGemmDefault = typename  sm90_16_bit_config_default<
214178      InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
@@ -217,16 +181,8 @@ void run_cutlass_moe_mm_sm90_16_bit(
217181  uint32_t  const  n = out_tensors.size (1 );
218182  uint32_t  const  k = a_tensors.size (1 );
219183
220-   if  (n >= 8192 ) {
221-     cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmN8192>(
222-         out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes,
223-         a_strides, b_strides, c_strides);
224-   } else  if  (k >= 8192 ) {
225-     cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmK8192>(
226-         out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes,
227-         a_strides, b_strides, c_strides);
228-   } else  if  (m <= 16 ) {
229-     cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmM16>(
184+   if  (m <= 512 ) {
185+     cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmM512>(
230186        out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes,
231187        a_strides, b_strides, c_strides);
232188  } else  {
0 commit comments