Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
1a0906b
[Dev] Add FlashDecoding example
chengyupku Jan 24, 2025
bf11c0a
[Dev] Merge conflicts
chengyupku Jan 25, 2025
1ecfac5
merge upstream
chengyupku Jan 25, 2025
26c96f6
[CI][Test] Add test cases for tilelang kernel convolution
chengyupku Jan 25, 2025
6eaa0f0
Merge branch 'main' of https://github.com/tile-ai/tilelang into main
chengyupku Jan 25, 2025
bf345ed
[CI][Test] Add test cases for tilelang kernel FlashAttention
chengyupku Jan 25, 2025
2bac564
Reduce the number of stages to ensure the shared memory allocation is…
chengyupku Jan 25, 2025
155ee64
Temporarily remove the dim128 case
chengyupku Jan 25, 2025
60a859c
lint
chengyupku Jan 25, 2025
1481781
update einops in requirements-dev.txt
chengyupku Jan 25, 2025
0084c9a
update einops in requirements-test.txt
chengyupku Jan 25, 2025
bb4babb
remove einops in requirements-dev.txt
chengyupku Jan 25, 2025
6253e1c
[CI][Test] Add test cases for tilelang transform ClusterPlanning
chengyupku Jan 26, 2025
7fc58a1
Merge branch 'main' of https://github.com/tile-ai/tilelang into main
chengyupku Jan 26, 2025
fb2716e
[CI][Test] Add test cases for tilelang transform LowerHopperIntrin
chengyupku Jan 26, 2025
0c7b643
Merge branch 'main' of https://github.com/tile-ai/tilelang into main
chengyupku Jan 26, 2025
2897423
Merge branch 'main' of https://github.com/tile-ai/tilelang into main
chengyupku Feb 8, 2025
da9f1a8
[CI][Test] Add test cases for tilelang transform InjectFenceProxy
chengyupku Feb 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang as tl
from tilelang.utils.target import determine_target
import tilelang.language as T
import tilelang.testing
from tvm import tir

auto_target = tvm.target.Target(determine_target("auto"))


def _check(original, transformed):
    """Assert that InjectFenceProxy rewrites ``original`` into ``transformed``.

    Both prim funcs are wrapped into an IRModule under the global symbol
    ``"main"`` and bound to ``auto_target``. Only ``original`` is run through
    ``tl.transform.InjectFenceProxy``; both sides then go through
    ``LowerOpaqueBlock`` before their ``"main"`` functions are compared with
    ``tvm.ir.assert_structural_equal``.

    Args:
        original: prim func given as input to the pass under test.
        transformed: prim func describing the expected output of the pass.
    """

    def bind_as_main(prim_func):
        # Shared preparation for both sides: name the function "main" and
        # attach the module-level target.
        module = tvm.IRModule.from_expr(prim_func.with_attr("global_symbol", "main"))
        return tvm.tir.transform.BindTarget(auto_target)(module)

    # Actual side: the pass under test runs before opaque blocks are lowered.
    actual = bind_as_main(original)
    actual = tl.transform.InjectFenceProxy()(actual)
    actual = tir.transform.LowerOpaqueBlock()(actual)

    # Expected side: same pipeline, minus the pass under test.
    expected = tir.transform.LowerOpaqueBlock()(bind_as_main(transformed))

    # NOTE(review): the trailing positional True is presumably TVM's
    # ``map_free_vars`` flag -- confirm against the TVM version in use.
    tvm.ir.assert_structural_equal(actual["main"], expected["main"], True)


def test_lower_fence_proxy():
    """Check that InjectFenceProxy adds a fence ahead of the extern GEMM call.

    ``before`` and ``after`` are identical except that ``after`` contains a
    ``T.FenceProxyAsyncOp()`` statement immediately before the
    ``tl::gemm_ss`` extern call; ``_check`` asserts that running the pass on
    ``before`` yields ``after``.
    """

    # Input program: no fence between the local-buffer initialisation and the
    # extern GEMM call.
    @T.prim_func
    def before():
        with T.Kernel(8):
            # GEMM operands are declared in the "shared.dyn" scope.
            A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn")
            B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
            # Accumulator is declared in the "local" scope.
            C_local = T.decl_buffer((32,), scope="local")
            # Zero the 32-element accumulator two lanes at a time (16 x 2).
            for i in T.unroll(16):
                C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
            T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                          T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
                          T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
                          T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))

    # Expected program: same as ``before`` plus the fence op before the GEMM.
    @T.prim_func
    def after():
        with T.Kernel(8):
            A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn")
            B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
            C_local = T.decl_buffer((32,), scope="local")
            for i in T.unroll(16):
                C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
            # The statement the pass is expected to insert.
            T.FenceProxyAsyncOp()
            T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                          T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
                          T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
                          T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))

    _check(before, after)


if __name__ == "__main__":
    # Use the test runner from the already-imported ``tilelang.testing``
    # module instead of calling one test by hand, so tests added to this
    # file later are not silently skipped when it is run as a script.
    # NOTE(review): presumably main() discovers and runs this file's tests
    # via pytest -- confirm against tilelang.testing.
    tilelang.testing.main()
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,3 @@ def after():

if __name__ == "__main__":
tilelang.testing.main()
test_lower_hopper_intrin_barrier()
4 changes: 4 additions & 0 deletions tilelang/language/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ def CreateTMADescriptorOp(*args):

def TMALoadOp(*args):
    """Build a call to the ``tl.TMALoadOp`` intrinsic.

    Args:
        *args: Operands forwarded unchanged to the intrinsic call.

    Returns:
        The result of ``tir.call_intrin`` with return dtype ``"handle"``.
    """
    tma_load_op = tir.op.Op.get("tl.TMALoadOp")
    return tir.call_intrin("handle", tma_load_op, *args)


def FenceProxyAsyncOp(*args):
    """Build a call to the ``tl.FenceProxyAsyncOp`` intrinsic.

    Args:
        *args: Operands forwarded unchanged to the intrinsic call.

    Returns:
        The result of ``tir.call_intrin`` with return dtype ``"handle"``.
    """
    fence_proxy_op = tir.op.Op.get("tl.FenceProxyAsyncOp")
    return tir.call_intrin("handle", fence_proxy_op, *args)
Loading