diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc
index 5fee946c26dd..69b031339770 100644
--- a/src/relax/transform/split_layout_rewrite_preproc.cc
+++ b/src/relax/transform/split_layout_rewrite_preproc.cc
@@ -81,7 +81,16 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator {
         Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"root", body));
 
-    PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map);
+    Map<String, ObjectRef> dict;
+    for (const auto& [key, original_value] : original_func_->attrs->dict) {
+      if (key == "global_symbol") {
+        dict.Set(key, Downcast<String>(original_value) + "_weight_prepack");
+      } else if (key != "layout_free_buffers") {
+        dict.Set(key, original_value);
+      }
+    }
+    DictAttrs attrs(dict);
+    PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map, attrs);
 
     return RenewDefs(func);
   }
 
@@ -118,7 +127,17 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator {
                           /*init=*/NullOpt,
                           /*alloc_buffers=*/alloc_buffers));
 
-    PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map);
+    Map<String, ObjectRef> dict;
+    for (const auto& [key, original_value] : original_func_->attrs->dict) {
+      if (key == "global_symbol") {
+        dict.Set(key, Downcast<String>(original_value) + "_prepacked");
+      } else if (key != "layout_free_buffers") {
+        dict.Set(key, original_value);
+      }
+    }
+    DictAttrs attrs(dict);
+    PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map, attrs);
+
     return RenewDefs(func);
   }
 
diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
index e6b4c8ec4e2a..a5b74283fe2a 100644
--- a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
+++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
@@ -216,5 +216,89 @@ def forward(
     tvm.ir.assert_structural_equal(mod, After)
 
 
+def test_attr_inheritance():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def tir_func(
+            X: T.Buffer((224, 224), "float32"),
+            W: T.Buffer((224, 224), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
+            W_rewrite = T.alloc_buffer((4, 4, 56, 56))
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)})
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class After:
+        @T.prim_func(private=True)
+        def tir_func_prepacked(
+            X: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @T.prim_func(private=True)
+        def tir_func_weight_prepack(
+            W: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = After
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32")
+                )
+                lv1 = R.call_tir(
+                    cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32")
+                )
+                gv: R.Tensor((224, 224), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    mod = relax.transform.SplitLayoutRewritePreproc()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
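
Usage sketch (not part of the diff): a minimal check, assuming an IRModule like `Before` from the new test above, that after running the pass both split PrimFuncs inherit attributes such as "tir.noalias" while "layout_free_buffers" is dropped.

# Minimal sketch, assuming `Before` is the IRModule defined in
# test_attr_inheritance above (names here mirror that test).
import tvm
from tvm import relax

mod = relax.transform.SplitLayoutRewritePreproc()(Before)
for gvar, func in mod.functions.items():
    if isinstance(func, tvm.tir.PrimFunc):
        # "layout_free_buffers" is consumed by the split and must not be inherited.
        assert "layout_free_buffers" not in func.attrs
        # "tir.noalias" from the original function is carried over to both halves.
        assert "tir.noalias" in func.attrs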