[RELAY] Turn reshape into nop in graph executor backend. #7945

Merged 3 commits on May 2, 2021
3 changes: 3 additions & 0 deletions include/tvm/relay/function.h
@@ -144,6 +144,9 @@ constexpr const char* kComposite = "Composite";
constexpr const char* kInline = "Inline";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";

/*! \brief Mark the function as only composed of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";
Contributor:
Are there other ops that could benefit? For example squeeze, expand_dims, etc.? I feel like this should more generally be something like relay.alias_only (although alias might be too general, since you could "alias" a subset of an input by slicing).

} // namespace attr

} // namespace relay
7 changes: 7 additions & 0 deletions include/tvm/relay/op_attr_types.h
@@ -82,6 +82,13 @@ using TOpIsStateful = bool;
*/
using TNonComputational = bool;

/*!
 * \brief Mark the operator as a reshape of its first input. It can be
 * turned into a nop when the first input and the output share the
 * same piece of memory.
 */
using TReshapeOp = bool;

/*!
* \brief Mark the operator whether output shape is data dependent.
*/
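Once the op registrations below are in place, the attribute can be queried from Python. A minimal sketch, assuming a TVM build that includes this patch:

from tvm import relay

# Query the new TReshapeOp attribute on an op; get_attr returns None when unset.
print(relay.op.get("reshape").get_attr("TReshapeOp"))  # expected: True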
23 changes: 23 additions & 0 deletions src/relay/backend/graph_executor_codegen.cc
@@ -349,6 +349,16 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
return AddNode(node, GetRef<Expr>(op));
}

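// Check whether lhs and rhs were assigned the same storage id by the memory planner.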
bool ShareSameStorage(const Expr& lhs, const Expr& rhs) {
auto lit = storage_device_map_.find(lhs);
auto rit = storage_device_map_.find(rhs);
ICHECK(lit != storage_device_map_.end());
ICHECK(rit != storage_device_map_.end());
int64_t lhs_storage_id = ((*lit).second)[0][0]->value;
int64_t rhs_storage_id = ((*rit).second)[0][0]->value;
return lhs_storage_id == rhs_storage_id;
}

std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
Expr expr = GetRef<Expr>(op);
Function func;
@@ -380,6 +390,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
}

// In the current flat memory allocation scenario, the flat memory
// allocator can always allocate the input and output of a reshape to
// the same memory, so we can turn a reshape-only function into a nop.
//
// NOTE that for non-flat memory this is not necessarily true.
//
// TODO(tvm-team) Update checks of flat memory enablement when we support
// opaque-nd memory planning to skip this path.
if (func->HasNonzeroAttr(attr::kReshapeOnly) && ShareSameStorage(expr, op->args[0])) {
return GraphAddCallNode(op, "reshape_nop", "__nop");
}

ICHECK_GT(storage_device_map_.count(expr), 0);
auto& device_type = storage_device_map_[expr][1];
auto call_dev_type = device_type[0]->value;
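The effect is visible in the generated graph JSON: a reshape-only fused function is emitted as a "__nop" node that shares a storage id with its input. A minimal sketch of observing this, mirroring the pattern of the new test below (assumes a build that includes this patch):

import json
import tvm
from tvm import relay

x = relay.var("x", shape=(10, 4))
y = relay.abs(x)
z0 = relay.reshape(y, (2, 20))  # y has another consumer, so the reshape forms its own group
z1 = relay.sqrt(y)
func = relay.Function([x], relay.Tuple([z0, z1]))
lib = relay.build(tvm.IRModule.from_expr(func), "llvm")
graph = json.loads(lib.get_json())
# One op node should report func_name "__nop" and reuse its input's storage id.
print([n["attrs"]["func_name"] for n in graph["nodes"] if n["op"] == "tvm_op"])
print(graph["attrs"]["storage_id"][1])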
45 changes: 43 additions & 2 deletions src/relay/backend/graph_plan_memory.cc
@@ -236,6 +236,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
auto it = prototype_.find(op);
ICHECK(it != prototype_.end());
std::vector<StorageToken*> tokens;

for (StorageToken* tok : it->second) {
if (can_realloc) {
tokens.push_back(Request(tok));
@@ -250,6 +251,22 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
}
token_map_[op] = tokens;
}
// Mark the op to reuse input_token and
// tie the two storages together.
void ReuseInputToken(const ExprNode* op, StorageToken* input_token) {
ICHECK(!token_map_.count(op));
auto it = prototype_.find(op);
ICHECK(it != prototype_.end());
ICHECK_EQ(it->second.size(), 1U);
StorageToken* prototype = it->second[0];
// Add the output's reference count to the input token
// so the input token is only released after references
// to both have expired.
input_token->ref_counter += prototype->ref_counter;
// reuse the input token
token_map_[op] = {input_token};
}

// The call map
void VisitExpr_(const CallNode* op) final {
std::vector<StorageToken*> args;
@@ -259,8 +276,21 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
args.push_back(tok);
}
}
// Under the flat-memory setting we can force the input and output
// of a reshape to alias, making it a nop. Note that this does not
// hold in the non-flat memory case. Since the current graph memory
// planner only works with flat memory, we go with this choice.
//
// TODO(tvm-team) Update checks of flat memory enablement when we support
// opaque-nd memory planning to skip this path.
if (IsReshape(op)) {
  ICHECK_EQ(args.size(), 1U);
  ReuseInputToken(op, args[0]);
} else {
  // create token for the call node.
  CreateToken(op, true);
}
// check if there is orphaned output that can be released immediately.
for (StorageToken* tok : token_map_.at(op)) {
CheckForRelease(tok);
@@ -278,6 +308,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
static size_t DivRoundUp(size_t size, size_t word_size) {
return (size + word_size - 1) / word_size;
}
/*!
* \brief Check whether the call is a reshape-only op.
* \param call The call to be checked.
* \return The check result.
*/
static bool IsReshape(const CallNode* call) {
if (const auto* fn = call->op.as<FunctionNode>()) {
return fn->HasNonzeroAttr(attr::kReshapeOnly);
}
return false;
}
/*!
* \brief Get the memory requirement.
* \param prototype The prototype token.
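To see why ReuseInputToken adds the output's reference count rather than replacing it, consider a toy Python model of the token-tying logic (illustrative only, not TVM code; the names are made up):

class Token:
    def __init__(self, ref_counter):
        self.ref_counter = ref_counter

def reuse_input_token(input_token, output_prototype):
    # The input storage must outlive consumers of BOTH tensors, so the
    # output's pending references are folded into the input token.
    input_token.ref_counter += output_prototype.ref_counter
    return input_token  # the op's output aliases the input storage

inp = Token(ref_counter=1)  # one remaining consumer of the input
out = Token(ref_counter=2)  # two consumers of the reshape output
tied = reuse_input_token(inp, out)
assert tied.ref_counter == 3  # storage is released only after all three uses expire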
3 changes: 2 additions & 1 deletion src/relay/op/dyn/tensor/transform.cc
@@ -141,7 +141,8 @@ RELAY_REGISTER_OP("dyn.reshape")
.set_support_level(3)
.add_type_rel("DynamicReshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TReshapeOp>("TReshapeOp", true);

// tile operator
// TVM_REGISTER_NODE_TYPE(TileAttrs);
12 changes: 8 additions & 4 deletions src/relay/op/tensor/transform.cc
@@ -238,7 +238,8 @@ RELAY_REGISTER_OP("expand_dims")
.set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel)
.set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
.set_attr<TReshapeOp>("TReshapeOp", true);

// relay.concatenate
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
@@ -887,7 +888,8 @@ Example::
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TReshapeOp>("TReshapeOp", true);

/*!
* \brief ReshapeLikeRel User defined type constraint function.
@@ -2243,7 +2245,8 @@ RELAY_REGISTER_OP("squeeze")
.add_type_rel("Squeeze", SqueezeRel)
.set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout);
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout)
.set_attr<TReshapeOp>("TReshapeOp", true);

// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool CollapseSumLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
@@ -3221,7 +3224,8 @@ example below::
.set_support_level(10)
.add_type_rel("ReverseReshape", ReverseReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TReshapeOp>("TReshapeOp", true);

// gather operator
TVM_REGISTER_NODE_TYPE(GatherAttrs);
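All five reshape-like ops touched by this PR should now carry the attribute. A quick sketch to verify, assuming a build that includes this patch:

from tvm import relay

for name in ["reshape", "expand_dims", "squeeze", "contrib_reverse_reshape", "dyn.reshape"]:
    print(name, relay.op.get(name).get_attr("TReshapeOp"))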
30 changes: 27 additions & 3 deletions src/relay/transforms/fuse_ops.cc
@@ -948,15 +948,39 @@ class FuseMutator : private MixedModeMutator {
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
// Quickly check special properties of the fused function.
// A pass to check if the fused op contains only reshape ops.
class CheckReshapeOnly : public ExprVisitor {
 public:
  void VisitExpr_(const CallNode* cn) final {
    this->has_call = true;
    static auto freshape_op = Op::GetAttrMap<TReshapeOp>("TReshapeOp");

    if (!freshape_op.get(cn->op, false)) {
      this->reshape_only = false;
    }

    if (!this->reshape_only) return;
    ExprVisitor::VisitExpr_(cn);
  }

  void VisitExpr_(const VarNode* vn) final {
    if (!vn->type_annotation.defined() || !vn->type_annotation->IsInstance<TensorTypeNode>()) {
      this->reshape_only = false;
    }
  }

  bool reshape_only = true;
  bool has_call = false;
} visitor;

visitor(body);
const GroupInfo& ginfo = ginfo_[group];
auto func = Function(ginfo.params, body, ret_type, {});
func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
if (visitor.has_call && visitor.reshape_only) {
func = WithAttr(std::move(func), attr::kReshapeOnly, tvm::Integer(visitor.reshape_only));
}
return Call(func, ginfo.arguments, Attrs());
}

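After fusion, the new attribute is visible on the fused primitive function. A minimal sketch, assuming a build that includes this patch:

import tvm
from tvm import relay

x = relay.var("x", shape=(10, 4))
mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, (2, 20))))
mod = relay.transform.InferType()(mod)
mod = relay.transform.FuseOps()(mod)
# The fused function should carry attrs Primitive=1 and relay.reshape_only=1.
print(mod)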
34 changes: 4 additions & 30 deletions src/relay/transforms/memory_alloc.cc
@@ -64,39 +64,13 @@ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dt
return AllocTensor(storage, offset, shape, dtype, assert_shape);
}

// Check if the primitive function contains only reshape ops.
bool IsReshapeOnly(const Expr& expr) {
  if (auto* func = expr.as<FunctionNode>()) {
    return func->HasNonzeroAttr(attr::kReshapeOnly);
  }
  return false;
}

class DialectRewriter : public ExprMutator {
Expand Down
47 changes: 47 additions & 0 deletions tests/python/relay/test_backend_graph_executor.py
@@ -17,6 +17,7 @@
import numpy as np

import tvm
import json
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.op import add
@@ -146,6 +147,51 @@ def test_plan_memory():
assert len(device_types) == 1


def test_reshape_nop():
# test that reshape can be turned into nop
x = relay.var("x", shape=(10, 4))
xx = relay.abs(x)
y = relay.expand_dims(xx, axis=1)
t0 = relay.reshape(y, (1, 40))
t1 = relay.abs(y)

z0 = relay.reshape(t0, (2, 20))
z1 = relay.sqrt(t1)
z2 = relay.reshape(t1, (1, 40))

func = relay.Function([x], relay.Tuple([z0, z1, z2]))
x_data = np.random.rand(10, 4).astype("float32")
graph = relay.build(tvm.IRModule.from_expr(func), "llvm")
graph_json_str = graph.get_json()

graph_json = json.loads(graph_json_str)

# reshape must force sharing memory
storage_ids = graph_json["attrs"]["storage_id"][1]
assert tuple(storage_ids) == (0, 1, 1, 2, 3, 2)
assert graph_json["nodes"][2]["attrs"]["func_name"] == "__nop"
assert graph_json["nodes"][5]["attrs"]["func_name"] == "__nop"

gmod = graph_executor.GraphModule(graph["default"](tvm.cpu(0)))

gmod.set_input(x=x_data)
gmod.run()
z0_np = x_data.reshape(2, 20)
z1_np = np.sqrt(
np.abs(
x_data.reshape(
10,
1,
4,
)
)
)
z2_np = np.abs(x_data).reshape(1, 40)
tvm.testing.assert_allclose(gmod.get_output(0).asnumpy(), z0_np)
tvm.testing.assert_allclose(gmod.get_output(1).asnumpy(), z1_np)
tvm.testing.assert_allclose(gmod.get_output(2).asnumpy(), z2_np)


@tvm.testing.uses_gpu
def test_gru_like():
def unit(rnn_dim):
@@ -231,6 +277,7 @@ def test_graph_executor_nested_tuples():


if __name__ == "__main__":
test_reshape_nop()
test_plan_memory()
test_with_params()
test_add_op_scalar()