diff --git a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py new file mode 100644 index 000000000..b9839755f --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tilelang import tvm as tvm +import tilelang as tl +from tilelang.utils.target import determine_target +import tilelang.language as T +import tilelang.testing +from tvm import tir + +auto_target = tvm.target.Target(determine_target("auto")) + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tvm.tir.transform.BindTarget(auto_target)(mod) + mod = tl.transform.InjectFenceProxy()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) + transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) + transformed = tir.transform.LowerOpaqueBlock()(transformed) + + tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True) + + +def test_lower_fence_proxy(): + + @T.prim_func + def before(): + with T.Kernel(8): + A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn") + B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") + C_local = T.decl_buffer((32,), scope="local") + for i in T.unroll(16): + C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2) + T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + + @T.prim_func + def after(): + with T.Kernel(8): + A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn") + B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") + C_local = T.decl_buffer((32,), scope="local") + for i in T.unroll(16): + C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2) + T.FenceProxyAsyncOp() + T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", + T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3)) + + _check(before, after) + + +if __name__ == "__main__": + test_lower_fence_proxy() diff --git a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py index c4bf1c35a..b9c5855bf 100644 --- a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py +++ b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py @@ -56,4 +56,3 @@ def after(): if __name__ == "__main__": tilelang.testing.main() - test_lower_hopper_intrin_barrier() diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index b2251bb90..08bc17e52 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -19,3 +19,7 @@ def CreateTMADescriptorOp(*args): def TMALoadOp(*args): return tir.call_intrin("handle", tir.op.Op.get("tl.TMALoadOp"), *args) + + +def FenceProxyAsyncOp(*args): + return tir.call_intrin("handle", tir.op.Op.get("tl.FenceProxyAsyncOp"), *args) \ No newline at end of file