Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify name mangling in TVM #12066

Merged
merged 7 commits into from
Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#ifndef TVM_DRIVER_DRIVER_API_H_
#define TVM_DRIVER_DRIVER_API_H_

#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
Expand Down Expand Up @@ -99,14 +100,15 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/

TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);
GlobalVarSupply global_var_supply, bool simple_mode = false);

/*!
* \brief Build an IRModule given a TE schedule, args and binds. This function also applies
Expand All @@ -115,13 +117,14 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
* \param args The arguments to the function (Array of Tensor, Buffer and Vars)
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param simple_mode Disables the loop partition pass. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
bool simple_mode = false);
GlobalVarSupply global_var_supply, bool simple_mode = false);

/*!
* \brief Create an IRModule out of a TE Schedule. It does not apply lowering passes. If you want
Expand All @@ -130,10 +133,13 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param global_var_supply The GlobalVarSupply to be used in the module and when creating
* GlobalVars.
* \return The result module.
*/
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds);
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply);
/*!
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
Expand Down
125 changes: 125 additions & 0 deletions include/tvm/ir/global_var_supply.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/global_var_supply.h
* \brief GlobalVarSupply that can be used to generate unique \class GlobalVar.
*/
#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_
#define TVM_IR_GLOBAL_VAR_SUPPLY_H_

#include <string>
#include <unordered_map>

#include "tvm/ir/expr.h"
#include "tvm/ir/module.h"
#include "tvm/ir/name_supply.h"

namespace tvm {

/*!
* \brief GlobalVarSupply can be used to generate unique GlobalVars.
*/
class GlobalVarSupplyNode : public Object {
public:
/*!
* \brief Empty constructor. Will use an empty NameSupply.
*/
GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}

/*!
* \brief Constructor.
* \param name_supply The NameSupply to use for generating the names of fresh GlobalVars.
* \param name_to_var_map An optional map.
*/
explicit GlobalVarSupplyNode(NameSupply name_supply,
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});

/*!
* \brief Generates a unique GlobalVar from this supply.
* \param name The name from which the name of the GlobalVar is derived.
* \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended
* to the name. \return A unique GlobalVar.
*/
GlobalVar FreshGlobal(String name, bool add_prefix = true);
gigiblender marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Looks up for a GlobalVar with the given name in this supply.
* If no entry is found, creates one, places it in the cache and returns it.
* \param name The name of the GlobalVar to search for.
* \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to
* the name before performing the search. \return A cached GlobalVar.
*/
GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);

/*!
* \brief Reserves an existing GlobalVar with this supply.
* \param var The GlobalVar to be registered.
* \param allow_conflict Allow conflict with other GlobalVars that have the same name.
*/
void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false);

void VisitAttrs(AttrVisitor* v) {}

/*! \brief The NameSupply used to generate unique name hints to GlobalVars. */
NameSupply name_supply_;

static constexpr const char* _type_key = "GlobalVarSupply";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object);

private:
std::unordered_map<std::string, GlobalVar> name_to_var_map_;
};

/*!
* \brief Managed reference class to GlobalVarSupplyNode.
* \sa GlobalVarSupplyNode
*/
class GlobalVarSupply : public ObjectRef {
public:
/*!
* \brief Constructor.
* \param name_supply The NameSupply to be used when generating new GlobalVars.
* \param name_to_var_map An optional map.
*/
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply,
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});

/*!
* \brief Constructs a supply from an array of IRModules. GlobalVars generated by this supply are
* guaranteed not to conflict with any GlobalVars that belong to the modules. \param modules Array
* of IRModules.
*/
TVM_DLL explicit GlobalVarSupply(const Array<IRModule>& modules);

/*!
* \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply are
* guaranteed not to conflict with GlobalVars that belong to the modules. \param module The
* IRModule.
*/
TVM_DLL explicit GlobalVarSupply(const IRModule module);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef, GlobalVarSupplyNode);
};

} // namespace tvm

#endif // TVM_IR_GLOBAL_VAR_SUPPLY_H_
17 changes: 9 additions & 8 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,14 +323,6 @@ class IRModuleNode : public Object {
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Returns a version of \p name which is unique amongst all function definitions in module.
*
* \param name The original name.
* \return Updated name which is unique.
*/
String GetUniqueName(const String& name);

/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Expand Down Expand Up @@ -481,6 +473,15 @@ namespace attr {

// Following are attributes for IRModule only.

/*!
* \brief Name of the module
*
* Type: String
*
* \sa tvm::runtime::String
*/
constexpr const char* kModuleName = "mod_name";

/*!
* \brief Executor targeted by the module
*
Expand Down
123 changes: 123 additions & 0 deletions include/tvm/ir/name_supply.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/name_supply.h
* \brief NameSupply that can be used to generate unique variable names.
*/
#ifndef TVM_IR_NAME_SUPPLY_H_
#define TVM_IR_NAME_SUPPLY_H_

#include <string>
#include <unordered_map>
#include <utility>

#include "tvm/ir/expr.h"

namespace tvm {

/*!
* \brief NameSupply can be used to generate unique names.
*/
class NameSupplyNode : public Object {
public:
/*!
* \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro.
*/
NameSupplyNode() = default;

/*!
* \brief Constructor.
* \param prefix The prefix to be used with this NameSupply.
* \param name_map The map used to guarantee uniqueness.
*/
NameSupplyNode(const String& prefix, std::unordered_map<std::string, int> name_map)
: prefix_(prefix), name_map(std::move(name_map)) {}

/*!
* \brief Generates a unique name from this NameSupply.
* \param name The name from which the generated name is derived.
* \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
* name. \return A unique name.
*/
String FreshName(const String& name, bool add_prefix = true);
gigiblender marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Reserves an existing name with this NameSupply.
* \param name The name to be reserved.
* \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
* name before reserving it. \return The name that was reserved with the NameSupply. It can be
* different if a prefix is added.
*/
String ReserveName(const String& name, bool add_prefix = true);

/*!
* \brief Checks if this NameSupply already generated a name.
* \param name The name to check.
* \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
* name before checking for it. \return True if the name has already been generated. False
* otherwise.
*/
bool ContainsName(const String& name, bool add_prefix = true);

void VisitAttrs(AttrVisitor* v) {}

// Prefix for all GlobalVar names. It can be empty.
std::string prefix_;

static constexpr const char* _type_key = "NameSupply";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object);

private:
/*! \brief Helper function to add the NameSupply prefix to the name. */
String add_prefix_to_name(const String& name);

/*!
* \brief Function that will generate a unique name.
* \param name The name to be used as a base.
* \return A unique name.
*/
std::string GetUniqueName(std::string name);

/*! \brief A map that is used to generate unique names. */
std::unordered_map<std::string, int> name_map;
};

/*!
* \brief Managed reference class to NameSupplyNode.
* \sa NameSupplyNode
*/
class NameSupply : public ObjectRef {
public:
/*!
* \brief Constructor.
* \param prefix The prefix to be used with this NameSupply.
* \param name_map An optional map.
*/
TVM_DLL explicit NameSupply(const String& prefix,
std::unordered_map<std::string, int> name_map = {});

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode);
};

} // namespace tvm

#endif // TVM_IR_NAME_SUPPLY_H_
Loading