Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for 3-5D TMA to allow loading non-matmul operands #5207

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

masahi
Copy link

@masahi masahi commented Nov 21, 2024

Mostly a mechanical change to expand the experimental support for TMA, which is currently limited to 1-2D. I have a use case for TMA to load 3D or 4D tensor which encodes blocked scales from MXFP in a specialized layout.

Swizzling is disabled for higher rank TMA, to set hasLeadingOffset = false for the dst SMEM allocated in TMA lowering. The new unittest fails if swizzling is enabled for TMA and hasLeadingOffset = true. I believe this is simply due to implementation limitations, so I hope we can enable swizziling for higher rank TMA in the future.

cc @ThomasRaoux @mbrookhart @csullivan

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@masahi masahi requested a review from ptillet as a code owner November 21, 2024 00:25
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");
Copy link
Author

@masahi masahi Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm removing this assert to unblock higher rank SMEM load, which seems to work fine.

I don't know what assumption this code has, so please let me know if there is a more reasonable relaxation.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks fine, we might want to test this more heavily but doesn't have to be part of this.

// For now, we do not swizzle for higher ranks. Enabling swizzling in TMA
// implies hasLeadingOffset = true in SMEM encoding, which is currently not
// supported for higher rank TMA copies. This convention needs to be in sync
// with the TMA lowering pass in codegen.
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ThomasRaoux This is my takeaway from our discussion yesterday. Let me know if this is ok.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is correct but this convention makes me bit nervous as we won't be able to handle the case where we 3D inputs for a batch matmul kind of cases

Copy link
Author

@masahi masahi Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With or without this PR, the underlying limitation that prevents 3D TMA with swizzling for matmul inputs continues to exist. So I would say this changes just make the limitation explicit in the API temporarily until the issue is fixed, at which point we can remove this convention.

For batched matmul kind of workloads, the flattening of a higher dim tensor into 2D is straightforward. So there is still an escape hatch.

Overall, I think this PR won't make the situation any worse. Maybe the new TMA representation you mentioned would solve all of those issues. But while we wait for that, it would be good to enable more use cases for TMA within Triton - it is an "experimental" feature, after all.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well the case that I think is interesting is to write a batch matmul case where the global tensor is 3D but each block loads a 2D tensor and compute matmul on it.

So I would say this changes just make the limitation explicit in the API temporarily until the issue is fixed, at which point we can remove this convention.

even then you wouldn't want swizzling for this case right?

For batched matmul kind of workloads, the flattening of a higher dim tensor into 2D is straightforward. So there is still an escape hatch.

That breaks the bound checking part right?

Copy link
Author

@masahi masahi Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even then you wouldn't want swizzling for this case right?

Oh sorry if there was misunderstanding. I believe 2D-5D TMA should be treated equally, and I do hope that we can remove this restriction. Whether or not swizzing would be beneficial for my use case is a separate question I need to investigate in the future. Right now my inner-most axis size is 16B so no swizzling would be applied anyway. But I can tweak the sizes of the inner-most two dims, to make the inner-most axis wider and apply swizzling if I want to.

That breaks the bound checking part right?

hmm I haven't thought about that but indeed I don't see how OOB check can work if some dims are flattened (if possible at all).

Maybe the device-side tensor-map creation can be used? After we get the batch id (or group id for grouped gemm), we can use 2D TMA. I'm not sure if that's supported now.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to get clarified on your concern. I thought this PR would not have any negative implications, since higher-rank TMA with swizzling doesn't work anyway. But are you saying that, one important special case of 3D TMA, where the actual load is 2D (since one of copy dim sizes is always one) is supposed to work with the current impl, but my change would disallow the swizzling for that case as well?

If that's the case, the only workaround, without a proper fix, would be to make swizzling a parameter for higher-rank TMA that the user provide. By default, we don't swizzle. We also need to pass the same swizzling param to tl.experimental_descriptor_load(...) to make the codegen and runtime code in sync.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another idea would be to base the decision to enable swizzling not on the rank of the global tensor but the "effective rank" of the box, where by "effective rank" I mean a rank after removing size-1 dims.

Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code changes look reasonable to me but it is definitely showing the limitation of the convention we have around swizzling.
@peterbell10 is working on the next representation for TMA and this representation will become deprecated or kept for experimental.
If this is going to block you I think it is fine to go with this as this is in the experimental path but if there is an alternative we should look at how to fit that in the new path.

// For now, we do not swizzle for higher ranks. Enabling swizzling in TMA
// implies hasLeadingOffset = true in SMEM encoding, which is currently not
// supported for higher rank TMA copies. This convention needs to be in sync
// with the TMA lowering pass in codegen.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is correct but this convention makes me bit nervous as we won't be able to handle the case where we 3D inputs for a batch matmul kind of cases

auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks fine, we might want to test this more heavily but doesn't have to be part of this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants