diff --git a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py new file mode 100644 index 000000000..860c449ed --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tl.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), + True) + + +def test_trival_pipeline(): + + @T.prim_func + def before(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 1, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1] + }): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + @T.prim_func + def expected(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0]) + T.writes(C[tx, 0]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, 0]) + T.writes(B[0, tx, 0]) + B[0, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads(A[tx, 1:1], B[0:2, tx, 0]) + T.writes(B[1:1, tx, 0], C[tx, 0:0]) + for i in range(0): + with T.block(""): + T.reads(A[tx, i + 1]) + T.writes(B[i + 1, tx, 0]) + B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(""): + T.reads(B[i, tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[i, tx, 0] + T.float32(1) + with T.block(): + T.reads(B[0, tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = B[0, tx, 0] + T.float32(1) + + _check(before, expected) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_frontend_legalize.py b/testing/python/transform/test_tilelang_transform_frontend_legalize.py new file mode 100644 index 000000000..076d63d9e --- /dev/null +++ b/testing/python/transform/test_tilelang_transform_frontend_legalize.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tilelang import tvm as tvm +import tilelang as tl +import tilelang.language as T +import tilelang.testing + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = tl.transform.FrontendLegalize()(mod) + print(mod.script()) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), + True) + + +def test_let_binding(): + + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")): + for i in range(128): + for j in range(128): + with T.block("compute"): + factor = T.float32(2.0) + value = A[i, j] * factor + B[i, j] = value + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")): + for i in range(128): + for j in range(128): + with T.block("compute"): + B[i, j] = A[i, j] * T.float32(2.0) + + _check(before, expected) + + +def test_parallel_scope(): + + @T.prim_func + def before(A: T.Buffer((128,), "float32")): + for i in T.Parallel(128): + with T.block("parallel"): + value = T.float32(1.0) + A[i] = value + + @T.prim_func + def expected(A: T.Buffer((128,), "float32")): + for i in T.Parallel(128): + with T.block("parallel"): + A[i] = T.float32(1.0) + + _check(before, expected) + + +if __name__ == "__main__": + tilelang.testing.main()