[Relay] Remove memory planning from LowerTEPass #8974
Changes from all commits
File: aot_executor_codegen (file name per the review comments below)

```diff
@@ -38,8 +38,8 @@
 #include <string>
 #include <vector>
 
-#include "te_compiler.h"
-#include "utils.h"
+#include "./te_compiler.h"
+#include "./utils.h"
 
 namespace tvm {
 namespace relay {
```
```diff
@@ -583,8 +583,16 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   // performing the preexisting AOT executor code generation phase.
   IRModule mod = IRModule::FromExpr(func);
 
+  backend::FunctionInfo func_info;
+
+  if (memory_plan.defined()) {
+    // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
+    func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
+    mod = WithAttr(mod, "main_func_info", func_info);
+  }
```
Review comment on lines +586 to +592: Can we condense this to:

Suggested change
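The suggestion body itself is not preserved above; the following is one plausible condensed form, assuming the reviewer meant folding the temporary `func_info` variable into the `WithAttr` call (a hedged sketch, not the reviewer's actual text):

```cpp
// Hypothetical condensed form of lines +586 to +592: attach the
// workspace-size metadata in a single expression when a memory plan exists.
if (memory_plan.defined()) {
  // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
  mod = WithAttr(mod, "main_func_info",
                 tec::UpdateMainWorkspaceSize(mod, targets_,
                                              memory_plan->expr_to_storage_info));
}
```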
```diff
   IRModule lowered_mod =
-      LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
+      tec::LowerTEPass(targets_, device_context_map, mod_name, [this](Function func) {
         // We need to maintain the constant map for external
         // functions so we pass this processing function which
         // allows us to process each function as we lower it.
```
```diff
@@ -661,7 +669,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   Optional<backend::FunctionInfo> main_func_info =
       lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
```

Review comment: Like I said above, can you attach the func_info right before LowerTEPass is called? And we can then remove the check about whether the attribute is on the module.

```diff
   main_func_info.value()->workspace_sizes.Set(target_host_, main_workspace_size);
   function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
```
File: graph_executor_codegen (inferred from the GraphExecutorCodegen class below)

```diff
@@ -36,8 +36,8 @@
 #include <string>
 #include <vector>
 
-#include "te_compiler.h"
-#include "utils.h"
+#include "./te_compiler.h"
+#include "./utils.h"
 
 namespace tvm {
 namespace relay {
```
```diff
@@ -221,8 +221,17 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     device_context_map.insert({expr, dev});
   }
 
+  backend::FunctionInfo func_info;
+
+  if (memory_plan_.defined()) {
+    // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
+    func_info =
+        relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
+    mod = WithAttr(mod, "main_func_info", func_info);
+  }
```

Review comment: Same comment as in aot_executor_codegen -- can we put the func_info on the module here instead of after LowerTEPass is called and delete the check for main_func_info being set?

Review comment on lines +224 to +232: As above:

Suggested change
```diff
   IRModule lowered_mod =
-      LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
+      tec::LowerTEPass(targets_, device_context_map, mod_name_, [this](Function func) {
         // We need to maintain the constant map for external
         // functions so we pass this processing function which
         // allows us to process each function as we lower it.
```
```diff
@@ -238,7 +247,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
   Optional<backend::FunctionInfo> main_func_info =
       lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
 
   function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
 
   Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));
```
File: interpreter (inferred from the Interpreter class below)

```diff
@@ -542,7 +542,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
  *
  * @param prim_fn_var Global bound to lowered primitive.
  * @param all_prim_fn_vars All globals references by lowered primitive, plus prim_fn_var itself.
- * @param prim_shape_fn_var Global bound to lowered shape function for primitive, if neeeded.
+ * @param prim_shape_fn_var Global bound to lowered shape function for primitive, if needed.
  * @param all_prim_shape_fn_vars All globals references by lowered shape function, plus
  * prim_shape_fn_var itself.
  * @param prim_shape_fn_states Records whether shape and/or data is needed by the dynamic

@@ -763,7 +763,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
   ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
     ObjectRef val = Eval(op->tuple);
     const auto* adt_obj = val.as<ADTObj>();
-    ICHECK(adt_obj) << "interal error: when evaluating TupleGetItem expected an ADT value";
+    ICHECK(adt_obj) << "internal error: when evaluating TupleGetItem expected an ADT value";
     auto adt = GetRef<ADT>(adt_obj);
     ICHECK_LT(static_cast<size_t>(op->index), adt.size()) << "internal error: index out of bounds";
     return adt[op->index];
```
```diff
@@ -902,21 +902,17 @@ IRModule Prepare(IRModule mod, Device device, Target target) {
   // All calls to primitives will use the unique target.
   tec::DeviceMap device_map;
 
-  // No need for a memory plan.
-  backend::StaticMemoryPlan memory_plan; /*=nullptr*/
-
   // Run minimal transforms on module to establish invariants needed by interpreter.
-  transform::Sequential seq(
-      {transform::SimplifyInference(),
-       // FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
-       // attribute.
-       transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
-       // eta expand to support constructors in argument position
-       transform::EtaExpand(
-           /*expand_constructor=*/true, /*expand_global_var=*/false),
-       transform::InferType(),
-       tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp",
-                        [](Function func) { /* no-op */ })});
+  transform::Sequential seq({transform::SimplifyInference(),
```

Review comment: Seems we have inconsistent formatters? In any case I'd revert this whitespace change.

```diff
+                             // FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
+                             // attribute.
+                             transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
+                             // eta expand to support constructors in argument position
+                             transform::EtaExpand(
+                                 /*expand_constructor=*/true, /*expand_global_var=*/false),
+                             transform::InferType(),
+                             tec::LowerTEPass(targets, device_map, /*module_name=*/"intrp",
+                                              [](Function func) { /* no-op */ })});
 
   transform::PassContext pass_ctx = transform::PassContext::Current();
   With<transform::PassContext> ctx(pass_ctx);
```
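For readers tracing the interpreter's `Prepare` flow: a `transform::Sequential` is itself a pass and is applied to a module inside a `PassContext` scope. A minimal sketch of how the pipeline above is typically invoked, following the shapes visible in this hunk (illustrative; the surrounding function body is assumed):

```cpp
// Sketch: run the Sequential built above under the current pass context.
transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);  // enter the context scope (RAII)
mod = seq(mod);  // applies SimplifyInference ... LowerTEPass in order
```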
File: te_compiler (inferred; the file includes its own header below)

```diff
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-#include "te_compiler.h"
+#include "./te_compiler.h"
 
 #include <tvm/driver/driver_api.h>
 #include <tvm/ir/attrs.h>

@@ -42,8 +42,8 @@
 #include <utility>
 #include <vector>
 
-#include "te_compiler_cache.h"
-#include "utils.h"
+#include "./te_compiler_cache.h"
+#include "./utils.h"
 
 namespace tvm {
 namespace relay {
```
```diff
@@ -596,19 +596,7 @@ class LowerTensorExprMutator : public ExprMutator {
   const Op& debug_op_;
 };
 
-Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map,
-                     backend::StaticMemoryPlan memory_plan, const String& module_name,
-                     TECompiler compiler, std::function<void(Function)> process_fn) {
-  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
-      [=](Function func, IRModule module, PassContext ctx) {
-        LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
-                                        module_name, compiler);
-        return Downcast<Function>(lower_te.Mutate(func));
-      };
-  return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
-}
-
-Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
+Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) {
   if (targets.size() == 1) {
     // The homogeneous execution case, return the only target.
     const auto& it = targets.begin();
```

Review comment: How about we don't move these to keep the diff down.
```diff
@@ -638,26 +626,30 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
   }
 }
 
 /*!
  * \brief Update the "main" control function's metadata
  *
  * \param mod The module
  * \param targets Map of targets
  * \return function_infos Function info for each function in the module
  */
+Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, const String& module_name,
+                     TECompiler compiler, std::function<void(Function)> process_fn) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function func, IRModule module, PassContext ctx) {
+        LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
+                                        module_name, compiler);
+        return Downcast<Function>(lower_te.Mutate(func));
+      };
+  return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
+}
+
-backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets,
+backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
                                               Map<Expr, backend::StorageInfo> storage_info_map) {
   CHECK_EQ(mod->functions.size(), 1)
       << "There should only be one function in the module passed to UpdateMainWorkspaceSize";
   Function func = Downcast<Function>(mod->Lookup("main"));
 
   // This is a Map<device,Map<storage_id, size>>
-  std::unordered_map<DLDeviceType, std::unordered_map<int, int>, EnumClassHash> sid_workspace;
+  std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash>
+      sid_workspace;
   // This is a Map<device, size_of_inputs_and_outputs>
-  std::unordered_map<DLDeviceType, int, EnumClassHash> device_io;
+  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_io;
   // This is a Map<device, size_of_constants>
-  std::unordered_map<DLDeviceType, int, EnumClassHash> device_consts;
+  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_consts;
 
   // Initialize the mapping from all storage identifiers to workspace sizes,
   // the amount of device io, and the device constants.
```
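The hunk above qualifies `EnumClassHash` with `backend::`, presumably because the functor now lives in the backend namespace. For readers unfamiliar with the pattern: `std::unordered_map` needs a hash functor for `enum class` keys such as `DLDeviceType`, and a cast-to-integer functor is the usual fix. A self-contained sketch of the idiom (the functor body and the enum here are illustrative stand-ins, not TVM's actual definitions):

```cpp
#include <cstddef>
#include <unordered_map>

// Illustrative stand-in for backend::EnumClassHash: hash an enum class
// key by casting it to an integral value.
struct EnumClassHash {
  template <typename T>
  std::size_t operator()(T t) const {
    return static_cast<std::size_t>(t);
  }
};

// Hypothetical stand-in for DLDeviceType.
enum class DeviceTypeExample { kCPU = 1, kCUDA = 2 };

int main() {
  // Map<device, workspace_size>, as accumulated in UpdateMainWorkspaceSize.
  std::unordered_map<DeviceTypeExample, int, EnumClassHash> device_workspace;
  device_workspace[DeviceTypeExample::kCPU] = 1024;
  return 0;
}
```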
```diff
@@ -723,7 +715,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
   }
 
   // This is a Map<device, workspace_size>
-  std::unordered_map<DLDeviceType, int, EnumClassHash> device_workspace;
+  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_workspace;
   // Once we know the sizes of sids, we need to accumulate per device
   for (const auto& dev_sid_size : sid_workspace) {
     auto dev = dev_sid_size.first;
```
```diff
@@ -746,17 +738,17 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
   }
 
   for (const auto& dev_and_size : device_workspace) {
-    auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
+    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
     workspace_sizes.Set(tgt, dev_and_size.second);
     relay_primfuncs.Set(tgt, func);
   }
   for (const auto& dev_and_size : device_io) {
-    auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
+    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
     io_sizes.Set(tgt, dev_and_size.second);
   }
 
   for (const auto& dev_and_size : device_consts) {
-    auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
+    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
     constant_sizes.Set(tgt, dev_and_size.second);
   }
```
```diff
@@ -844,20 +836,13 @@ void UpdateFunctionMetadata(Function relay_func,
 }
 
 IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
-                 backend::StaticMemoryPlan memory_plan, const String& module_name,
-                 std::function<void(Function)> process_fn) {
+                 const String& module_name, std::function<void(Function)> process_fn) {
   DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);
 
   TECompiler compiler;
 
-  backend::FunctionInfo func_info;
-  if (memory_plan.defined()) {
-    // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
-    func_info = UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info);
-  }
-
-  auto updated_module = LowerTensorExpr(targets, device_context_map, memory_plan, module_name,
-                                        compiler, process_fn)(module);
+  auto updated_module =
+      LowerTensorExpr(targets, device_context_map, module_name, compiler, process_fn)(module);
 
   // A temporary solution until we can rewrite the auto-scheduler task extraction code to work
   // in a more reasonable way.

@@ -882,7 +867,6 @@ IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_con
   // Annotate the module with the external modules and function info
   updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions());
-  updated_module = WithAttr(updated_module, "main_func_info", func_info);
 
   return updated_module;
 }
```
```diff
@@ -919,12 +903,11 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
   return per_target_modules;
 }
 
-Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
-                 backend::StaticMemoryPlan memory_plan, const String& module_name,
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, const String& module_name,
                  std::function<void(Function)> process_fn) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
                                                                             PassContext ctx) {
-    return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
+    return LowerTE(module, targets, device_context_map, module_name, process_fn);
   };
   return tvm::transform::Sequential(
       {tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});
```

Review comment: Can you just put the func_info on the mod here before passing the module into LowerTE? Then you don't need to re-extract it later, and also the logic surrounding func_info is all in one place. (LowerTEPass should preserve all attributes on modules passed into it)
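To make the reviewer's proposal concrete, here is a minimal sketch of the attach-then-read-back flow, assuming `WithAttr`/`GetAttr` behave as shown elsewhere in this diff and that `LowerTEPass` preserves module attributes (variable names are illustrative, not the PR as merged):

```cpp
// Sketch only: attach main_func_info before lowering so no code needs to
// re-derive it afterwards. Mirrors the reviewer's suggestion.
backend::FunctionInfo func_info =
    UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info);
IRModule annotated = WithAttr(module, "main_func_info", func_info);

// LowerTEPass preserves module attributes, so the annotation survives lowering.
IRModule lowered =
    LowerTEPass(targets, device_context_map, module_name, process_fn)(annotated);

// Consumers then read the attribute back instead of recomputing it.
Optional<backend::FunctionInfo> info =
    lowered->GetAttr<backend::FunctionInfo>("main_func_info");
ICHECK(info) << "main_func_info should have been attached before lowering";
```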