From 694719a08d9406a1ea5e81a24d29877f723caf4d Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 25 Sep 2024 21:10:39 +0200 Subject: [PATCH] [backend] Fix improper mma->dot shortcut when `warpsPerCTA[1] > 1` (#4803) https://github.com/triton-lang/triton/pull/4492/ started causing an issue where chained MMAs on Hopper would segfault with 8 warps. It seems that previously this was checked, but the check was removed in that PR even though the `warpsPerCTA[1] > 1` case is still unsupported. Adding back this check means these MMAs will have to go back to shared memory, but it's better than segfaulting until it's actually supported. Resolves https://github.com/openxla/xla/issues/17356 Co-authored-by: Tori --- lib/Analysis/Utility.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 679dcc88d788..56630c731858 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -565,6 +565,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent()); auto ans = mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && + mmaLayout.getWarpsPerCTA()[1] == 1 && !cvtNeedsSharedMemory(parentTy, srcTy) && (elementTypeSize == 16 || elementTypeSize == 8); return ans;