
Fix EVT for S32 accum and BF16 C/output tensors #1826

Conversation

@alexsamardzic (Contributor) commented Sep 18, 2024

To reproduce the problem:

First, apply the patch below to change the 47_ampere_gemm_universal_streamk_broadcast example so that an S8/S8 GEMM is performed, producing an S32 result, and the accumulator is then combined with some F16 values in the epilogue to produce an F16 result. After these changes, the example builds and runs fine. However, if cutlass::half_t is then replaced with cutlass::bfloat16_t in examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu, the example no longer builds. The failure is caused by a missing specialization of DefaultIteratorsTensorOp, which this PR adds; see the sketches after the patch below.

The patch for the 47_ampere_gemm_universal_streamk_broadcast example, to reproduce the problem:
diff --git a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu
index ed65e58c..e2125bdf 100644
--- a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu
+++ b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu
@@ -96,13 +96,13 @@
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 // A matrix configuration
-using         ElementA         = cutlass::half_t;                                  // Element type for A matrix operand
+using         ElementA         = int8_t;                                  // Element type for A matrix operand
 using         LayoutA          = cutlass::layout::RowMajor;                        // Layout type for A matrix operand
 constexpr int AlignmentA       = 128 / cutlass::sizeof_bits<ElementA>::value;      // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
 
 // B matrix configuration
-using         ElementB         = cutlass::half_t;                                  // Element type for B matrix operand
-using         LayoutB          = cutlass::layout::RowMajor;                        // Layout type for B matrix operand
+using         ElementB         = int8_t;                                  // Element type for B matrix operand
+using         LayoutB          = cutlass::layout::ColumnMajor;                        // Layout type for B matrix operand
 constexpr int AlignmentB       = 128 / cutlass::sizeof_bits<ElementB>::value;      // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
 
 // C1/C2/D matrix configuration
@@ -116,13 +116,13 @@ using         LayoutOutput     = cutlass::layout::RowMajor;
 // constexpr int AlignmentOutput  = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)
 
 // Multiply-accumulate blocking/pipelining details
-using ElementAccumulator  = cutlass::half_t;                          // Element type for internal accumulation
-using ElementCompute      = cutlass::half_t;                          // Element type for compute
+using ElementAccumulator  = int32_t;                          // Element type for internal accumulation
+using ElementCompute      = float;                          // Element type for compute
 using ArchTag             = cutlass::arch::Sm80;                      // Tag indicating the minimum SM that supports the intended feature
 using OperatorClass       = cutlass::arch::OpClassTensorOp;           // Operator class tag
-using ThreadblockShape    = cutlass::gemm::GemmShape<128, 128, 32>;   // Threadblock-level tile size (concept: GemmShape)
-using WarpShape           = cutlass::gemm::GemmShape<64, 64, 32>;     // Warp-level tile size (concept: GemmShape)
-using InstructionShape    = cutlass::gemm::GemmShape<16, 8, 16>;      // Instruction-level tile size (concept: GemmShape)
+using ThreadblockShape    = cutlass::gemm::GemmShape<128, 128, 128>;   // Threadblock-level tile size (concept: GemmShape)
+using WarpShape           = cutlass::gemm::GemmShape<64, 64, 64>;     // Warp-level tile size (concept: GemmShape)
+using InstructionShape    = cutlass::gemm::GemmShape<16, 8, 32>;      // Instruction-level tile size (concept: GemmShape)
 constexpr int NumStages   = 4;                                        // Number of global->shared pipeline stages used in the GEMM mainloop
 constexpr int EVTEpilogueStages = 1;                                  // Number of epilogue stages in EVT
 
@@ -253,7 +253,7 @@ using EVTKernelStreamK =
     EVTD,
     cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
     NumStages,
-    cutlass::arch::OpMultiplyAdd,
+    cutlass::arch::OpMultiplyAddSaturate,
     EVTEpilogueStages
 >::GemmKernel;
 
@@ -707,32 +707,32 @@ int main(int argc, const char **argv)
   if (options.split_k_factor == 1)
   {
     // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
-    Result basic_dp         = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
+    // Result basic_dp         = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
     Result streamk_default  = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);
 
-    printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));
+    // printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));
 
     // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
     options.avail_sms       = 1;        // Set loadbalancing width to 1 SM (no load balancing)
     Result streamk_dp       = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
     options.avail_sms       = -1;       // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs)
 
-    printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));
+    // printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));
 
     options.split_k_factor++;     // Increment splitting factor for next evaluation
 
   }
 
   // Show that StreamK can emulate "Split-K" with a tile-splitting factor
-  Result basic_splitk = run<DeviceGemmBasic>(
-    std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
-    options);
+  // Result basic_splitk = run<DeviceGemmBasic>(
+  //   std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
+  //   options);
 
   Result streamk_splitk = run<DeviceGemmStreamK>(
     std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
     options);
 
-  printf("  Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));
+  // printf("  Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));
 
   return 0;
 }
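
For reference, the additional change mentioned above is just a type substitution in the example's C1/C2/D matrix configuration. A minimal sketch of that substitution (ElementOutput appears in the patch context above; the ElementC alias name is illustrative and may not match the example file exactly):

 // Replace cutlass::half_t with cutlass::bfloat16_t for the C/output tensors
 using ElementC      = cutlass::bfloat16_t;   // Element type for C matrix operands (was cutlass::half_t)
 using ElementOutput = cutlass::bfloat16_t;   // Element type for output matrix D (was cutlass::half_t)

With S32 accumulation, this combination hits the missing DefaultIteratorsTensorOp specialization and the build fails.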

(Note: This PR is practically a completion of #812. BTW, the issue was found in the context of this work.)
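
For context, the fix amounts to adding a partial specialization of cutlass::epilogue::threadblock::detail::DefaultIteratorsTensorOp for a bfloat16_t output with an int32_t accumulator. The sketch below only illustrates the general shape of such a specialization, modeled after the existing half_t/int32_t one; the iterator types and their template arguments are assumptions, and the code actually added by this PR may differ.

 // Rough sketch (not the PR's actual code): a DefaultIteratorsTensorOp partial
 // specialization for bfloat16_t output with int32_t accumulation.
 template <
   typename ThreadblockShape,
   typename WarpShape,
   typename InstructionShape,
   typename ThreadMap
 >
 struct DefaultIteratorsTensorOp<
   cutlass::bfloat16_t,   // ElementOutput
   int32_t,               // ElementAccumulator
   8,                     // ElementsPerAccess
   ThreadblockShape,
   WarpShape,
   InstructionShape,
   ThreadMap> {

   // Iterator choices mirror the existing half_t/int32_t specialization
   // (assumed here; the actual iterator arguments may differ).
   using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
     WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;

   using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
     ThreadMap, int32_t, 32, 16, 8, 8>;

   static int const kFragmentsPerIteration = 2;
 };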

@alexsamardzic force-pushed the fix-evt-int32-accum-bfloat16-c branch from 8dbe183 to fe42718 on October 14, 2024 at 18:38
@alexsamardzic (Contributor, Author) commented Oct 14, 2024

@hwu36: Could someone please check this and, if appropriate, merge it?

Apparently, this fix is included in 3.6.0. Closing the PR.

@alexsamardzic deleted the fix-evt-int32-accum-bfloat16-c branch on October 14, 2024 at 22:03