From a7d3f5dc94d239eee8d0ddf8104ca7b5c9cf99ed Mon Sep 17 00:00:00 2001 From: honglinzhu Date: Wed, 26 Feb 2025 15:40:47 +0800 Subject: [PATCH 1/3] Add support for func attr inheritance in SplitLayoutRewritePreproc fix bug in test delete layout_free_buffers --- .../transform/split_layout_rewrite_preproc.cc | 23 ++++- ..._transform_split_layout_rewrite_preproc.py | 84 +++++++++++++++++++ 2 files changed, 105 insertions(+), 2 deletions(-) diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 5fee946c26dd..69b031339770 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -81,7 +81,16 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"root", body)); - PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map); + Map dict; + for (const auto& [key, original_value] : original_func_->attrs->dict) { + if (key == "global_symbol") { + dict.Set(key, Downcast(original_value) + "_weight_prepack"); + } else if (key != "layout_free_buffers") { + dict.Set(key, original_value); + } + } + DictAttrs attrs(dict); + PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map, attrs); return RenewDefs(func); } @@ -118,7 +127,17 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { /*init=*/NullOpt, /*alloc_buffers=*/alloc_buffers)); - PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map); + Map dict; + for (const auto& [key, original_value] : original_func_->attrs->dict) { + if (key == "global_symbol") { + dict.Set(key, Downcast(original_value) + "_prepacked"); + } else if (key != "layout_free_buffers") { + dict.Set(key, original_value); + } + } + DictAttrs attrs(dict); + PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map, attrs); + return RenewDefs(func); } diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py index e6b4c8ec4e2a..a5b74283fe2a 100644 --- a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py +++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py @@ -216,5 +216,89 @@ def forward( tvm.ir.assert_structural_equal(mod, After) +def test_attr_inheritance(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def tir_func( + X: T.Buffer((224, 224), "float32"), + W: T.Buffer((224, 224), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)}) + W_rewrite = T.alloc_buffer((4, 4, 56, 56)) + for i, j in T.grid(224, 224): + with T.block("W_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)}) + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32") + ) + R.output(gv) + return gv + + @I.ir_module + class After: + @T.prim_func(private=True) + def tir_func_prepacked( + X: T.Buffer((224, 224), "float32"), + W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + @T.prim_func(private=True) + def tir_func_weight_prepack( + W: T.Buffer((224, 224), "float32"), + W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i, j in T.grid(224, 224): + with T.block("W_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = After + with R.dataflow(): + lv = R.call_tir( + cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32") + ) + lv1 = R.call_tir( + cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32") + ) + gv: R.Tensor((224, 224), dtype="float32") = lv1 + R.output(gv) + return gv + + mod = relax.transform.SplitLayoutRewritePreproc()(Before) + tvm.ir.assert_structural_equal(mod, After) + + if __name__ == "__main__": tvm.testing.main() From 1aa94925aec1e4f07795eeac35e39564d98e6dbe Mon Sep 17 00:00:00 2001 From: honglinzhu Date: Thu, 27 Feb 2025 10:59:54 +0800 Subject: [PATCH 2/3] rebase latest main --- .../python/relax/test_transform_split_layout_rewrite_preproc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py index a5b74283fe2a..53e528303123 100644 --- a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py +++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py @@ -66,6 +66,7 @@ def tir_func_prepacked( W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), Out: T.Buffer((224, 224), "float32"), ): + T.func_attr({"layout_free_buffers": [1]}) for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): with T.block("Out"): vi = T.axis.spatial(224, i0 * 56 + i1) @@ -77,6 +78,7 @@ def tir_func_weight_prepack( W: T.Buffer((224, 224), "float32"), W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), ): + T.func_attr({"layout_free_buffers": [1]}) for i, j in T.grid(224, 224): with T.block("W_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) From 17cd76220bfb6a90838ea32a964d027cc9e68cbd Mon Sep 17 00:00:00 2001 From: honglinzhu Date: Thu, 27 Feb 2025 18:33:31 +0800 Subject: [PATCH 3/3] delete layout_free_buffers --- .../python/relax/test_transform_split_layout_rewrite_preproc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py index 53e528303123..a5b74283fe2a 100644 --- a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py +++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py @@ -66,7 +66,6 @@ def tir_func_prepacked( W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), Out: T.Buffer((224, 224), "float32"), ): - T.func_attr({"layout_free_buffers": [1]}) for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): with T.block("Out"): vi = T.axis.spatial(224, i0 * 56 + i1) @@ -78,7 +77,6 @@ def tir_func_weight_prepack( W: T.Buffer((224, 224), "float32"), W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), ): - T.func_attr({"layout_free_buffers": [1]}) for i, j in T.grid(224, 224): with T.block("W_rewrite"): vi, vj = T.axis.remap("SS", [i, j])