
[BACKEND] Support Hopper MMA to MMA convert_layout ops #4492

Merged
merged 15 commits into main from keren/mma-to-mma on Aug 12, 2024

Conversation

Jokeren (Contributor) commented Aug 8, 2024

This PR enables MMA to MMA layout conversion on the Hopper architecture.
We also replace the previous isMmaToMmaShortcut check with cvtReordersRegisters in several places.
Note that MMA to MMA conversion via shared memory still goes through the legacy ConvertLayoutOpConversion function; we will deprecate it in a follow-up PR.

@Jokeren Jokeren linked an issue Aug 9, 2024 that may be closed by this pull request
@Jokeren Jokeren marked this pull request as ready for review August 9, 2024 16:31
@Jokeren Jokeren requested a review from ptillet as a code owner August 9, 2024 16:31
@Jokeren Jokeren changed the title from "[BACKEND][DRAFT] Hopper mma to mma" to "[BACKEND] Hopper mma to mma" on Aug 9, 2024
jlebar (Collaborator) commented Aug 9, 2024

[BACKEND] Hopper mma to mma

Could we make the PR description have a verb in it? Like, "Support Hopper MMA to MMA convert_layout ops", something like that?

@@ -289,12 +290,37 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// stronger than this, checking also that the choice of lane/warp/block does
// not affect the permutation of registers. If we allow different
// lane/warp/blocks to have different permutations, we can generalize this.
if (std::optional<LinearLayout> c = conversion.divideRight(

// There are two three possible cases
Collaborator

typo

// 2. The `src_layout` has fewer registers than the `dst_layout`.
// 3. The `src_layout` has more registers than the `dst_layout`.
// In the second case, we may generate a conversion that is not surjective
// because not all lanes are covered. Instead, we could use the inverse of
Collaborator

Are you saying that we may *have generated* a conversion, that is, are you saying that the function "conversion" is non-surjective?

// 2. The `src_layout` has fewer registers than the `dst_layout`.
// 3. The `src_layout` has more registers than the `dst_layout`.
// In the second case, we may generate a conversion that is not surjective
// because not all lanes are covered. Instead, we could use the inverse of
Collaborator

Instead, we can use

// 2. The `src_layout` has fewer registers than the `dst_layout`.
// 3. The `src_layout` has more registers than the `dst_layout`.
// In the second case, we may generate a conversion that is not surjective
// because not all lanes are covered. Instead, we could use the inverse of
Collaborator

not all lanes are covered

I think you mean "not all destination registers" are covered, you don't mean to talk about lanes?

// because not all lanes are covered. Instead, we could use the inverse of
// the conversion, mapping from `dst_layout` to `src_layout`, which is
// surjective. This inverse layout indicates that multiple destination
// registers may come from the same source register.
Collaborator

If I understand this comment correctly, I think it can be rewritten more concisely.

If src_layout has fewer registers than dst_layout, then conversion, which is src_layout . dst_layout^-1, will necessarily be non-surjective. But the whole point is to cover all of the destination registers, so in this case we use the inverse layout, namely dst_layout . src_layout^-1 instead.

Additionally:

  1. I think this comment should be moved inside the if statement?
  2. Instead of using conversion sometimes and inverseConversion other times, why don't we simply always use inverseConversion? (See the sketch below.)
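
A minimal sketch of the inverse-conversion idea above, assuming the LinearLayout API visible elsewhere in this PR (invertAndCompose, getInDimSize, apply) and assuming the kRegister/kLane/kWarp/kBlock StringAttrs and the unpacked input values (inVals) are in scope as in the surrounding diff; this is an illustration, not the code as merged:

  // conversion        maps a src (reg, lane, warp, block) to a dst location.
  // inverseConversion maps a dst (reg, lane, warp, block) to a src location.
  LinearLayout inverseConversion = dstLayout.invertAndCompose(srcLayout);
  // Enumerating dst registers guarantees every output register is written,
  // even when several of them read the same src register (duplication).
  SmallVector<Value> outVals;
  for (int32_t i = 0; i < inverseConversion.getInDimSize(kRegister); ++i) {
    // Within a thread, lane/warp/block are identity, so pass zeros and read
    // back only the register component of the result.
    int32_t srcIdx = inverseConversion
                         .apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0]
                         .second;
    outVals.push_back(inVals[srcIdx]);
  }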

// <inDimName, baseIdx, outValue>
std::vector<std::tuple<StringAttr, int, int>> sortedBases;
for (auto [inDimName, basis] : bases) {
for (size_t baseIdx = 0; baseIdx < basis.size(); baseIdx++) {
Collaborator

basisIdx, I think? There's no such thing as a base here.

@@ -673,12 +673,13 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
return newOp;
}

// Check if the convert will be a no-op in codegen.
// Check if the convert will be performed without shared memory.
static bool isFreeConvert(Operation *op) {
Collaborator

Maybe change the function name too?

Collaborator

Also, are you sure that "convert is performed without shared memory" is what you want to check? Maybe you actually want to check "convert is performed by reordering registers, i.e. it's a nop"? For example, a convert that does a warp shuffle maybe should not be "free"? I don't know how it's used though.

Contributor Author

It's unknown whether a warp shuffle would be considered "free" or merely inexpensive, but we agree that it's better to consider only register->register conversions for now.
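
For reference, a minimal sketch of a register-reordering check along the lines agreed on here, assuming ConvertLayoutOp's usual accessors; an illustration, not necessarily the code as merged:

  // "Free" = the conversion only permutes registers within each thread;
  // warp shuffles and shared-memory round trips are deliberately excluded.
  static bool isFreeConvert(Operation *op) {
    auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
    if (!convertOp)
      return false;
    return cvtReordersRegisters(
        cast<RankedTensorType>(convertOp.getSrc().getType()),
        cast<RankedTensorType>(convertOp.getType()));
  }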

@@ -1936,7 +1936,8 @@ tt.func public @yield_outside_loop2(%arg0: i32, %arg1: i32) -> (i32, i32) {

// -----

// Check that we handle corner cases when hoisting convert on top of extf. For complex slices we may hoist convert on top of extf while the source of extf has multiple uses in the slice.
// Check that we handle corner cases when hoisting convert on top of extf because convert ops on a smaller type which is faster.
Collaborator

Typo in comment

LinearLayout({{S("register"), {{4}}},
{S("lane"), {{0}, {0}, {1}, {2}, {0}}},
LinearLayout({{S("register"), {{0}}},
{S("lane"), {{0}, {0}, {1}, {2}, {4}}},
Collaborator

Was this a bug before? If so, did this affect production code? If so, could we explain the bug in the PR description? (I don't fully understand why the new thing is right...)

Contributor Author

Yes, it was a bug

Contributor Author

We didn't find problems in real mmav2 test cases, but we did find wrong mappings in real mmav3 test cases.
By real cases, I mean those used in test_core.py, as opposed to the C++ unit tests.

Contributor Author

For this unit test case, the slice is a column from the parent. Because the output dimension is 8, it's fully covered by the first registers of t0, t4, ..., t28. The second registers of t0, t4, ..., t28 just store duplicated values of the first registers.

@Jokeren Jokeren changed the title from "[BACKEND] Hopper mma to mma" to "[BACKEND] Support Hopper MMA to MMA convert_layout ops" on Aug 9, 2024
Jokeren (Contributor Author) commented Aug 9, 2024

Hi @jlebar, feel free to take a second pass now. Thanks for the suggestions. I hope my comments are better now :)

@@ -189,18 +189,16 @@ bool supportMMA(triton::DotOp op, int version);

bool supportMMA(Value value, int version);

bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy);
Collaborator

Add a brief comment to this?
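
For instance, something along these lines (wording assumed; not necessarily the comment that actually landed):

  // Returns true iff converting a tensor from srcTy to dstTy only permutes
  // registers within each thread, i.e. requires no cross-lane, cross-warp,
  // or cross-block data movement (and hence no shared memory).
  bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy);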

Collaborator

(In particular does the MmaToDot shortcut also just reorder registers? If so maybe the name for this should have "LinearLayout" or "LL" in the name?)

Contributor Author

What do you mean by "LL" in the name?

Jokeren (Contributor Author) commented Aug 10, 2024

In particular does the MmaToDot shortcut also just reorder registers?

There are some hacks. If matchMmaV3AndDotOperandLayout returns true, we actually already use a warp shuffle, but that's a very special case.

Collaborator

"LL" meant LinearLayout. Anyway this is good, thanks. :)

@@ -279,6 +281,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
int numWarps = conversion.getInDimSize(str_attr("warp"));
int numBlocks = conversion.getInDimSize(str_attr("block"));

StringAttr kRegister = str_attr("register");
Collaborator

unused now?

auto dstToSrc = inverseConversion.divideRight(
LinearLayout::identity1D(numLanes, kLane, kLane) *
LinearLayout::identity1D(numWarps, kWarp, kWarp) *
LinearLayout::identity1D(numBlocks, kBlock, kBlock));
Collaborator

I think it would be clearer to move this into the transferWithinThread function.

Fundamentally the current function creates a layout called conversion, which it uses to decide what kind of conversion to do -- reg-to-reg, shfl, shmem, whatever. Then it calls transferWithinX, which figures out how to do the conversion.

If you moved this into the transferWithinThread function, you wouldn't need much of a comment at all, because you wouldn't be tempted to use conversion! That is, it would be (more) obvious to do dstLayout->invertAndCompose(*srcLayout) rather than vice versa.

Collaborator

(It's also confusing as-is that we take a variable called inverseConversion and pass it into a function where it's used in a variable called plain conversion. Again moving the code into the function may clarify this...)
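
A rough sketch of the suggested restructuring, reusing names from this diff; the signature and body are assumptions about the idea, not the final code:

  // The caller only classifies the conversion and dispatches; the callee
  // derives the map it needs directly from the src/dst layouts, so no
  // `conversion`/`inverseConversion` variable ever crosses the boundary.
  void transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout,
                            const LinearLayout &dstLayout,
                            ConversionPatternRewriter &rewriter) {
    // Written here, it is (more) obvious that dst -> src is the map we want.
    LinearLayout dstToSrcFull = dstLayout.invertAndCompose(srcLayout);
    // ... divide out lane/warp/block and emit the register moves ...
  }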

auto dstToSrc = inverseConversion.divideRight(
LinearLayout::identity1D(numLanes, kLane, kLane) *
LinearLayout::identity1D(numWarps, kWarp, kWarp) *
LinearLayout::identity1D(numBlocks, kBlock, kBlock));
Collaborator

As written here, we're implicitly asserting that if cvtReordersRegisters is true, then you can divide the layout like this.

I don't love this because we're duplicating the logic (i.e. the division) in two places. Could we make a function which tries to divide and returns an Optional instead?
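
A sketch of such a helper (name and placement assumed): divideRight already returns std::optional, so the helper just centralizes the identity1D product that is currently duplicated:

  std::optional<LinearLayout>
  tryGetRegToRegMap(const LinearLayout &conversion, StringAttr kLane,
                    StringAttr kWarp, StringAttr kBlock) {
    // Succeeds iff the conversion is the identity on lane/warp/block,
    // leaving only a permutation (with possible duplication) of registers.
    return conversion.divideRight(
        LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) *
        LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) *
        LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock,
                                 kBlock));
  }

cvtReordersRegisters would then reduce to checking tryGetRegToRegMap(...).has_value(), and the lowering would unwrap the same optional instead of re-deriving the division.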

// because bases that map to a location larger the shape[d]
// effectively duplicate along that dimension. For example, consider a layout
// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to
// shrink the output dimension size to 8:
Collaborator

I notice some typos in this. Here's what ChatGPT suggests, I think this is a good start. https://chatgpt.com/share/e/0290a621-c78b-49e3-97cf-8a95b6b83030

// L(register=4) = 1
// L(lane=1) = 2
// L(lane=2) = 16.
// We achieve this by setting the largest value in each output dimension d to 0
Collaborator

Hm, I don't know what the "largest value in [an] output dimension" is. The old comment makes sense to me, but the new comment seems to have lost information, making it difficult for me to understand.

// L(register=2) = 4
// L(register=4) = 1
// L(lane=1) = 2
// L(lane=2) = 16
Collaborator

We seem to have lost the indentation in the example here. Was that intentional? Indentation helps readers scan the example.

// dimensions.
// Now the output dimension of this layout has a size of 8, which is the desired
// size. Note that this method works only because the bases are powers of two.
// It is unclear what to do when they are not.
LinearLayout ensureLayoutNotLargerThan(
Collaborator

I see, I think I understand what you're doing here now. Thank you for improving the comments.

I am not convinced this new behavior matches the behavior of the old legacy layouts, though. (Maybe we don't need to match them because we got rid of emitIndices based on legacy layouts? But changing the behavior of the layouts is very subtle and can affect lots of other code.)

What problem are we solving that requires us to make this change?

Contributor Author

The problem we are trying to solve is that the old logic is wrong for wgmma (mma v3).

Contributor Author

"triton_gpu.convert_layout"(%1912) {allocation.offset = 0 : i32} : (tensor<64x64xf16, #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>>) -> tensor<64x64xf16, #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>>

Contributor Author

Old logic:

 - register=1 -> (0, 1)
   register=2 -> (8, 0)
   register=4 -> (0, 4)
   register=8 -> (0, 8)
   register=16 -> (0, 16)
   register=32 -> (0, 32)
 - lane=1 -> (0, 2)
   lane=2 -> (0, 0)
   lane=4 -> (1, 0)
   lane=8 -> (2, 0)
   lane=16 -> (4, 0)
 - warp=1 -> (16, 0)
   warp=2 -> (32, 0)
 - block is a size 1 dimension
where out dims are: [dim0 (size 64), dim1 (size 64)]

Contributor Author

lane=2 should be mapped to (0, 4) instead of (0, 0).

Collaborator

Thanks
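
For readers following along, one plausible reading of the shrinking rule discussed in this thread, mirroring the bases iteration visible earlier in the diff; illustrative only, and the real implementation may differ:

  // To shrink output dimension d of a layout to `size` (a power of two):
  // zero every basis component in d that is >= size. Such bases only
  // duplicate data along d, so dropping them loses no information, and
  // because all bases are powers of two the survivors span [0, size).
  for (auto &[inDimName, basesForDim] : bases)
    for (auto &basis : basesForDim)
      if (basis[d] >= size)
        basis[d] = 0;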

Jokeren (Contributor Author) commented Aug 11, 2024

Hi @jlebar , please feel free to take another pass now.


@Jokeren Jokeren enabled auto-merge (squash) August 12, 2024 00:20
@Jokeren Jokeren merged commit 7d89248 into main Aug 12, 2024
6 checks passed
@Jokeren Jokeren deleted the keren/mma-to-mma branch August 12, 2024 00:29
vwbaker added a commit to openxla/triton that referenced this pull request Sep 25, 2024
triton-lang#4492 started causing an
issue where chained MMAs on hopper would segfault with 8 warps. It seems
that previously this was checked, but the check got removed in this PR
and it's still unsupported.

Adding back this check means these MMAs will have to go back to shared
memory, but it's better than segfaulting until it's actually supported.

Resolves openxla/xla#17356
Jokeren pushed a commit that referenced this pull request Sep 25, 2024
…4803)

#4492 started causing an
issue where chained MMAs on hopper would segfault with 8 warps. It seems
that previously this was checked, but the check got removed in this PR
and it's still unsupported.

Adding back this check means these MMAs will have to go back to shared
memory, but it's better than segfaulting until it's actually supported.

Resolves openxla/xla#17356

Co-authored-by: Tori <vwbaker@google.com>

Successfully merging this pull request may close these issues.

Layout conversion error on H100