Remove LoweredModule #8886

Merged
2 changes: 1 addition & 1 deletion include/tvm/runtime/container/map.h
@@ -1353,7 +1353,7 @@ class Map : public ObjectRef {
* Otherwise make a new copy of the array to ensure the current handle
* hold a unique copy.
*
- * \return Handle to the internal node container(which ganrantees to be unique)
+ * \return Handle to the internal node container(which guarantees to be unique)
*/
MapNode* CopyOnWrite() {
if (data_.get() == nullptr) {
1 change: 0 additions & 1 deletion include/tvm/target/target.h
@@ -31,7 +31,6 @@
#include <tvm/target/target_kind.h>

#include <string>
- #include <unordered_map>
#include <unordered_set>
#include <vector>

20 changes: 14 additions & 6 deletions src/relay/backend/aot_executor_codegen.cc
@@ -583,7 +583,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);

- IRModule new_mod =
+ IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
@@ -598,9 +598,12 @@
tec::UpdateFunctionMetadata(func, this->function_metadata_);
})(mod);

- tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
- function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
- auto lowered_main = lowered_module.main_module->Lookup("main");
+ Optional<backend::FunctionInfo> main_func_info =
+     lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
+ ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
+ function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
+ auto lowered_main = lowered_mod->Lookup("main");

auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

// Post-lowering storage map for writing main func - this should be the same map as previously
@@ -662,8 +665,13 @@ class AOTExecutorCodegen : public MixedModeVisitor {

ret.function_metadata = std::move(function_metadata_);

- ret.lowered_funcs = lowered_module.per_target_module;
- ret.external_mods = lowered_module.external_mods;
+ Optional<Array<tvm::runtime::Module>> external_modules =
+     lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+ ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
+
+ // This is the point where we separate the functions in the module by target
+ ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
+ ret.external_mods = external_modules.value();

if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_]->Update(mod_run);
22 changes: 16 additions & 6 deletions src/relay/backend/graph_executor_codegen.cc
@@ -221,7 +221,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
device_context_map.insert({expr, dev});
}

- IRModule new_mod =
+ IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
@@ -236,9 +236,13 @@
tec::UpdateFunctionMetadata(func, this->function_metadata_);
})(mod);

- tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
- function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
- auto main_module = lowered_module.main_module;
+ Optional<backend::FunctionInfo> main_func_info =
+     lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
+ ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
+ function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
+
+ // Get only the Relay functions out of the lowered module so we can run type inference on them
+ IRModule main_module = tec::GetMainModule(lowered_mod);
Contributor:

Ok in a follow up if you like, but I think folding type inference into LowerTEPass makes sense, to account for the rewritten calls. Someday we'll figure out how to gracefully do that incrementally, since it's really overkill for just transferring known types from old to new CallNodes!
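
For illustration, a minimal sketch of what the suggested follow-up could look like, composing the lowering pass with type inference (setup of targets, device_context_map, memory_plan, mod_name, and process_fn as in the codegen files above; the later comments in this thread explain why this does not yet work on the whole module):

```cpp
// Sketch only: run type inference right after lowering so the rewritten
// CallNodes pick up fresh checked types. All inputs are assumed to be set up
// exactly as in aot_executor_codegen.cc / graph_executor_codegen.cc above.
transform::Sequential lower_and_infer(
    {tec::LowerTEPass(targets, device_context_map, memory_plan, mod_name, process_fn),
     transform::InferType()});
IRModule lowered_mod = lower_and_infer(mod);
```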

Contributor Author:

I will add it as a TODO.

Contributor Author (@electriclilies, Sep 2, 2021):

Actually there's a slight problem with this -- you can't run InferType on the whole lowered_mod, because lowered_mod has functions with GlobalVars that are in the IRModule's global_var_map, but you can't find them using Lookup because they are not the same object (it fails on line 210):
https://github.com/apache/tvm/blob/main/src/relay/transforms/type_infer.cc#L209:L215
My workaround was to apply type inference only to the main_module (since that is what was done before this PR).

I'm not sure if this is a bug in how the type inferencer is dealing with GlobalVars (maybe it should be doing lookup by name hint, not pointer equality) or a bug in how those GlobalVars are being made / propagated.
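
To make the pointer-equality problem concrete, here is a hypothetical name-hint-based lookup (this helper is not existing TVM API; it only sketches the alternative mentioned above):

```cpp
// Hypothetical sketch: resolve a function by the GlobalVar's name hint instead
// of by pointer identity, so a structurally equal but distinct GlobalVar still
// finds its definition. Not existing TVM API.
Optional<BaseFunc> LookupByNameHint(const IRModule& mod, const String& name) {
  for (const auto& kv : mod->functions) {
    if (kv.first->name_hint == name) {
      return kv.second;  // Match on the name, ignoring object identity.
    }
  }
  return NullOpt;
}
```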

Contributor Author:

The other problem is that even if you do successfully manage to look up the function, there are PrimFuncs in the module which the type inferencer doesn't know how to deal with. We could just skip PrimFuncs we find during type inference.
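
A sketch of the skip-PrimFuncs idea (illustrative only; whether the inferencer should silently skip TIR functions is exactly the open question here):

```cpp
// Illustrative only: visit just the Relay functions in a mixed module,
// leaving tir::PrimFuncs untouched, as a type-inference pass might.
void ForEachRelayFunction(const IRModule& mod,
                          std::function<void(const GlobalVar&, const relay::Function&)> visit) {
  for (const auto& kv : mod->functions) {
    if (kv.second->IsInstance<tir::PrimFuncNode>()) {
      continue;  // The type inferencer has no rules for TIR functions, so skip them.
    }
    visit(kv.first, Downcast<relay::Function>(kv.second));
  }
}
```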

Contributor Author:

OK, actually I think I fixed this in #8399.

main_module = relay::transform::InferType()(main_module);
relay::Function main_func = Downcast<relay::Function>(main_module->Lookup("main"));

@@ -270,8 +274,14 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
}
ret.function_metadata = std::move(function_metadata_);
- ret.lowered_funcs = lowered_module.per_target_module;
- ret.external_mods = lowered_module.external_mods;
+
+ Optional<Array<tvm::runtime::Module>> external_modules =
+     lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+ ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
+
+ // This is the point where we separate the functions in the module by target
+ ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
+ ret.external_mods = external_modules.value();
return ret;
}

42 changes: 21 additions & 21 deletions src/relay/backend/interpreter.cc
@@ -292,7 +292,8 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st
class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
- // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
+ // TODO(mbs, electriclilies): Collapse mod and per_target_module once IRModule subsumes
+ // LoweredModule.
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
: mod_(mod),
per_target_module_(per_target_module),
@@ -902,20 +903,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
* functions needed by the rewritten module.
*/
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
- // Run minimal transforms on module to establish invariants needed by interpreter.
- transform::Sequential seq({transform::SimplifyInference(),
- // FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
- // attribute.
- transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
- // eta expand to support constructors in argument position
- transform::EtaExpand(
- /*expand_constructor=*/true, /*expand_global_var=*/false),
- transform::InferType()});
-
- transform::PassContext pass_ctx = transform::PassContext::Current();
- With<transform::PassContext> ctx(pass_ctx);
- mod = seq(mod);

+ // Things to initialize to pass into tec::LowerTEPass
// We only have one device-specific target.
tec::TargetMap targets = {{device.device_type, target}};

@@ -925,13 +913,25 @@
// No need for a memory plan.
backend::StaticMemoryPlan memory_plan; /*=nullptr*/

+ // Run minimal transforms on module to establish invariants needed by interpreter.
+ transform::Sequential seq(
+ {transform::SimplifyInference(),
+ // FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
+ // attribute.
+ transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
+ // eta expand to support constructors in argument position
+ transform::EtaExpand(
+ /*expand_constructor=*/true, /*expand_global_var=*/false),
+ transform::InferType(),
+ tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp",
+ [](Function func) { /* no-op */ })});
+
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+ With<transform::PassContext> ctx(pass_ctx);
+ mod = seq(mod);

// Lower all primitive functions reachable from expr.
- // TODO(mbs): This should be just another pass in seq above, which requires LoweredModule to
- // be merged into IRModule.
- LoweredModule lowered_module =
- tec::LowerTE(mod, targets, device_map, memory_plan, /*module_name=*/"intrp",
- [](Function func) { /* no-op */ });
- return {lowered_module.main_module, lowered_module.per_target_module};
+ return {tec::GetMainModule(mod), tec::GetPerTargetModules(mod)};
}
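
As a usage sketch (assuming mod, device, and target are already constructed), the pair returned by Prepare feeds straight into the Interpreter constructor above:

```cpp
// Sketch: Prepare splits the lowered IRModule back into the two pieces the
// Interpreter constructor still takes separately (see the TODO above about
// collapsing them once IRModule fully subsumes LoweredModule).
std::pair<IRModule, Map<Target, IRModule>> prepared = Prepare(mod, device, target);
Interpreter intrp(prepared.first, prepared.second, device, target);
```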

/*! \brief Check if an expression could be changed by \p Prepare.
159 changes: 61 additions & 98 deletions src/relay/backend/te_compiler.cc
@@ -85,33 +85,46 @@ class TECompilerImpl : public TECompilerNode {
return LowerShapeFuncInternal(key)->cached_func;
}

- Map<Target, IRModule> GetLoweredFunctions() {
- std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
- lowered_functions;
+ IRModule GetLoweredFunctions() {
+ IRModule mod;
+ // Extract lowered functions from the cache
for (const auto& it : cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
- auto target = source_func->target;

- if (!lowered_functions.count(target)) {
- lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
- }
+ IRModule lowered_mod = lowered_func->cached_func->funcs;

- lowered_functions[target]->Update(lowered_func->cached_func->funcs);
- }
+ // Annotate functions with their target and put them in the return module
+ for (auto kv : lowered_mod->functions) {
+ const GlobalVar& var = kv.first;
+ const BaseFunc& func = kv.second;
+
+ // Only add functions that are not external functions
+ if (!func->GetAttr<String>(attr::kCompiler).defined()) {
+ ICHECK(func->IsInstance<tir::PrimFuncNode>())
+ << "Expected all functions that are not external to be PrimFuncs, but found "
+ << func->GetTypeKey();
+ const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
+ mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
+ }
+ }
+ }
+ // Extract lowered dynamic shape functions from the shape cache
for (const auto& it : shape_func_cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

- if (!lowered_functions.count(target)) {
- lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
+ IRModule lowered_mod = lowered_func->cached_func->funcs;
+
+ // Annotate functions with their target and put them in the return module
+ for (auto kv : lowered_mod->functions) {
+ const GlobalVar& var = kv.first;
+ const BaseFunc& func = kv.second;
+ const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
+ mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
}

- lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}
- return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions);
+ return mod;
}

Array<tvm::runtime::Module> LowerExternalFunctions() {
@@ -830,9 +843,9 @@ void UpdateFunctionMetadata(Function relay_func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}

- LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
- backend::StaticMemoryPlan memory_plan, const String& module_name,
- std::function<void(Function)> process_fn) {
+ IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
+ backend::StaticMemoryPlan memory_plan, const String& module_name,
+ std::function<void(Function)> process_fn) {
DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);

TECompiler compiler;
@@ -864,76 +877,23 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
(*te_compiler_update_weights)(weight_map);
}

- LoweredModule lowered_module;
- lowered_module.main_module = updated_module;
- lowered_module.per_target_module = compiler->GetLoweredFunctions();
- lowered_module.external_mods = compiler->LowerExternalFunctions();
- lowered_module.main_func_info = func_info;
- return lowered_module;
- }
+ // Copy the lowered functions into the return module
+ updated_module->Update(compiler->GetLoweredFunctions());

- IRModule LoweredModuleToIRModule(LoweredModule mod) {
- IRModule unified_module;
+ // Annotate the module with the external modules and function info
+ updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions());
+ updated_module = WithAttr(updated_module, "main_func_info", func_info);

- // Copy the main module and its typedefs
- for (const auto& kv : mod.main_module->functions) {
- unified_module->Add(kv.first, kv.second);
- }
- for (const auto& kv : mod.main_module->type_definitions) {
- unified_module->AddTypeDef(kv.first, kv.second);
- }
-
- // Annotate the per-target functions with their target and add them to the unified module
- for (const auto& kv : mod.per_target_module) {
- const Target target = kv.first;
- const IRModule target_module = kv.second;
-
- // Right now, per-target functions are TIR functions, which don't have type definitions, so
- // there should be no type defs in the per_target_modules
- size_t ty_def_size = target_module->type_definitions.size();
- ICHECK(ty_def_size == 0)
- << "Expected there to be no type definitions in the per_target_modules, but found "
- << ty_def_size;
-
- for (const auto& kv : target_module->functions) {
- const GlobalVar& var = kv.first;
- const BaseFunc& func = kv.second;
- if (func->IsInstance<tir::PrimFuncNode>()) {
- tir::PrimFunc primFunc =
- WithAttr(Downcast<tir::PrimFunc>(std::move(func)), tvm::attr::kTarget, target);
- unified_module->Add(var, primFunc);
- } else if (func->IsInstance<relay::FunctionNode>()) {
- relay::Function relayFunc =
- WithAttr(Downcast<relay::Function>(std::move(func)), tvm::attr::kTarget, target);
- unified_module->Add(var, relayFunc);
- } else {
- LOG(FATAL)
- << "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found "
- << func->GetTypeKey();
- }
- }
- }
-
- IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods);
- ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
- return ret_mod;
+ return updated_module;
}

- LoweredModule IRModuleToLoweredModule(IRModule mod) {
- IRModule main_mod;
- // Copy just the TypeDefs from the IRModule to the LoweredModule's main module
- // This is the only time we need to do this since there are no TypeDefs in TIR
- for (const auto& kv : mod->type_definitions) {
- main_mod->AddTypeDef(kv.first, kv.second);
- }
-
- Map<Target, IRModule> per_target_modules;
+ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
+ std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
+ per_target_modules;
for (const auto& kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
- if (func->IsInstance<relay::FunctionNode>()) {
- main_mod->Add(var, func);
- } else if (func->IsInstance<tir::PrimFuncNode>()) {
+ if (func->IsInstance<tir::PrimFuncNode>()) {
// Extract target
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target) << "Target should be set at this point";
@@ -943,44 +903,47 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
// Initialize the IRModule for this target and add the function
IRModule target_module;
target_module->Add(var, func);
- per_target_modules.Set(target.value(), target_module);
+ per_target_modules[target.value()] = target_module;
} else {
// The IRModule for this target is initialized, so just add the function.
IRModule target_module = per_target_modules.at(target.value());
target_module->Add(var, func);
}
- } else {
+ } else if (!func->IsInstance<relay::FunctionNode>()) {
LOG(FATAL)
<< "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
<< func->GetTypeKey();
}
}
+ return per_target_modules;
}

- // Put the LoweredModule together
- LoweredModule lowered_module;
- lowered_module.main_module = main_mod;
- lowered_module.per_target_module = per_target_modules;
-
- // Extract external modules and main func info, add to lowered module if they exist
- auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
- if (external_mods) {
- lowered_module.external_mods = external_mods.value();
+ IRModule GetMainModule(IRModule mod) {
+ IRModule main_module;
+ // Copy the type defs
+ for (const auto& kv : mod->type_definitions) {
+ main_module->AddTypeDef(kv.first, kv.second);
}
- auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
- if (main_func_info) {
- lowered_module.main_func_info = main_func_info.value();
+ // Copy all Relay functions (we don't include PrimFuncs)
+ for (auto kv : mod->functions) {
+ const GlobalVar& var = kv.first;
+ const BaseFunc& func = kv.second;
+ if (func->IsInstance<tvm::relay::FunctionNode>()) {
+ main_module->Add(var, func);
+ }
}
- return lowered_module;
+ return main_module;
}

Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
- return LoweredModuleToIRModule(
- LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn));
+ return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
};
+ // TODO(@electriclilies, mbs): Fold InferType() pass into LowerTEPass since it will always need to
+ // be called afterwards
return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
}
} // namespace tec
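
Taken together, the executors above now consume a single lowered IRModule rather than a LoweredModule struct. A minimal consumption sketch (setup of targets, device_context_map, and memory_plan elided; the module name is arbitrary):

```cpp
// Sketch of the post-LowerTEPass pattern used by the AOT and graph codegens above.
IRModule lowered_mod = tec::LowerTEPass(targets, device_context_map, memory_plan,
                                        /*module_name=*/"example",
                                        [](Function func) { /* no-op */ })(mod);

// Lowered TIR functions, separated by their kTarget attribute:
Map<Target, IRModule> per_target = tec::GetPerTargetModules(lowered_mod);
// The remaining Relay functions and type definitions:
IRModule main_module = tec::GetMainModule(lowered_mod);
// External-codegen modules and main function info now live in module attributes:
Optional<Array<tvm::runtime::Module>> external_mods =
    lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
Optional<backend::FunctionInfo> main_func_info =
    lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
```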