Introduce triton-arith-to-linalg pass #85

Merged

Conversation

haishanzzzz
Contributor

This PR introduces the triton-arith-to-linalg pass. Please see #81 for background.

Most of the pass's logic is copied directly from triton-to-linalg. The main differences are:

  • Add support for the tt.addptr op. This handles cases where the tensor of pointers does not have any structure and we have to materialize the tensor.
  • Add pass options to control the behavior of a few miscellaneous conversions:
    • tt.func -> func.func
    • tt.get_program_id -> function arguments
    • tt.assert -> cf.assert

@nhat-nguyen
Collaborator

@haishanzzzz We have a new pattern added recently in this PR as well: #86

@haishanzzzz haishanzzzz marked this pull request as ready for review January 19, 2024 23:21
@nhat-nguyen
Collaborator

nhat-nguyen commented Jan 22, 2024

It would be great if we could reuse the patterns instead of duplicating them. I'm afraid it would be hard to maintain both versions as people add more patterns; we already have one outstanding PR that adds more patterns to the TritonToLinalg pass.

Sharing these patterns is a little tricky, but I spent some time experimenting and I think this would work for us with minimal effort:

  • copy all of the shared patterns to a header file, perhaps ConversionPatterns.hpp, and place it under triton/include/TritonArithToLinalg/ConversionPatterns.hpp
    • note that we still leave the specific patterns intact; I think the only difference is AddPtrConverter
  • remove all the pattern definitions in both TritonArithToLinalg.cpp and TritonToLinalg.cpp
  • now both lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp and lib/Conversion/TritonToLinalg/TritonToLinalg.cpp will need to include ConversionPatterns.hpp:
    • #include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp"
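
A minimal sketch of the header-sharing layout described in the list above, written as a single self-contained C++ file rather than against MLIR. Every name in it (`populateSharedPatterns`, the pattern strings, the two `...PassPatterns` functions) is a hypothetical placeholder and not taken from triton-shared; pattern objects are modeled as plain strings. The one point it is meant to show is that definitions living in a header included by two .cpp files need to be `inline` (or templates) so the linker does not see duplicate definitions.

```cpp
#include <cassert>
#include <string>
#include <vector>

// ---- would live in the shared header (ConversionPatterns.hpp) ----
namespace shared {
// `inline` lets several translation units include this definition
// without a multiple-definition link error.
inline void populateSharedPatterns(std::vector<std::string> &patterns) {
  patterns.push_back("SharedPatternA"); // placeholder name
  patterns.push_back("SharedPatternB"); // placeholder name
}
} // namespace shared

// ---- would live in TritonToLinalg.cpp ----
inline std::vector<std::string> monolithPassPatterns() {
  std::vector<std::string> patterns;
  shared::populateSharedPatterns(patterns);
  patterns.push_back("StructuredAddPtrPattern"); // pass-specific placeholder
  return patterns;
}

// ---- would live in TritonArithToLinalg.cpp ----
inline std::vector<std::string> arithPassPatterns() {
  std::vector<std::string> patterns;
  shared::populateSharedPatterns(patterns);
  patterns.push_back("MaterializingAddPtrPattern"); // pass-specific placeholder
  return patterns;
}

// Small usage check: both passes see the same shared prefix and differ
// only in their pass-specific pattern.
inline void checkSharing() {
  const auto a = monolithPassPatterns();
  const auto b = arithPassPatterns();
  assert(a.size() == 3 && b.size() == 3);
  assert(a[0] == b[0] && a[1] == b[1]);
  assert(a[2] != b[2]);
}
```

In the real tree the three sections would be separate files, and the shared function would register actual conversion patterns instead of strings.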

This of course isn't ideal, since all of the implementation is now in a header file, but at least it saves us from having to duplicate all this code and consolidate it later. An alternative would be to place all these patterns in another lib and have both TritonToLinalg and TritonArithToLinalg depend on it, but given that this is all temporary, I don't think it's worth the trouble. After we retire the monolith pass, we can move this header back to a normal cpp file under TritonArithToLinalg.

Here's the rough diff for the approach I describe above:
a2d16df
Note that in the above diff I made TritonArithToLinalg reuse the old AddPtrConverter, but hopefully this illustrates what I'm trying to describe.

@haishanzzzz
Contributor Author

Thank you for suggesting this, Nhat, and for trying it out for me. This is a great idea. I will address it before the PR closes.

Please feel free to continue with the review in the meantime, btw.

nhat-nguyen
nhat-nguyen previously approved these changes Jan 24, 2024
Collaborator

@nhat-nguyen nhat-nguyen left a comment

@haishanzzzz Thanks for the refactor. This mostly reuses what we already have, so it all looks good to me.

I do have one question before closing regarding

Add support for the tt.addptr op. This handles cases where the tensor of pointers does not have any structure and we have to materialize the tensor.

So after triton-to-structured, the following types of tt.addptr may exist:

  1. those that deal with scalars
  2. those that we can't analyze

So in this pass, we have the option of converting tt.addptr to linalg, but looking at AddPtrConverter, I don't see us filtering out scalar addptr at all. Should we add this?

@nhat-nguyen nhat-nguyen dismissed their stale review January 24, 2024 19:17

further questions

@haishanzzzz
Contributor Author

Thank you for the review @nhat-nguyen!

We filter for non-scalar addptr in TritonArithToLinalgPass.cpp with the following:

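  if (addptrToLinalg) {
    // Scalar tt.addptr stays legal; tt.addptr with a shaped (tensor) result
    // is reported as illegal, so the conversion has to rewrite it.
    target.addDynamicallyLegalOp<triton::AddPtrOp>([](triton::AddPtrOp op) {
      return !op.getResult().getType().isa<ShapedType>();
    });
  }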
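
To make the effect of that predicate concrete, here is a toy model in plain C++ with no MLIR dependency. `ToyOp`, `isLegal`, and `opsNeedingConversion` are hypothetical names invented for this sketch. It also assumes that tt.addptr is simply left alone when the option is off, which the snippet above does not show.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation; not an MLIR type.
struct ToyOp {
  std::string name;
  bool resultIsShaped; // true when the result is a tensor of pointers
};

// Same logic as the lambda above: an addptr op is legal only when its
// result is not shaped, i.e. when it is a scalar addptr.
// Assumption of this toy: with the option off, addptr is left untouched.
inline bool isLegal(const ToyOp &op, bool addptrToLinalg) {
  if (op.name != "tt.addptr")
    return true; // other ops are out of scope for this sketch
  if (!addptrToLinalg)
    return true;
  return !op.resultIsShaped;
}

// Collect the indices of ops that a conversion would still have to rewrite.
inline std::vector<std::size_t>
opsNeedingConversion(const std::vector<ToyOp> &ops, bool addptrToLinalg) {
  std::vector<std::size_t> illegal;
  for (std::size_t i = 0; i < ops.size(); ++i)
    if (!isLegal(ops[i], addptrToLinalg))
      illegal.push_back(i);
  return illegal;
}
```

With one scalar and one tensor-valued addptr in the input, only the tensor-valued one is reported as needing conversion when the option is on, which is the filtering the review question asked about.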

@haishanzzzz haishanzzzz reopened this Jan 24, 2024
@nhat-nguyen nhat-nguyen merged commit 7033698 into microsoft:main Jan 25, 2024
3 checks passed
@aaronsm
Contributor

aaronsm commented Jan 26, 2024

I believe this change has broken the build because it hardcodes triton-shared directory names. See issue #91 for more details.

@nhat-nguyen
Collaborator

nhat-nguyen commented Jan 26, 2024

@aaronsm Would you mind sharing more details? Issue #91 was opened 3 days before, but we only merged this yesterday. The failure that you see is expected, as triton had a big refactor in how third-party plugins work. I'm working on a fix right now. I also left a comment in the linked issue.

nhat-nguyen pushed a commit that referenced this pull request Feb 1, 2024