Skip to content

Commit

Permalink
Build GlobalVarSupply from IRModules instead of having it attached to…
Browse files Browse the repository at this point in the history
… an IRModule.
  • Loading branch information
gigiblender committed Jul 15, 2022
1 parent a7a9278 commit 3c327fe
Show file tree
Hide file tree
Showing 23 changed files with 161 additions and 76 deletions.
2 changes: 2 additions & 0 deletions include/tvm/ir/global_var_supply.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class GlobalVarSupplyNode : public Object {

GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);

void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false);

void VisitAttrs(AttrVisitor* v) { v->Visit("name_supply", &name_supply_); }

NameSupply name_supply_;
Expand Down
23 changes: 12 additions & 11 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ class IRModuleNode : public Object {
/* \brief Additional attributes storing meta-data about the module. */
DictAttrs attrs;

GlobalVarSupply global_var_supply;

/*!
* \brief Get a module attribute.
*
Expand Down Expand Up @@ -128,7 +126,6 @@ class IRModuleNode : public Object {
v->Visit("global_type_var_map_", &global_type_var_map_);
v->Visit("source_map", &source_map);
v->Visit("attrs", &attrs);
v->Visit("global_var_supply", &global_var_supply);
}

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
Expand Down Expand Up @@ -358,14 +355,12 @@ class IRModule : public ObjectRef {
/*!
* \brief constructor
* \param functions Functions in the module.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module.
* \param map The module source map.
* \param attrs The module attributes.
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply(),
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<String> import_set = {}, parser::SourceMap map = {},
DictAttrs attrs = {});
Expand Down Expand Up @@ -401,7 +396,6 @@ class IRModule : public ObjectRef {
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map. Default empty.
* \param global_var_supply The GlobalVarSupply to be used in the module.
* \param type_definitions The global type definition map. Default empty.
* \param import_set Set of external modules already imported. Default empty.
*
Expand All @@ -412,18 +406,16 @@ class IRModule : public ObjectRef {
*/
static std::pair<IRModule, GlobalVar> FromExprInContext(
const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply(),
const Map<GlobalTypeVar, TypeData>& type_definitions = {},
std::unordered_set<String> import_set = {});

/*!
* \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no
* imports.
*/
TVM_DLL static IRModule FromExpr(
const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply(),
const Map<GlobalTypeVar, TypeData>& type_definitions = {});
TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
const Map<GlobalVar, BaseFunc>& global_funcs = {},
const Map<GlobalTypeVar, TypeData>& type_definitions = {});

/*!
* \brief Parse text format source file into an IRModule.
Expand Down Expand Up @@ -482,6 +474,15 @@ namespace attr {

// Following are attributes for IRModule only.

/*!
* \brief Name of the module
*
* Type: String
*
* \sa tvm::runtime::String
*/
constexpr const char* kModuleName = "name";

/*!
* \brief Executor targeted by the module
*
Expand Down
5 changes: 2 additions & 3 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,15 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
*
* \param expr An expression to evaluate.
* \param type_definitions Global type definitions which \p expr may references.
* \param global_var_supply The GlobalVarSupply to be used during evaluation.
* \param import_set Already imported external modules.
* \param device The device on which all primitives will be executed.
* \param target The compiler target flag for compiling primitives.
* \param attrs Attributes for the expression to be evaluated with
* @return The object representing the result.
*/
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
GlobalVarSupply global_var_supply, std::unordered_set<String> import_set,
Device device, Target target, Map<String, ObjectRef> attrs = {});
std::unordered_set<String> import_set, Device device, Target target,
Map<String, ObjectRef> attrs = {});

} // namespace relay
} // namespace tvm
Expand Down
16 changes: 4 additions & 12 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""IRModule that holds the functions and type definitions."""
from tvm._ffi.base import string_types
import tvm._ffi
from tvm.ir.supply import GlobalVarSupply

from .base import Node
from . import expr as _expr
Expand All @@ -37,7 +36,7 @@ class IRModule(Node):
Map of global var to BaseFunc
"""

def __init__(self, functions=None, type_definitions=None, globar_var_supply=None):
def __init__(self, functions=None, type_definitions=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
Expand All @@ -60,11 +59,7 @@ def __init__(self, functions=None, type_definitions=None, globar_var_supply=None
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
if globar_var_supply is None:
globar_var_supply = GlobalVarSupply()
self.__init_handle_by_constructor__(
_ffi_api.IRModule, functions, type_definitions, globar_var_supply
)
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)

def __setitem__(self, var, val):
"""Add a mapping to the module.
Expand Down Expand Up @@ -222,7 +217,7 @@ def get_type(self, name):
return tuple([ty_var] + list(ty_data.constructors))

@staticmethod
def from_expr(expr, functions=None, type_defs=None, global_var_supply=None):
def from_expr(expr, functions=None, type_defs=None):
"""Construct a module from a standalone expression.
Parameters
Expand All @@ -243,12 +238,9 @@ def from_expr(expr, functions=None, type_defs=None, global_var_supply=None):
where expr is set as the entry point
(wrapped in a function if necessary)
"""
global_var_supply = (
global_var_supply if global_var_supply is not None else GlobalVarSupply()
)
funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {}
return _ffi_api.Module_FromExpr(expr, funcs, global_var_supply, defs)
return _ffi_api.Module_FromExpr(expr, funcs, defs)

def _import(self, file_to_import):
return _ffi_api.Module_Import(self, file_to_import)
Expand Down
6 changes: 2 additions & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
}
GlobalVar global_var = global_var_supply->UniqueGlobalFor(name, false);
return IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}), global_var_supply);
return IRModule(Map<GlobalVar, BaseFunc>({{global_var, f}}));
}

TVM_REGISTER_GLOBAL("driver.schedule_to_module")
Expand Down Expand Up @@ -432,9 +432,7 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
// Take the attrs from the first module so the eventual modules have them.
// Ideally this would just be one unified module all the way through;
IRModule first_module = (*inputs.begin()).second;
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply();
IRModule mhost_all =
IRModule(Map<GlobalVar, BaseFunc>(), global_var_supply, {}, {}, {}, first_module->attrs);
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

ICHECK(mhost_all.defined()) << "The host module must be defined";

Expand Down
9 changes: 9 additions & 0 deletions src/ir/global_var_supply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ GlobalVarSupply GlobalVarSupply::EmptySupply() {
return GlobalVarSupplyFromNameSupply(NameSupply::NameSupplyWithPrefix(""));
}

void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) {
name_supply_->ReserveName(var->name_hint, false);
if (!allow_conflict) {
ICHECK(name_to_var_map_.count(var->name_hint) == 0)
<< "GlobalVar " << var << " conflicts by name in this supply.";
}
name_to_var_map_[var->name_hint] = var;
}

GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply)
: name_supply_(std::move(name_supply)) {}

Expand Down
29 changes: 14 additions & 15 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,16 @@
#include <sstream>
#include <unordered_set>

#include "../relay/backend/supply_provider.h"

namespace tvm {

IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions, GlobalVarSupply global_var_supply,
IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, parser::SourceMap source_map,
DictAttrs attrs) {
auto n = make_object<IRModuleNode>();
n->functions = std::move(functions);
n->global_var_supply = global_var_supply;
n->type_definitions = std::move(type_definitions);
n->global_type_var_map_ = {};
n->global_var_map_ = {};
Expand Down Expand Up @@ -361,16 +362,15 @@ void IRModuleNode::Update(const IRModule& mod) {
}

IRModule IRModuleNode::ShallowCopy() {
return IRModule(this->functions, this->global_var_supply, this->type_definitions, this->Imports(),
this->source_map, this->attrs);
return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map,
this->attrs);
}

std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
const RelayExpr& expr, const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
GlobalVarSupply global_var_supply, const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
std::unordered_set<String> import_set) {
auto mod =
IRModule(global_funcs, std::move(global_var_supply), type_definitions, std::move(import_set));
auto mod = IRModule(global_funcs, type_definitions, std::move(import_set));
String gv_name;

// All global definitions must be functions.
Expand All @@ -386,21 +386,20 @@ std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
}

GlobalVar main_gv;
auto global_var_supply = tvm::BuildGlobalVarSupply(mod);
if (gv_name.empty()) {
// Bind function to 'main' (though rename if would clash with existing 'main').
main_gv = mod->global_var_supply->FreshGlobal("main");
main_gv = global_var_supply->FreshGlobal("main", false);
} else {
main_gv = mod->global_var_supply->UniqueGlobalFor(gv_name);
main_gv = global_var_supply->UniqueGlobalFor(gv_name, false);
}
mod->Add(main_gv, func);
return {mod, main_gv};
}

IRModule IRModule::FromExpr(const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs,
GlobalVarSupply global_var_supply,
const Map<GlobalTypeVar, TypeData>& type_definitions) {
return FromExprInContext(expr, global_funcs, std::move(global_var_supply), type_definitions)
.first;
return FromExprInContext(expr, global_funcs, type_definitions).first;
}

void IRModuleNode::Import(const String& path) {
Expand Down Expand Up @@ -438,9 +437,9 @@ IRModule IRModule::FromText(const String& text, const String& source_path) {
TVM_REGISTER_NODE_TYPE(IRModuleNode);

TVM_REGISTER_GLOBAL("ir.IRModule")
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types,
GlobalVarSupply global_var_supply) {
return IRModule(funcs, global_var_supply, types, {});
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> types) {
return IRModule(funcs, types, {});
});

TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) {
Expand Down
2 changes: 1 addition & 1 deletion src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1889,7 +1889,7 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,
IRModule module;
if (!init_module) {
SourceMap source_map;
module = IRModule({}, GlobalVarSupply::EmptySupply(), {}, {}, source_map);
module = IRModule({}, {}, {}, source_map);
} else {
module = init_module.value();
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/analysis/kind_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ Kind KindCheck(const Type& t, const IRModule& mod, Optional<DiagnosticContext> d

TVM_REGISTER_GLOBAL("relay.analysis.check_kind").set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = KindCheck(args[0], IRModule({}, GlobalVarSupply::EmptySupply(), {}));
*ret = KindCheck(args[0], IRModule({}, {}));
} else if (args.size() == 2) {
*ret = KindCheck(args[0], args[1], Optional<DiagnosticContext>());
} else {
Expand Down
3 changes: 1 addition & 2 deletions src/relay/analysis/match_exhaustion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,7 @@ Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
// expose for testing only
TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases")
.set_body_typed([](const Match& match, const Optional<IRModule>& mod_ref) {
IRModule call_mod =
mod_ref.defined() ? mod_ref.value() : IRModule({}, GlobalVarSupply::EmptySupply(), {});
IRModule call_mod = mod_ref.defined() ? mod_ref.value() : IRModule({}, {});
return UnmatchedCases(match, call_mod);
});

Expand Down
2 changes: 1 addition & 1 deletion src/relay/analysis/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
auto module = IRModule({}, GlobalVarSupply::EmptySupply(), {});
auto module = IRModule({}, {});
DiagnosticContext diag_ctx = DiagnosticContext::Default(module);
auto dummy_fn_name = GlobalVar("test");
module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array<relay::Expr>({})), Type(), {}, {}));
Expand Down
12 changes: 6 additions & 6 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1049,8 +1049,8 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
mod = transform::InferType()(mod);
mod_and_global = IRModule::FromExprInContext(expr, mod->functions, mod->global_var_supply,
mod->type_definitions, mod->Imports());
mod_and_global =
IRModule::FromExprInContext(expr, mod->functions, mod->type_definitions, mod->Imports());
} else {
mod_and_global = IRModule::FromExprInContext(expr);
}
Expand Down Expand Up @@ -1104,14 +1104,14 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
}

ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
GlobalVarSupply global_var_supply, std::unordered_set<String> import_set,
Device device, Target target, Map<String, ObjectRef> attrs) {
std::unordered_set<String> import_set, Device device, Target target,
Map<String, ObjectRef> attrs) {
ICHECK_EQ(device.device_type, target->kind->device_type);
Array<Target> raw_targets = {target};
CompilationConfig config(transform::PassContext::Current(), raw_targets);

std::pair<IRModule, GlobalVar> mod_and_global = IRModule::FromExprInContext(
expr, /*global_funcs=*/{}, global_var_supply, type_definitions, import_set);
std::pair<IRModule, GlobalVar> mod_and_global =
IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);

IRModule mod = Prepare(WithAttrs(mod_and_global.first, {attrs}), config);

Expand Down
52 changes: 52 additions & 0 deletions src/relay/backend/supply_provider.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "supply_provider.h"

#include <string>

namespace tvm {

// TODO(gigiblender): move this method
std::string GetModuleName(const IRModule& module) {
return module->GetAttr<String>(tvm::attr::kModuleName).value_or("tvmgen_default");
}

GlobalVarSupply BuildGlobalVarSupply(const IRModule module) {
return BuildGlobalVarSupply(Array<IRModule>({module}));
}

GlobalVarSupply BuildGlobalVarSupply(const Array<IRModule>& modules) {
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply();
// TODO(gigiblender): For now use as prefix the name of the first module.
if (!modules.empty()) {
IRModule first_mod = modules.front();
global_var_supply->name_supply_->prefix_ = GetModuleName(first_mod);
}
for (auto& mod : modules) {
for (auto kv : mod->functions) {
global_var_supply->ReserveGlobalVar(kv.first);
}
}

return global_var_supply;
}

} // namespace tvm
Loading

0 comments on commit 3c327fe

Please sign in to comment.