diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 679dcc88d788..56630c731858 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -565,6 +565,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent()); auto ans = mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && + mmaLayout.getWarpsPerCTA()[1] == 1 && !cvtNeedsSharedMemory(parentTy, srcTy) && (elementTypeSize == 16 || elementTypeSize == 8); return ans;