Skip to content

Commit

Permalink
Don't segfault on unsupported mma->mma
Browse files Browse the repository at this point in the history
triton-lang#4492 started causing an
issue where chained MMAs on hopper would segfault with 8 warps. It seems
that previously this was checked, but the check got removed in this PR
and it's still unsupported.

Adding back this check means these MMAs will have to go back to shared
memory, but it's better than segfaulting until it's actually supported.

Resolves openxla/xla#17356
  • Loading branch information
vwbaker committed Sep 25, 2024
1 parent 493f991 commit b552571
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent());
auto ans = mmaLayout.getVersionMajor() == 3 &&
dotOperandLayout.getOpIdx() == 0 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8);
return ans;
Expand Down

0 comments on commit b552571

Please sign in to comment.