Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Remove memory planning from LowerTEPass #8974

Merged
merged 18 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
#include <string>
#include <vector>

#include "te_compiler.h"
#include "utils.h"
#include "./te_compiler.h"
#include "./utils.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -583,8 +583,16 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);

backend::FunctionInfo func_info;

if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you just put the func_info on the mod here before passing the module into LowerTE? Then you don't need to re-extract it later, and also the logic surrounding func_info is all in one place. (LowerTEPass should preserve all attributes on modules passed into it)

Comment on lines +586 to +592
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we condense this to:

Suggested change
backend::FunctionInfo func_info;
if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}
if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
backend::FunctionInfo func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}


IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
tec::LowerTEPass(targets_, device_context_map, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
Expand Down Expand Up @@ -661,7 +669,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like I said above, can you attach the func_info right before LowerTEPass is called? And we can then remove the check about whether the attribute is on the module.

main_func_info.value()->workspace_sizes.Set(target_host_, main_workspace_size);
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

Expand Down
17 changes: 13 additions & 4 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
#include <string>
#include <vector>

#include "te_compiler.h"
#include "utils.h"
#include "./te_compiler.h"
#include "./utils.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -221,8 +221,17 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
device_context_map.insert({expr, dev});
}

backend::FunctionInfo func_info;

if (memory_plan_.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info =
relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as in aot_executor_codegen -- can we put the func_info on the module here instead of after LowerTEPass is called and delete the check for main_func_info being set?

Comment on lines +224 to +232
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above:

Suggested change
backend::FunctionInfo func_info;
if (memory_plan_.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info =
relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}
if (memory_plan_.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
backend::FunctionInfo func_info =
relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}

IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
tec::LowerTEPass(targets_, device_context_map, mod_name_, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
Expand All @@ -238,7 +247,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";

function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));
Expand Down
28 changes: 12 additions & 16 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
*
* @param prim_fn_var Global bound to lowered primitive.
* @param all_prim_fn_vars All globals referenced by lowered primitive, plus prim_fn_var itself.
* @param prim_shape_fn_var Global bound to lowered shape function for primitive, if neeeded.
* @param prim_shape_fn_var Global bound to lowered shape function for primitive, if needed.
* @param all_prim_shape_fn_vars All globals referenced by lowered shape function, plus
* prim_shape_fn_var itself.
* @param prim_shape_fn_states Records whether shape and/or data is needed by the dynamic
Expand Down Expand Up @@ -763,7 +763,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
ObjectRef val = Eval(op->tuple);
const auto* adt_obj = val.as<ADTObj>();
ICHECK(adt_obj) << "interal error: when evaluating TupleGetItem expected an ADT value";
ICHECK(adt_obj) << "internal error: when evaluating TupleGetItem expected an ADT value";
auto adt = GetRef<ADT>(adt_obj);
ICHECK_LT(static_cast<size_t>(op->index), adt.size()) << "internal error: index out of bounds";
return adt[op->index];
Expand Down Expand Up @@ -902,21 +902,17 @@ IRModule Prepare(IRModule mod, Device device, Target target) {
// All calls to primitives will use the unique target.
tec::DeviceMap device_map;

// No need for a memory plan.
backend::StaticMemoryPlan memory_plan; /*=nullptr*/

// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq(
{transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
// attribute.
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
// eta expand to support constructors in argument position
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp",
[](Function func) { /* no-op */ })});
transform::Sequential seq({transform::SimplifyInference(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems we have inconsistent formatters? In any case I'd revert this whitespace change.

// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
// attribute.
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
// eta expand to support constructors in argument position
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(targets, device_map, /*module_name=*/"intrp",
[](Function func) { /* no-op */ })});

transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
Expand Down
73 changes: 28 additions & 45 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* under the License.
*/

#include "te_compiler.h"
#include "./te_compiler.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
Expand All @@ -42,8 +42,8 @@
#include <utility>
#include <vector>

#include "te_compiler_cache.h"
#include "utils.h"
#include "./te_compiler_cache.h"
#include "./utils.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -596,19 +596,7 @@ class LowerTensorExprMutator : public ExprMutator {
const Op& debug_op_;
};

Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we don't move these to keep the diff down.

backend::StaticMemoryPlan memory_plan, const String& module_name,
TECompiler compiler, std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function func, IRModule module, PassContext ctx) {
LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
module_name, compiler);
return Downcast<Function>(lower_te.Mutate(func));
};
return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
}

Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) {
if (targets.size() == 1) {
// The homogeneous execution case, return the only target.
const auto& it = targets.begin();
Expand Down Expand Up @@ -638,26 +626,30 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
}
}

/*!
* \brief Update the "main" control function's metadata
*
* \param mod The module
* \param targets Map of targets
* \return function_infos Function info for each function in the module
*/
/*!
 * \brief Returns a pass which rewrites each function in the module, lowering calls to
 * primitive functions into calls to their TE-compiled (TIR) equivalents.
 */
Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, const String& module_name,
                     TECompiler compiler, std::function<void(Function)> process_fn) {
  // Per-function transform: run the lowering mutator over the function body.
  auto lower_fn = [=](Function func, IRModule module, PassContext ctx) -> Function {
    LowerTensorExprMutator mutator(module, targets, device_context_map, process_fn, module_name,
                                   compiler);
    Expr rewritten = mutator.Mutate(func);
    return Downcast<Function>(rewritten);
  };
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> packed_lower_fn = lower_fn;
  return CreateFunctionPass(packed_lower_fn, 0, "LowerTensorExpr", {});
}

backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets,
backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
Map<Expr, backend::StorageInfo> storage_info_map) {
CHECK_EQ(mod->functions.size(), 1)
<< "There should only be one function in the module passed to UpdateMainWorkspaceSize";
Function func = Downcast<Function>(mod->Lookup("main"));

// This is a Map<device,Map<storage_id, size>>
std::unordered_map<DLDeviceType, std::unordered_map<int, int>, EnumClassHash> sid_workspace;
std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash>
sid_workspace;
// This is a Map<device, size_of_inputs_and_outputs>
std::unordered_map<DLDeviceType, int, EnumClassHash> device_io;
std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_io;
// This is a Map<device, size_of_constants>
std::unordered_map<DLDeviceType, int, EnumClassHash> device_consts;
std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_consts;

// Initialize the mapping from all storage identifiers to workspace sizes,
// the amount of device io, and the device constants.
Expand Down Expand Up @@ -723,7 +715,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
}

// This is a Map<device, workspace_size>
std::unordered_map<DLDeviceType, int, EnumClassHash> device_workspace;
std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_workspace;
// Once we know the sizes of sids, we need to accumulate per device
for (const auto& dev_sid_size : sid_workspace) {
auto dev = dev_sid_size.first;
Expand All @@ -746,17 +738,17 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
}

for (const auto& dev_and_size : device_workspace) {
auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
workspace_sizes.Set(tgt, dev_and_size.second);
relay_primfuncs.Set(tgt, func);
}
for (const auto& dev_and_size : device_io) {
auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
io_sizes.Set(tgt, dev_and_size.second);
}

for (const auto& dev_and_size : device_consts) {
auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
constant_sizes.Set(tgt, dev_and_size.second);
}

Expand Down Expand Up @@ -844,20 +836,13 @@ void UpdateFunctionMetadata(Function relay_func,
}

IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
const String& module_name, std::function<void(Function)> process_fn) {
DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);

TECompiler compiler;

backend::FunctionInfo func_info;
if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info = UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info);
}

auto updated_module = LowerTensorExpr(targets, device_context_map, memory_plan, module_name,
compiler, process_fn)(module);
auto updated_module =
LowerTensorExpr(targets, device_context_map, module_name, compiler, process_fn)(module);

// A temporary solution until we can rewrite the auto-scheduler task extraction code to work
// in a more reasonable way.
Expand All @@ -882,7 +867,6 @@ IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_con

// Annotate the module with the external modules and function info
updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions());
updated_module = WithAttr(updated_module, "main_func_info", func_info);

return updated_module;
}
Expand Down Expand Up @@ -919,12 +903,11 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
return per_target_modules;
}

Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, const String& module_name,
std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
return LowerTE(module, targets, device_context_map, module_name, process_fn);
};
return tvm::transform::Sequential(
{tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});
Expand Down
27 changes: 13 additions & 14 deletions src/relay/backend/te_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,15 @@
#include "../transforms/infer_layout_utils.h"
#include "../transforms/pass_utils.h"
#include "./te_compiler_cache.h"
#include "utils.h"
#include "./utils.h"

namespace tvm {
namespace relay {
namespace tec {

// This class is needed to avoid a GCC 5 bug that prevents maps containing enums
// from being compiled. If i386 GCC version is increased, we can remove it.
struct EnumClassHash {
template <typename T>
std::size_t operator()(T t) const {
return static_cast<std::size_t>(t);
}
};

// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake
// we should add a version of context which works in Map
using TargetMap = std::unordered_map<DLDeviceType, Target, EnumClassHash>;
using TargetMap = std::unordered_map<DLDeviceType, Target, backend::EnumClassHash>;
using DeviceMap =
std::unordered_map<Expr, tvm::Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
using ProcessFn = std::function<void(Function)>;
Expand Down Expand Up @@ -158,6 +149,16 @@ void UpdateFunctionMetadata(Function relay_func,
*/
Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);

/*!
* \brief Update the "main" control function's metadata
*
* \param mod The module
* \param targets Map of targets
* \param storage_info_map Map from each expression to its storage info
* \return The FunctionInfo for the module's "main" function
*/
backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
Map<Expr, backend::StorageInfo> storage_info_map);

/*! \brief Utility to separate the functions in an IRModule by Target.
*
* \param mod The IRModule to extract the per target module from
Expand Down Expand Up @@ -192,15 +193,13 @@ IRModule LowerTE(
*
* \param targets The mapping for devices to targets.
* \param device_context_map An analysis result mapping each sub-expression to a device.
* \param memory_plan The memory plan used during lowering
* \param module_name The name of this module
* \param process_fn Callback allowing one-level up code generators to process
* each function that we lower
* \returns The pass which lowers primitive functions to TIR
*/
transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn);
const String& module_name, std::function<void(Function)> process_fn);
} // namespace tec
} // namespace relay
} // namespace tvm
Expand Down
11 changes: 11 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ struct LoweredOutput {
runtime::Metadata metadata;
};

/*!
 * \brief Hash functor that maps an enum (or enum class) value to its integral value.
 *
 * Works around a GCC 5 bug that prevents unordered maps keyed on enum class types from
 * being compiled. Once the minimum i386 GCC version is increased, this can be removed.
 */
struct EnumClassHash {
  template <typename EnumT>
  std::size_t operator()(EnumT value) const {
    return static_cast<std::size_t>(value);
  }
};

/*!
* \brief A helper to expand the params by adding the ones used in a given expression.
*/
Expand Down