@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
8484 torch::Tensor const & b_tensors, torch::Tensor const & a_scales,
8585 torch::Tensor const & b_scales, torch::Tensor const & expert_offsets,
8686 torch::Tensor const & problem_sizes, torch::Tensor const & a_strides,
87- torch::Tensor const & b_strides, torch::Tensor const & c_strides) {
87+ torch::Tensor const & b_strides, torch::Tensor const & c_strides,
88+ bool per_act_token, bool per_out_ch) {
8889 TORCH_CHECK (a_tensors.size (0 ) > 0 , " No input A tensors provided." );
8990 TORCH_CHECK (b_tensors.size (0 ) > 0 , " No input B tensors provided." );
9091 TORCH_CHECK (out_tensors.size (0 ) > 0 , " No output tensors provided." );
@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
113114 if (n >= 8192 ) {
114115 cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
115116 out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
116- problem_sizes, a_strides, b_strides, c_strides);
117+ problem_sizes, a_strides, b_strides, c_strides, per_act_token,
118+ per_out_ch);
117119 } else if (k >= 8192 ) {
118120 cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
119121 out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
120- problem_sizes, a_strides, b_strides, c_strides);
122+ problem_sizes, a_strides, b_strides, c_strides, per_act_token,
123+ per_out_ch);
121124 } else if (m <= 16 ) {
122125 cutlass_group_gemm_caller<Cutlass3xGemmM16>(
123126 out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
124- problem_sizes, a_strides, b_strides, c_strides);
127+ problem_sizes, a_strides, b_strides, c_strides, per_act_token,
128+ per_out_ch);
125129 } else {
126130 cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
127131 out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
128- problem_sizes, a_strides, b_strides, c_strides);
132+ problem_sizes, a_strides, b_strides, c_strides, per_act_token,
133+ per_out_ch);
129134 }
130135}
131136
@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
134139 torch::Tensor const & b_tensors, torch::Tensor const & a_scales,
135140 torch::Tensor const & b_scales, torch::Tensor const & expert_offsets,
136141 torch::Tensor const & problem_sizes, torch::Tensor const & a_strides,
137- torch::Tensor const & b_strides, torch::Tensor const & c_strides) {
142+ torch::Tensor const & b_strides, torch::Tensor const & c_strides,
143+ bool per_act_token, bool per_out_ch) {
138144 if (out_tensors.dtype () == torch::kBFloat16 ) {
139145 run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t , cutlass::bfloat16_t >(
140146 out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
141- problem_sizes, a_strides, b_strides, c_strides);
147+ problem_sizes, a_strides, b_strides, c_strides, per_act_token,
148+ per_out_ch);
142149 } else {
143150 run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t , cutlass::half_t >(
144151 out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
145- problem_sizes, a_strides, b_strides, c_strides);
152+ problem_sizes, a_strides, b_strides, c_strides, per_act_token,
153+ per_out_ch);
146154 }
147155}
148156
@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
153161 torch::Tensor const & b_tensors, torch::Tensor const & a_scales,
154162 torch::Tensor const & b_scales, torch::Tensor const & expert_offsets,
155163 torch::Tensor const & problem_sizes, torch::Tensor const & a_strides,
156- torch::Tensor const & b_strides, torch::Tensor const & c_strides) {
164+ torch::Tensor const & b_strides, torch::Tensor const & c_strides,
165+ bool per_act_token, bool per_out_ch) {
157166 dispatch_moe_mm_sm90 (out_tensors, a_tensors, b_tensors, a_scales, b_scales,
158167 expert_offsets, problem_sizes, a_strides, b_strides,
159- c_strides);
168+ c_strides, per_act_token, per_out_ch );
160169}
0 commit comments