Skip to content

Commit

Permalink
[BYOC] Prevent duplicate outputs in subgraph Tuple (apache#5320)
Browse files Browse the repository at this point in the history
* Fix duplicate output in partitiongraph

* Add test case

* Fix test_annotated_regions with duplicate compiler_end outputs

* Revert "Fix duplicate output in partitiongraph"

This reverts commit e1f8ef3.

* Prevent duplicate outputs in Tuple in PartitionGraph

* Fix lint

* Add another test case for when regions are merged, and when TupleGetItem was duplicated

* Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput

* Use std::move for GetFunctionOutput. Fix typo with testcase name

* Use tvm.transform.Sequential
  • Loading branch information
Trevor Morris authored and zhiics committed Apr 17, 2020
1 parent cc43a1c commit e839876
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 101 deletions.
226 changes: 125 additions & 101 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,99 +205,13 @@ class Partitioner : public ExprMutator {
// region_function_calls is map that maintains
// (each annotated regions) --> created function

if (region_function_calls.find(region) != region_function_calls.end()) {
// This section is executed only if there are multiple outputs in the
// region Thus, the function is always created and at the end there
// would be a tuple node Therefore, we insert a tuple get item node.

// Use the already created tuple node
auto sg_call = region_function_calls[region];
int index = GetRetIdx(region, GetRef<Call>(call));
CHECK_NE(index, -1);

auto tuple_get_item_ = TupleGetItem(sg_call, index);
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
return std::move(tuple_get_item_);
} else {
// First time this region is encountered in the traversal
// Creating the function

Array<Expr> fields;

for (auto ret : region->GetOutputs()) {
auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
fields.push_back(ret_expr);
}
int index = GetRetIdx(region, GetRef<Call>(call));
CHECK_NE(index, -1);

Array<Var> params;
Array<Expr> param_expr;
std::unordered_map<std::string, runtime::NDArray> params_bind;

for (auto pair : region_args[region]) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
} else {
param_expr.push_back(pair.second);
}
}

Function global_region_func;
if (region->GetOutputs().size() == 1) {
// If there are only a single output; no need to add a tuple
global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}

std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());

global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::runtime::String(target));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

// Constant propagation
if (!params_bind.empty()) {
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
}

std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname))
<< "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent
// relay function level passes (i.e. simplify inference and fusion)
// optimizing it.
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);

// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = Call(glob_func, param_expr);
region_function_calls[region] = ret;

if (region->GetOutputs().size() == 1) {
// If there is only a single output; no need to add a tuplegetitem
// node
return std::move(ret);
} else {
// Add a tuplegetitem node to select this output out of many
auto tuple_get_item_ = TupleGetItem(ret, index);
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
return std::move(tuple_get_item_);
}
if (region_function_calls.find(region) == region_function_calls.end()) {
// First time this region is encountered in the traversal.
// Creating the function.
CreateFunction(region, call);
}
// Retrieve this particular output of function.
return GetFunctionOutput(region, GetRef<Call>(call));
}
}

Expand Down Expand Up @@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
}

/*!
* \brief Get the index of the return(output);
* this is to be used as tuplegetitem idx
* \brief This function is called first time that we encounter a compiler_end
* node to create the function for the subgraph.
*/
int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
int idx = 0;
for (auto arg_ : sg->GetOutputs()) {
if (arg == arg_) {
return idx;
void CreateFunction(AnnotatedRegion region, const CallNode* call) {
// Create fields which is a unique list of outputs. Also populate
// region_return_indices_ map which maps parent of compiler_end node to
// corresponding index in fields.
Array<Expr> fields;
int i = 0;
for (auto ret : region->GetOutputs()) {
auto ret_node = Downcast<Call>(ret)->args[0];
// Don't duplicate outputs.
if (!region_return_indices_.count(region) ||
!region_return_indices_[region].count(ret_node)) {
auto ret_expr = VisitExpr(ret_node);
fields.push_back(ret_expr);
region_return_indices_[region][ret_node] = i;
i++;
}
idx++;
}
return -1;

Array<Var> params;
Array<Expr> param_expr;
std::unordered_map<std::string, runtime::NDArray> params_bind;

for (auto pair : region_args[region]) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
} else {
param_expr.push_back(pair.second);
}
}

Function global_region_func;
if (fields.size() == 1) {
// If there are only a single output; no need to add a tuple
global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}

std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());

global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::runtime::String(target));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

// Constant propagation
if (!params_bind.empty()) {
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
}

std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname))
<< "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent
// relay function level passes (i.e. simplify inference and fusion)
// optimizing it.
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);

// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = Call(glob_func, param_expr);
region_function_calls[region] = ret;
}

/*!
* \brief Get the return(output) of the function for compiler end node "end_arg".
* This will return either a Call (for a function with a single output) or a
* TupleGetItem (for a function with multiple outputs).
*/
Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
Expr arg = Downcast<Call>(end_arg)->args[0];
// Function has one output.
if (region_return_indices_[region].size() == 1) {
return region_function_calls[region];
}
// Function has multiple outputs.
// Use already made TupleGetItem.
if (region_return_tuplegetitem_.count(region) &&
region_return_tuplegetitem_[region].count(arg)) {
return region_return_tuplegetitem_[region][arg];
}
// Create new TupleGetItem.
CHECK(region_return_indices_.count(region) &&
region_return_indices_[region].count(arg));
int index = region_return_indices_[region][arg];

auto func_call = region_function_calls[region];
auto tuple_get_item_ = TupleGetItem(func_call, index);
tuple_get_item_->checked_type_ = arg->checked_type_;
region_return_tuplegetitem_[region][arg] = tuple_get_item_;
return std::move(tuple_get_item_);
}

/*!
Expand All @@ -485,6 +492,23 @@ class Partitioner : public ExprMutator {
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
region_args;

/*!
* \brief This map maintains the index of an output in the subgraph function
* for a given region. If there are multiple entries for a region, then the
* function has a tuple of multiple outputs for its return.
*/
using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
region_return_indices_;

/*!
* \brief This map holds already created TupleGetItem nodes for accessing
* outputs of a function.
*/
using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
region_return_tuplegetitem_;

/*!
* \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it
Expand Down
Loading

0 comments on commit e839876

Please sign in to comment.