diff --git a/src/op/builtin.h b/src/op/builtin.h
index 05f017214..5e34ed965 100644
--- a/src/op/builtin.h
+++ b/src/op/builtin.h
@@ -39,7 +39,7 @@ const Op &CreateTMAIm2ColDescriptorOp();
 /*!
  * \brief Create a list of mbarrier with num_threads
  *
- * GetMBarrier(num_threads0, num_threads1, ...)
+ * CreateListofMBarrierOp(num_threads0, num_threads1, ...)
  *
  */
 const Op &CreateListofMBarrierOp();
diff --git a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py
new file mode 100644
index 000000000..c4bf1c35a
--- /dev/null
+++ b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from tilelang import tvm as tvm
+import tilelang as tl
+from tilelang.utils.target import determine_target
+import tilelang.language as T
+import tilelang.testing
+from tvm import tir
+
+auto_target = tvm.target.Target(determine_target("auto"))
+
+
+def _check(original, transformed):
+    """Assert that LowerHopperIntrin rewrites `original` into `transformed`.
+
+    Both functions are wrapped in an IRModule under the symbol "main", bound
+    to `auto_target` and run through LowerOpaqueBlock. Only `original` is
+    additionally run through LowerHopperIntrin before the comparison.
+    """
+    mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main"))
+    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
+    mod = tl.transform.LowerHopperIntrin()(mod)
+    mod = tir.transform.LowerOpaqueBlock()(mod)
+    transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main"))
+    transformed = tvm.tir.transform.BindTarget(auto_target)(transformed)
+    transformed = tir.transform.LowerOpaqueBlock()(transformed)
+
+    tvm.ir.assert_structural_equal(mod["main"], transformed["main"], map_free_vars=True)
+
+
+def test_lower_hopper_intrin_barrier():
+    """CreateListofMBarrierOp with four counts lowers to create + per-barrier init."""
+
+    @T.prim_func
+    def before():
+        with T.Kernel(8):
+            _ = T.launch_thread("threadIdx.x", 128)
+            T.CreateListofMBarrierOp(128, 128, 128, 128)
+
+    @T.prim_func
+    def after():
+        with T.Kernel(8):
+            v_1 = T.launch_thread("threadIdx.x", 128)
+            T.evaluate(tir.Call("handle", "tir.create_barriers", [4]))
+            with T.If(v_1 == 0), T.Then():
+                T.evaluate(
+                    tir.Call("handle", "tir.ptx_init_barrier_thread_count",
+                             [T.GetMBarrierOp(0), 128]))
+                T.evaluate(
+                    tir.Call("handle", "tir.ptx_init_barrier_thread_count",
+                             [T.GetMBarrierOp(1), 128]))
+                T.evaluate(
+                    tir.Call("handle", "tir.ptx_init_barrier_thread_count",
+                             [T.GetMBarrierOp(2), 128]))
+                T.evaluate(
+                    tir.Call("handle", "tir.ptx_init_barrier_thread_count",
+                             [T.GetMBarrierOp(3), 128]))
+            T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"]))
+
+    _check(before, after)
+
+
+if __name__ == "__main__":
+    # Run through the testing entry point only; invoking the test function
+    # again after it would be redundant.
+    tilelang.testing.main()
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 3d93b2dcb..3249680e7 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -30,6 +30,7 @@
     atomic_addx2,  # noqa: F401
     dp4a,  # noqa: F401
 )
+from .builtin import *  # noqa: F401
 
 
 def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py
new file mode 100644
index 000000000..b2251bb90
--- /dev/null
+++ b/tilelang/language/builtin.py
@@ -0,0 +1,55 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""The language interface for tl programs."""
+
+from tvm import tir
+
+
+def CreateListofMBarrierOp(*args):
+    """Build a call to the ``tl.CreateListofMBarrierOp`` intrinsic.
+
+    Args:
+        *args: Forwarded unchanged as the call arguments. The C++ header
+            documents them as ``num_threads0, num_threads1, ...``.
+
+    Returns:
+        The result of ``tir.call_intrin`` with dtype ``"handle"``.
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.CreateListofMBarrierOp"), *args)
+
+
+def GetMBarrierOp(*args):
+    """Build a call to the ``tl.GetMBarrierOp`` intrinsic.
+
+    Args:
+        *args: Forwarded unchanged as the call arguments (presumably a
+            barrier index -- confirm against the op registration).
+
+    Returns:
+        The result of ``tir.call_intrin`` with dtype ``"handle"``.
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.GetMBarrierOp"), *args)
+
+
+def CreateTMADescriptorOp(*args):
+    """Build a call to the ``tl.CreateTMADescriptorOp`` intrinsic.
+
+    Args:
+        *args: Forwarded unchanged as the call arguments.
+
+    Returns:
+        The result of ``tir.call_intrin`` with dtype ``"handle"``.
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.CreateTMADescriptorOp"), *args)
+
+
+def TMALoadOp(*args):
+    """Build a call to the ``tl.TMALoadOp`` intrinsic.
+
+    Args:
+        *args: Forwarded unchanged as the call arguments.
+
+    Returns:
+        The result of ``tir.call_intrin`` with dtype ``"handle"``.
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.TMALoadOp"), *args)