Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass #8802

Merged
12 commits merged on Aug 24, 2021
1 change: 0 additions & 1 deletion include/tvm/relay/function.h
@@ -144,7 +144,6 @@ constexpr const char* kComposite = "Composite";
constexpr const char* kInline = "Inline";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";

/*! \brief Mark the function as only composed of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";
} // namespace attr
8 changes: 5 additions & 3 deletions src/relay/backend/aot_executor_codegen.cc
@@ -586,8 +586,9 @@ class AOTExecutorCodegen : public ExprVisitor {
// to instead explicitly lowering the incoming IRModule, and then
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);
auto lowered_module = tec::LowerTE(
mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) {

IRModule new_mod =
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -599,8 +600,9 @@ class AOTExecutorCodegen : public ExprVisitor {
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
});
})(mod);

tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto lowered_main = lowered_module.main_module->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
7 changes: 4 additions & 3 deletions src/relay/backend/graph_executor_codegen.cc
@@ -221,8 +221,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
device_context_map.insert({expr, dev});
}

auto lowered_module = tec::LowerTE(
mod, targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
IRModule new_mod =
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -234,8 +234,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
});
})(mod);

tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto main_module = lowered_module.main_module;
main_module = relay::transform::InferType()(main_module);
117 changes: 112 additions & 5 deletions src/relay/backend/te_compiler.cc
@@ -20,6 +20,7 @@
#include "te_compiler.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/function.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
@@ -749,8 +750,6 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
relay_primfuncs);
}

// TODO(@electriclilies): Is the function passed in here relay_func??
// Also should this be inlined?
/*!
* \brief A function to create the function metadata for an input function (ie calculate buffer
* input/output sizes)
@@ -830,9 +829,6 @@ void UpdateFunctionMetadata(Function relay_func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}

// TODO(mbs): Make this an IRModule->IRModule pass by folding LoweredModule back into IRModule.
// Currently we rely on accumulating bindings inside the local TECompiler which we then
// host into the LoweredModule result.
LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
@@ -875,6 +871,117 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
return lowered_module;
}

IRModule LoweredModuleToIRModule(LoweredModule mod) {
IRModule unified_module;

// Copy the main module and its typedefs
for (const auto& kv : mod.main_module->functions) {
unified_module->Add(kv.first, kv.second);
}
for (const auto& kv : mod.main_module->type_definitions) {
unified_module->AddTypeDef(kv.first, kv.second);
}

// Annotate the per-target functions with their target and add them to the unified module
for (const auto& kv : mod.per_target_module) {
const String target = kv.first;
const IRModule target_module = kv.second;

// Right now, per-target functions are TIR functions, which don't have type definitions, so
// there should be no type defs in the per_target_modules
size_t ty_def_size = target_module->type_definitions.size();
ICHECK(ty_def_size == 0)
<< "Expected there to be no type definitions in the per_target_modules, but found "
<< ty_def_size;

for (const auto& kv : target_module->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<tir::PrimFuncNode>()) {
tir::PrimFunc primFunc =
WithAttr(Downcast<tir::PrimFunc>(std::move(func)), tvm::attr::kTarget, target);
unified_module->Add(var, primFunc);
} else if (func->IsInstance<relay::FunctionNode>()) {
relay::Function relayFunc =
WithAttr(Downcast<relay::Function>(std::move(func)), tvm::attr::kTarget, target);
unified_module->Add(var, relayFunc);
} else {
LOG(FATAL)
<< "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found "
<< func->GetTypeKey();
}
}
}

IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods);
ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
return ret_mod;
}

LoweredModule IRModuleToLoweredModule(IRModule mod) {
IRModule main_mod;
// Copy just the TypeDefs from the IRModule to the LoweredModule's main module
// This is the only time we need to do this since there are no TypeDefs in TIR
for (const auto& kv : mod->type_definitions) {
main_mod->AddTypeDef(kv.first, kv.second);
}

Map<String, IRModule> per_target_modules;
for (const auto& kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<relay::FunctionNode>()) {
main_mod->Add(var, func);
} else if (func->IsInstance<tir::PrimFuncNode>()) {
// Extract target
Optional<String> target = func->GetAttr<String>(tvm::attr::kTarget);
ICHECK(target) << "Target should be set at this point";

// Put the function in per_target_modules
if (!per_target_modules.count(target.value())) {
// Initialize the IRModule for this target and add the function
IRModule target_module;
target_module->Add(var, func);
per_target_modules.Set(target.value(), target_module);
} else {
// The IRModule for this target is initialized, so just add the function.
IRModule target_module = per_target_modules.at(target.value());
target_module->Add(var, func);
}
} else {
LOG(FATAL)
<< "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
<< func->GetTypeKey();
}
}

// Put the LoweredModule together
LoweredModule lowered_module;
lowered_module.main_module = main_mod;
lowered_module.per_target_module = per_target_modules;

// Extract external modules and main func info, add to lowered module if they exist
auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
if (external_mods) {
lowered_module.external_mods = external_mods.value();
}
auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
if (main_func_info) {
lowered_module.main_func_info = main_func_info.value();
}
return lowered_module;
}

Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
return LoweredModuleToIRModule(
LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn));
};
return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
}
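
To make the effect of the new pass form concrete, here is a minimal sketch (not part of this diff) of how LowerTEPass composes with other module passes now that it is an ordinary tvm::transform::Pass. The wrapper name LowerWithPass, the "default" module name, and the empty per-function callback are illustrative assumptions; targets, device_map, and memory_plan are assumed to be built the same way the executor codegens above build them.

```cpp
// Sketch only (not in this PR): LowerTEPass composing as an ordinary module pass.
// Assumes this lives next to te_compiler.h in src/relay/backend/.
#include <tvm/ir/transform.h>
#include <tvm/relay/transform.h>

#include "te_compiler.h"

namespace tvm {
namespace relay {

// Hypothetical helper: lower an IRModule by running LowerTEPass inside a
// Sequential instead of calling LowerTE directly.
IRModule LowerWithPass(IRModule mod, tec::TargetMap targets, tec::DeviceMap device_map,
                       backend::StaticMemoryPlan memory_plan) {
  transform::Pass seq = transform::Sequential(
      {transform::InferType(),
       tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"default",
                        [](Function) { /* per-function hook, e.g. metadata collection */ })},
      "LowerWithTE");
  // Applying the Sequential runs InferType and then the TE lowering pass on `mod`.
  return seq(mod);
}

}  // namespace relay
}  // namespace tvm
```
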
Contributor:

[A question to the community and not specific to your PR Lily!]

This is a good example of code which could be easily unit tested in C++ in the, er, 'conventional' sense. That is, as a reader I could expect to go to tests/cpp/relay/backend/te_compiler_test.cc and look for TEST(IRModuleToLoweredModule, ...). Currently this new code is tested indirectly via its use by LowerTEPass and consumers thereof, which in turn are tested indirectly by virtue of everything passing into TIR via this choke point. Just wanted to test the waters on whether folks on this PR have opinions here so I don't go off tilting at windmills.

Contributor Author:

This is a good point. Since these two functions are supposed to be inverses of each other, it would be pretty easy to write a unit test for it in theory. When I was developing, I actually inserted the conversions in some places and ran existing unit tests to make sure that the functions worked, but it would be great to have a way to directly write unit tests in C++. That way I wouldn't have to remove my tests before merging!
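
For concreteness, a minimal sketch of the kind of C++ test discussed above, assuming a new tests/cpp/relay/backend/te_compiler_test.cc compiled into the existing gtest target and able to include the internal src/relay/backend/te_compiler.h header; the test name and the trivial identity-function fixture are illustrative only. A round-trip check through LoweredModuleToIRModule could be layered on the same fixture.

```cpp
#include <gtest/gtest.h>
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>

// Assumes the gtest target can resolve the internal backend header.
#include "relay/backend/te_compiler.h"

namespace tvm {
namespace relay {
namespace tec {

TEST(IRModuleToLoweredModule, RelayFunctionStaysInMainModule) {
  // Trivial fixture: a module holding a single Relay identity function.
  auto x = relay::Var("x", relay::TensorType({}, DataType::Float(32)));
  auto func = relay::Function({x}, x, relay::Type(), {});
  IRModule mod = IRModule::FromExpr(func);

  LoweredModule lowered = IRModuleToLoweredModule(mod);

  // The Relay function should land in the main module; with no TIR functions
  // present there should be no per-target modules.
  EXPECT_TRUE(lowered.main_module->ContainGlobalVar("main"));
  EXPECT_TRUE(lowered.per_target_module.empty());
}

}  // namespace tec
}  // namespace relay
}  // namespace tvm
```
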

Member (@tqchen, Aug 19, 2021):

Currently we unit-test passes by exposing them via a Python API, constructing the expected input and output, and running the tests in Python:

https://github.com/apache/tvm/blob/main/tests/python/unittest/test_tir_transform_loop_partition.py#L30

There are certainly pros and cons of doing so. The original rationale is that we require most of the compiler passes to be accessible from Python, and it is relatively easy to construct and expand test cases there.

We could revisit this point regarding the need for the related test cases.

Member:

@electriclilies is there any reason you can't create a file similar to https://github.com/apache/tvm/blob/main/tests/cpp/build_module_test.cc and test the functions there?

Ideally I'd definitely like to see a C++ test setup as @mbs-octoml describes rather than the single folder, but would this work here? It's not an absolute rule that we must expose via Python for testing, is it @tqchen?

Member:

For most of the passes that can be modularized, we encourage the Python-first principle and expose them via Python. This one is in an intermediate state, so it is not an absolute rule to do so here.

Contributor:

Thanks for the comments. Perhaps a rule of thumb here is that if it's part of the public API it should be tested on the Python side, but otherwise tests should stay on the C++ side. I'm struggling to see how to write targeted unit tests on the Python side without both risking making something internal part of the de facto API and paying for all the unit-test boundaries to be FFI-able.

} // namespace tec
} // namespace relay
} // namespace tvm
49 changes: 46 additions & 3 deletions src/relay/backend/te_compiler.h
@@ -166,30 +166,73 @@ void UpdateFunctionMetadata(Function relay_func,
/*!
* \brief Obtain the Target from the device type.
* If homogenous compilation, this will return the only target.
* If heteregenous compilation, this will select associated using the targets_ Map.
* If heterogeneous compilation, this will select the associated target using the
* targets_ Map.
*
* \param dev_type
* \return Target
*/
Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);

/*! \brief Utility to convert a LoweredModule to an IRModule.
*
* This function takes all the target specific modules in LoweredModule and
* annotates their functions with the correct target, and puts all those functions
* in one IRModule.
* The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase.
*
* \param mod The LoweredModule to convert.
* \return The IRModule form of the input LoweredModule.
*/
IRModule LoweredModuleToIRModule(LoweredModule mod);
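
As an illustration of the attribute convention this conversion relies on, a small sketch (not part of this diff) of how a caller could read the metadata back off the unified IRModule; the helper name HasExternalModules is hypothetical.

```cpp
// Sketch only (not in this PR): reading the conversion's module-level attributes.
#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>

#include "te_compiler.h"

namespace tvm {
namespace relay {

// Hypothetical helper: after the conversion, external runtime modules and the
// main function info travel as IRModule attributes rather than LoweredModule fields.
bool HasExternalModules(const tec::LoweredModule& lowered) {
  IRModule unified = tec::LoweredModuleToIRModule(lowered);
  auto external_mods = unified->GetAttr<Array<tvm::runtime::Module>>("external_mods");
  auto main_func_info = unified->GetAttr<backend::FunctionInfo>("main_func_info");
  (void)main_func_info;  // would be handed to the executor codegen
  return external_mods && !external_mods.value().empty();
}

}  // namespace relay
}  // namespace tvm
```
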

/*! \brief Utility to convert an IRModule to a LoweredModule.
*
* This function takes all the functions in the IRModule and moves them into target-specific
* IRModules stored inside a LoweredModule.
* The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase.
* \param mod The IRModule to convert.
* \return The LoweredModule form of the input IRModule.
*/
LoweredModule IRModuleToLoweredModule(IRModule mod);

/*! \brief Lower an IRModule's primitive functions to TIR.
*
* This is the "back half" of the Relay compiler which lowers "primitive functions"
* to TE expressions, schedules them, and then to TIR.
*
* \param compiler The TE-to-TIR compiler (which caches lowered functions)
* \param module The IRModule.
* \param targets The mapping for devices to targets.
* \param device_map An analysis result mapping each sub-expression to a device.
* \param memory_plan The memory plan used during lowering
* \param module_name The name of this module
* \param process_fn Callback allowing one-level up code generators to process
* each function that we lower
* \return The lowered module, see above.
*/
// TODO(@electriclilies): Not sure if this default initialization is correct...
LoweredModule LowerTE(
const IRModule& module, TargetMap targets, DeviceMap device_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
ProcessFn process_fn = [](Function f) {});

/*! \brief Pass to lower an IRModule's primitive functions to TIR.
*
* This is the "back half" of the Relay compiler which lowers "primitive functions"
* to TE expressions, schedules them, and then to TIR. This Pass calls LowerTE, and
* uses the LoweredModuleToIRModule utility to convert LowerTE's output
* LoweredModule into an IRModule before returning it.
*
* \param targets The mapping for devices to targets.
* \param device_context_map An analysis result mapping each sub-expression to a device.
* \param memory_plan The memory plan used during lowering
* \param module_name The name of this module
* \param process_fn Callback allowing one-level up code generators to process
* each function that we lower
* \return The pass which lowers primitive functions to TIR
*/
transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn);
} // namespace tec
} // namespace relay
} // namespace tvm