Skip to content

Commit

Permalink
[backend] Fix improper mma->dot shortcut when warpsPerCTA[1] > 1 (#…
Browse files Browse the repository at this point in the history
…4803)

#4492 started causing an
issue where chained MMAs on hopper would segfault with 8 warps. It seems
that previously this was checked, but the check got removed in this PR
and it's still unsupported.

Adding back this check means these MMAs will have to go back to shared
memory, but it's better than segfaulting until it's actually supported.

Resolves openxla/xla#17356

Co-authored-by: Tori <vwbaker@google.com>
  • Loading branch information
chsigg and vwbaker authored Sep 25, 2024
1 parent eb7925d commit 694719a
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent());
auto ans = mmaLayout.getVersionMajor() == 3 &&
dotOperandLayout.getOpIdx() == 0 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8);
return ans;
Expand Down

0 comments on commit 694719a

Please sign in to comment.