-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce triton-arith-to-linalg pass #85
Introduce triton-arith-to-linalg pass #85
Conversation
@haishanzzzz We have a new pattern added recently in this PR as well: #86 |
It would be great if we can reuse the patterns instead of duplicating them. I'm afraid it would be hard to maintain both versions as people add more patterns; we currently already have one outstanding PR that adds more pattern in the TritonToLinalg pass. Sharing these patterns is a little tricky, but I spent some time experimenting and I think this would work for us with minimal effort:
This of course isn't the best since all of the implementation is now in a header file, but at least it saves us from having to duplicate all this code and consolidating them later. Another alternative would be to place all these patterns in another lib and have both Here's the rough diff for the approach I describe above: |
Thank you for suggesting this Nhat and trying it out for me. This is a great idea. Will address before PR closes. Please feel free to continue with the review in the mean time btw. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@haishanzzzz Thanks for the refactor. This mostly reuses what we already have so all look good to me.
I do have one question before closing regarding
Add support for tt.addptr op. This is to tackle cases when the tensors of pointers does not have any structure and we have to materialize the tensor.
So after triton-to-structured
, the following types of tt.addptr
may exist:
- those that deal with scalars
- those that we can't analyze
So in this pass, we have the option of converting tt.addptr
to linalg, but looking at AddPtrConverter
, I don't see us filtering out scalar addptr
at all. Should we add this?
Thank you for the review @nhat-nguyen! We filter for non-scalar addptr in TritonArithToLinalgPass.cpp with the following:
|
I believe this change has broken the build due to hardcodes triton-shared directory names. See issue #91 for more details. |
This PR introduces
triton-to-structured
pass. Please see #81 for background.Most of the logic of the pass are directly copied from
triton-to-linalg
. The main differences are:tt.addptr
op. This is to tackle cases when the tensors of pointers does not have any structure and we have to materialize the tensor.tt.func
->func.func
tt.get_program_id
-> function argumentstt.assert
->cf.assert