diff --git a/include/tvm/ir/analysis.h b/include/tvm/ir/analysis.h new file mode 100644 index 000000000000..afe18792dee0 --- /dev/null +++ b/include/tvm/ir/analysis.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/analysis.h + * + * Analysis routines that must function across multiple IR types for + * correctness. For example, identifying unused functions, when both TIR + * + */ +#ifndef TVM_IR_ANALYSIS_H_ +#define TVM_IR_ANALYSIS_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ir { + +class CalleeCollector { + public: + /* \brief Functor to be registered for IR types + * + * Should be implemented for each `BaseFunc` subclass. + * Implementation should call `CalleeCollector::Mark` for each + * `GlobalVar` in the function. + */ + using FType = NodeFunctor; + TVM_DLL static FType& vtable() { + static FType inst; + return inst; + } + + virtual ~CalleeCollector() {} + + /* \brief Collect the GlobalVar in a function */ + virtual void Mark(GlobalVar gvar) = 0; +}; + +Map> CollectCallMap(const IRModule& mod); + +} // namespace ir +} // namespace tvm + +#endif // TVM_IR_ANALYSIS_H_ diff --git a/include/tvm/ir/replace_global_var.h b/include/tvm/ir/replace_global_var.h new file mode 100644 index 000000000000..c15dd5f4e5ad --- /dev/null +++ b/include/tvm/ir/replace_global_var.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/replace_global_var.h + * + * \brief A utility to replace GlobalVar instances across all TVM IR + * types in an IRMdoule. + */ +#ifndef TVM_IR_REPLACE_GLOBAL_VAR_H_ +#define TVM_IR_REPLACE_GLOBAL_VAR_H_ + +#include + +namespace tvm { +namespace transform { + +/*! + * \brief Replace GlobalVar instances across any IR type. + * + * \param mod The module to update + * + * \param replacements The map, where each entry maps from an old + * `GlobalVar` to the new `GlobalVar` that should replace it. + * + * \return The updated IRModule + */ +TVM_DLL IRModule ReplaceGlobalVar(IRModule mod, Map replacements); + +struct GlobalVarReplacer { + using FType = NodeFunctor)>; + TVM_DLL static FType& vtable() { + static FType inst; + return inst; + } +}; + +} // namespace transform +} // namespace tvm + +#endif // TVM_IR_REPLACE_GLOBAL_VAR_H_ diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 939a5f638381..fdac74a0b4ec 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=unused-import """Common data structures across all IR variants.""" + from . import diagnostics, instrument, transform from .adt import Constructor, TypeData from .affine_type import TensorAffineType, TupleAffineType @@ -61,3 +62,5 @@ TypeVar, ) from .type_relation import TypeCall, TypeRelation + +from . import analysis diff --git a/python/tvm/ir/_ffi_analysis_api.py b/python/tvm/ir/_ffi_analysis_api.py new file mode 100644 index 000000000000..0013ec3b5026 --- /dev/null +++ b/python/tvm/ir/_ffi_analysis_api.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.ir.analysis""" + +import tvm._ffi + + +tvm._ffi._init_api("ir.analysis", __name__) diff --git a/python/tvm/ir/analysis.py b/python/tvm/ir/analysis.py new file mode 100644 index 000000000000..11fa819e2275 --- /dev/null +++ b/python/tvm/ir/analysis.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=unused-import + +"""Common analysis across all IR variants.""" + +from typing import Dict, List + +import tvm +from . import _ffi_analysis_api as _ffi + + +def collect_call_map( + module: "tvm.ir.IRModule", +) -> Dict["tvm.ir.GlobalVar", List["tvm.ir.GlobalVar"]]: + """Collect the call map of a module + + Parameters + ---------- + module: tvm.ir.IRModule + The module to inspect + + Returns + ------- + call_map: Dict[tvm.ir.GlobalVar, List[tvm.ir.GlobalVar]] + A map from functions to the subroutines they call. + + """ + return _ffi.CollectCallMap(module) diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc new file mode 100644 index 000000000000..9de36b0a28af --- /dev/null +++ b/src/ir/analysis.cc @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/analysis.cc + * \brief Analysis functions that must span multiple IR types + */ +#include + +#include "../support/ordered_set.h" + +namespace tvm { +namespace ir { + +Map> CollectCallMap(const IRModule& mod) { + struct CalleeCollectorImpl : CalleeCollector { + void Mark(GlobalVar gvar) override { gvars.push_back(gvar); } + support::OrderedSet gvars; + }; + + Map> call_map; + for (const auto& [gvar, base_func] : mod->functions) { + CalleeCollectorImpl collector; + CalleeCollector::vtable()(base_func, &collector); + call_map.Set(gvar, Array{collector.gvars.begin(), collector.gvars.end()}); + } + return call_map; +} + +TVM_REGISTER_GLOBAL("ir.analysis.CollectCallMap").set_body_typed(CollectCallMap); + +} // namespace ir +} // namespace tvm diff --git a/src/ir/replace_global_var.cc b/src/ir/replace_global_var.cc new file mode 100644 index 000000000000..08d66d0e7cf2 --- /dev/null +++ b/src/ir/replace_global_var.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/replace_global_var.cc + * \brief IRModule transform to replace GlobalVar instances across any IR type. + */ + +#include + +#include + +namespace tvm { +namespace transform { + +IRModule ReplaceGlobalVar(IRModule mod, Map replacements) { + std::vector to_remove; + IRModule updates; + + const auto& vtable = GlobalVarReplacer::vtable(); + + for (const auto& [old_gvar, old_func] : mod->functions) { + auto new_gvar = replacements.Get(old_gvar).value_or(old_gvar); + auto new_func = vtable(old_func, replacements); + + if (!new_gvar.same_as(old_gvar)) { + to_remove.push_back(old_gvar); + } + if (!old_gvar.same_as(new_gvar) || !old_func.same_as(new_func)) { + updates->Add(new_gvar, new_func); + } + } + + if (to_remove.size() || updates->functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + for (const auto& old_gvar : to_remove) { + write_ptr->Remove(old_gvar); + } + write_ptr->Update(updates); + } + return mod; +} + +TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVar").set_body_typed(ReplaceGlobalVar); + +} // namespace transform +} // namespace tvm diff --git a/src/relax/analysis/collect_call_map.cc b/src/relax/analysis/collect_call_map.cc new file mode 100644 index 000000000000..3e0170d3444d --- /dev/null +++ b/src/relax/analysis/collect_call_map.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/relax/analysis/collect_call_map.cc + * + * \brief Collect cross-IR call graph + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +namespace { +using ir::CalleeCollector; + +struct Visitor : ExprVisitor { + explicit Visitor(CalleeCollector* collector) : collector(collector) {} + CalleeCollector* collector; + void VisitExpr_(const GlobalVarNode* node) override { collector->Mark(GetRef(node)); } +}; + +} // namespace + +TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable) + .set_dispatch([](const ObjectRef& func, CalleeCollector* collector) { + Visitor visitor{collector}; + visitor(Downcast(func)); + }); + +TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable) + .set_dispatch([](const ObjectRef& func, CalleeCollector* collector) {}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 9b2a561c7fec..a517d5a035e2 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -22,6 +22,8 @@ */ #include +#include +#include #include #include @@ -32,26 +34,46 @@ namespace transform { Pass AttachGlobalSymbol() { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { - mod.CopyOnWrite(); - String c_prefix = mod->GetAttr(tvm::attr::kSystemLibPrefix).value_or(""); - std::vector > updates; + IRModule updates; + Map gvar_updates; + + for (const auto& [gvar, func] : mod->functions) { + Optional old_name = func->GetAttr(tvm::attr::kGlobalSymbol); - for (auto& p : mod->functions) { - BaseFunc func = p.second; // TODO(tvm-team): re-enable once fix relax integration part - // if (func->GetAttr(tvm::attr::kGlobalSymbol)) continue; + // if (old_name) continue; + + Optional new_name; + BaseFunc new_func; + if (auto* prim_func = func.as()) { - updates.emplace_back(p.first, - WithAttr(GetRef(prim_func), tvm::attr::kGlobalSymbol, - c_prefix + p.first->name_hint)); + new_name = c_prefix + gvar->name_hint; + new_func = WithAttr(GetRef(prim_func), tvm::attr::kGlobalSymbol, new_name); } else if (auto* relax_func = func.as()) { - updates.emplace_back(p.first, WithAttr(GetRef(relax_func), - tvm::attr::kGlobalSymbol, p.first->name_hint)); + new_name = gvar->name_hint; + new_func = WithAttr(GetRef(relax_func), tvm::attr::kGlobalSymbol, new_name); + } + + if (new_name.defined() && (!old_name.defined() || old_name.value() != new_name.value())) { + updates->Add(gvar, new_func); + if (new_name.value() != gvar->name_hint) { + GlobalVar new_gvar(new_name.value()); + if (auto sinfo = gvar->struct_info_.as()) { + UpdateStructInfo(new_gvar, sinfo.value()); + } + + gvar_updates.Set(gvar, new_gvar); + } } } - for (const auto& pair : updates) { - mod->Add(pair.first, pair.second, true); + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + + if (gvar_updates.size()) { + mod = tvm::transform::ReplaceGlobalVar(mod, gvar_updates); + } } return mod; }; diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 9591b45595f9..4305554342ad 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -32,6 +32,7 @@ * Any binding blocks that are left empty will be removed by the normalizer. */ +#include #include #include #include @@ -42,89 +43,40 @@ namespace tvm { namespace relax { -/** - * \brief Detects all the functions that can be possibly called by entry function. - */ -class CallTracer : public ExprVisitor { - public: - explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{} {} - - void VisitExpr_(const GlobalVarNode* op) final { - auto gvar = GetRef(op); - called_funcs_.insert(gvar); - if (auto func = mod_->functions.Get(gvar)) { - if (const auto* function_node = func.as()) { - VisitExpr(GetRef(function_node)); - } - // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein. - } else { - // The GlobalVar is not contained in the IRModule. While the - // input IRModule is ill-formed, this specific case is allowed - // for use with `relax.transform.ApplyPassToFunction`. If this - // occurs, DCE should not remove any internal functions from the - // IRModule, as their removal is only valid if we have a - // complete call graph. - all_callees_found_ = false; - } - } +IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set& entry_funcs) { + auto call_map = ir::CollectCallMap(mod); + + std::unordered_set reachable = entry_funcs; + std::vector to_visit(entry_funcs.begin(), entry_funcs.end()); + bool all_callees_in_module = true; - void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); } + while (to_visit.size()) { + GlobalVar visiting = to_visit.back(); + to_visit.pop_back(); - void VisitExpr_(const FunctionNode* func_node) final { - auto func = GetRef(func_node); - if (visiting_.find(func) == visiting_.end()) { - visiting_.insert(func); - for (auto param : func_node->params) { - ExprVisitor::VisitExpr(param); + if (auto it = call_map.find(visiting); it != call_map.end()) { + for (GlobalVar callee : (*it).second) { + if (!reachable.count(callee)) { + reachable.insert(callee); + to_visit.push_back(callee); + } } - ExprVisitor::VisitExpr(func_node->body); + } else { + all_callees_in_module = false; } } - void Trace(std::string entry) { - called_funcs_.insert(mod_->GetGlobalVar(entry)); - auto main_func = mod_->Lookup(entry); - VisitExpr(main_func); - } - - /* \brief Check if a function is unreachable - * - * \param gvar The function to be checked - * - * \return True if the function can be proven to be unreachable, - * either directly or indirectly, from an external caller. - * Otherwise, false. - */ - bool CheckIfProvablyUnreachable(const GlobalVar& gvar) const { - return all_callees_found_ && !called_funcs_.count(gvar); - } - - private: - IRModule mod_; - - /* \brief Whether all callees could be located within the IRModule */ - bool all_callees_found_{true}; - - // Record the names of all encountered functions. - std::unordered_set called_funcs_; - - // Record the expressions that are being visited. - std::unordered_set visiting_; -}; - -IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set& entry_funcs) { - CallTracer tracer(mod); - for (const auto& gvar : entry_funcs) { - tracer.VisitExpr(gvar); + if (!all_callees_in_module) { + return mod; } std::vector to_remove; - for (const auto& kv : mod->functions) { + for (const auto& [gvar, func] : mod->functions) { // The tracer contains all user-provided entry functions, all // externally-callable functions, and anything that is directly or // indirectly accessible from an entry function. - if (tracer.CheckIfProvablyUnreachable(kv.first)) { - to_remove.push_back(kv.first); + if (!reachable.count(gvar)) { + to_remove.push_back(gvar); } } diff --git a/src/relax/transform/replace_global_var.cc b/src/relax/transform/replace_global_var.cc new file mode 100644 index 000000000000..b81b831036ff --- /dev/null +++ b/src/relax/transform/replace_global_var.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/relax/transform/replace_global_var.cc + * + * \brief GlobalVar replacement across IR types + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +namespace { +using tvm::transform::GlobalVarReplacer; + +struct Mutator : ExprMutator { + Map replacements; + explicit Mutator(Map replacements) : replacements(replacements) {} + + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const GlobalVarNode* node) override { + auto gvar = GetRef(node); + return replacements.Get(gvar).value_or(gvar); + } +}; + +} // namespace + +TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) + .set_dispatch([](const ObjectRef& func, + Map replacements) -> BaseFunc { + Mutator mutator(replacements); + return Downcast(mutator(Downcast(func))); + }); + +TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) + .set_dispatch([](const ObjectRef& func, + Map) -> BaseFunc { + return Downcast(func); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/tir/analysis/collect_call_map.cc b/src/tir/analysis/collect_call_map.cc new file mode 100644 index 000000000000..98f7585c6b79 --- /dev/null +++ b/src/tir/analysis/collect_call_map.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/tir/analysis/collect_call_map.cc + * + * \brief Collect cross-IR call graph + */ + +#include +#include +#include + +namespace tvm { +namespace tir { + +namespace { +using ir::CalleeCollector; + +struct Visitor : StmtExprVisitor { + explicit Visitor(CalleeCollector* collector) : collector(collector) {} + CalleeCollector* collector; + void VisitExpr_(const CallNode* node) override { + StmtExprVisitor::VisitExpr_(node); + if (auto opt_gvar = node->op.as()) { + collector->Mark(opt_gvar.value()); + } + } +}; + +} // namespace + +TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable) + .set_dispatch([](const ObjectRef& func, CalleeCollector* collector) { + Visitor visitor{collector}; + visitor(Downcast(func)->body); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/replace_global_var.cc b/src/tir/transforms/replace_global_var.cc new file mode 100644 index 000000000000..8ef8ba9276b0 --- /dev/null +++ b/src/tir/transforms/replace_global_var.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file src/tir/transforms/replace_global_var.cc + * + * \brief GlobalVar replacement across IR types + */ + +#include +#include +#include + +namespace tvm { +namespace tir { + +namespace { +using tvm::transform::GlobalVarReplacer; + +struct Mutator : StmtExprMutator { + Map replacements; + explicit Mutator(Map replacements) : replacements(replacements) {} + + PrimExpr VisitExpr_(const CallNode* node) override { + auto call = Downcast(StmtExprMutator::VisitExpr_(node)); + if (auto old_gvar = call->op.as()) { + if (auto new_gvar = replacements.Get(old_gvar.value())) { + call.CopyOnWrite()->op = new_gvar.value(); + } + } + return call; + } +}; + +} // namespace + +TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) + .set_dispatch([](const ObjectRef& obj, + Map replacements) -> BaseFunc { + Mutator mutator(replacements); + auto func = Downcast(obj); + auto new_body = mutator(func->body); + + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + } + return func; + }); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/ir/analysis/test_collect_call_map.py b/tests/python/ir/analysis/test_collect_call_map.py new file mode 100644 index 000000000000..9068bffc5fe0 --- /dev/null +++ b/tests/python/ir/analysis/test_collect_call_map.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, List + +import tvm +import tvm.testing +from tvm.ir import GlobalVar + +from tvm.script import ir as I, tir as T, relax as R + +from tvm.ir.analysis import collect_call_map + + +def _build_str_map(call_map: Dict[GlobalVar, List[GlobalVar]]) -> Dict[str, List[str]]: + return { + caller.name_hint: [callee.name_hint for callee in callees] + for caller, callees in call_map.items() + } + + +def test_collect_relax_to_relax(): + @I.ir_module + class Module: + @R.function + def main(): + return Module.subroutine() + + @R.function + def subroutine(): + return R.tuple() + + call_map = collect_call_map(Module) + str_map = _build_str_map(call_map) + expected = { + "main": ["subroutine"], + "subroutine": [], + } + assert str_map == expected + + +def test_collect_relax_to_tir(): + @I.ir_module + class Module: + @R.function + def main() -> R.Prim("int32"): + return Module.subroutine(R.prim_value(T.int32(42))) + + @T.prim_func + def subroutine(i: T.int32) -> T.int32: + return i + 1 + + call_map = collect_call_map(Module) + str_map = _build_str_map(call_map) + expected = { + "main": ["subroutine"], + "subroutine": [], + } + assert str_map == expected + + +def test_collect_tir_to_tir(): + @I.ir_module + class Module: + @T.prim_func + def main() -> T.int32: + return Module.subroutine(42) + + @T.prim_func + def subroutine(i: T.int32) -> T.int32: + return i + 1 + + call_map = collect_call_map(Module) + str_map = _build_str_map(call_map) + expected = { + "main": ["subroutine"], + "subroutine": [], + } + assert str_map == expected + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index 680df969474a..39f6d061f721 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -89,7 +89,7 @@ def test_system_lib_prefix(): class Before: I.module_attrs({"system_lib_prefix": "hello_"}) - @T.prim_func + @T.prim_func(private=True) def tir_zeros(x: T.Buffer((2), "float32")) -> None: x[0] = T.float32(0) @@ -103,13 +103,13 @@ class Expected: I.module_attrs({"system_lib_prefix": "hello_"}) @T.prim_func - def tir_zeros(x: T.Buffer((2), "float32")) -> None: + def hello_tir_zeros(x: T.Buffer((2), "float32")) -> None: T.func_attr({"global_symbol": "hello_tir_zeros"}) x[0] = T.float32(0) @R.function def main() -> R.Tensor: - gv0 = R.call_tir(Expected.tir_zeros, (), R.Tensor((2,), dtype="float32")) + gv0 = R.call_tir(Expected.hello_tir_zeros, (), R.Tensor((2,), dtype="float32")) return gv0 before = Before diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 142faf51607b..04a4379d77f6 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -346,6 +346,42 @@ def main( assert check_if_func_exists(new_mod, "unused_func") +def test_preserve_indirectly_used_prim_func(): + @tvm.script.ir_module + class InputModule: + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.call_tir( + InputModule.tir_add_tensors, + [x, w], + out_sinfo=R.Tensor((16, 16), "float32"), + ) + return gv0 + + @T.prim_func(private=True) + def tir_add_tensors( + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), + ): + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = InputModule.tir_add_float32(x[vi, vj], y[vi, vj]) + + @T.prim_func(private=True) + def tir_add_float32(x: T.float32, y: T.float32) -> T.float32: + return x + y + + mod = InputModule + assert mod + new_mod = DeadCodeElimination()(mod) + + tvm.ir.assert_structural_equal(mod, new_mod) + + def test_multiple_unused_funcs(): @tvm.script.ir_module class InputModule: @@ -399,7 +435,11 @@ def main( ) lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( - lv0, lv1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + lv0, + lv1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", ) lv3: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims( lv2, axes=[0, 3, 1, 2] @@ -428,7 +468,11 @@ def main( ) lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( - lv0, lv1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + lv0, + lv1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", ) R.output(lv2) gv3 = R.astype(lv2, dtype="float16") @@ -464,7 +508,11 @@ def main( gv_w, axes=[0, 2, 3, 1] ) lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d( - lv1, lv2, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + lv1, + lv2, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", ) # dead instruction -> usee lv1 also dead. lv4: R.Tensor((2, 3, 28, 28), dtype="float32") = R.permute_dims( @@ -491,7 +539,11 @@ def main( gv_w, axes=[0, 2, 3, 1] ) lv3: R.Tensor((2, 26, 26, 4), dtype="float16") = R.nn.conv2d( - lv1, lv2, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC" + lv1, + lv2, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", ) R.output(lv3) return lv3