From b5525715be4d9fa9b4acb6c06fe67a5880d7b349 Mon Sep 17 00:00:00 2001 From: Tori Date: Wed, 25 Sep 2024 14:11:21 +0200 Subject: [PATCH] Don't segfault on unsupported mma->mma https://github.com/triton-lang/triton/pull/4492/ started causing an issue where chained MMAs on hopper would segfault with 8 warps. It seems that previously this was checked, but the check got removed in this PR and it's still unsupported. Adding back this check means these MMAs will have to go back to shared memory, but it's better than segfaulting until it's actually supported. Resolves https://github.com/openxla/xla/issues/17356 --- lib/Analysis/Utility.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 679dcc88d788..56630c731858 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -565,6 +565,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent()); auto ans = mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && + mmaLayout.getWarpsPerCTA()[1] == 1 && !cvtNeedsSharedMemory(parentTy, srcTy) && (elementTypeSize == 16 || elementTypeSize == 8); return ans;