diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index fefb08f878ef..852c7d0d8a98 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -122,6 +122,7 @@ class IRModuleNode : public Object {
     v->Visit("global_var_map_", &global_var_map_);
     v->Visit("global_type_var_map_", &global_type_var_map_);
     v->Visit("source_map", &source_map);
+    v->Visit("attrs", &attrs);
   }

   TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
@@ -277,6 +278,12 @@ class IRModuleNode : public Object {
    */
   TVM_DLL void Update(const IRModule& other);

+  /*!
+   * \brief Create a shallow copy of this IRModule.
+   * \returns The shallow copy of the IRModule.
+   */
+  TVM_DLL IRModule ShallowCopy();
+
   /*!
   * \brief Import Relay code from the file at path.
   * \param path The path of the Relay code to import.
@@ -348,12 +355,14 @@ class IRModule : public ObjectRef {
   * \brief constructor
   * \param functions Functions in the module.
   * \param type_definitions Type definitions in the module.
-   * \param import_set Set of imported files in the module
+   * \param import_set Set of imported files in the module.
   * \param map The module source map.
+   * \param attrs The module attributes.
   */
  TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
                            Map<GlobalTypeVar, TypeData> type_definitions = {},
-                            std::unordered_set<String> import_set = {}, parser::SourceMap map = {});
+                            std::unordered_set<String> import_set = {}, parser::SourceMap map = {},
+                            DictAttrs attrs = {});

  /*! \brief default constructor */
  IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
@@ -415,6 +424,13 @@ class IRModule : public ObjectRef {
   */
  TVM_DLL static IRModule FromText(const String& text, const String& source_path);

+  /*!
+   * \brief Create a shallow copy of an IRModule.
+   * \param mod The module to copy.
+   * \return The copied module.
+   */
+  IRModule ShallowCopyIRModule(IRModule mod);
+
  /*! \brief Declare the container type. */
  using ContainerType = IRModuleNode;
diff --git a/src/ir/module.cc b/src/ir/module.cc
index d4129c84ccf5..15c441d61a23 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -43,7 +43,8 @@ namespace tvm {

 IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
                    tvm::Map<GlobalTypeVar, TypeData> type_definitions,
-                   std::unordered_set<String> import_set, parser::SourceMap source_map) {
+                   std::unordered_set<String> import_set, parser::SourceMap source_map,
+                   DictAttrs attrs) {
   auto n = make_object<IRModuleNode>();
   n->functions = std::move(functions);
   n->type_definitions = std::move(type_definitions);
@@ -52,6 +53,7 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
   n->constructor_tag_map_ = {};
   n->import_set_ = std::move(import_set);
   n->source_map = source_map;
+  n->attrs = std::move(attrs);

   for (const auto& kv : n->functions) {
     // set global var map
@@ -72,6 +74,7 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,

 bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
   if (functions.size() != other->functions.size()) return false;
+  if (!equal(this->attrs, other->attrs)) return false;
   for (const auto& kv : this->functions) {
     if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
     if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
@@ -112,6 +115,7 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
     temp.emplace_back(kv.first->name_hint, kv.second);
   }
   reduce_temp();
+  hash_reduce(this->attrs);
 }

 bool IRModuleNode::ContainGlobalVar(const String& name) const {
@@ -361,6 +365,11 @@ void IRModuleNode::Update(const IRModule& mod) {
   }
 }

+IRModule IRModuleNode::ShallowCopy() {
+  return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map,
+                  this->attrs);
+}
+
 std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
     const RelayExpr& expr, const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
     const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 0c094cb1fa2c..70779ac58abf 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -676,7 +676,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {

     Optional<Array<runtime::Module>> external_modules =
         lowered_mod->GetAttr<Array<runtime::Module>>("external_mods");
-    ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
+    ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";

     // This is the point where we separate the functions in the module by target
     ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc
index b7b388431ca1..aca95db34c4e 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -241,26 +241,23 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>>
-    Function main_func = Downcast<Function>(main_module->Lookup("main"));
+    Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));

     // Now that we have lowered all operators to TIR code, we can proceed with compilation.
     //
     // We need to unfortunately re-plan as the previous results have been invalidated by lowering
     // we will fix this in future refactors.
-    memory_plan_ = GraphPlanMemory(main_func);
+    memory_plan_ = GraphPlanMemory(lowered_main_func);

     // The graph planner also can not handle planning calls to global variables to we must remap
     // First we convert all the parameters into input nodes.
-    for (auto param : main_func->params) {
+    for (auto param : lowered_main_func->params) {
       auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
       var_map_[param.get()] = AddNode(node_ptr, param);
     }
-    heads_ = VisitExpr(main_func->body);
+    heads_ = VisitExpr(lowered_main_func->body);
     std::ostringstream os;
     dmlc::JSONWriter writer(&os);
@@ -277,7 +274,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>>

     Optional<Array<runtime::Module>> external_modules =
         lowered_mod->GetAttr<Array<runtime::Module>>("external_mods");
-    ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
+    ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";

     // This is the point where we separate the functions in the module by target
     ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 82455bdf925c..df14b9e078b6 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -292,14 +292,8 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st
 class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
                     PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
  public:
-  // TODO(mbs, electriclilies): Collapse mod and per_target_module once IRModule subsumes
-  // LoweredModule.
-  Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
-      : mod_(mod),
-        per_target_module_(per_target_module),
-        device_(device),
-        target_(target),
-        debug_op_(Op::Get("debug")) {}
+  Interpreter(IRModule unified_mod, Device device, Target target)
+      : unified_mod_(unified_mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {}

   template <typename T>
   T WithFrame(const Frame& fr, const std::function<T()>& f) {
@@ -316,7 +310,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
   ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef<Var>(var_node)); }

   ObjectRef VisitExpr_(const GlobalVarNode* op) final {
-    return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
+    return Eval(unified_mod_->Lookup(GetRef<GlobalVar>(op)));
   }

   ObjectRef VisitExpr_(const OpNode* id) override {
@@ -387,9 +381,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

     // Project out just the function(s) we need.
     IRModule lowered_projected_mod;
+    Map<Target, IRModule> per_target_module = tec::GetPerTargetModules(unified_mod_);
     std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
-        per_target_module_std_map =
-            backend::TargetModuleMapToTargetStrModuleMap(per_target_module_);
+        per_target_module_std_map = backend::TargetModuleMapToTargetStrModuleMap(per_target_module);
     auto mod_itr = per_target_module_std_map.find(target);
     ICHECK(mod_itr != per_target_module_std_map.end())
         << "No target module for target '" << target->str() << "'";
@@ -876,13 +870,11 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
   }

  private:
-  // Main module. All expressions are eval'ed w.r.t. the definitions in this module. This module
-  // may contain calls to TIR functions bound in a per_target_module_ below.
-  IRModule mod_;
-  // Map from target key to lowered TIR functions derived from mod_.
-  // Note that primitives are implicitly executed on target_, while shape functions are implicitly
-  // executed on the default 'cpu' host. Thus this map has at most two entries.
-  Map<Target, IRModule> per_target_module_;
+  // Unified module. Functions are annotated with their target.
+  // All expressions are eval'ed w.r.t. the definitions in this module.
+  // This module contains functions that used to be in main_module and the per_target_module (TIR
+  // functions) in one module.
+  IRModule unified_mod_;
   // Cached packed functions for the primitives and shape functions, keyed by target and
   // global var name.
   std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_;
@@ -902,7 +894,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
  * rewritten \p mod and target-specific modules containing bindings for all TIR primitive
  * functions needed by the rewritten module.
  */
-std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
+IRModule Prepare(IRModule mod, Device device, Target target) {
   // Things to initialize to pass into tec::LowerTEPass
   // We only have one device-specific target.
   tec::TargetMap targets = {{device.device_type, target}};
@@ -930,8 +922,7 @@ std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device,

   With<PassContext> ctx(pass_ctx);
   mod = seq(mod);
-  // Lower all primitive functions reachable from expr.
-  return {tec::GetMainModule(mod), tec::GetPerTargetModules(mod)};
+  return mod;
 }

 /*! \brief Check if an expression could be changed by \p Prepare.
@@ -1020,11 +1011,9 @@ TypedPackedFunc<ObjectRef(Array<ObjectRef>)> EvalFunction(IRModule mod, Expr expr, De
     // and can just eval it directly.
     expr_to_eval = expr;
   }
-  std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
-      Prepare(mod_with_expr, device, target);
-  std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(
-      /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
-      target);
+  IRModule lowered_mod = Prepare(mod_with_expr, device, target);
+
+  std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(lowered_mod, device, target);

   //
   // Step 2: Evaluate target function to a closure.
@@ -1063,12 +1052,11 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
                std::unordered_set<String> import_set, Device device, Target target) {
   std::pair<IRModule, GlobalVar> mod_and_global =
       IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);
-  std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
-      Prepare(mod_and_global.first, device, target);
-  Interpreter intrp(
-      /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
-      target);
-  Expr expr_to_eval = main_and_lowered.first->GetGlobalVar(mod_and_global.second->name_hint);
+
+  IRModule mod = Prepare(mod_and_global.first, device, target);
+
+  Interpreter intrp(mod, device, target);
+  Expr expr_to_eval = mod->GetGlobalVar(mod_and_global.second->name_hint);
   if (expr.as<FunctionNode>() == nullptr) {
     // TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr
     // unless it is a function, so we must reverse that in the expression to eval.
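With the Interpreter now holding a single unified module, the per-target TIR modules are recovered on demand via tec::GetPerTargetModules, which (per the te_compiler.cc hunk that follows) also seeds each per-target module with the parent module's attrs. The sketch below is illustrative only and is not part of the patch: it shows the general shape of that per-target split under the assumption that lowered functions carry their Target in the tvm::attr::kTarget function attribute; the helper name SplitByTargetSketch is made up, and the error handling of the real implementation is omitted.

#include <tvm/ir/function.h>
#include <tvm/ir/module.h>
#include <tvm/target/target.h>

using namespace tvm;

// Group functions by the target they were lowered for, seeding each fresh
// module with the parent module's attrs so module-level attributes survive.
Map<Target, IRModule> SplitByTargetSketch(const IRModule& mod) {
  Map<Target, IRModule> per_target;
  for (const auto& kv : mod->functions) {
    // Lowered functions are assumed to be annotated with their target.
    Optional<Target> target = kv.second->GetAttr<Target>(tvm::attr::kTarget);
    if (!target.defined()) continue;  // e.g. the Relay "main" function stays Relay-level
    if (!per_target.count(target.value())) {
      // Pass the parent module's attrs through, mirroring the te_compiler.cc change below.
      per_target.Set(target.value(), IRModule({}, {}, {}, {}, mod->attrs));
    }
    IRModule target_module = per_target.at(target.value());
    target_module->Add(kv.first, kv.second);
  }
  return per_target;
}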
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 4d7f50b4f3a0..0393fdfec70d 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -900,8 +900,9 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {

       // Put the function in per_target_modules
       if (!per_target_modules.count(target.value())) {
-        // Initialize the IRModule for this target and add the function
-        IRModule target_module;
+        // Initialize the IRModule for this target with the attributes from the input IRModule
+        IRModule target_module = IRModule({}, {}, {}, {}, mod->attrs);
+        // Add the function to the IRModule
         target_module->Add(var, func);
         per_target_modules[target.value()] = target_module;
       } else {
@@ -918,23 +919,6 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
   return per_target_modules;
 }

-IRModule GetMainModule(IRModule mod) {
-  IRModule main_module;
-  // Copy the type defs
-  for (const auto& kv : mod->type_definitions) {
-    main_module->AddTypeDef(kv.first, kv.second);
-  }
-  // Copy all Relay functions (we don't include PrimFuncs)
-  for (auto kv : mod->functions) {
-    const GlobalVar& var = kv.first;
-    const BaseFunc& func = kv.second;
-    if (func->IsInstance<relay::FunctionNode>()) {
-      main_module->Add(var, func);
-    }
-  }
-  return main_module;
-}
-
 Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
                  backend::StaticMemoryPlan memory_plan, const String& module_name,
                  std::function<void(Function)> process_fn) {
@@ -942,9 +926,8 @@ Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
                                                                    PassContext ctx) {
     return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
   };
-  // TODO(@electriclilies, mbs): Fold InferType() pass into LowerTEPass since it will always need to
-  // be called afterwards
-  return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
+  return tvm::transform::Sequential(
+      {tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});
 }
 }  // namespace tec
 }  // namespace relay
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index 082cd8c4491a..9d0eb1078ee0 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -165,13 +165,6 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
  */
 Map<Target, IRModule> GetPerTargetModules(IRModule mod);

-/*!
- * \brief Utility to extract all the Relay functions from an IRModule, with no PrimFuncs.
- * \param mod The IRModule to extract the Relay functions from
- * \return An IRModule containing only the Relay functions that are in the input mod (no PrimFuncs)
- */
-IRModule GetMainModule(IRModule mod);
-
 /*! \brief Lower an IRModule's primitive functions to TIR.
  *
  * This is the "back half" of the Relay compiler which lowers "primitive functions"
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index 4a7974cae5ae..344d1cae7823 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -133,9 +133,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
   DLOG(INFO) << "Executing function pass : " << pass_info->name
              << " with opt level: " << pass_info->opt_level;

-  // Execute the pass function and return a new module.
-  IRModule updated_mod =
-      IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
+  IRModule updated_mod = mod->ShallowCopy();

   std::vector<std::pair<GlobalVar, Function> > updates;
   for (const auto& it : updated_mod->functions) {
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index b48fbe44bd11..f74cf983ccae 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -30,6 +30,7 @@
  */

 #include <tvm/ir/error.h>
+#include <tvm/ir/module.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/expr.h>
@@ -509,7 +510,9 @@ class NameMangleExtFuncs : public MixedModeMutator {

     // Walk the tree and mangle the functions. Then replace compiler functions
     // with mangled functions in the module
-    IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports());
+    IRModule new_module = module_->ShallowCopy();
+    new_module->functions = {};
+
     for (const auto& pair : glob_funcs) {
       if (auto* fn = pair.second.as<FunctionNode>()) {
         auto func = GetRef<Function>(fn);
diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc
index d03fc1488aea..8e952d60b8b7 100644
--- a/src/relay/transforms/to_basic_block_normal_form.cc
+++ b/src/relay/transforms/to_basic_block_normal_form.cc
@@ -52,7 +52,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) {
   DLOG(INFO) << "ToBBlock:" << std::endl << mod;

   // Create a new module by shallow copy.
-  auto mod_ = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
+  IRModule mod_ = mod->ShallowCopy();

   tvm::Map<GlobalVar, Function> updates;
   auto funcs = mod_->functions;
diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc
index f29087dcc049..6c2371716b16 100644
--- a/src/relay/transforms/type_infer.cc
+++ b/src/relay/transforms/type_infer.cc
@@ -205,13 +205,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
       this->EmitFatal(Diagnostic::Error(op->span)
                       << "Cannot do type inference on global variables "
                       << "without a module");
     }
-    if (mod_->ContainGlobalVar(var->name_hint)) {
-      relay::Function e = Downcast<Function>(mod_->Lookup(var));
-      return e->checked_type();
-    } else {
-      return op->checked_type_;
+    BaseFunc func = mod_->Lookup(var->name_hint);
+
+    if (func->IsInstance<FunctionNode>()) {
+      relay::Function relay_func = Downcast<Function>(func);
+      return relay_func->checked_type();
+    }
     }
+    // Return op->checked_type if the module doesn't contain the GlobalVar or the function is a
+    // PrimFunc (we don't typecheck PrimFuncs)
+    return op->checked_type_;
   }

   Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); }
@@ -822,8 +826,7 @@ Pass InferType() {
       [=](IRModule mod, const PassContext& pass_ctx) {
         DLOG(INFO) << "tvm::relay::transform::InferType";
         // Execute the pass function and return a new module.
-        IRModule updated_mod =
-            IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
+        IRModule updated_mod = mod->ShallowCopy();

         pass_ctx->diag_ctx = DiagnosticContext::Default(updated_mod);
diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py
index b90bce548a5e..092cae01f568 100644
--- a/tests/python/relay/test_backend_compile_engine.py
+++ b/tests/python/relay/test_backend_compile_engine.py
@@ -194,6 +194,8 @@ def get_func(shape):
     engine.dump()


+# Note: Once compile engine is removed, we should keep this test so that
+# we make sure that opt_level=0 passes are being called correctly.
 def test_compile_placeholder_bypass():
     engine = relay.backend.compile_engine.get()
     x = relay.var("x", shape=(2, 3))
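Taken together, the new attrs field and ShallowCopy() let a pass duplicate a module without silently dropping module-level attributes, which is what the field-by-field copies replaced above used to do. The following is a minimal sketch, not part of the patch; it relies only on the constructor, attrs field, GetAttr and ShallowCopy shown above, and the attribute key "custom_key" and the helper names are made up for illustration.

#include <tvm/ir/attrs.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/logging.h>

using namespace tvm;

IRModule MakeAnnotatedModule() {
  // Attach a hypothetical module-level annotation via the new constructor argument.
  Map<String, ObjectRef> dict;
  dict.Set("custom_key", String("custom_value"));
  return IRModule(/*functions=*/{}, /*type_definitions=*/{}, /*import_set=*/{}, /*map=*/{},
                  DictAttrs(dict));
}

IRModule CopyForPass(IRModule mod) {
  // The old pattern, IRModule(mod->functions, mod->type_definitions, mod->Imports(),
  // mod->source_map), rebuilt the module field by field and lost mod->attrs.
  // ShallowCopy() forwards attrs as well, so the annotation survives the copy.
  IRModule copied = mod->ShallowCopy();
  ICHECK(copied->GetAttr<String>("custom_key").defined());
  return copied;
}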