diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h
index 48800b193cb4..fffcab49667c 100644
--- a/include/tvm/driver/driver_api.h
+++ b/include/tvm/driver/driver_api.h
@@ -29,6 +29,7 @@
 #ifndef TVM_DRIVER_DRIVER_API_H_
 #define TVM_DRIVER_DRIVER_API_H_
 
+#include <tvm/ir/global_var_supply.h>
 #include <tvm/ir/module.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/support/with.h>
@@ -99,6 +100,7 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param global_var_supply The GlobalVarSupply to be used in the module.
  * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
@@ -106,7 +108,7 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
 TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
                                const std::string& name,
                                const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-                               bool simple_mode = false);
+                               GlobalVarSupply global_var_supply, bool simple_mode = false);
 
 /*!
  * \brief Build an IRModule given a TE schedule, args and binds. This function also applies
@@ -115,13 +117,14 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
  * \param args The arguments to the function (Array of Tensor, Buffer and Vars)
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param global_var_supply The GlobalVarSupply to be used in the module.
  * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
 TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
                                const std::string& name,
                                const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-                               bool simple_mode = false);
+                               GlobalVarSupply global_var_supply, bool simple_mode = false);
 
 /*!
  * \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want
@@ -130,10 +133,13 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param global_var_supply The GlobalVarSupply to be used in the module and when creating
+ * GlobalVars.
  * \return The result module.
  */
 IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
-                          const std::unordered_map<te::Tensor, tir::Buffer>& binds);
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                          GlobalVarSupply global_var_supply);
 /*!
  * \brief Build a device and host module for a specific target from an IRModule.
  * \param funcs The functions to be built.
diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h
new file mode 100644
index 000000000000..276c64a0d753
--- /dev/null
+++ b/include/tvm/ir/global_var_supply.h
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/global_var_supply.h
+ * \brief GlobalVarSupply that can be used to generate unique \class GlobalVar.
+ */
+#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_
+#define TVM_IR_GLOBAL_VAR_SUPPLY_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tvm/ir/expr.h"
+#include "tvm/ir/module.h"
+#include "tvm/ir/name_supply.h"
+
+namespace tvm {
+
+/*!
+ * \brief GlobalVarSupply can be used to generate unique GlobalVars.
+ */
+class GlobalVarSupplyNode : public Object {
+ public:
+  /*!
+   * \brief Empty constructor. Will use an empty NameSupply.
+   */
+  GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}
+
+  /*!
+   * \brief Constructor.
+   * \param name_supply The NameSupply to use for generating the names of fresh GlobalVars.
+   * \param name_to_var_map An optional map from names to GlobalVars.
+   */
+  explicit GlobalVarSupplyNode(NameSupply name_supply,
+                               std::unordered_map<std::string, GlobalVar> name_to_var_map = {});
+
+  /*!
+   * \brief Generates a unique GlobalVar from this supply.
+   * \param name The name from which the name of the GlobalVar is derived.
+   * \param add_prefix If set to true, then the prefix of the contained NameSupply will be
+   * prepended to the name.
+   * \return A unique GlobalVar.
+   */
+  GlobalVar FreshGlobal(String name, bool add_prefix = true);
+
+  /*!
+   * \brief Looks up a GlobalVar with the given name in this supply.
+   * If no entry is found, creates one, places it in the cache and returns it.
+   * \param name The name of the GlobalVar to search for.
+   * \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended
+   * to the name before performing the search.
+   * \return A cached GlobalVar.
+   */
+  GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);
+
+  /*!
+   * \brief Reserves an existing GlobalVar with this supply.
+   * \param var The GlobalVar to be registered.
+   * \param allow_conflict Allow conflict with other GlobalVars that have the same name.
+   */
+  void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false);
+
+  void VisitAttrs(AttrVisitor* v) {}
+
+  /*! \brief The NameSupply used to generate unique name hints for GlobalVars. */
+  NameSupply name_supply_;
+
+  static constexpr const char* _type_key = "GlobalVarSupply";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
+  static constexpr const bool _type_has_method_shash_reduce = false;
+  TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object);
+
+ private:
+  std::unordered_map<std::string, GlobalVar> name_to_var_map_;
+};
+
+/*!
+ * \brief Managed reference class to GlobalVarSupplyNode.
+ * \sa GlobalVarSupplyNode
+ */
+class GlobalVarSupply : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param name_supply The NameSupply to be used when generating new GlobalVars.
+   * \param name_to_var_map An optional map from names to GlobalVars.
+   */
+  TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply,
+                                   std::unordered_map<std::string, GlobalVar> name_to_var_map = {});
+
+  /*!
+   * \brief Constructs a supply from an array of IRModules. GlobalVars generated by this supply
+   * are guaranteed not to conflict with any GlobalVars that belong to the modules.
+   * \param modules Array of IRModules.
+   */
+  TVM_DLL explicit GlobalVarSupply(const Array<IRModule>& modules);
+
+  /*!
+   * \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply
+   * are guaranteed not to conflict with GlobalVars that belong to the module.
+   * \param module The IRModule.
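+ *
+ * A minimal usage sketch (illustrative only; `mod` stands for any existing IRModule):
+ * \code
+ * GlobalVarSupply supply(mod);                        // reserves mod's GlobalVars
+ * GlobalVar gv = supply->FreshGlobal("main", false);  // e.g. "main_1" if "main" exists
+ * \endcode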
+ */ + TVM_DLL explicit GlobalVarSupply(const IRModule module); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef, GlobalVarSupplyNode); +}; + +} // namespace tvm + +#endif // TVM_IR_GLOBAL_VAR_SUPPLY_H_ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index f73f2230df4d..7313b4f78349 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -323,14 +323,6 @@ class IRModuleNode : public Object { /*! \brief Helper function for registering a typedef's constructors */ void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type); - /*! - * \brief Returns a version of \p name which is unique amongst all function definitions in module. - * - * \param name The original name. - * \return Updated name which is unique. - */ - String GetUniqueName(const String& name); - /*! \brief A map from string names to global variables that * ensures global uniqueness. */ @@ -481,6 +473,15 @@ namespace attr { // Following are attributes for IRModule only. +/*! + * \brief Name of the module + * + * Type: String + * + * \sa tvm::runtime::String + */ +constexpr const char* kModuleName = "mod_name"; + /*! * \brief Executor targeted by the module * diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h new file mode 100644 index 000000000000..a85a6fe70a66 --- /dev/null +++ b/include/tvm/ir/name_supply.h @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/name_supply.h + * \brief NameSupply that can be used to generate unique variable names. + */ +#ifndef TVM_IR_NAME_SUPPLY_H_ +#define TVM_IR_NAME_SUPPLY_H_ + +#include +#include +#include + +#include "tvm/ir/expr.h" + +namespace tvm { + +/*! + * \brief NameSupply can be used to generate unique names. + */ +class NameSupplyNode : public Object { + public: + /*! + * \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro. + */ + NameSupplyNode() = default; + + /*! + * \brief Constructor. + * \param prefix The prefix to be used with this NameSupply. + * \param name_map The map used to guarantee uniqueness. + */ + NameSupplyNode(const String& prefix, std::unordered_map name_map) + : prefix_(prefix), name_map(std::move(name_map)) {} + + /*! + * \brief Generates a unique name from this NameSupply. + * \param name The name from which the generated name is derived. + * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the + * name. \return A unique name. + */ + String FreshName(const String& name, bool add_prefix = true); + + /*! + * \brief Reserves an existing name with this NameSupply. + * \param name The name to be reserved. + * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the + * name before reserving it. 
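+ * For example (illustrative): after ReserveName("x"), a later FreshName("x")
+ * yields "x_1" rather than "x".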
\return The name that was reserved with the NameSupply. It can be + * different if a prefix is added. + */ + String ReserveName(const String& name, bool add_prefix = true); + + /*! + * \brief Checks if this NameSupply already generated a name. + * \param name The name to check. + * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the + * name before checking for it. \return True if the name has already been generated. False + * otherwise. + */ + bool ContainsName(const String& name, bool add_prefix = true); + + void VisitAttrs(AttrVisitor* v) {} + + // Prefix for all GlobalVar names. It can be empty. + std::string prefix_; + + static constexpr const char* _type_key = "NameSupply"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object); + + private: + /*! \brief Helper function to add the NameSupply prefix to the name. */ + String add_prefix_to_name(const String& name); + + /*! + * \brief Function that will generate a unique name. + * \param name The name to be used as a base. + * \return A unique name. + */ + std::string GetUniqueName(std::string name); + + /*! \brief A map that is used to generate unique names. */ + std::unordered_map name_map; +}; + +/*! + * \brief Managed reference class to NameSupplyNode. + * \sa NameSupplyNode + */ +class NameSupply : public ObjectRef { + public: + /*! + * \brief Constructor. + * \param prefix The prefix to be used with this NameSupply. + * \param name_map An optional map. + */ + TVM_DLL explicit NameSupply(const String& prefix, + std::unordered_map name_map = {}); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode); +}; + +} // namespace tvm + +#endif // TVM_IR_NAME_SUPPLY_H_ diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py new file mode 100644 index 000000000000..095ac43c03b8 --- /dev/null +++ b/python/tvm/ir/supply.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Suppliers that are used to guarantee uniqueness of names and GlobalVars.""" +import tvm +from tvm import Object, IRModule +from . import _ffi_api + + +@tvm._ffi.register_object("NameSupply") +class NameSupply(Object): + """NameSupply that can be used to generate unique names. + + Parameters + ---------- + prefix: The prefix to be added to the generated names. + """ + + def __init__(self, prefix=""): + self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix) + + def fresh_name(self, name, add_prefix=True): + """Generates a unique name from this NameSupply. + + Parameters + ---------- + name: String + The name from which the generated name is derived. 
+
+        add_prefix: bool
+            If set to true, then the prefix of this NameSupply will be prepended to the name.
+        """
+        return _ffi_api.NameSupply_FreshName(self, name, add_prefix)
+
+    def reserve_name(self, name, add_prefix=True):
+        """Reserves an existing name with this NameSupply.
+
+        Parameters
+        ----------
+        name: String
+            The name to be reserved.
+
+        add_prefix: bool
+            If set to true, then the prefix of this NameSupply will be prepended to the name
+            before reserving it.
+        """
+        return _ffi_api.NameSupply_ReserveName(self, name, add_prefix)
+
+    def contains_name(self, name, add_prefix=True):
+        """Checks if this NameSupply has already generated a name.
+
+        Parameters
+        ----------
+        name: String
+            The name to check.
+
+        add_prefix: bool
+            If set to true, then the prefix of this NameSupply will be prepended to the name
+            before checking for it.
+        """
+        return _ffi_api.NameSupply_ContainsName(self, name, add_prefix)
+
+
+@tvm._ffi.register_object("GlobalVarSupply")
+class GlobalVarSupply(Object):
+    """GlobalVarSupply that holds a mapping between names and GlobalVars.
+
+    GlobalVarSupply can be used to generate new GlobalVars with a unique name.
+    It can also be used to retrieve previously generated GlobalVars based on a name.
+
+    Parameters
+    ----------
+    value: Union[List[IRModule], IRModule, NameSupply]
+        The IRModules used to build this GlobalVarSupply, or a NameSupply.
+    """
+
+    def __init__(self, value=None):
+        if value is None:
+            name_supply = NameSupply("")
+            self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, name_supply)
+        elif isinstance(value, NameSupply):
+            self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, value)
+        elif isinstance(value, (list, tvm.container.Array)):
+            self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModules, value)
+        elif isinstance(value, IRModule):
+            self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value)
+
+    def fresh_global(self, name, add_prefix=True):
+        """Generates a unique GlobalVar from this supply.
+
+        Parameters
+        ----------
+        name: String
+            The name from which the name of the GlobalVar is derived.
+
+        add_prefix: bool
+            If set to true, then the prefix of the contained NameSupply will be prepended
+            to the name.
+        """
+        return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix)
+
+    def unique_global_for(self, name, add_prefix=True):
+        """Looks up a GlobalVar with the given name in this supply. If no entry is found,
+        creates one, places it in the cache, and returns it.
+
+        Parameters
+        ----------
+        name: String
+            The name of the GlobalVar to search for.
+
+        add_prefix: bool
+            If set to true, the prefix of the contained NameSupply will be prepended to the
+            name before performing the search.
+        """
+        return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix)
+
+    def reserve_global(self, global_var, allow_conflict=False):
+        """Reserves an existing GlobalVar with this supply.
+
+        Parameters
+        ----------
+        global_var: GlobalVar
+            The GlobalVar to be registered.
+ + allow_conflict: bool + Allow conflict with other GlobalVars that have the same name + """ + return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var, allow_conflict) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index ab60aef9ae1f..c930bf0c4e73 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -1371,7 +1372,8 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i auto pass_ctx = tvm::transform::PassContext::Current(); auto mod = ScheduleToModule(sch, Array{tensors.begin(), tensors.end()}, name, - std::unordered_map()); + std::unordered_map(), + GlobalVarSupply(NameSupply(""))); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 24c7ee74cdcf..79c9e567b459 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -42,24 +42,6 @@ std::string dot_to_underscore(std::string s) { return s; } -std::string CodeGenHybrid::GetUniqueName(std::string prefix) { - prefix = dot_to_underscore(prefix); - auto it = ids_allocated_.find(prefix); - if (it != ids_allocated_.end()) { - while (true) { - std::ostringstream os; - os << prefix << (++it->second); - std::string name = os.str(); - if (ids_allocated_.count(name) == 0) { - prefix = name; - break; - } - } - } - ids_allocated_[prefix] = 0; - return prefix; -} - std::string CodeGenHybrid::Finish() { return stream.str(); } void CodeGenHybrid::PrintType(DataType t, std::ostream& os) { @@ -428,7 +410,7 @@ std::string CodeGenHybrid::GetVarID(const VarNode* v) { if (id_map_.count(key)) { return id_map_[key]; } - return id_map_[key] = GetUniqueName(v->name_hint); + return id_map_[key] = ids_allocated->FreshName(v->name_hint); } std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) { @@ -440,57 +422,57 @@ std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) { if (tensor->op->num_outputs() > 1) { name_hint += "_v" + std::to_string(tensor->value_index); } - return id_map_[key] = GetUniqueName(name_hint); + return id_map_[key] = ids_allocated->FreshName(name_hint); } void CodeGenHybrid::ReserveKeywords() { - GetUniqueName("def"); - GetUniqueName("for"); - GetUniqueName("in"); - GetUniqueName("range"); - GetUniqueName("True"); - GetUniqueName("False"); - GetUniqueName("unroll"); - GetUniqueName("const_range"); - GetUniqueName("parallel"); - GetUniqueName("vectorize"); - GetUniqueName("bind"); - GetUniqueName("threadIdx.x"); - GetUniqueName("threadIdx.y"); - GetUniqueName("threadIdx.z"); - GetUniqueName("blockIdx.x"); - GetUniqueName("blockIdx.y"); - GetUniqueName("blockIdx.z"); - GetUniqueName("vthread"); - GetUniqueName("allocate"); - GetUniqueName("output_tensor"); - GetUniqueName("sqrt"); - GetUniqueName("log"); - GetUniqueName("tanh"); - GetUniqueName("power"); - GetUniqueName("exp"); - GetUniqueName("sigmoid"); - GetUniqueName("popcount"); - GetUniqueName("likely"); - GetUniqueName("int8"); - GetUniqueName("int16"); - GetUniqueName("int32"); - GetUniqueName("int64"); - GetUniqueName("uint8"); - GetUniqueName("uint16"); - GetUniqueName("uint32"); - GetUniqueName("uint64"); - GetUniqueName("float16"); - GetUniqueName("float32"); - GetUniqueName("float64"); - GetUniqueName("ceil_div"); - GetUniqueName("max_num_threads"); + ids_allocated->ReserveName("def"); + ids_allocated->ReserveName("for"); + 
ids_allocated->ReserveName("in"); + ids_allocated->ReserveName("range"); + ids_allocated->ReserveName("True"); + ids_allocated->ReserveName("False"); + ids_allocated->ReserveName("unroll"); + ids_allocated->ReserveName("const_range"); + ids_allocated->ReserveName("parallel"); + ids_allocated->ReserveName("vectorize"); + ids_allocated->ReserveName("bind"); + ids_allocated->ReserveName("threadIdx.x"); + ids_allocated->ReserveName("threadIdx.y"); + ids_allocated->ReserveName("threadIdx.z"); + ids_allocated->ReserveName("blockIdx.x"); + ids_allocated->ReserveName("blockIdx.y"); + ids_allocated->ReserveName("blockIdx.z"); + ids_allocated->ReserveName("vthread"); + ids_allocated->ReserveName("allocate"); + ids_allocated->ReserveName("output_tensor"); + ids_allocated->ReserveName("sqrt"); + ids_allocated->ReserveName("log"); + ids_allocated->ReserveName("tanh"); + ids_allocated->ReserveName("power"); + ids_allocated->ReserveName("exp"); + ids_allocated->ReserveName("sigmoid"); + ids_allocated->ReserveName("popcount"); + ids_allocated->ReserveName("likely"); + ids_allocated->ReserveName("int8"); + ids_allocated->ReserveName("int16"); + ids_allocated->ReserveName("int32"); + ids_allocated->ReserveName("int64"); + ids_allocated->ReserveName("uint8"); + ids_allocated->ReserveName("uint16"); + ids_allocated->ReserveName("uint32"); + ids_allocated->ReserveName("uint64"); + ids_allocated->ReserveName("float16"); + ids_allocated->ReserveName("float32"); + ids_allocated->ReserveName("float64"); + ids_allocated->ReserveName("ceil_div"); + ids_allocated->ReserveName("max_num_threads"); } void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array& inputs, const Array& outputs, const std::string& name) { ReserveKeywords(); - GetUniqueName(name); + ids_allocated->ReserveName(name); stream << "def " << name << "("; for (size_t i = 0; i < inputs.size(); ++i) { diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index da45ffb6a8ce..53026c7fc3b3 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -24,6 +24,7 @@ #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ +#include #include #include #include @@ -145,19 +146,14 @@ class CodeGenHybrid : public ExprFunctor, const int tab_{4}; /*! \brief Print the current indent spaces. */ inline void PrintIndent(); - /*! \brief Keys are ids allocated, and values are the suffix to prevent double-name. */ - std::map ids_allocated_; + /*! \brief NameSupply for allocated ids. */ + NameSupply ids_allocated = NameSupply(""); /*! * \brief Keys are either (tensors, value_index) or (variables, 0). * Values are the corresponding IDs.*/ std::map, std::string> id_map_; /*! \brief Variables (keys) binded to the threads (values). */ std::map binds_; - /*! - * \brief Find an unallocated name for the given prefix. - * \param prefix The given prefix. - */ - std::string GetUniqueName(std::string prefix); /*! \brief The output code string builder. */ std::stringstream stream; /*! 
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 6f4fb618d334..cbf809a267a6 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -261,7 +261,8 @@ IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { // Convert te schedule to IRModule IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds) { + const std::unordered_map& binds, + GlobalVarSupply global_var_supply) { sch = sch.normalize(); transform::PassContext pass_ctx = transform::PassContext::Current(); @@ -289,7 +290,8 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const if (noalias) { f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } - return IRModule(Map({{GlobalVar(name), f}})); + GlobalVar global_var = global_var_supply->UniqueGlobalFor(name, false); + return IRModule(Map({{global_var, f}})); } TVM_REGISTER_GLOBAL("driver.schedule_to_module") @@ -302,7 +304,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") c_binds.insert({kv.first, kv.second}); } } - IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds); + IRModule mod = + ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply(""))); return mod; }); @@ -337,17 +340,19 @@ TVM_REGISTER_GLOBAL("driver.lower_primfunc") }); IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { + const std::unordered_map& binds, + GlobalVarSupply global_var_supply, bool simple_mode) { Array ref_args; for (ObjectRef x : args) { ref_args.push_back(x); } - return LowerSchedule(std::move(sch), ref_args, name, binds); + return LowerSchedule(std::move(sch), ref_args, name, binds, global_var_supply); } IRModule LowerSchedule(te::Schedule sch, const Array& args, const std::string& name, - const std::unordered_map& binds, bool simple_mode) { - IRModule mod = ScheduleToModule(std::move(sch), args, name, binds); + const std::unordered_map& binds, + GlobalVarSupply global_var_supply, bool simple_mode) { + IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply); // Get the legacy TE pass list Array pass_list = CreatePassList(simple_mode); return LowerWithPassList(mod, pass_list); @@ -363,7 +368,8 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") c_binds.insert({kv.first, kv.second}); } } - return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); + return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")), + simple_mode); }); /** diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc new file mode 100644 index 000000000000..383d4445adcf --- /dev/null +++ b/src/ir/global_var_supply.cc @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file global_var_supply.cc + * \brief GlobalVarSupply that can be used to generate unique GlobalVars. + */ +#include "tvm/ir/global_var_supply.h" + +#include + +#include + +#include "tvm/ir/expr.h" + +namespace tvm { +GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply, + std::unordered_map name_to_var_map) { + auto n = make_object(name_supply, name_to_var_map); + data_ = std::move(n); +} + +std::string GetModuleName(const IRModule& module) { + return module->GetAttr(tvm::attr::kModuleName).value_or("tvmgen_default"); +} + +GlobalVarSupply::GlobalVarSupply(const Array& modules) : GlobalVarSupply(NameSupply("")) { + if (!modules.empty()) { + IRModule first_mod = modules.front(); + this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod); + } + for (auto& mod : modules) { + for (auto kv : mod->functions) { + this->operator->()->ReserveGlobalVar(kv.first); + } + } +} + +GlobalVarSupply::GlobalVarSupply(const IRModule module) + : GlobalVarSupply(Array{module}) {} + +void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) { + name_supply_->ReserveName(var->name_hint, false); + if (!allow_conflict) { + ICHECK(name_to_var_map_.count(var->name_hint) == 0) + << "GlobalVar " << var << " conflicts by name in this supply."; + } + name_to_var_map_[var->name_hint] = var; +} + +GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply, + std::unordered_map name_to_var_map) + : name_supply_(std::move(name_supply)), name_to_var_map_(std::move(name_to_var_map)) {} + +GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_prefix) { + String final_name = name_supply_->ReserveName(name, add_prefix); + + auto it = name_to_var_map_.find(final_name); + if (it != name_to_var_map_.end()) { + return it->second; + } else { + GlobalVar var = GlobalVar(final_name); + name_to_var_map_.emplace(final_name, var); + return var; + } +} + +GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) { + String final_name = name_supply_->FreshName(name, add_prefix); + ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end()) + << "GlobalVar already exists for name " << final_name; + GlobalVar var = GlobalVar(final_name); + name_to_var_map_.emplace(final_name, var); + return var; +} + +TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode); + +TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply") + .set_body_typed([](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); }); + +TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) { + return GlobalVarSupply(std::move(mod)); +}); + +TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules").set_body_typed([](const Array& mods) { + return GlobalVarSupply(mods); +}); + +TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal") + .set_body_method(&GlobalVarSupplyNode::FreshGlobal); + +TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor") + .set_body_method(&GlobalVarSupplyNode::UniqueGlobalFor); + +TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar") + .set_body_method(&GlobalVarSupplyNode::ReserveGlobalVar); + +} // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 6f2c9f9fe994..8d6de5a536a7 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -21,6 +21,7 @@ * \file module.cc * \brief The global module in Relay. 
*/ +#include #include #include #include @@ -292,20 +293,6 @@ Constructor IRModuleNode::LookupTag(const int32_t tag) { return (*it).second; } -String IRModuleNode::GetUniqueName(const String& name) { - String result = name; - int suffix = 0; - while (true) { - auto it = global_var_map_.find(result); - if (it == global_var_map_.end()) { - return result; - } - std::ostringstream os; - os << name << "_" << ++suffix; - result = os.str(); - } -} - /*! * \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs * ('one') side above the rhs ('two'). @@ -397,12 +384,14 @@ std::pair IRModule::FromExprInContext( func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); } + GlobalVar main_gv; + auto global_var_supply = GlobalVarSupply(mod); if (gv_name.empty()) { // Bind function to 'main' (though rename if would clash with existing 'main'). - gv_name = mod->GetUniqueName("main"); + main_gv = global_var_supply->FreshGlobal("main", false); + } else { + main_gv = global_var_supply->UniqueGlobalFor(gv_name, false); } - - GlobalVar main_gv(gv_name); mod->Add(main_gv, func); return {mod, main_gv}; } diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc new file mode 100644 index 000000000000..93f568253cba --- /dev/null +++ b/src/ir/name_supply.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file name_supply.cc + * \brief NameSupply that can be used to generate unique variable names. 
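+ *
+ * Illustrative counter behavior (not additional API): successive FreshName("x")
+ * calls return "x", then "x_1", then "x_2".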
+ */ +#include "tvm/ir/name_supply.h" + +#include + +#include + +namespace tvm { + +NameSupply::NameSupply(const String& prefix, std::unordered_map name_map) { + auto n = make_object(prefix, std::move(name_map)); + data_ = std::move(n); +} + +String NameSupplyNode::ReserveName(const String& name, bool add_prefix) { + String final_name = name; + if (add_prefix) { + final_name = add_prefix_to_name(name); + } + name_map[final_name] = 0; + return final_name; +} + +String NameSupplyNode::FreshName(const String& name, bool add_prefix) { + String unique_name = name; + if (add_prefix) { + unique_name = add_prefix_to_name(name); + } + unique_name = GetUniqueName(unique_name); + return unique_name; +} + +bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) { + String unique_name = name; + if (add_prefix) { + unique_name = add_prefix_to_name(name); + } + + return name_map.count(unique_name); +} + +String NameSupplyNode::add_prefix_to_name(const String& name) { + if (prefix_.empty()) { + return name; + } + + std::ostringstream ss; + ICHECK(name.defined()); + ss << prefix_ << "_" << name; + return ss.str(); +} + +std::string NameSupplyNode::GetUniqueName(std::string name) { + for (size_t i = 0; i < name.size(); ++i) { + if (name[i] == '.') name[i] = '_'; + } + auto it = name_map.find(name); + if (it != name_map.end()) { + auto new_name = name; + while (!name_map.insert({new_name, 0}).second) { + std::ostringstream os; + os << name << "_" << (++it->second); + new_name = os.str(); + } + return new_name; + } + name_map[name] = 0; + return name; +} + +TVM_REGISTER_NODE_TYPE(NameSupplyNode); + +TVM_REGISTER_GLOBAL("ir.NameSupply").set_body_typed([](String prefix) { + return NameSupply(prefix); +}); + +TVM_REGISTER_GLOBAL("ir.NameSupply_FreshName") + .set_body_method(&NameSupplyNode::FreshName); + +TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName") + .set_body_method(&NameSupplyNode::ReserveName); + +TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName") + .set_body_method(&NameSupplyNode::ContainsName); + +} // namespace tvm diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index c72511775acd..ab725d82e676 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -466,7 +466,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorFreshName(func_name); auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs); return AddNode(node, call); } @@ -604,22 +604,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorEndObject(); } - /*! - * \brief Get unique name for func - * - * \param name - * \return std::string - */ - std::string _GetUniqueName(const std::string& name) { - if (!name_map_.count(name)) { - name_map_[name] = 1; - return name; - } - auto index = name_map_[name]; - name_map_[name] += 1; - return _GetUniqueName(name + std::to_string(index)); - } - protected: /*! \brief nodes */ std::vector nodes_; @@ -645,8 +629,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator function_metadata_; - /*! \brief name map */ - std::unordered_map name_map_; + /*! 
\brief NameSupply */ + NameSupply name_supply_ = NameSupply(""); }; class GraphExecutorCodegenModule : public runtime::ModuleNode { diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index af4b49b4f1da..c577e8e356d6 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -74,9 +74,9 @@ Array ExtractTask( }); // Tasks are extracted via post order visit, return the reversed list. std::reverse(tasks.begin(), tasks.end()); - std::unordered_map name_map; + NameSupply name_supply = NameSupply(""); for (ExtractedTask task : tasks) { - task->task_name = tec::GetUniqueName(task->task_name, &name_map); + task->task_name = name_supply->FreshName(task->task_name); } return tasks; } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 8ca5a32b7fb9..5c79ed2070cc 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -92,6 +92,7 @@ #include #include #include +#include #include #include #include @@ -134,30 +135,33 @@ TVM_REGISTER_OBJECT_TYPE(TECompilerNode); class TECompilerImpl : public TECompilerNode { public: - explicit TECompilerImpl(Optional opt_mod) { + explicit TECompilerImpl(Optional opt_mod, Optional opt_mod_name) { + String mod_name = opt_mod_name.value_or(""); + NameSupply name_supply = NameSupply(mod_name /* prefix */); + global_var_supply_ = GlobalVarSupply(name_supply); // Make sure we don't collide with any existing globals in the module. if (opt_mod) { for (const auto& kv : opt_mod.value()->functions) { - name_map_[kv.first->name_hint] = 1; + global_var_supply_->name_supply_->ReserveName(kv.first->name_hint, false); } } } // Lower the function. - CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { - return LowerInternal(key, mangle_fn)->cached_func; + CachedFunc Lower(const CCacheKey& key) { + return LowerInternal(key, global_var_supply_)->cached_func; } + // TODO(gigiblender): Only to be called by the global TE compiler. + // Remove this when the global TE compiler is removed. CachedFunc Lower(const CCacheKey& key, const String mod_name) { - auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; - - return Lower(key, mangle_fn); + global_var_supply_->name_supply_->prefix_ = mod_name; + return LowerInternal(key, global_var_supply_)->cached_func; } // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { - auto mangle_fn = [](String name) { return name; }; - CCacheValue value = LowerInternal(key, mangle_fn); + CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply(""))); if (value->packed_func != nullptr) { return value->packed_func; } @@ -335,7 +339,7 @@ class TECompilerImpl : public TECompilerNode { private: // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { + CCacheValue LowerInternal(const CCacheKey& key, GlobalVarSupply global_var_supply) { VLOG(1) << "lowering:" << std::endl << PrettyPrint(key->source_func) << std::endl << "for target:" << std::endl @@ -360,7 +364,7 @@ class TECompilerImpl : public TECompilerNode { if (opt_compiler.defined()) { // Don't compile now since we don't have anywhere to put the resulting runtime module. // Instead place the original definition in the cache and wait for LowerExternalFunctions. 
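+      // Note: UniqueGlobalFor(name, false) caches per name, so repeated lookups
+      // return the same GlobalVar and the external function keeps one identity.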
- IRModule ir_module; + IRModule ir_module({}, {}); Optional opt_global_symbol = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(opt_global_symbol.defined()) << "External function has not been attached a name yet."; @@ -369,7 +373,7 @@ class TECompilerImpl : public TECompilerNode { // the module's globals. Furthermore, the external codegen tool must bind the compiled // function to the "global_symbol" attribute on the source_func. So do not use GetUniqueName // here. - auto global_var = GlobalVar(opt_global_symbol.value()); + auto global_var = global_var_supply->UniqueGlobalFor(opt_global_symbol.value(), false); global_var->checked_type_ = key->source_func->checked_type(); ir_module->Add(global_var, key->source_func); value->cached_func = CachedFunc(key->target, global_var, {}, {}, te::Schedule{nullptr}, @@ -388,10 +392,7 @@ class TECompilerImpl : public TECompilerNode { With target_scope(key->target); ICHECK(!value->cached_func.defined()); - value->cached_func = PrimFuncFor(key->source_func, key->target, [&](std::string name) { - auto mangled = mangle_fn(name); - return GetUniqueName(mangled, &name_map_); - }); + value->cached_func = PrimFuncFor(key->source_func, key->target, global_var_supply); if (value->cached_func->prim_func.defined()) { VLOG(1) << "Lowering PrimFunc"; @@ -443,16 +444,11 @@ class TECompilerImpl : public TECompilerNode { } auto func_name = value->cached_func->prim_fn_var->name_hint; VLOG(1) << "scheduling"; - IRModule scheduled_module = - tvm::LowerSchedule(value->cached_func->schedule, all_args, func_name, binds); + IRModule scheduled_module = tvm::LowerSchedule(value->cached_func->schedule, all_args, + func_name, binds, global_var_supply); scheduled_module->Update(tir::transform::BindParams(all_consts)(scheduled_module)); - // Unfortunately the above machinery creates its own GlobalVars instead of using *the* - // GlobalVar we established above. Fix this before the confusion spreads any further. - // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name. for (const auto& kv : scheduled_module->functions) { - GlobalVar global_var = kv.first->name_hint == value->cached_func->prim_fn_var->name_hint - ? value->cached_func->prim_fn_var - : kv.first; + GlobalVar global_var = kv.first; auto func = kv.second; // Propagate the structural hash of the relay function to the tir // function so associations can be made between the two. @@ -498,9 +494,7 @@ class TECompilerImpl : public TECompilerNode { using tvm::transform::PassContext; With fresh_pass_ctx_scope(PassContext::Create()); - value->cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(name, &name_map_); - }); + value->cached_func = ShapeFuncFor(key->source_func, key->target, global_var_supply_); ICHECK( value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var).as()); @@ -527,8 +521,8 @@ class TECompilerImpl : public TECompilerNode { /*! \brief compiler cache lock*/ std::mutex mutex_; - /*! \brief internal name map to get an unique name */ - std::unordered_map name_map_; + /*! \brief internal GlobalVarSupply to get unique GlobalVars */ + GlobalVarSupply global_var_supply_; /*! \brief internal compiler cache */ std::unordered_map cache_; /*! 
\brief internal compiler cache for shape funcs */ @@ -539,15 +533,16 @@ class TECompilerImpl : public TECompilerNode { Map device_contexts_; }; -TECompiler::TECompiler(Optional opt_mod) { - auto object = make_object(std::move(opt_mod)); +TECompiler::TECompiler(Optional opt_mod, Optional mod_name) { + auto object = make_object(std::move(opt_mod), std::move(mod_name)); data_ = object; } /*! \brief The global TE compiler */ // TODO(mbs): To be terminated with extreme prejudice. TECompiler& TECompiler::Global() { - static TECompiler* inst = new TECompiler(make_object(Optional())); + static TECompiler* inst = + new TECompiler(make_object(Optional(), Optional())); return *inst; } TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); @@ -629,12 +624,11 @@ using AnalysisRemapping = std::unordered_map(primitive_func), target, GetVirtualDevice(GetRef(call_node))); - CachedFunc cfunc = compiler_->Lower(key, module_name_); + CachedFunc cfunc = compiler_->Lower(key); ICHECK(cfunc.defined()); return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args), call_node->span, target, cfunc->funcs->functions); @@ -942,17 +936,15 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // module we'll ultimately emit for each required device-type. Note that a primitive may be // lowered for multiple device types, each which will be assigned a fresh var. std::unordered_map primitive_functions_; - String module_name_; TECompiler compiler_; // Cache ops that need to be frequently used later to reduce lookup overhead. const Op& debug_op_; }; -Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn, - CompilationConfig config) { +Pass LowerTensorExpr(TECompiler compiler, ProcessFn process_fn, CompilationConfig config) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExprMutator lower_te(module, process_fn, config, module_name, compiler); + LowerTensorExprMutator lower_te(module, process_fn, config, compiler); return Downcast(lower_te.Mutate(func)); }; return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); @@ -1184,7 +1176,7 @@ void UpdateFunctionMetadata(BaseFunc func, /*! \brief Main lowering driving. */ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn, CompilationConfig config) { - TECompiler compiler(module); + TECompiler compiler(module, module_name); // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten // module as we go (including rewritten Functions, lowered primitives, and runtime modules @@ -1199,7 +1191,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated // (using call_lowered convention). IRModule updated_module = - LowerTensorExpr(module_name, compiler, std::move(process_fn), std::move(config))(module); + LowerTensorExpr(compiler, std::move(process_fn), std::move(config))(module); // The Functions tagged with "Compiler" are now residing in the cache ready to be // compiled by LowerExternalFunctions. However we still need a record of them in the diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 5d16da4b8bb2..f2ba84014a09 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -73,7 +73,7 @@ class TECompilerNode : public Object { * \param key The key to the cached function. 
* \return The result. */ - virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; + virtual CachedFunc Lower(const CCacheKey& key) = 0; /*! * \brief Get lowered result. @@ -137,7 +137,7 @@ class TECompilerNode : public Object { /*! \brief cache entry used in compile engine */ class TECompiler : public ObjectRef { public: - explicit TECompiler(Optional opt_mod = {}); + explicit TECompiler(Optional opt_mod = {}, Optional mod_name = {}); explicit TECompiler(ObjectPtr n) : ObjectRef(n) {} TECompilerNode* operator->() { return static_cast(get_mutable()); } using ContainerType = TECompilerNode; diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index bfb351f82b78..da52d94b4e46 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -137,8 +137,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator Lower(const Function& relay_func, - std::function renamer) { + Array Lower(const Function& relay_func) { for (Var param : relay_func->params) { Array inputs; for (const auto& ttype : FlattenTupleType(param->checked_type())) { @@ -327,15 +326,15 @@ class ScheduleBuilder : public ExprVisitor { } } - CachedFunc Create(const Function& relay_func, std::function renamer) { + CachedFunc Create(const Function& relay_func, GlobalVarSupply global_var_supply) { LowerToTECompute lower_te_compute(target_); - Array tensor_outs = lower_te_compute.Lower(relay_func, renamer); + Array tensor_outs = lower_te_compute.Lower(relay_func); Array fn_inputs = lower_te_compute.fn_inputs_; VisitExpr(relay_func->body); // TODO(mbs): This should be the definitive global by which the PrimFunc is known and // no other GlobalVar ctors should appear inside the lowering machinery. - auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_)); + auto prim_fn_var = global_var_supply->FreshGlobal(lower_te_compute.candidate_name_); prim_fn_var->checked_type_ = relay_func->checked_type(); // Fusion over tupled results may leave identity relationships @@ -402,8 +401,9 @@ class ScheduleBuilder : public ExprVisitor { } } - return CachedFunc(target_, prim_fn_var, fn_inputs, tensor_outs, schedule, prim_func, {}, - IRModule(Map({})), lower_te_compute.constant_tensors_); + IRModule funcs = IRModule(Map({})); + return CachedFunc(target_, prim_fn_var, fn_inputs, tensor_outs, schedule, prim_func, {}, funcs, + lower_te_compute.constant_tensors_); } void VisitExpr_(const CallNode* call_node) final { @@ -446,8 +446,8 @@ class ScheduleBuilder : public ExprVisitor { * The funcs field in cache is not yet populated. */ CachedFunc PrimFuncFor(const Function& source_func, const Target& target, - std::function renamer) { - return ScheduleBuilder(target).Create(source_func, renamer); + GlobalVarSupply global_var_supply) { + return ScheduleBuilder(target).Create(source_func, global_var_supply); } // Creates shape function from functor. @@ -456,7 +456,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> MakeShapeFunc() {} CachedFunc Create(const Function& prim_func, const Target& target, - std::function renamer) { + GlobalVarSupply global_var_supply) { VLOG_CONTEXT << "MakeShapeFunc"; TShapeDataDependent shape_func_param_states; @@ -527,8 +527,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> // TODO(mbs): This should be the definitive global by which the PrimFunc is known and // no other GlobalVar ctors should appear inside the lowering machinery. 
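+    // Unlike UniqueGlobalFor, FreshGlobal always draws a new unique name from the
+    // underlying NameSupply and ICHECKs that it was never handed out before.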
- auto func_name = renamer(candidate_name); - auto prim_fn_gvar = GlobalVar(func_name); + auto prim_fn_gvar = global_var_supply->FreshGlobal(candidate_name); // Gather the result types, again from the p.o.v. of the shape function rather than // the primitive it is derived for. @@ -569,19 +568,10 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - IRModule lowered_module = tvm::LowerSchedule(schedule, all_args, func_name, binds); - - // Unfortunately the above machinery creates its own GlobalVars instead of using *the* - // GlobalVar we established above. Fix this before the confusion spreads any further. - // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name. - IRModule fixed_lowered_module; - for (const auto& kv : lowered_module->functions) { - GlobalVar global_var = - kv.first->name_hint == prim_fn_gvar->name_hint ? prim_fn_gvar : kv.first; - fixed_lowered_module->Add(global_var, kv.second); - } + IRModule lowered_module = + tvm::LowerSchedule(schedule, all_args, prim_fn_gvar->name_hint, binds, global_var_supply); return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, tir::PrimFunc{nullptr}, - shape_func_param_states, fixed_lowered_module); + shape_func_param_states, lowered_module); } Array VisitExpr(const Expr& expr) final { @@ -791,15 +781,14 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> }; CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, - std::function renamer) { - return MakeShapeFunc().Create(prim_func, target, renamer); + GlobalVarSupply global_var_supply) { + return MakeShapeFunc().Create(prim_func, target, global_var_supply); } std::pair, std::string> LowerTECompute(const Function& source_func, Target target, bool return_inputs) { LowerToTECompute lower_te_compute(target); - Array outputs = - lower_te_compute.Lower(source_func, [](std::string name) { return name; }); + Array outputs = lower_te_compute.Lower(source_func); // Following ScheduleBuilder, remove placeholder ops from outputs. tvm::Array tensor_outs; for (const auto& tensor : outputs) { @@ -814,34 +803,10 @@ std::pair, std::string> LowerTECompute(const Function& source_ return std::make_pair(tensor_outs, lower_te_compute.candidate_name_); } -/*! - * \brief Get unique name from name. - * \param name The orginal name. - * \return Updated name which is unique. 
- */ -std::string GetUniqueName(std::string name, std::unordered_map* name_map_) { - for (size_t i = 0; i < name.length(); ++i) { - if (name[i] == '.') name[i] = '_'; - } - while (true) { - auto it = name_map_->find(name); - if (it == name_map_->end()) { - (*name_map_)[name] = 1; - return name; - } else { - std::ostringstream os; - os << name << "_" << it->second; - ++(it->second); - name = os.str(); - } - } - return name; -} - TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { auto tgt = tvm::Target("ext_dev"); LowerToTECompute lower_te_compute(tgt); - auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; }); + auto outputs = lower_te_compute.Lower(prim_func); return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_, outputs, te::Schedule(), tir::PrimFunc(), {}, IRModule(Map({})), lower_te_compute.constant_tensors_); diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index ac2619826019..894a5f5be5f6 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ #define TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ +#include #include #include #include @@ -227,13 +228,10 @@ std::pair, std::string> LowerTECompute(const Function& source_ * The funcs field in cache is not yet populated. */ CachedFunc PrimFuncFor(const Function& source_func, const Target& target, - std::function renamer); + GlobalVarSupply global_var_supply); CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, - std::function renamer); - -// TODO(mbs): Bring name uniqification under control -- this is replicated in quite a few places. -std::string GetUniqueName(std::string name, std::unordered_map* name_map); + GlobalVarSupply global_var_supply); // implementations inline size_t CCacheKeyNode::Hash() const { diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index b2776a41c50c..42fec9e27af2 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -22,6 +22,7 @@ * \brief The dataflow pattern matcher for Relay. 
*/ +#include #include #include #include @@ -438,15 +439,8 @@ Expr InferType(const Expr& expr) { Expr InferTypeWithModule(const Expr& expr, const IRModule& m) { IRModule mod(m->functions, m->type_definitions, m->Imports()); - int idx = 0; - std::string gv_name; - do { - std::ostringstream oss; - oss << "_tmp" << idx; - gv_name = oss.str(); - ++idx; - } while (mod->ContainGlobalVar(gv_name)); - GlobalVar gvar(gv_name); + GlobalVarSupply global_var_supply = GlobalVarSupply(mod); + GlobalVar gvar = global_var_supply->FreshGlobal("_tmp", false); BaseFunc func; if (expr.as()) { func = Downcast(expr); diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index c538dac048b3..25111cec8eda 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -126,8 +126,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - tec::PrimFuncFor(GetRef(func), Target::Current(), - [](std::string name) { return name; }); + tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply(NameSupply(""))); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/relay/transforms/meta_schedule_layout_rewrite.cc b/src/relay/transforms/meta_schedule_layout_rewrite.cc index b817802f17ef..8a70f224c611 100644 --- a/src/relay/transforms/meta_schedule_layout_rewrite.cc +++ b/src/relay/transforms/meta_schedule_layout_rewrite.cc @@ -127,8 +127,7 @@ Expr MetaScheduleLayoutRewriter::VisitExpr_(const CallNode* call) { if (const auto* func = call->op.as()) { LayoutIndexQueue* self = LayoutIndexQueue::Global(); self->queue_.clear(); - tec::PrimFuncFor(GetRef(func), Target::Current(), - [](std::string name) { return name; }); + tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply(NameSupply(""))); if (!self->queue_.empty()) { std::deque queue = std::move(self->queue_); self->queue_.clear(); diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index bc1ed518d473..e2df2e4272ad 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -333,14 +333,15 @@ class Partitioner : public MixedModeMutator { WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target)); global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); - std::string fname = name; - ICHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists"; + GlobalVarSupply global_var_supply = GlobalVarSupply(module_); + GlobalVar glob_func = global_var_supply->FreshGlobal(name, false); + ICHECK(!module_->ContainGlobalVar(glob_func->name_hint)) + << "Global function " << glob_func->name_hint << " already exists"; // Create a global function and add it to the IRModule for the region. // This way we lift the functions that should be handled by external // codegen to the module scope and rely on the pass manager to prevent // relay function level passes (i.e. simplify inference and fusion) // optimizing it. 
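+    // FreshGlobal(name, false) keeps the requested name when it is still unused and
+    // only appends a numeric suffix ("_1", "_2", ...) on a collision.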
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index bc1ed518d473..e2df2e4272ad 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -333,14 +333,15 @@ class Partitioner : public MixedModeMutator {
         WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
     global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
 
-    std::string fname = name;
-    ICHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists";
+    GlobalVarSupply global_var_supply = GlobalVarSupply(module_);
+    GlobalVar glob_func = global_var_supply->FreshGlobal(name, false);
+    ICHECK(!module_->ContainGlobalVar(glob_func->name_hint))
+        << "Global function " << glob_func->name_hint << " already exists";
     // Create a global function and add it to the IRModule for the region.
     // This way we lift the functions that should be handled by external
     // codegen to the module scope and rely on the pass manager to prevent
     // relay function level passes (i.e. simplify inference and fusion)
     // optimizing it.
-    GlobalVar glob_func(fname);
     module_->Add(glob_func, global_region_func);
     module_ = relay::transform::InferType()(module_);
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 3fe7fa50d3cf..6b1ca81d85f6 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -45,33 +45,33 @@ void CodeGenC::InitFuncState(const PrimFunc& f) {
 
 void CodeGenC::ReserveKeywordsAsUnique() {
   // skip the first underscore, so SSA variable starts from _1
-  GetUniqueName("_");
-  GetUniqueName("extern");
-  GetUniqueName("void");
-  GetUniqueName("int");
-  GetUniqueName("float");
-  GetUniqueName("double");
-  GetUniqueName("char");
-  GetUniqueName("unsigned");
-  GetUniqueName("short");
-  GetUniqueName("long");
-  GetUniqueName("if");
-  GetUniqueName("else");
-  GetUniqueName("switch");
-  GetUniqueName("case");
-  GetUniqueName("default");
-  GetUniqueName("for");
-  GetUniqueName("do");
-  GetUniqueName("while");
-  GetUniqueName("goto");
-  GetUniqueName("register");
-  GetUniqueName("continue");
-  GetUniqueName("break");
-  GetUniqueName("typedef");
-  GetUniqueName("struct");
-  GetUniqueName("enum");
-  GetUniqueName("union");
-  GetUniqueName("return");
+  name_supply_->ReserveName("_");
+  name_supply_->ReserveName("extern");
+  name_supply_->ReserveName("void");
+  name_supply_->ReserveName("int");
+  name_supply_->ReserveName("float");
+  name_supply_->ReserveName("double");
+  name_supply_->ReserveName("char");
+  name_supply_->ReserveName("unsigned");
+  name_supply_->ReserveName("short");
+  name_supply_->ReserveName("long");
+  name_supply_->ReserveName("if");
+  name_supply_->ReserveName("else");
+  name_supply_->ReserveName("switch");
+  name_supply_->ReserveName("case");
+  name_supply_->ReserveName("default");
+  name_supply_->ReserveName("for");
+  name_supply_->ReserveName("do");
+  name_supply_->ReserveName("while");
+  name_supply_->ReserveName("goto");
+  name_supply_->ReserveName("register");
+  name_supply_->ReserveName("continue");
+  name_supply_->ReserveName("break");
+  name_supply_->ReserveName("typedef");
+  name_supply_->ReserveName("struct");
+  name_supply_->ReserveName("enum");
+  name_supply_->ReserveName("union");
+  name_supply_->ReserveName("return");
 }
 
 void CodeGenC::AddFunction(const PrimFunc& f) {
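For context on the ReserveName calls: a reserved name is recorded but never handed out, so later FreshName requests with the same hint get suffixed. A small sketch under the suffixing behavior exercised by the new tests later in this diff:

    #include <tvm/ir/name_supply.h>
    #include <string>

    void ReserveExample() {
      tvm::NameSupply supply("");                   // codegen uses an empty prefix
      supply->ReserveName("extern");                // the C keyword is now off-limits
      std::string a = supply->FreshName("extern");  // "extern_1": "extern" is taken
      std::string b = supply->FreshName("extern");  // "extern_2"
      (void)a; (void)b;
    }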
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index 54975d166ea2..a47158d37883 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -42,7 +42,7 @@
 namespace tvm {
 namespace codegen {
 
-CodeGenCHost::CodeGenCHost() { module_name_ = GetUniqueName("__tvm_module_ctx"); }
+CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_module_ctx"); }
 
 void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_str,
                         const std::unordered_set<std::string>& devices) {
@@ -207,8 +207,8 @@ void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name,
 void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_args) {
   this->PrintIndent();
-  std::string ret_val = GetUniqueName("ret_val");
-  std::string ret_type_code = GetUniqueName("ret_type_code");
+  std::string ret_val = name_supply_->FreshName("ret_val");
+  std::string ret_type_code = name_supply_->FreshName("ret_type_code");
   this->stream << "TVMValue " << ret_val << ";\n";
   this->PrintIndent();
   this->stream << "int " << ret_type_code << ";\n";
@@ -231,8 +231,8 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar
 void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args,
                                   const std::string& resource_handle_name) {
   this->PrintIndent();
-  std::string ret_val = GetUniqueName("ret_val");
-  std::string ret_type_code = GetUniqueName("ret_type_code");
+  std::string ret_val = name_supply_->FreshName("ret_val");
+  std::string ret_type_code = name_supply_->FreshName("ret_type_code");
   this->stream << "TVMValue " << ret_val << ";\n";
   this->PrintIndent();
   this->stream << "int " << ret_type_code << ";\n";
@@ -264,7 +264,7 @@ std::string CodeGenCHost::GetPackedName(const CallNode* op) {
   if (it != declared_globals_.end()) {
     unique_name = it->second;
   } else {
-    unique_name = GetUniqueName(packed_func_name);
+    unique_name = name_supply_->FreshName(packed_func_name);
     declared_globals_[packed_func_name] = unique_name;
     decl_stream << "static void* " << unique_name << " = NULL;\n";
   }
@@ -310,7 +310,7 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op,
 void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
   if (op->op.same_as(builtin::tvm_stack_alloca())) {
-    std::string stack_name = GetUniqueName("stack");
+    std::string stack_name = name_supply_->FreshName("stack");
     const std::string& type = op->args[0].as<StringImmNode>()->value;
     const IntImmNode* num = op->args[1].as<IntImmNode>();
     ICHECK(num != nullptr);
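The FreshName swap keeps the old contract that two locals emitted from the same hint never collide in the generated C. A sketch of the expected sequence, matching the suffix scheme verified by the tests below:

    #include <tvm/ir/name_supply.h>
    #include <string>

    void LocalNameExample() {
      tvm::NameSupply supply("");
      std::string first = supply->FreshName("ret_val");   // "ret_val"
      std::string second = supply->FreshName("ret_val");  // "ret_val_1"
      (void)first; (void)second;
    }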
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index dde1d112edb3..7350292167ac 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -43,8 +43,8 @@ CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
 
 void CodeGenCUDA::Init(bool output_ssa) {
   CodeGenC::Init(output_ssa);
-  vid_global_barrier_state_ = GetUniqueName(runtime::symbol::tvm_global_barrier_state);
-  vid_global_barrier_expect_ = GetUniqueName("__barrier_expect");
+  vid_global_barrier_state_ = name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
+  vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
   ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
 }
 
@@ -403,7 +403,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
 void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                                    std::ostream& os) {  // NOLINT(*)
   // Declare the result.
-  std::string sret = GetUniqueName("_");
+  std::string sret = name_supply_->FreshName("_");
   this->PrintIndent();
   this->PrintType(t, stream);
   stream << ' ' << sret << ";\n";
@@ -555,7 +555,7 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
     this->PrintIndent();
     this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
     this->PrintIndent();
-    std::string ptr = GetUniqueName("pf");
+    std::string ptr = name_supply_->FreshName("pf");
     this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n";
     this->PrintIndent();
     this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
@@ -589,7 +589,7 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
   // We could emit make_float4 like calls, but the emitted code looks
   // too compact to read. Emit this as vectorized unary ops.
-  std::string sret = GetUniqueName("_");
+  std::string sret = name_supply_->FreshName("_");
   this->PrintIndent();
   this->PrintType(target_ty, stream);
   stream << ' ' << sret << ";\n";
@@ -631,7 +631,7 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Arr
   //   v = __ret;
   //
   // Declare the result vector.
-  std::string sret = GetUniqueName("_");
+  std::string sret = name_supply_->FreshName("_");
   this->PrintIndent();
   this->PrintType(ret_dtype, stream);
   stream << ' ' << sret << ";\n";
@@ -1138,7 +1138,7 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
   ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype &&
          op->dtype.lanes() == op->condition.dtype().lanes());
 
-  std::string r_var = GetUniqueName("_");
+  std::string r_var = name_supply_->FreshName("_");
   this->PrintIndent();
   this->PrintType(op->dtype, stream);
   stream << ' ' << r_var << ";\n";
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 0ec617911519..b3ca3eb46149 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -55,7 +55,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   // clear previous generated state.
   this->InitFuncState(f);
   // skip the first underscore, so SSA variable starts from _1
-  GetUniqueName("_");
+  name_supply_->FreshName("_");
   // add to alloc buffer type.
   auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
@@ -94,7 +94,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   }
   // Setup normal arguments.
   size_t nargs = f->params.size() - num_buffer;
-  std::string varg = GetUniqueName("arg");
+  std::string varg = name_supply_->FreshName("arg");
   if (nargs != 0) {
     std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
     stream << "  constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
@@ -127,8 +127,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
     decl_stream << "};\n\n";
   }
   // Setup the thread group info.
-  ICHECK_EQ(GetUniqueName("threadIdx"), "threadIdx");
-  ICHECK_EQ(GetUniqueName("blockIdx"), "blockIdx");
+  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
+  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
   int work_dim = 0;
   auto thread_axis = f->GetAttr<Array<IterVar>>(tir::attr::kDeviceThreadAxis).value();
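The ICHECK_EQ lines above rely on a subtle property of the supply: the first FreshName request for a hint it has never seen returns the hint unchanged. Sketched:

    #include <tvm/ir/name_supply.h>
    #include <cassert>

    void IdentityExample() {
      tvm::NameSupply supply("");
      assert(supply->FreshName("threadIdx") == "threadIdx");    // first use is verbatim
      assert(supply->FreshName("threadIdx") == "threadIdx_1");  // reuse gets a suffix
    }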
diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc
index 2353d2e6baf2..75833fd93629 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -28,34 +28,14 @@ namespace tvm {
 namespace codegen {
 
 void CodeGenSourceBase::ClearFuncState() {
-  name_alloc_map_.clear();
+  name_supply_ = NameSupply("");
   ssa_assign_map_.clear();
   var_idmap_.clear();
   scope_mark_.clear();
 }
 
-std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
-  for (size_t i = 0; i < prefix.size(); ++i) {
-    if (prefix[i] == '.') prefix[i] = '_';
-  }
-  auto it = name_alloc_map_.find(prefix);
-  if (it != name_alloc_map_.end()) {
-    while (true) {
-      std::ostringstream os;
-      os << prefix << (++it->second);
-      std::string name = os.str();
-      if (name_alloc_map_.count(name) == 0) {
-        prefix = name;
-        break;
-      }
-    }
-  }
-  name_alloc_map_[prefix] = 0;
-  return prefix;
-}
-
 std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
-  if (name_alloc_map_.count(src)) return src;
+  if (name_supply_->ContainsName(src)) return src;
   auto it = ssa_assign_map_.find(src);
   if (it != ssa_assign_map_.end()) {
     if (scope_mark_.at(it->second.scope_id)) {
@@ -63,7 +43,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
     }
   }
   SSAEntry e;
-  e.vid = GetUniqueName("_");
+  e.vid = name_supply_->FreshName("_");
   e.scope_id = static_cast<int>(scope_mark_.size() - 1);
   ssa_assign_map_[src] = e;
   this->PrintIndent();
@@ -74,7 +54,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
 
 std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) {
   ICHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint;
   std::string key = v->name_hint;
-  std::string vid = GetUniqueName(key);
+  std::string vid = name_supply_->FreshName(key);
   std::replace(vid.begin(), vid.end(), ':', '_');
   std::replace(vid.begin(), vid.end(), '-', '_');
   std::replace(vid.begin(), vid.end(), '.', '_');
diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h
index 66287f9ad181..2fd0abcd68a6 100644
--- a/src/target/source/codegen_source_base.h
+++ b/src/target/source/codegen_source_base.h
@@ -25,6 +25,7 @@
 #ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
 #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
 
+#include <tvm/ir/name_supply.h>
 #include
 #include
 #include
@@ -97,12 +98,6 @@ class CodeGenSourceBase {
    * \param t The type of the expression.
    */
  std::string SSAGetID(std::string src, DataType t);
-  /*!
-   * \brief get a unique name with the corresponding prefix
-   * \param prefix The prefix of the name
-   * \return The returned name.
-   */
-  std::string GetUniqueName(std::string prefix);
  /*!
   * \brief mark the beginning of a new scope
   * \return The scope id.
@@ -127,12 +122,12 @@ class CodeGenSourceBase {
   std::ostringstream stream;
   /*! \brief name of each variable */
   std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
+  /*! \brief NameSupply for allocation */
+  NameSupply name_supply_ = NameSupply("");
 
  private:
   /*! \brief assignment map of ssa */
   std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
-  /*! \brief name allocation map */
-  std::unordered_map<std::string, int> name_alloc_map_;
   /*! \brief array to check whether we are inside certain scope */
   std::vector<bool> scope_mark_;
   /*! \brief The current indentation value */
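Note the asymmetry that makes the SSAGetID change safe: judging by its use there, ContainsName only queries the supply, while FreshName also registers the result. A sketch of the query-only path (add_prefix=false is an explicit choice here; it is moot for codegen supplies, which use an empty prefix):

    #include <tvm/ir/name_supply.h>
    #include <string>

    // Pure lookup: does not allocate or reserve anything in the supply.
    bool IsKnownName(tvm::NameSupply supply, const std::string& name) {
      return supply->ContainsName(name, /*add_prefix=*/false);
    }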
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 68b25a165373..55df71a8053e 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -18,6 +18,7 @@
  */
 
 #include
+#include <tvm/ir/name_supply.h>
 #include
 #include
 #include
@@ -61,8 +62,10 @@ struct CreateFuncInfo {
   ProducerToBufferTransformer transformer;
   /*! \brief The buffers should be allocated at function root. */
   Array<tir::Buffer> root_alloc;
-  /*! \brief The count map to make block name unique. */
-  std::unordered_map<String, int> name_count;
+  /*! \brief The NameSupply to make block name unique. */
+  NameSupply name_supply = NameSupply("");
+
+  String FreshName(String base_name) { return name_supply->FreshName(base_name); }
 
   explicit CreateFuncInfo(Array<te::Tensor> arg_list)
       : arg_list(std::move(arg_list)), transformer(tensor2buffers) {}
@@ -71,16 +74,6 @@ struct CreateFuncInfo {
     return std::any_of(arg_list.begin(), arg_list.end(),
                        [&tensor](const te::Tensor& arg) { return tensor == arg; });
   }
-
-  String GetUniqueName(const String& prefix) {
-    String unique_prefix = prefix;
-    auto it = name_count.find(prefix);
-    while (name_count.count(unique_prefix)) {
-      unique_prefix = prefix + "_" + std::to_string(++it->second);
-    }
-    name_count[unique_prefix] = 0;
-    return unique_prefix;
-  }
 };
 
 class LayoutFreePlaceholdersNormalizer : public StmtMutator {
@@ -179,7 +172,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
   Stmt body;
   if (const auto* reduce = expr_body.as<ReduceNode>()) {
     // Case 1. Reduce compute
-    block_name = info->GetUniqueName(compute_op->name);
+    block_name = info->FreshName(compute_op->name);
     int n_buffers = buffers.size();
 
     Array<PrimExpr> lhs;
@@ -236,7 +229,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
   } else {
     // Case 2. Data parallel compute
     ICHECK_EQ(tensors.size(), 1);
-    block_name = info->GetUniqueName(tensors[0]->GetNameHint());
+    block_name = info->FreshName(tensors[0]->GetNameHint());
     const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
     body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
   }
@@ -387,7 +380,7 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
         Block(/*iter_vars=*/{},
               /*reads=*/std::move(reads),
               /*writes=*/std::move(writes),
-              /*name_hint=*/info->GetUniqueName(extern_op->name),
+              /*name_hint=*/info->FreshName(extern_op->name),
               /*body=*/std::move(body),
               /*init=*/NullOpt,
              /*alloc_buffers=*/{},
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 85845616f1a6..dc56a3ce762f 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -21,6 +21,7 @@
 * \file split_host_device.cc
 * \brief Split device function from host.
 */
+#include <tvm/ir/global_var_supply.h>
 #include
 #include
 #include
@@ -302,12 +303,15 @@ class HostDeviceSplitter : public StmtMutator {
         arguments.push_back(var);
       }
     }
+    GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
+    GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false);
+
     PrimFunc device_func(params, Substitute(body, remap_vars));
     device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_);
     device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
                            Integer(CallingConv::kDeviceKernelLaunch));
-    device_func =
-        WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol));
+    device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
+                           runtime::String(kernel_symbol_global->name_hint));
     device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
     device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
     device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1));
@@ -315,11 +319,11 @@ class HostDeviceSplitter : public StmtMutator {
       device_func =
          WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1));
     }
-    (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);
+    (*device_mod_)->Add(kernel_symbol_global, device_func);
 
     // generate calls to the device function
     Array<PrimExpr> call_args;
-    call_args.push_back(StringImm(kernel_symbol));
+    call_args.push_back(StringImm(kernel_symbol_global->name_hint));
     for (PrimExpr arg : arguments) {
       call_args.push_back(arg);
     }
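One consequence of the split_host_device.cc change worth calling out: because the supply is seeded from *device_mod_, asking for a kernel symbol that already exists renames the new request instead of silently overwriting the old function. A sketch with a hypothetical symbol name:

    #include <tvm/ir/global_var_supply.h>
    #include <tvm/ir/module.h>

    void KernelSymbolExample(const tvm::IRModule& device_mod) {
      tvm::GlobalVarSupply supply(device_mod);
      // Assuming "my_kernel" is free in device_mod, the second request is suffixed.
      tvm::GlobalVar first = supply->FreshGlobal("my_kernel", false);   // "my_kernel"
      tvm::GlobalVar second = supply->FreshGlobal("my_kernel", false);  // "my_kernel_1"
      (void)first; (void)second;
    }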
"elemwise_add", binds, global_var_supply); + auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds, global_var_supply); Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target()); diff --git a/tests/cpp/c_codegen_test.cc b/tests/cpp/c_codegen_test.cc index 097de862a926..442f76a8cff3 100644 --- a/tests/cpp/c_codegen_test.cc +++ b/tests/cpp/c_codegen_test.cc @@ -52,7 +52,8 @@ TEST(CCodegen, MainFunctionOrder) { auto args = Array({A, B, elemwise_add}); std::unordered_map binds; - auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds); + auto lowered = + LowerSchedule(fcreate(), args, "elemwise_add", binds, GlobalVarSupply(NameSupply(""))); Map inputs = {{target_c, lowered}}; runtime::Module module = build(inputs, Target()); Array functions = module->GetFunction("get_func_names", false)(); @@ -81,7 +82,8 @@ auto BuildLowered(std::string op_name, tvm::Target target) { auto args = Array({A, B, op}); std::unordered_map binds; - auto lowered_s = LowerSchedule(fcreate_s(), args, op_name, binds); + auto lowered_s = + LowerSchedule(fcreate_s(), args, op_name, binds, GlobalVarSupply(NameSupply(""))); return lowered_s; } diff --git a/tests/cpp/name_supply_test.cc b/tests/cpp/name_supply_test.cc new file mode 100644 index 000000000000..75b9ae86a9ab --- /dev/null +++ b/tests/cpp/name_supply_test.cc @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
diff --git a/tests/cpp/name_supply_test.cc b/tests/cpp/name_supply_test.cc
new file mode 100644
index 000000000000..75b9ae86a9ab
--- /dev/null
+++ b/tests/cpp/name_supply_test.cc
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace tvm;
+
+NameSupply preambleNameSupply() {
+  NameSupply name_supply = NameSupply("prefix");
+  name_supply->FreshName("test");
+  return name_supply;
+}
+
+TEST(NameSupply, FreshName) {
+  NameSupply name_supply = preambleNameSupply();
+  String fresh = name_supply->FreshName("test");
+
+  EXPECT_EQ(fresh.compare("prefix_test_1"), 0);
+}
+
+TEST(NameSupply, FreshNameNoConflict) {
+  NameSupply name_supply = preambleNameSupply();
+  String fresh = name_supply->FreshName("name_2");
+  EXPECT_EQ(fresh.compare("prefix_name_2"), 0);
+
+  fresh = name_supply->FreshName("name");
+  EXPECT_EQ(fresh.compare("prefix_name"), 0);
+
+  fresh = name_supply->FreshName("name");
+  EXPECT_EQ(fresh.compare("prefix_name_1"), 0);
+
+  fresh = name_supply->FreshName("name");
+  EXPECT_EQ(fresh.compare("prefix_name_3"), 0);
+}
+
+TEST(NameSupply, ContainsName) {
+  NameSupply name_supply = preambleNameSupply();
+
+  EXPECT_TRUE(name_supply->ContainsName("test"));
+  EXPECT_FALSE(name_supply->ContainsName("test_1"));
+}
+
+TEST(NameSupply, ReserveName) {
+  NameSupply name_supply = preambleNameSupply();
+  name_supply->ReserveName("otherTest", false);
+
+  EXPECT_TRUE(name_supply->ContainsName("otherTest", false));
+  EXPECT_FALSE(name_supply->ContainsName("otherTest"));
+
+  name_supply->ReserveName("otherTest");
+  EXPECT_TRUE(name_supply->ContainsName("prefix_otherTest", false));
+  EXPECT_TRUE(name_supply->ContainsName("otherTest"));
+}
+
+GlobalVarSupply preambleVarSupply() {
+  GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply(""));
+  global_var_supply->FreshGlobal("test");
+  return global_var_supply;
+}
+
+TEST(GlobalVarSupply, FreshGlobal) {
+  GlobalVarSupply global_var_supply = preambleVarSupply();
+  GlobalVar first_var = global_var_supply->FreshGlobal("test");
+  GlobalVar second_var = global_var_supply->FreshGlobal("test");
+
+  EXPECT_FALSE(tvm::StructuralEqual()(first_var, second_var));
+  EXPECT_EQ(first_var->name_hint.compare("test_1"), 0);
+  EXPECT_EQ(second_var->name_hint.compare("test_2"), 0);
+}
+
+TEST(GlobalVarSupply, UniqueGlobalFor) {
+  GlobalVarSupply global_var_supply = preambleVarSupply();
+  GlobalVar first_var = global_var_supply->UniqueGlobalFor("someName");
+  GlobalVar second_var = global_var_supply->UniqueGlobalFor("someName");
+
+  EXPECT_TRUE(tvm::StructuralEqual()(first_var, second_var));
+  EXPECT_EQ(first_var->name_hint.compare("someName"), 0);
+  EXPECT_EQ(second_var->name_hint.compare("someName"), 0);
+}
+
+TEST(GlobalVarSupply, ReserveGlobal) {
+  GlobalVarSupply global_var_supply = preambleVarSupply();
+  GlobalVar var = GlobalVar("someName");
+  global_var_supply->ReserveGlobalVar(var);
+  GlobalVar second_var = global_var_supply->UniqueGlobalFor("someName");
+  GlobalVar third_var = global_var_supply->FreshGlobal("someName");
+
+  EXPECT_TRUE(tvm::StructuralEqual()(var, second_var));
+  EXPECT_FALSE(tvm::StructuralEqual()(var, third_var));
+  EXPECT_EQ(second_var->name_hint.compare("someName"), 0);
+  EXPECT_EQ(third_var->name_hint.compare("someName_1"), 0);
+}
+
+TEST(GlobalVarSupply, BuildIRModule) {
+  auto x = relay::Var("x", relay::Type());
+  auto f = relay::Function(tvm::Array<relay::Var>{x}, x, relay::Type(), {});
+  GlobalVar var = GlobalVar("test");
+  IRModule module = IRModule({{var, f}});
+
+  GlobalVarSupply global_var_supply = GlobalVarSupply(module);
+  GlobalVar second_var = global_var_supply->UniqueGlobalFor("test", false);
+  GlobalVar third_var = global_var_supply->FreshGlobal("test", false);
+
+  EXPECT_TRUE(tvm::StructuralEqual()(var, second_var));
+  EXPECT_FALSE(tvm::StructuralEqual()(var, third_var));
+  EXPECT_EQ(second_var->name_hint.compare("test"), 0);
+  EXPECT_EQ(third_var->name_hint.compare("test_1"), 0);
+}
diff --git a/tests/python/relay/backend/test_pass_lower_te.py b/tests/python/relay/backend/test_pass_lower_te.py
index 310a16e269e0..fb79c1f2e7a6 100644
--- a/tests/python/relay/backend/test_pass_lower_te.py
+++ b/tests/python/relay/backend/test_pass_lower_te.py
@@ -203,12 +203,12 @@ def @my_dyn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Extern=1
     # Expected:
     #   def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(?, ?), float32] {
     #     %0 = (%a, %a);
-    #     call_lowered(@my_dyn, %0, metadata={prim_shape_fn_var='shape_func_add', relay_attrs={Extern=1}, prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2, all_prim_shape_fn_vars=['shape_func_add'], prim_shape_fn_num_outputs=1, all_prim_fn_vars=[]})
+    #     call_lowered(@my_dyn, %0, metadata={prim_shape_fn_var='test_shape_func_add', relay_attrs={Extern=1}, prim_shape_fn_states=[2, 2], prim_shape_fn_num_inputs=2, all_prim_shape_fn_vars=['test_shape_func_add'], prim_shape_fn_num_outputs=1, all_prim_fn_vars=[]})
     #   }
     #   def @my_dyn(%x: Tensor[(5, 7), float32] , %y: Tensor[(5, 7), float32] , Extern=1) -> Tensor[(?, ?), float32] {
     #     add(%x, %y)
     #   }
-    #   def @shape_func_add =
+    #   def @test_shape_func_add =
 
     main = actual_mod["main"]
     call = main.body
@@ -218,14 +218,14 @@ def @my_dyn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Extern=1
     assert len(call.args[1].fields) == 2
     assert call.args[1].fields[0].name_hint == "a"
     assert call.args[1].fields[1].name_hint == "a"
-    assert call.attrs.metadata["prim_shape_fn_var"].name_hint == "shape_func_add"
+    assert call.attrs.metadata["prim_shape_fn_var"].name_hint == "test_shape_func_add"
     assert call.attrs.metadata["relay_attrs"].Extern == 1
     assert len(call.attrs.metadata["prim_shape_fn_states"]) == 2
     assert call.attrs.metadata["prim_shape_fn_states"][0] == 2
     assert call.attrs.metadata["prim_shape_fn_states"][1] == 2
     assert call.attrs.metadata["prim_shape_fn_num_inputs"] == 2
     assert len(call.attrs.metadata["all_prim_shape_fn_vars"]) == 1
-    assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint == "shape_func_add"
+    assert call.attrs.metadata["all_prim_shape_fn_vars"][0].name_hint == "test_shape_func_add"
     assert call.attrs.metadata["prim_shape_fn_num_outputs"] == 1
     assert len(call.attrs.metadata["all_prim_fn_vars"]) == 0
 
@@ -233,7 +233,7 @@ def @my_dyn(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], Extern=1
     assert isinstance(my_dyn, tvm.relay.Function)
     assert my_dyn.attrs["Extern"] == 1
 
-    shape_func_add = actual_mod["shape_func_add"]
+    shape_func_add = actual_mod["test_shape_func_add"]
     assert isinstance(shape_func_add, tvm.tir.PrimFunc)
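The renamed expectations appear to follow from the NameSupply prefix rule: a supply created with prefix "test" (presumably derived from the enclosing test module; that derivation is an assumption, not shown in this hunk) joins prefix and hint with an underscore, matching the "prefix_test_1" pattern in the C++ tests above.

    #include <tvm/ir/name_supply.h>
    #include <string>

    void PrefixExample() {
      // Hypothetical: "test" stands in for the module-derived prefix.
      tvm::NameSupply supply("test");
      std::string n = supply->FreshName("shape_func_add");  // "test_shape_func_add"
      (void)n;
    }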
diff --git a/tests/python/relay/test_name_supply.py b/tests/python/relay/test_name_supply.py
new file mode 100644
index 000000000000..688be19c8171
--- /dev/null
+++ b/tests/python/relay/test_name_supply.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+
+from tvm import relay
+from tvm.ir import GlobalVar, structural_equal
+from tvm.ir.supply import NameSupply
+from tvm.ir.supply import GlobalVarSupply
+
+
+def test_name_supply():
+    name_supply = NameSupply("prefix")
+    name_supply.reserve_name("test")
+
+    assert name_supply.contains_name("test")
+    assert name_supply.fresh_name("test") == "prefix_test_1"
+    assert name_supply.contains_name("test_1")
+    assert not name_supply.contains_name("test_1", False)
+    assert not name_supply.contains_name("test_2")
+
+
+def test_global_var_supply_from_none():
+    var_supply = GlobalVarSupply()
+    global_var = GlobalVar("test")
+    var_supply.reserve_global(global_var)
+
+    assert structural_equal(var_supply.unique_global_for("test"), global_var)
+    assert not structural_equal(var_supply.fresh_global("test"), global_var)
+
+
+def test_global_var_supply_from_name_supply():
+    name_supply = NameSupply("prefix")
+    var_supply = GlobalVarSupply(name_supply)
+    global_var = GlobalVar("test")
+    var_supply.reserve_global(global_var)
+
+    assert structural_equal(var_supply.unique_global_for("test", False), global_var)
+    assert not structural_equal(var_supply.unique_global_for("test"), global_var)
+
+
+def test_global_var_supply_from_ir_mod():
+    x = relay.var("x")
+    y = relay.var("y")
+    mod = tvm.IRModule()
+    global_var = GlobalVar("test")
+    mod[global_var] = relay.Function([x, y], relay.add(x, y))
+    var_supply = GlobalVarSupply(mod)
+
+    second_global_var = var_supply.fresh_global("test", False)
+
+    assert structural_equal(var_supply.unique_global_for("test", False), global_var)
+    assert not structural_equal(var_supply.unique_global_for("test"), global_var)
+    assert not structural_equal(second_global_var, global_var)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()