@@ -146,6 +146,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
146146
147147 using ElementAB = typename Gemm::ElementAB;
148148 using ElementD = typename Gemm::ElementD;
149+ using ElementBlockScale = typename Gemm::ElementBlockScale;
149150
150151 int32_t m = a.size (0 ), n = b.size (1 ), k = a.size (1 );
151152
@@ -166,26 +167,29 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
166167 ScaleConfig::tile_atom_to_shape_SFB (make_shape (n, m, k, 1 )) :
167168 ScaleConfig::tile_atom_to_shape_SFB (make_shape (m, n, k, 1 ));
168169
169- auto a_ptr = static_cast <ElementAB*>(a.data_ptr ());
170- auto b_ptr = static_cast <ElementAB*>(b.data_ptr ());
171- auto a_scales_ptr = static_cast <float *>(a_scales.data_ptr ());
172- auto b_scales_ptr = static_cast <float *>(b_scales.data_ptr ());
170+ auto a_ptr = static_cast <ElementAB const *>(a.data_ptr ());
171+ auto b_ptr = static_cast <ElementAB const *>(b.data_ptr ());
172+ auto a_scales_ptr = static_cast <ElementBlockScale const *>(a_scales.data_ptr ());
173+ auto b_scales_ptr = static_cast <ElementBlockScale const *>(b_scales.data_ptr ());
173174
174- auto mainloop_args = [&](){
175- // layout_SFA and layout_SFB cannot be swapped since they are deduced.
176- if (swap_ab) {
177- return typename GemmKernel::MainloopArguments{
178- b_ptr, b_stride, a_ptr, a_stride,
179- b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
180- };
181- }
182- else {
183- return typename GemmKernel::MainloopArguments{
184- a_ptr, a_stride, b_ptr, b_stride,
185- a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
186- };
187- }
188- }();
175+ typename GemmKernel::MainloopArguments mainloop_args{};
176+ mainloop_args.layout_SFA = layout_SFA;
177+ mainloop_args.layout_SFB = layout_SFB;
178+ if (swap_ab) {
179+ mainloop_args.ptr_A = b_ptr;
180+ mainloop_args.dA = b_stride;
181+ mainloop_args.ptr_B = a_ptr;
182+ mainloop_args.dB = a_stride;
183+ mainloop_args.ptr_SFA = b_scales_ptr;
184+ mainloop_args.ptr_SFB = a_scales_ptr;
185+ } else {
186+ mainloop_args.ptr_A = a_ptr;
187+ mainloop_args.dA = a_stride;
188+ mainloop_args.ptr_B = b_ptr;
189+ mainloop_args.dB = b_stride;
190+ mainloop_args.ptr_SFA = a_scales_ptr;
191+ mainloop_args.ptr_SFB = b_scales_ptr;
192+ }
189193 auto prob_shape = swap_ab ? cute::make_shape (n, m, k, 1 ) : cute::make_shape (m, n, k, 1 );
190194
191195 auto c_ptr = static_cast <ElementD*>(out.data_ptr ());
0 commit comments