Skip to content

Commit

Permalink
[ROCm] Fix FP32 atomic_rmw
Browse files (browse the repository at this point in the history)
  • Loading branch information
zoranjovanovic-ns committed Jun 19, 2024
1 parent 6739494 commit 308cd1e
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions third_party/triton/temporary/amd_pr7.patch
Original file line number | Diff line number | Diff line change
Expand Up @@ -179,3 +179,31 @@ index f59efd6..cf601f0 100644
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = convertFp32ToFp16NZ(loc, rewriter, v);
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
index 83f24d711..82aad06c5 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -599,12 +599,23 @@ struct AtomicRMWOpConversion
auto maybeKind = matchAtomicOp(atomicRmwAttr);
// TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
// atomics for MI-* series of AMD GPU.
+ if(isa<FloatType>(valElements[i].getType()) &&
+ (*maybeKind != mlir::LLVM::AtomicBinOp::fadd)) {
+ valElem = bitcast(valElements[i],
+ int_ty(valElements[i].getType().getIntOrFloatBitWidth()));
+ }
+
Value atom = rewriter
.create<LLVM::AtomicRMWOp>(
loc, *maybeKind, rmwPtr, valElements[i],
atomicMemOrdering, StringRef("agent"))
.getResult();

+ if(isa<FloatType>(valElements[i].getType()) &&
+ (*maybeKind != mlir::LLVM::AtomicBinOp::fadd)) {
+ atom = bitcast(atom, valElements[i].getType());
+ }
+
// NV for the f16v2 case generates one packed instruction. We have to
// create two separate instructions since LLVM::AtomicRMWOp doesn't
// support this. Can be optimized out with rocdl.raw.buffer.atomic.

0 comments on commit 308cd1e

Please sign in to comment.