diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh index 939879b2c59f..dbf79a065115 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -146,6 +146,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; + using ElementBlockScale = typename Gemm::ElementBlockScale; int32_t m = a.size(0), n = b.size(1), k = a.size(1); @@ -166,26 +167,29 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) : ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto b_scales_ptr = static_cast(b_scales.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); - auto mainloop_args = [&](){ - // layout_SFA and layout_SFB cannot be swapped since they are deduced. - if (swap_ab) { - return typename GemmKernel::MainloopArguments{ - b_ptr, b_stride, a_ptr, a_stride, - b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB - }; - } - else { - return typename GemmKernel::MainloopArguments{ - a_ptr, a_stride, b_ptr, b_stride, - a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB - }; - } - }(); + typename GemmKernel::MainloopArguments mainloop_args{}; + mainloop_args.layout_SFA = layout_SFA; + mainloop_args.layout_SFB = layout_SFB; + if (swap_ab) { + mainloop_args.ptr_A = b_ptr; + mainloop_args.dA = b_stride; + mainloop_args.ptr_B = a_ptr; + mainloop_args.dB = a_stride; + mainloop_args.ptr_SFA = b_scales_ptr; + mainloop_args.ptr_SFB = a_scales_ptr; + } else { + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.ptr_SFB = b_scales_ptr; + } auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr()); diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh index 78d5cf37fa6d..811741aee58b 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -125,6 +125,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; + using ElementBlockScale = typename Gemm::ElementBlockScale; int32_t m = a.size(0), n = b.size(1), k = a.size(1); @@ -143,17 +144,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto b_scales_ptr = static_cast(b_scales.data_ptr()); - - auto mainloop_args = [&](){ - return typename GemmKernel::MainloopArguments{ - a_ptr, a_stride, b_ptr, b_stride, - a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB - }; - }(); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args{}; + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.layout_SFA = layout_SFA; + mainloop_args.ptr_SFB = b_scales_ptr; + mainloop_args.layout_SFB = layout_SFB; auto prob_shape = cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr()); diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index 86220264151e..147eb8efc077 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -115,6 +115,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; + using ElementBlockScale = typename Gemm::ElementBlockScale; int32_t m = a.size(0), n = b.size(1), k = a.size(1); @@ -135,17 +136,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto b_scales_ptr = static_cast(b_scales.data_ptr()); - - auto mainloop_args = [&](){ - return typename GemmKernel::MainloopArguments{ - a_ptr, a_stride, b_ptr, b_stride, - a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB - }; - }(); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args{}; + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.layout_SFA = layout_SFA; + mainloop_args.ptr_SFB = b_scales_ptr; + mainloop_args.layout_SFB = layout_SFB; auto prob_shape = cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr());