@@ -77,31 +77,45 @@ struct enable_sm89_to_sm90 : Kernel {
 };
 
 /*
-   This epilogue function defines a quantized GEMM operation similar to
-   torch._scaled_mm.
-
-   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
-   per-row. B can be quantized per-tensor or per-column.
-   Any combination of per-tensor and per-row or column is supported.
-   A and B must have symmetric quantization (zero point == 0).
-
-   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
-   scales are applied elementwise with numpy-style broadcasting.
-
-   ScaleA and ScaleB define the epilogue functions that apply the scales for
-   the A and B operands respectively. These scales may be either per-tensor or
-   per row or column.
-*/
+ * This class provides the common ScaleA and ScaleB descriptors for the
+ * ScaledEpilogue and ScaledEpilogueBias classes.
+ */
 template <typename ElementD, typename OutputTileThreadMap>
-struct ScaledEpilogue {
- private:
+struct ScaledEpilogueBase {
+ protected:
   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
 
   using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
       OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
 
   using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
       OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
+};
+
+/*
+   This epilogue function defines a quantized GEMM operation similar to
+   torch._scaled_mm.
+
+   A and B may be either both int8 or both fp8_e4m3. A can be quantized
+   per-tensor or per-row. B can be quantized per-tensor or per-column.
+   Any combination of per-tensor and per-row or per-column scales is
+   supported. A and B must use symmetric quantization (zero point == 0).
+
+   The GEMM operation is then D = (a_scales * A) (b_scales * B), where the
+   scales are applied elementwise with numpy-style broadcasting.
+
+   ScaleA and ScaleB define the epilogue functions that apply the scales to
+   the A and B operands respectively. These scales may be per-tensor, or
+   per-row for A and per-column for B.
+*/
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogue
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;
 
   using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
       cutlass::multiplies, float, float,
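
The comment block above fixes the semantics; the sketch below restates them as a minimal, self-contained reference in plain C++ (illustrative names only; int8 inputs, row-major A, column-major B, and an fp32 result are assumptions, not taken from this file):

#include <cstdint>
#include <vector>

// D = (a_scales * A) (b_scales * B) with numpy-style broadcasting of the
// scales: a_scales holds 1 (per-tensor) or M (per-row) values, b_scales
// holds 1 (per-tensor) or N (per-column) values. The scales factor out of
// the inner product, so each can be applied once per output element.
std::vector<float> scaled_mm_ref(std::vector<int8_t> const& A,  // M x K, row-major
                                 std::vector<int8_t> const& B,  // K x N, column-major
                                 std::vector<float> const& a_scales,
                                 std::vector<float> const& b_scales,
                                 int M, int N, int K) {
  std::vector<float> D(size_t(M) * N);
  for (int m = 0; m < M; ++m) {
    float const a_s = a_scales[a_scales.size() == 1 ? 0 : m];
    for (int n = 0; n < N; ++n) {
      float const b_s = b_scales[b_scales.size() == 1 ? 0 : n];
      int32_t acc = 0;  // integer accumulation, as in the int8 kernels
      for (int k = 0; k < K; ++k)
        acc += int32_t(A[size_t(m) * K + k]) * int32_t(B[size_t(n) * K + k]);
      D[size_t(m) * N + n] = a_s * b_s * float(acc);
    }
  }
  return D;
}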
@@ -134,6 +148,53 @@ struct ScaledEpilogue {
   }
 };
 
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBias
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;
+
+  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
+
+  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
+                                                             EVTCompute0, Bias>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& bias) {
+    using ScaleAArgs = typename ScaleA::Arguments;
+    using ScaleBArgs = typename ScaleB::Arguments;
+    using BiasArgs = typename Bias::Arguments;
+
+    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
+    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
+    BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
+
+    typename EVTCompute0::Arguments evt0_compute_args{b_args};
+
+    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
+                                                    bias_args};
+    return evt_compute_args;
+  }
+};
+
 template <typename Arch, template <typename> typename ArchGuard,
           typename ElementAB_, typename ElementD_,
           template <typename, typename> typename Epilogue_, typename TileShape,
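
A reading aid for the visitor trees above: roughly, Sm80EVT<Op, C0, C1, ...> applies Op to the results of its children, so per output element ScaledEpilogueBias computes a_scale * (b_scale * acc) + bias, and prepare_args packs the argument structs in the same child order ({a_args, evt0_compute_args, bias_args}). A scalar sketch of that evaluation order (names assumed, not CUTLASS API):

// Scalar model of the ScaledEpilogueBias tree:
//   EVTCompute0 = Compute0(ScaleB, Accum)              -> multiplies
//   EVTCompute  = Compute1(ScaleA, EVTCompute0, Bias)  -> multiply_add
template <typename ElementD>
ElementD scaled_epilogue_bias_ref(float acc, float a_scale, float b_scale,
                                  ElementD bias) {
  float evt0 = b_scale * acc;                     // Compute0: ScaleB * Accum
  return ElementD(a_scale * evt0 + float(bias));  // Compute1: ScaleA * evt0 + Bias
}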
@@ -168,13 +229,13 @@ struct cutlass_2x_gemm {
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
-  using KernelType = 
+  using KernelType =
     ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, 
-      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, 
+      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
+      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
       float, cutlass::layout::RowMajor, 4,
-      ElementAcc, float, cutlass::arch::OpClassTensorOp, 
-      Arch, 
+      ElementAcc, float, cutlass::arch::OpClassTensorOp,
+      Arch,
       TileShape, WarpShape, InstructionShape,
       EVTD,
       cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
@@ -404,14 +465,13 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
   }
 }
 
-void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
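
The pattern introduced in this hunk recurs for sm80 and sm89 below: the public entry point keeps the dtype checks on the scales, while a helper that is generic over the epilogue (a template template parameter) perfect-forwards whatever arguments that epilogue needs. A compilable miniature of the idiom, with all names hypothetical:

#include <iostream>
#include <utility>

// Stand-ins for the two epilogues; the real template parameters are
// <ElementD, OutputTileThreadMap>.
template <typename ElementD, typename ThreadMap>
struct PlainEpilogue {
  static void prepare_args(float a_scale, float b_scale) {
    std::cout << "scales only: " << a_scale * b_scale << '\n';
  }
};

template <typename ElementD, typename ThreadMap>
struct BiasEpilogue {
  static void prepare_args(float a_scale, float b_scale, float bias) {
    std::cout << "with bias: " << a_scale * b_scale + bias << '\n';
  }
};

// Mirrors cutlass_scaled_mm_smXX_epilogue: the epilogue is fixed at compile
// time, and whatever arguments it needs pass through unchanged.
template <template <typename, typename> typename Epilogue, typename... Args>
void run_gemm(Args&&... args) {
  Epilogue<float, int>::prepare_args(std::forward<Args>(args)...);
}

int main() {
  run_gemm<PlainEpilogue>(0.5f, 2.0f);       // two scales
  run_gemm<BiasEpilogue>(0.5f, 2.0f, 1.0f);  // two scales plus a bias
}

This keeps the bias-vs-no-bias decision out of the per-architecture dispatch code entirely.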
@@ -420,78 +480,130 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
   if (out.dtype() == torch::kBFloat16) {
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
 
-void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
+void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(b.dtype() == torch::kInt8);
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
+
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
-                                      ScaledEpilogue>(out, a, b, a_scales,
-                                                      b_scales);
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, ScaledEpilogue>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
 
-void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
+
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
 
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
   if (a.dtype() == torch::kInt8) {
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_caller<cutlass_2x_gemm<
           cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
-          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       assert(out.dtype() == torch::kFloat16);
       return cutlass_gemm_caller<cutlass_2x_gemm<
           cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
-          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_gemm_caller<cutlass_2x_gemm<
-          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
-          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_caller<
+          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
+                          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue,
+                          TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_gemm_caller<cutlass_2x_gemm<
-          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
-          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_caller<
+          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
+                          cutlass::float_e4m3_t, cutlass::half_t, Epilogue,
+                          TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   }
 }
+
+void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
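
With that, every public entry point has the same shape. Hypothetical call sites (tensor setup elided; names assumed):

// The optional bias selects the epilogue at runtime; scales must be fp32
// and bias must match out's dtype, as checked above.
cutlass_scaled_mm_sm89(out, a, b, a_scales, b_scales, c10::nullopt);  // ScaledEpilogue
cutlass_scaled_mm_sm89(out, a, b, a_scales, b_scales, bias);          // ScaledEpilogueBias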