
Allow bf16 operands on new MFMAs #144925


Merged
merged 1 commit into llvm:main on Jun 19, 2025

Conversation

Contributor

@umangyadav umangyadav commented Jun 19, 2025

The new gfx950 MFMAs allow bf16 operands:

def int_amdgcn_mfma_f32_16x16x32_bf16 : AMDGPUMfmaIntrinsic<llvm_v4f32_ty, llvm_v8bf16_ty>;

When running amdgpu-to-rocdl, the current logic always converts bf16 to i16, which fails to compile for the newer bf16 MFMAs, e.g. v_mfma_f32_16x16x32bf16.
The backend expects bf16-typed operands for those newer MFMAs. This patch fixes that.
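
For illustration, a minimal before/after sketch drawn from the updated mfma-gfx950.mlir test below (the i16 form is what the old lowering produced):

  amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
  // Old lowering (always bitcasts bf16 to i16; fails to compile for the new MFMA):
  //   rocdl.mfma.f32.16x16x32.bf16 ... : (vector<8xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
  // New lowering (bf16 operands passed through, matching the intrinsic signature):
  //   rocdl.mfma.f32.16x16x32.bf16 ... : (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>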

CC: @krzysz00 @dhernandez0 @giuseros @antiagainst @kuhar

@llvmbot
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-backend-amdgpu

Author: Umang Yadav (umangyadav)

Full diff: https://github.com/llvm/llvm-project/pull/144925.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+21-7)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+2-2)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 074404add47f1..700563460f525 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -499,7 +499,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
-/// 1. If the element type is bfloat16, bitcast it to i16.
+/// 1. If the element type is bfloat16, bitcast it to i16 unless the ROCDL
+/// intrinsic allows bf16. Newer MFMAs support bf16 operands; see
+/// IntrinsicsAMDGPU.td for reference.
 /// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
 /// instead, which is what the f8f6f4 intrinsics use.
 /// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -509,10 +511,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
 static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
-                                      Location loc, Value input) {
+                                      Location loc, Value input,
+                                      bool allowBf16 = true) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
-    if (vectorType.getElementType().isBF16())
+    if (vectorType.getElementType().isBF16() && !allowBf16)
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
     if (vectorType.getElementType().isInteger(8) &&
@@ -958,12 +961,23 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
 
     StringRef intrinsicName =
         isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
+    // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
+    // allow bf16 as the input; see IntrinsicsAMDGPU.td for reference.
+    bool allowBf16 = [&]() {
+      if (chipset < kGfx950)
+        return false;
+      if (isScaled)
+        return true;
+      return intrinsicName.contains("16x16x32.bf16") ||
+             intrinsicName.contains("32x32x16.bf16");
+    }();
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
-    loweredOp.addOperands(
-        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
-         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
-         adaptor.getDestC()});
+    loweredOp.addOperands({convertMFMAVectorOperand(
+                               rewriter, loc, adaptor.getSourceA(), allowBf16),
+                           convertMFMAVectorOperand(
+                               rewriter, loc, adaptor.getSourceB(), allowBf16),
+                           adaptor.getDestC()});
     if (isScaled) {
       Value zero = createI32Constant(rewriter, loc, 0);
       auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index 52a5d39f668c6..39c31d5bf2fa3 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -11,9 +11,9 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
   amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
-  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
   amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
-  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 }  blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
   amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 }  blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
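
Note that the pre-gfx950 path is unchanged: when allowBf16 is false (chipsets older than gfx950, or MFMA intrinsics whose names do not contain the new 16x16x32.bf16 / 32x32x16.bf16 shapes), convertMFMAVectorOperand still emits the i16 bitcast. A rough sketch of the LLVM-dialect IR it produces in that case (value names hypothetical):

  // bf16 operands are reinterpreted as i16 before being passed to the intrinsic.
  %src = llvm.bitcast %operand : vector<8xbf16> to vector<8xi16>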

Contributor

@krzysz00 krzysz00 left a comment


I can't think of a less awkward way to do this, so ... approved

@krzysz00 krzysz00 merged commit 836201f into llvm:main Jun 19, 2025
11 checks passed