Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule][M4a] Schedule Rule: Multi-Level-Tiling #10043

Merged
merged 4 commits into from
Jan 26, 2022

Conversation

jinhongyii
Copy link
Contributor

@jinhongyii jinhongyii commented Jan 24, 2022

This PR is one of the schedule rule for MetaSchedule.
The rule does not support auto tensorization for now.

Co-authored-by: Junru Shao junrushao1994@gmail.com
Co-authored-by: Xiyou Zhou xiyou@octoml.ai
Co-authored-by: Bohan Hou spectrometerh@gmail.com
Co-authored-by: Siyuan Feng Hzfengsy@sjtu.edu.cn
Co-authored-by: Ruihang Lai lairuihangdongdong@qq.com
Co-authored-by: Wuwei Lin wuwei@apache.org

@junrushao1994 @Hzfengsy @comaniac

- None on CPU
- [blockIdx.x, vthread.x, threadIdx.x] on GPU
use_tensor_core : bool
Whether to apply tensor core wmma intrinsic for the computation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a broader question, but are we going to bake tensorcore specific logic into this core schedule rule? What if I want to support other matrix cores from Intel or AMD?

I feel like the design should be that backend specific rules be decoupled from the main "driver" that does actual tiling like the one introduced in this PR. cc @junrushao1994 @vinx13

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We want to support auto tensorization for all hardware platforms. use-tensor-core is probably not a good name here - maybe use-tensor-intrin could be a better one

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I brought this up because I saw the upcoming code in https://github.com/junrushao1994/tvm/blob/meta-schedule/src/meta_schedule/schedule_rule/multi_level_tiling.cc where tensorcore stuff is hardcoded all over the place. I expect non-trivial refactoring before we can land this on main.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, you are right. We might want to remove all the tensor core hard-coding here and focusing on AutoScheduler alignment first during upstreaming

Copy link
Member

@Hzfengsy Hzfengsy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@junrushao junrushao changed the title [MetaSchedule] Schedule Rule: Multi Level Tiling [MetaSchedule][M4a] Schedule Rule: Multi-Level-Tiling Jan 26, 2022
@Hzfengsy Hzfengsy merged commit ffbe491 into apache:main Jan 26, 2022
sunggg pushed a commit to sunggg/tvm that referenced this pull request Jan 29, 2022
* multi level tiling

* remove tensor core related code

* pylint

* fix

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
* multi level tiling

* remove tensor core related code

* pylint

* fix

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants