From b8b5fb6a1c63bdd3409e2e266d2ac386f8fbbb26 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 12 Sep 2024 13:25:23 -0500 Subject: [PATCH] [IR] Expose ReplaceGlobalVars utility in the Python API (#17361) * [IR] Expose ReplaceGlobalVars utility in the Python API This is a follow-up PR to https://github.com/apache/tvm/pull/17202, which added a general utility to replace `GlobalVar` instances across all TVM IR types. This PR exposes this new utility through the Python API, and explicitly tests its functionality. * Lint fix --- ...ace_global_var.h => replace_global_vars.h} | 10 +- python/tvm/ir/module.py | 28 ++ ...e_global_var.cc => replace_global_vars.cc} | 43 ++- src/relax/transform/attach_global_symbol.cc | 4 +- ...e_global_var.cc => replace_global_vars.cc} | 23 +- ...e_global_var.cc => replace_global_vars.cc} | 20 +- .../ir/test_transform_replace_global_var.py | 306 ++++++++++++++++++ 7 files changed, 418 insertions(+), 16 deletions(-) rename include/tvm/ir/{replace_global_var.h => replace_global_vars.h} (85%) rename src/ir/{replace_global_var.cc => replace_global_vars.cc} (55%) rename src/relax/transform/{replace_global_var.cc => replace_global_vars.cc} (72%) rename src/tir/transforms/{replace_global_var.cc => replace_global_vars.cc} (75%) create mode 100644 tests/python/ir/test_transform_replace_global_var.py diff --git a/include/tvm/ir/replace_global_var.h b/include/tvm/ir/replace_global_vars.h similarity index 85% rename from include/tvm/ir/replace_global_var.h rename to include/tvm/ir/replace_global_vars.h index c15dd5f4e5ad..ea91d46d7c0a 100644 --- a/include/tvm/ir/replace_global_var.h +++ b/include/tvm/ir/replace_global_vars.h @@ -18,13 +18,13 @@ */ /*! - * \file tvm/ir/replace_global_var.h + * \file tvm/ir/replace_global_vars.h * * \brief A utility to replace GlobalVar instances across all TVM IR * types in an IRMdoule. */ -#ifndef TVM_IR_REPLACE_GLOBAL_VAR_H_ -#define TVM_IR_REPLACE_GLOBAL_VAR_H_ +#ifndef TVM_IR_REPLACE_GLOBAL_VARS_H_ +#define TVM_IR_REPLACE_GLOBAL_VARS_H_ #include @@ -41,7 +41,7 @@ namespace transform { * * \return The updated IRModule */ -TVM_DLL IRModule ReplaceGlobalVar(IRModule mod, Map replacements); +TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, Map replacements); struct GlobalVarReplacer { using FType = NodeFunctor)>; @@ -54,4 +54,4 @@ struct GlobalVarReplacer { } // namespace transform } // namespace tvm -#endif // TVM_IR_REPLACE_GLOBAL_VAR_H_ +#endif // TVM_IR_REPLACE_GLOBAL_VARS_H_ diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index ea3ef6d8831b..3c76dbfdd839 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """IRModule that holds the functions and type definitions.""" + from __future__ import annotations from typing import Dict, Union @@ -216,6 +217,33 @@ def get_global_vars(self): """ return _ffi_api.Module_GetGlobalVars(self) + def replace_global_vars( + self, + replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]], + ) -> "IRModule": + """Replace GlobalVar instances within the module + + Replace GlobalVars within the IRModule. Since the IRModule + may contain internal references to a GlobalVar, either in TIR + or in Relax, this method should be used whenever replacing or + renaming a GlobalVar. + + Parameters + ---------- + replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]] + + A dictionary where each key is a GlobalVar to be replaced, + and the corresponding value is the GlobalVar with which to + replace it. + + Returns + ------- + IRModule + The updated module + + """ + return _ffi_api.Module_ReplaceGlobalVars(self, replacements) + def get_global_type_vars(self): """Collect all global type vars defined in this module. diff --git a/src/ir/replace_global_var.cc b/src/ir/replace_global_vars.cc similarity index 55% rename from src/ir/replace_global_var.cc rename to src/ir/replace_global_vars.cc index 08d66d0e7cf2..9607dab11a6a 100644 --- a/src/ir/replace_global_var.cc +++ b/src/ir/replace_global_vars.cc @@ -18,18 +18,22 @@ */ /*! - * \file src/ir/replace_global_var.cc + * \file src/ir/replace_global_vars.cc * \brief IRModule transform to replace GlobalVar instances across any IR type. */ -#include +#include #include namespace tvm { namespace transform { -IRModule ReplaceGlobalVar(IRModule mod, Map replacements) { +IRModule ReplaceGlobalVars(IRModule mod, Map replacements) { + if (replacements.empty()) { + return mod; + } + std::vector to_remove; IRModule updates; @@ -57,7 +61,38 @@ IRModule ReplaceGlobalVar(IRModule mod, Map replacements) return mod; } -TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVar").set_body_typed(ReplaceGlobalVar); +TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars); + +IRModule ModuleReplaceGlobalVars( + IRModule mod, Map, Variant> replacements) { + Map gvar_replacements; + for (const auto& [before, after] : replacements) { + GlobalVar gvar_before; + if (auto gvar = before.as()) { + gvar_before = gvar.value(); + } else if (auto str = before.as()) { + gvar_before = mod->GetGlobalVar(str.value()); + } else { + LOG(FATAL) << "Variant must contain either String or GlobalVar"; + } + + GlobalVar gvar_after; + if (auto gvar = after.as()) { + gvar_after = gvar.value(); + } else if (auto str = after.as()) { + gvar_after = gvar_before; + gvar_after.CopyOnWrite()->name_hint = str.value(); + } else { + LOG(FATAL) << "Variant must contain either String or GlobalVar"; + } + + gvar_replacements.Set(gvar_before, gvar_after); + } + + return ReplaceGlobalVars(mod, gvar_replacements); +} + +TVM_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index a517d5a035e2..6f18339436fb 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -22,7 +22,7 @@ */ #include -#include +#include #include #include #include @@ -72,7 +72,7 @@ Pass AttachGlobalSymbol() { mod.CopyOnWrite()->Update(updates); if (gvar_updates.size()) { - mod = tvm::transform::ReplaceGlobalVar(mod, gvar_updates); + mod = tvm::transform::ReplaceGlobalVars(mod, gvar_updates); } } return mod; diff --git a/src/relax/transform/replace_global_var.cc b/src/relax/transform/replace_global_vars.cc similarity index 72% rename from src/relax/transform/replace_global_var.cc rename to src/relax/transform/replace_global_vars.cc index b81b831036ff..ea5d5e18d8ff 100644 --- a/src/relax/transform/replace_global_var.cc +++ b/src/relax/transform/replace_global_vars.cc @@ -19,13 +19,13 @@ /*! * - * \file src/relax/transform/replace_global_var.cc + * \file src/relax/transform/replace_global_vars.cc * * \brief GlobalVar replacement across IR types */ #include -#include +#include #include #include #include @@ -53,7 +53,24 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) .set_dispatch([](const ObjectRef& func, Map replacements) -> BaseFunc { Mutator mutator(replacements); - return Downcast(mutator(Downcast(func))); + auto new_func = Downcast(mutator(Downcast(func))); + + // If the function is externally exposed, and is being replaced + // by a GlobalVar with a new name, then the function's + // kGlobalSymbol must be updated to match. + if (auto opt = new_func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto name = opt.value(); + for (const auto& [before, after] : replacements) { + if (before->name_hint == name) { + if (after->name_hint != name) { + new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, after->name_hint); + } + break; + } + } + } + + return new_func; }); TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) diff --git a/src/tir/transforms/replace_global_var.cc b/src/tir/transforms/replace_global_vars.cc similarity index 75% rename from src/tir/transforms/replace_global_var.cc rename to src/tir/transforms/replace_global_vars.cc index 8ef8ba9276b0..3e8437063775 100644 --- a/src/tir/transforms/replace_global_var.cc +++ b/src/tir/transforms/replace_global_vars.cc @@ -19,12 +19,12 @@ /*! * - * \file src/tir/transforms/replace_global_var.cc + * \file src/tir/transforms/replace_global_vars.cc * * \brief GlobalVar replacement across IR types */ -#include +#include #include #include @@ -61,6 +61,22 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) if (!new_body.same_as(func->body)) { func.CopyOnWrite()->body = new_body; } + + // If the function is externally exposed, and is being replaced + // by a GlobalVar with a new name, then the function's + // kGlobalSymbol must be updated to match. + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto name = opt.value(); + for (const auto& [before, after] : replacements) { + if (before->name_hint == name) { + if (after->name_hint != name) { + func = WithAttr(func, tvm::attr::kGlobalSymbol, after->name_hint); + } + break; + } + } + } + return func; }); diff --git a/tests/python/ir/test_transform_replace_global_var.py b/tests/python/ir/test_transform_replace_global_var.py new file mode 100644 index 000000000000..d31993141500 --- /dev/null +++ b/tests/python/ir/test_transform_replace_global_var.py @@ -0,0 +1,306 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm.testing +from tvm.script import ir as I, relax as R, tir as T + + +def _get_before_module(): + @I.ir_module + class Module: + @R.function + def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Module.relax_subroutine(A) + C = R.call_tir(Module.tir_main, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Module.tir_main(C, D) + + return D + + @R.function(private=True) + def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Module.tir_subroutine(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + return Module + + +def test_no_op_if_no_replacements(): + """If no replacements are performed, the IRModule is unmodified""" + + before = _get_before_module() + expected = before + + after = before.replace_global_vars({}) + + tvm.ir.assert_structural_equal(expected, after) + assert before.same_as(after) + + +def test_replace_relax_main(): + """An externally-exposed Relax function may be replaced + + In this example, the "relax_main" function is renamed. This + requires changing both the GlobalVar used to refer to the + function, and the "global_symbol" attribute of the + externally-exposed function. + + """ + + before = _get_before_module() + after = before.replace_global_vars({"relax_main": "relax_main_with_new_name"}) + + @I.ir_module + class Expected: + @R.function + def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine(A) + C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main(C, D) + + return D + + @R.function(private=True) + def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +def test_replace_relax_subroutine(): + """An internal Relax function may be replaced + + In this example, the "relax_subroutine" function is renamed. This + requires changing both the GlobalVar used to refer to the + function, and the GlobalVar used to call the subroutine within + "relax_main". The "global_symbol" attribute does not need to be + updated, because internal functions do not have this attribute. + + """ + + before = _get_before_module() + after = before.replace_global_vars({"relax_subroutine": "relax_subroutine_with_new_name"}) + + @I.ir_module + class Expected: + @R.function + def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine_with_new_name(A) + C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main(C, D) + + return D + + @R.function(private=True) + def relax_subroutine_with_new_name( + A: R.Tensor([16], "float32"), + ) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +def test_replace_tir_main(): + """An externally-exposed TIR function may be replaced + + In this example, the "tir_main" function is renamed. This + requires changing both the GlobalVar used to refer to the + function, the "global_symbol" attribute of the externally-exposed + function. In addition, calls to the TIR function should be + updated to use the new GlobalVar. + + """ + + before = _get_before_module() + after = before.replace_global_vars({"tir_main": "tir_main_with_new_name"}) + + @I.ir_module + class Expected: + @R.function + def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine(A) + C = R.call_tir(Expected.tir_main_with_new_name, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main_with_new_name(C, D) + + return D + + @R.function(private=True) + def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +def test_replace_tir_subroutine(): + """An internally-exposed TIR function may be replaced + + In this example, the "tir_subroutine" function is renamed. This + requires changing both the GlobalVar used to refer to the + function, and the GlobalVar used to refer to it. Internal + functions do not have the "global_symbol" attribute, so it does + not need to be updated. + + """ + + before = _get_before_module() + after = before.replace_global_vars({"tir_subroutine": "tir_subroutine_with_new_name"}) + + @I.ir_module + class Expected: + @R.function + def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine(A) + C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main(C, D) + + return D + + @R.function(private=True) + def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine_with_new_name(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +def test_simultaneous_replacements(): + """Multiple replacements may be performed simultaneously""" + + before = _get_before_module() + after = before.replace_global_vars( + { + "relax_main": "relax_main_with_new_name", + "relax_subroutine": "relax_subroutine_with_new_name", + "tir_main": "tir_main_with_new_name", + "tir_subroutine": "tir_subroutine_with_new_name", + } + ) + + @I.ir_module + class Expected: + @R.function + def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + R.func_attr({"relax.force_pure": True}) + + B = Expected.relax_subroutine_with_new_name(A) + C = R.call_tir(Expected.tir_main_with_new_name, B, out_sinfo=R.Tensor([16], "float32")) + + D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0) + Expected.tir_main_with_new_name(C, D) + + return D + + @R.function(private=True) + def relax_subroutine_with_new_name( + A: R.Tensor([16], "float32"), + ) -> R.Tensor([16], "float32"): + B = R.add(A, R.prim_value(T.float32(1.0))) + return B + + @T.prim_func + def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + Expected.tir_subroutine_with_new_name(A.data, B.data) + + @T.prim_func(private=True) + def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: T.ptr("float32")): + A = T.decl_buffer(16, "float32", data=A_data) + B = T.decl_buffer(16, "float32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + tvm.ir.assert_structural_equal(Expected, after) + + +if __name__ == "__main__": + tvm.testing.main()