Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Introducing MemHammer #14164

Merged
merged 1 commit into from
Mar 20, 2023
Merged

Conversation

cblmemo
Copy link
Contributor

@cblmemo cblmemo commented Mar 1, 2023

This PR introduces MemHammer, which performs threadblock level auto data movement in MetaSchedule. The vision that memhammer holds is to free users from laborious manual schedules at threadblock level. With memhammer, the only thing user needs to do is mark a specific block with an annotation auto_copy, and memhammer will lower it with auto thread index binding, vectorize, and wmma API calls. We also introduced two new schedule primitives read_at and write_at, which enable users to perform a cache read / write in an easy-to-use manner, without arduous cache_read, compute_at, and other manual optimizations.

Given a data movement description like this:

@tvm.script.ir_module
class GlobalToShared:
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, [1024, 1024])
        B = T.match_buffer(b, [1024, 1024])
        with T.block("root"):
            T.block_attr({"warp_execution": True})
            for bx in T.thread_binding(8, thread="blockIdx.x"):
                for by in T.thread_binding(8, thread="blockIdx.y"):
                    for ty in T.thread_binding(8, thread="threadIdx.y"):
                        with T.block():
                            A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn")
                            with T.block("A_shared"):
                                T.block_attr({"auto_copy": 1, "vector_bytes": 16})
                                for ax0, ax1 in T.grid(128, 128):
                                    A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1]
                            with T.block("B"):
                                for ax0, ax1 in T.grid(128, 128):
                                    B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1]

By annotating the block with T.block_attr({"auto_copy": 1}) and other optional arguments, it will be lowered to the following code with cooperative fetch, vectorize, and other specified features:

@tvm.script.ir_module
class TransformedGlobalToShared:
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, [1024, 1024])
        B = T.match_buffer(b, [1024, 1024])
        with T.block("root"):
            T.block_attr({"warp_execution":True})
            for bx in T.thread_binding(8, thread="blockIdx.x"):
                for by in T.thread_binding(8, thread="blockIdx.y"):
                    for ty in T.thread_binding(8, thread="threadIdx.y"):
                        with T.block():
                            A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn")
                            with T.block("A_shared"):
                                T.block_attr({"auto_copy":1, "vector_bytes":16})
                                for outer in T.serial(16):
                                    for ty_1 in T.thread_binding(8, thread="threadIdx.y"):
                                        for tx in T.thread_binding(32, thread="threadIdx.x"):
                                            for vec in T.vectorized(4):
                                                A_shared_dyn[(((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A[bx * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, by * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128]
                            with T.block("B"):
                                for ax0, ax1 in T.grid(128, 128):
                                    B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1]

For more examples, see tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py.

All supported features:

  • Inverse index mapping;
  • Coalesced access that auto binds loops to thread index and performs vectorize;
  • Add local cache stage;
  • Rewrite data movement to wmma fragments with wmma API calls.

@tvm-bot
Copy link
Collaborator

tvm-bot commented Mar 1, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@cblmemo cblmemo force-pushed the memhammer branch 2 times, most recently from 5d0de40 to b27abfb Compare March 13, 2023 07:14
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
@cblmemo cblmemo changed the title [WIP][MetaSchedule] Introducing MemHammer [MetaSchedule] Introducing MemHammer Mar 16, 2023
@cblmemo cblmemo marked this pull request as ready for review March 16, 2023 02:28
@spectrometerHBH spectrometerHBH merged commit 36b3097 into apache:main Mar 20, 2023
@cblmemo cblmemo deleted the memhammer branch November 24, 2023 21:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants