Include primfuncs in update metadata
Co-authored-by: Chris Sidebottom <chris.sidebottom@arm.com>

Change-Id: I15e7e7e4fa864ddc469a4e430c4f472057290e8a
ashutosh-arm committed Nov 16, 2021
1 parent 8de73b7 commit cdebe57
Showing 10 changed files with 67 additions and 57 deletions.
11 changes: 8 additions & 3 deletions python/tvm/micro/model_library_format.py
@@ -30,6 +30,7 @@
from .._ffi import get_global_func
from ..contrib import utils
from ..driver import build_module
+ from ..driver.tvmc.composite_target import get_codegen_names
from ..runtime import ndarray as _nd
from ..relay.backend import executor_factory
from ..relay.backend.name_transforms import to_c_variable_style, prefix_generated_name
@@ -174,11 +175,12 @@ def _build_function_memory_map(function_metadata):
device_max_workspace = dict()
main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR]
num_targets = len(main_func_metadata.workspace_sizes.items())
+ external_codegens = get_codegen_names()
func_entries = []
target_local_entries = dict()
for i in range(num_targets):
- target = main_func_metadata.workspace_sizes.items()[i][0]
- device_max_workspace[target] = 0
+ main_target = main_func_metadata.workspace_sizes.items()[i][0]
+ device_max_workspace[main_target] = 0
for func_name, finfo in function_metadata.items():
if func_name == MAIN_FUNC_NAME_STR:
continue
@@ -201,8 +203,11 @@ def _build_function_memory_map(function_metadata):
"workspace_size_bytes": int(workspace_size),
}
target_local_entries[func_name].append(target_entry)
- if workspace_size > device_max_workspace[target]:
+ if workspace_size > device_max_workspace.get(target, 0):
device_max_workspace[target] = workspace_size
+ # TODO(Mousius) - Remove this massive hack when Targets are unified
+ if target.kind.name in external_codegens:
+     device_max_workspace[main_target] += int(workspace_size)

for func_name, target_entries_ in target_local_entries.items():
func_entry = {
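Note on the hunk above: the behavioral change is that workspace sizes of functions offloaded to external codegens are now folded into the main target's total, rather than only tracked per-target. A minimal runnable sketch of that aggregation, using plain dictionaries instead of TVM's metadata objects (the input shape here is an assumption for illustration):

```python
# Sketch only: mirrors the aggregation in _build_function_memory_map.
# function_workspaces maps function name -> {target_name: workspace_bytes};
# external_codegens would come from get_codegen_names().
def max_workspace_per_device(function_workspaces, main_target, external_codegens):
    device_max_workspace = {main_target: 0}
    for _func_name, workspaces in function_workspaces.items():
        for target, workspace_size in workspaces.items():
            # Track the largest workspace per target; .get() avoids a KeyError
            # for targets that only appear on sub-functions.
            if workspace_size > device_max_workspace.get(target, 0):
                device_max_workspace[target] = workspace_size
            # The "massive hack" from the diff: external-codegen functions run
            # via the main target, so their workspaces are added to its total.
            if target in external_codegens:
                device_max_workspace[main_target] += int(workspace_size)
    return device_max_workspace

sizes = max_workspace_per_device(
    {"conv": {"c": 512}, "dense": {"cmsis-nn": 2048}}, "c", ["cmsis-nn"])
assert sizes == {"c": 2560, "cmsis-nn": 2048}
```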
28 changes: 8 additions & 20 deletions src/relay/backend/aot_executor_codegen.cc
@@ -648,25 +648,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {}

LoweredOutput Codegen(relay::Function func, String mod_name) {
- AOTOnDemandAllocator initial_aot_allocator;
- initial_aot_allocator.Run(func);
-
- // Pre-lowering storage map and memory plan
- // TODO(mbs): Why plan memory and update workspace sizes before lowering?
- StorageMap initial_storage_map = initial_aot_allocator.GetStorageMap();
- StaticMemoryPlan memory_plan(initial_storage_map);
-
IRModule mod = IRModule::FromExpr(func);
-
- backend::FunctionInfo func_info;
-
- if (memory_plan.defined()) {
-   // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
-   func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
-   mod = WithAttr(mod, "main_func_info", func_info);
- }
-
- IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](Function func) {
+ IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -683,12 +666,17 @@ class AOTExecutorCodegen : public MixedModeVisitor {
auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

- // Post-lowering storage map for writing main func - this should be the same map as previously
- // created, just referencing the new expressions created from lowering
+ // Post-lowering storage map for writing main func
AOTOnDemandAllocator final_aot_allocator;
final_aot_allocator.Run(lowered_main_func);
storage_device_map_ = final_aot_allocator.GetStorageMap();

+ // TODO(@electriclilies, @jroesch, @Mousius): remove UpdateMainWorkspaceSize
+ StaticMemoryPlan memory_plan(storage_device_map_);
+ backend::FunctionInfo func_info =
+     tec::UpdateMainWorkspaceSize(lowered_mod, targets_, memory_plan->expr_to_storage_info);
+ lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info);
+
for (auto input : lowered_main_func->params) {
input_vars_.push_back(input);
main_signature_.push_back(tir::Var("input", DataType::Handle()));
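The net effect of this file's changes is a reordering: memory planning and UpdateMainWorkspaceSize now run after LowerTEPass, on the lowered module, so PrimFuncs contributed by external codegens are counted when sizing the main function's workspace. A runnable toy of why the ordering matters (names and the data model are illustrative, not the C++ API):

```python
# "Lowering" reveals an extra external PrimFunc, as BYOC codegens do.
def lower_te(mod):
    return mod + [("cmsis_nn_conv2d", {"workspace": 2048})]

def main_workspace_size(mod):
    # Workspace the main function must provision across all functions.
    return sum(info["workspace"] for _name, info in mod)

mod = [("main", {"workspace": 512})]
before = main_workspace_size(mod)            # computed pre-lowering: 512
after = main_workspace_size(lower_te(mod))   # computed post-lowering: 2560
assert (before, after) == (512, 2560)
```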
20 changes: 10 additions & 10 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -38,7 +38,9 @@ namespace cmsisnn {
class RelayToTIRVisitor : public MixedModeMutator {
public:
explicit RelayToTIRVisitor(IRModule ir_module, Target target)
- : ir_module_(ir_module), target_(target) { context_buffer_id_ = 0; }
+ : ir_module_(ir_module), target_(target) {
+   context_buffer_id_ = 0;
+ }

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
@@ -71,8 +73,8 @@ class RelayToTIRVisitor : public MixedModeMutator {
tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));

if (context_buffer_size) {
- tir::Var buffer_var(
-     context_buffer_name, PointerType(PrimType(DataType::Int(8)), "global.workspace"));
+ tir::Var buffer_var(context_buffer_name,
+                     PointerType(PrimType(DataType::Int(8)), "global.workspace"));
body = tir::Allocate(buffer_var, DataType::Int(8), {context_buffer_size}, tir::const_true(),
body);
body =
@@ -83,8 +85,6 @@ class RelayToTIRVisitor : public MixedModeMutator {
tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));

- LOG(INFO) << PrettyPrint(replacement_func);
-
ir_module_->Add(global_var, replacement_func);
}

@@ -95,8 +95,7 @@
ToArg(qnn::get_const_int(shape[3]))};
}

- void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
-
+   void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
const CallNode* clip_call = nullptr;
const CallNode* requantize_call = nullptr;
const CallNode* bias_add_call = nullptr;
@@ -194,7 +193,8 @@ void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
if (context_buffer_size) {
context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
}
- tvm::Array<PrimExpr> context_buffer_args = { tir::StringImm(context_buffer_name), ToArg(context_buffer_size) };
+ tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
+                                             ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
scalar_args = tvm::runtime::Concat(scalar_args, input_dims);
@@ -211,8 +211,8 @@ void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
func_signature.push_back(shift);
func_signature.push_back(output);

- CreatePrimFuncForExtern(global_var, func_signature, call_ext_args,
-                         context_buffer_name, context_buffer_size);
+ CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
+                         context_buffer_size);
}

void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) {
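Aside from dropping a debug LOG, this file's changes are formatting cleanups, but the surrounding context shows how the CMSIS-NN extern call is assembled: the context-buffer name and size are prepended to the scalar arguments before the call is built. A tiny runnable sketch of that ordering (plain Python; the values are made up):

```python
# Illustrative only: mirrors the Concat order visible in EmitConv2D.
def assemble_args(context_buffer_name, context_buffer_size, scalar_args, input_dims):
    context_buffer_args = [context_buffer_name, context_buffer_size]
    return context_buffer_args + scalar_args + input_dims  # buffer args lead

assert assemble_args("context_buffer_0", 2048, [3, 7], [1, 28, 28, 3]) == \
    ["context_buffer_0", 2048, 3, 7, 1, 28, 28, 3]
```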
2 changes: 0 additions & 2 deletions src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -228,8 +228,6 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
function_names.push_back(global_symbol.value());
LOG(INFO) << "------------------------";
LOG(INFO) << PrettyPrint(prim_func);
codegen.AddFunction(prim_func);
}
std::string code = codegen.Finish();
2 changes: 1 addition & 1 deletion src/relay/backend/graph_executor_codegen.cc
@@ -225,7 +225,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
mod = WithAttr(mod, "main_func_info", func_info);
}

- IRModule lowered_mod = tec::LowerTEPass(mod_name_, [this](Function func) {
+ IRModule lowered_mod = tec::LowerTEPass(mod_name_, [this](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
2 changes: 1 addition & 1 deletion src/relay/backend/interpreter.cc
@@ -937,7 +937,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) {
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(/*module_name=*/"intrp", [](Function func) { /* no-op */ })});
tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ })});

transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
44 changes: 32 additions & 12 deletions src/relay/backend/te_compiler.cc
@@ -615,6 +615,20 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
// Already lowered by other means so we don't need to mutate
// the call but we do need to mutate the arguments
if (prim_func->IsInstance<tir::PrimFuncNode>()) {
+ // Function should already be Target annotated by this point
+ // but the TE Compiler metadata is still needed for the callback
+ // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks
+ GlobalVar prim_func_var = Downcast<GlobalVar>(call_node->op);
+ tir::PrimFunc downcast_prim_func = Downcast<tir::PrimFunc>(prim_func);
+
+ Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, downcast_prim_func}};
+ tir::PrimFunc func_with_metadata =
+     WithAttrs(downcast_prim_func, {
+         {"prim_fn_var", prim_func_var},
+         {"prim_funcs", prim_fns},
+     });
+
+ this->process_fn_(func_with_metadata);
return Call(call_node->op, visited_args, call_node->attrs);
}
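This added block is the heart of the commit: PrimFuncs that were already lowered by other means (e.g. the CMSIS-NN target hook above) now get the same "prim_fn_var" and "prim_funcs" attributes that TE-lowered functions carry, so the process_fn callback, such as UpdateFunctionMetadata, can handle both uniformly. A runnable toy of the attribute plumbing (dicts stand in for IR attributes; the data model is an assumption):

```python
def with_attrs(func, attrs):
    # Return a copy of `func` with extra attributes attached.
    return {**func, "attrs": {**func.get("attrs", {}), **attrs}}

def process_prim_func(prim_func_var, prim_func, process_fn):
    prim_fns = {prim_func_var: prim_func}
    func_with_metadata = with_attrs(
        prim_func, {"prim_fn_var": prim_func_var, "prim_funcs": prim_fns})
    process_fn(func_with_metadata)  # callback sees the metadata it expects

seen = []
process_prim_func("cmsis_nn_main_0", {"body": "..."}, seen.append)
assert seen[0]["attrs"]["prim_fn_var"] == "cmsis_nn_main_0"
```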

@@ -682,8 +696,7 @@ Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) {
}
}

- Pass LowerTensorExpr(const String& module_name, TECompiler compiler,
-                      std::function<void(Function)> process_fn) {
+ Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function func, IRModule module, PassContext ctx) {
LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler);
@@ -831,13 +844,13 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa
/*!
* \brief A function to create the function metadata for an input function (ie calculate buffer
* input/output sizes)
- * \param relay_func The function to calculate function metadata for
+ * \param func The function to calculate function metadata for
* \param function_metadata The map that stores all the function metadatas
*/
- void UpdateFunctionMetadata(Function relay_func,
+ void UpdateFunctionMetadata(BaseFunc func,
Map<String, backend::FunctionInfo>& function_metadata) { // NOLINT(*)
VLOG_CONTEXT << "UpdateFunctionMetadata";
VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(relay_func);
VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(func);
// Originally UpdateFunctionMetadata took in CCachedFunc and looped through all the funcs stored
// there Now the goal is to take only one func because process_fn should be controlling the
// iteration However, to do the workspace calculations we need the primfuncs. So process_fn
@@ -852,13 +865,13 @@ void UpdateFunctionMetadata(Function relay_func,
Map<Target, Function> relay_primfuncs;

Optional<Map<GlobalVar, tir::PrimFunc>> prim_fns =
- relay_func->GetAttr<Map<GlobalVar, tir::PrimFunc>>("prim_funcs");
+ func->GetAttr<Map<GlobalVar, tir::PrimFunc>>("prim_funcs");
CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler.";

- Optional<GlobalVar> prim_fn_var = relay_func->GetAttr<GlobalVar>("prim_fn_var");
+ Optional<GlobalVar> prim_fn_var = func->GetAttr<GlobalVar>("prim_fn_var");
CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler.";

- Optional<Target> relay_target = relay_func->GetAttr<Target>(tvm::attr::kTarget);
+ Optional<Target> relay_target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(relay_target) << "target must be set on Relay functions by the TECompiler.";

for (const auto& kv : prim_fns.value()) {
@@ -883,6 +896,12 @@ void UpdateFunctionMetadata(Function relay_func,
// Calculating size for I/O
// TODO(mbs): See also the other three utils for calculating tensor bytesize.
for (auto const& param : prim_fn->params) {
+ bool not_a_buffer = prim_fn->buffer_map.count(param) == 0;
+ if (not_a_buffer) {
+   io_sizes.Set(prim_fn_target, 0);
+   continue;
+ }
+
auto p_shape = prim_fn->buffer_map[param]->shape;
int num_of_elements = 1;
for (const auto& dim_index_expr : p_shape) {
@@ -899,7 +918,9 @@ void UpdateFunctionMetadata(Function relay_func,

constant_sizes.Set(prim_fn_target, 0);
tir_primfuncs.Set(prim_fn_target, prim_fn);
- relay_primfuncs.Set(prim_fn_target, relay_func);
+ if (func->IsInstance<FunctionNode>()) {
+   relay_primfuncs.Set(prim_fn_target, Downcast<Function>(func));
+ }
}

backend::FunctionInfo fi = backend::FunctionInfo(
@@ -913,8 +934,7 @@ void UpdateFunctionMetadata(Function relay_func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}

- IRModule LowerTE(const IRModule& module, const String& module_name,
-                  std::function<void(Function)> process_fn) {
+ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn) {
TECompiler compiler;

auto updated_module = LowerTensorExpr(module_name, compiler, process_fn)(module);
@@ -966,7 +986,7 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
return per_target_modules;
}

- Pass LowerTEPass(const String& module_name, std::function<void(Function)> process_fn) {
+ Pass LowerTEPass(const String& module_name, ProcessFn process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule module, PassContext ctx) { return LowerTE(module, module_name, process_fn); };

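Two guards in UpdateFunctionMetadata make external PrimFuncs safe to process: parameters with no buffer_map entry (such as opaque handles) no longer crash the I/O size loop, and relay_primfuncs is only populated when the input really is a Relay Function. A runnable toy of the guarded I/O size computation (plain Python; the fixed element byte-width is a simplifying assumption):

```python
def io_bytes(params, buffer_map, element_bytes=1):
    total = 0
    for param in params:
        # The new not_a_buffer guard: skip params without a buffer mapping.
        if param not in buffer_map:
            continue
        num_elements = 1
        for dim in buffer_map[param]:   # buffer shape, e.g. NHWC dims
            num_elements *= dim
        total += num_elements * element_bytes
    return total

# "ctx_handle" has no buffer entry and contributes nothing.
assert io_bytes(["data", "ctx_handle"], {"data": [1, 28, 28, 3]}) == 2352
```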
10 changes: 5 additions & 5 deletions src/relay/backend/te_compiler.h
@@ -63,7 +63,7 @@ namespace tec {
using TargetMap = std::unordered_map<DLDeviceType, Target, backend::EnumClassHash>;
using DeviceMap =
std::unordered_map<Expr, tvm::Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
- using ProcessFn = std::function<void(Function)>;
+ using ProcessFn = std::function<void(BaseFunc)>;

/*!
* \brief A compiler which lowers primitive Relay functions to tensor expressions
@@ -140,10 +140,10 @@ class TECompiler : public ObjectRef {
/*!
* \brief A function to create the function metadata for an input function (ie calculate buffer
* input/output sizes)
- * \param relay_func The function to calculate function metadata for
+ * \param func The function to calculate function metadata for
* \param function_metadata The map that stores all the function metadatas
*/
- void UpdateFunctionMetadata(Function relay_func,
+ void UpdateFunctionMetadata(BaseFunc func,
Map<String, backend::FunctionInfo>& function_metadata); // NOLINT(*)

/*!
@@ -188,7 +188,7 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod);
*/
IRModule LowerTE(
const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name,
- ProcessFn process_fn = [](Function f) {});
+ ProcessFn process_fn = [](BaseFunc f) {});

/*! \brief Pass to lower an IRModule's primitive functions to TIR.
*
Expand All @@ -201,7 +201,7 @@ IRModule LowerTE(
* each function that we lower
* \returns The pass which lowers primitive functions to TIR
*/
- transform::Pass LowerTEPass(const String& module_name, std::function<void(Function)> process_fn);
+ transform::Pass LowerTEPass(const String& module_name, ProcessFn process_fn);
} // namespace tec
} // namespace relay
} // namespace tvm
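The header-side summary of the whole commit: ProcessFn is widened from Function to BaseFunc, the common base of Relay Functions and TIR PrimFuncs, which is what allows the callbacks above to receive external PrimFuncs at all. The same idea as a runnable Python toy (class names are stand-ins for the C++ IR hierarchy):

```python
class BaseFunc: ...                 # common base, as in TVM's IR
class RelayFunction(BaseFunc): ...  # Relay function
class PrimFunc(BaseFunc): ...       # TIR function, e.g. from external codegen

def lower_te(funcs, process_fn):
    # process_fn accepts any BaseFunc, so PrimFuncs flow through too.
    for f in funcs:
        process_fn(f)

seen = []
lower_te([RelayFunction(), PrimFunc()], lambda f: seen.append(type(f).__name__))
assert seen == ["RelayFunction", "PrimFunc"]
```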
2 changes: 1 addition & 1 deletion src/relay/backend/utils.h
@@ -190,7 +190,7 @@ struct ConstantUpdater : public ExprVisitor {
* \param func The function from which to get the constant params.
* \param params The params to update with the constants.
*/
- inline void UpdateConstants(Function func,
+ inline void UpdateConstants(BaseFunc func,
std::unordered_map<std::string, runtime::NDArray>* params) {
VLOG_CONTEXT << "UpdateConstants";
VLOG(1) << "updating constants for:" << std::endl << PrettyPrint(func);
3 changes: 1 addition & 2 deletions src/relay/backend/vm/compiler.cc
@@ -1151,8 +1151,7 @@ void VMCompiler::Codegen() {
// Collect metadata in functions that are handled by external codegen.
auto name = cfunc->prim_fn_var->name_hint;
ICHECK(mod->ContainGlobalVar(name));
- Function func = Downcast<Function>(mod->Lookup(name));
- backend::UpdateConstants(func, &params_);
+ backend::UpdateConstants(mod->Lookup(name), &params_);
} else if (funcs.count(target) == 0) {
funcs.Set(target, mod);
} else {
