From b2c4f1c2cbd4c4b7ac27ac853dbad7c489784444 Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Sun, 2 May 2021 00:00:10 -0400
Subject: [PATCH] [RELAY] Turn reshape into nop in graph executor backend.
 (#7945)

* [RELAY] Turn reshape into nop in graph executor backend.

Previously we were generating function calls for reshape. This PR
updates the optimization to turn reshape into a nop:

- Tag a fused function as reshape-only if it contains only reshape ops.
- Update the memory planner to force the input and output of a reshape
  to share the same piece of memory.
- Update the graph executor codegen to emit a nop when a reshape-only
  function is encountered.

* Address review comments.

* Additional comments and TODOs on the rationale.
---
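Note: as a quick illustration (not part of the patch), here is a minimal
sketch of the observable effect, modeled on the new test below. It assumes
fusion keeps the lone reshape in its own group because its input has two
consumers; exact node order and storage ids depend on the planner.

    import json

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    x = relay.var("x", shape=(10, 4))
    y = relay.abs(x)
    z0 = relay.reshape(y, (2, 20))  # y has two consumers, so this reshape
    z1 = relay.sqrt(y)              # should form its own reshape-only group.
    func = relay.Function([x], relay.Tuple([z0, z1]))

    lib = relay.build(tvm.IRModule.from_expr(func), "llvm")
    graph_json = json.loads(lib.get_json())

    # The reshape-only fused function is emitted as a "__nop" node.
    assert any(
        node.get("attrs", {}).get("func_name") == "__nop"
        for node in graph_json["nodes"]
    )

    # Numerical results are unchanged by the rewrite.
    gmod = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
    x_data = np.random.rand(10, 4).astype("float32")
    gmod.set_input(x=x_data)
    gmod.run()
    np.testing.assert_allclose(gmod.get_output(0).asnumpy(), np.abs(x_data).reshape(2, 20))
    np.testing.assert_allclose(gmod.get_output(1).asnumpy(), np.sqrt(np.abs(x_data)))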
 include/tvm/relay/function.h                  |  3 ++
 include/tvm/relay/op_attr_types.h             |  7 +++
 src/relay/backend/graph_executor_codegen.cc   | 23 +++++++++
 src/relay/backend/graph_plan_memory.cc        | 45 +++++++++++++++++-
 src/relay/op/dyn/tensor/transform.cc          |  3 +-
 src/relay/op/tensor/transform.cc              | 12 +++--
 src/relay/transforms/fuse_ops.cc              | 30 ++++++++++--
 src/relay/transforms/memory_alloc.cc          | 34 ++------------
 .../relay/test_backend_graph_executor.py      | 47 +++++++++++++++++++
 9 files changed, 164 insertions(+), 40 deletions(-)

diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index db973b91f92a..95eaad0b2797 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -144,6 +144,9 @@ constexpr const char* kComposite = "Composite";
 constexpr const char* kInline = "Inline";
 /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
+
+/*! \brief Mark the function as only composed of reshape operations. */
+constexpr const char* kReshapeOnly = "relay.reshape_only";
 }  // namespace attr
 }  // namespace relay

diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h
index f916dbeb713f..97a3d5e2a01f 100644
--- a/include/tvm/relay/op_attr_types.h
+++ b/include/tvm/relay/op_attr_types.h
@@ -82,6 +82,13 @@ using TOpIsStateful = bool;
  */
 using TNonComputational = bool;
 
+/*!
+ * \brief Mark the operator as a reshape of its first input,
+ * which can be turned into a nop when the first input and the output
+ * share the same piece of memory.
+ */
+using TReshapeOp = bool;
+
 /*!
  * \brief Mark the operator whether output shape is data dependent.
  */

diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc
index 72989b5ba46a..3ea8a2bed91b 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -349,6 +349,16 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector
     return AddNode(node, GetRef<Expr>(op));
   }
 
+  bool ShareSameStorage(const Expr& lhs, const Expr& rhs) {
+    auto lit = storage_device_map_.find(lhs);
+    auto rit = storage_device_map_.find(rhs);
+    ICHECK(lit != storage_device_map_.end());
+    ICHECK(rit != storage_device_map_.end());
+    int64_t lhs_storage_id = ((*lit).second)[0][0]->value;
+    int64_t rhs_storage_id = ((*rit).second)[0][0]->value;
+    return lhs_storage_id == rhs_storage_id;
+  }
+
   std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
     Expr expr = GetRef<Expr>(op);
     Function func;
@@ -380,6 +390,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector
       return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
     }
 
+    // In the current flat memory allocation scenario,
+    // the flat memory allocator can always allocate the input
+    // and output of a reshape to the same memory, so we can turn
+    // a reshape-only function into a nop.
+    //
+    // NOTE that for non-flat memory this is not necessarily true.
+    //
+    // TODO(tvm-team) Update checks of flat memory enablement when we support
+    // opaque-nd memory planning to skip this path.
+    if (func->HasNonzeroAttr(attr::kReshapeOnly) && ShareSameStorage(expr, op->args[0])) {
+      return GraphAddCallNode(op, "reshape_nop", "__nop");
+    }
+
     ICHECK_GE(storage_device_map_.count(expr), 0);
     auto& device_type = storage_device_map_[expr][1];
     auto call_dev_type = device_type[0]->value;

diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc
index 4260f052d2c0..cf843236da61 100644
--- a/src/relay/backend/graph_plan_memory.cc
+++ b/src/relay/backend/graph_plan_memory.cc
@@ -236,6 +236,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
     auto it = prototype_.find(op);
     ICHECK(it != prototype_.end());
     std::vector<StorageToken*> tokens;
+
     for (StorageToken* tok : it->second) {
       if (can_realloc) {
         tokens.push_back(Request(tok));
       } else {
@@ -250,6 +251,22 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
     }
     token_map_[op] = tokens;
   }
+  // Mark op to reuse the input_token and
+  // tie the two memories together.
+  void ReuseInputToken(const ExprNode* op, StorageToken* input_token) {
+    ICHECK(!token_map_.count(op));
+    auto it = prototype_.find(op);
+    ICHECK(it != prototype_.end());
+    ICHECK_EQ(it->second.size(), 1U);
+    StorageToken* prototype = it->second[0];
+    // Add the reference counter of the output
+    // so that the input token can only be deleted after references
+    // to both have expired.
+    input_token->ref_counter += prototype->ref_counter;
+    // Reuse the input token.
+    token_map_[op] = {input_token};
+  }
+
   // The call map
   void VisitExpr_(const CallNode* op) final {
     std::vector<StorageToken*> args;
@@ -259,8 +276,21 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
         args.push_back(tok);
       }
     }
-    // create token for the call node.
-    CreateToken(op, true);
+    // Under the flat-memory setting,
+    // we can force the input and output of a reshape to alias,
+    // which makes it a nop. Note that this does not hold
+    // for the non-flat memory case. Given that the current graph plan memory
+    // only works for the flat memory case, we will go with this choice.
+    //
+    // TODO(tvm-team) Update checks of flat memory enablement when we support
+    // opaque-nd memory planning to skip this path.
+    if (IsReshape(op)) {
+      ICHECK_EQ(args.size(), 1U);
+      ReuseInputToken(op, args[0]);
+    } else {
+      // create token for the call node.
+      CreateToken(op, true);
+    }
     // check if there is orphaned output that can be released immediately.
     for (StorageToken* tok : token_map_.at(op)) {
       CheckForRelease(tok);
@@ -278,6 +308,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
   static size_t DivRoundUp(size_t size, size_t word_size) {
     return (size + word_size - 1) / word_size;
   }
+  /*!
+   * \brief Check whether the call is a reshape-only op.
+   * \param call The call to be checked.
+   * \return The check result.
+   */
+  static bool IsReshape(const CallNode* call) {
+    if (const auto* fn = call->op.as<FunctionNode>()) {
+      return fn->HasNonzeroAttr(attr::kReshapeOnly);
+    }
+    return false;
+  }
   /*!
    * \brief Get the memory requirement.
    * \param prototype The prototype token.

diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc
index 9724a92e8776..cf8f3689b045 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -141,7 +141,8 @@ RELAY_REGISTER_OP("dyn.reshape")
     .set_support_level(3)
     .add_type_rel("DynamicReshape", ReshapeRel)
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
-    .set_attr<TOpPattern>("TOpPattern", kInjective);
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<TReshapeOp>("TReshapeOp", true);
 
 // tile operator
 // TVM_REGISTER_NODE_TYPE(TileAttrs);

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index e937cb0c7b1f..02cdb211463e 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -238,7 +238,8 @@ RELAY_REGISTER_OP("expand_dims")
     .set_support_level(1)
     .add_type_rel("ExpandDims", ExpandDimsRel)
     .set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
-    .set_attr<TOpPattern>("TOpPattern", kBroadcast);
+    .set_attr<TOpPattern>("TOpPattern", kBroadcast)
+    .set_attr<TReshapeOp>("TReshapeOp", true);
 
 // relay.concatenate
 TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
@@ -887,7 +888,8 @@ Example::
     .set_support_level(3)
     .add_type_rel("Reshape", ReshapeRel)
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
-    .set_attr<TOpPattern>("TOpPattern", kInjective);
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<TReshapeOp>("TReshapeOp", true);
 
 /*!
  * \brief ReshapeLikeRel User defined type constraint function.
@@ -2243,7 +2245,8 @@ RELAY_REGISTER_OP("squeeze")
     .add_type_rel("Squeeze", SqueezeRel)
     .set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective)
-    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout);
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout)
+    .set_attr<TReshapeOp>("TReshapeOp", true);
 
 // CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
 bool CollapseSumLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
@@ -3221,7 +3224,8 @@ example below::
     .set_support_level(10)
     .add_type_rel("ReverseReshape", ReverseReshapeRel)
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
-    .set_attr<TOpPattern>("TOpPattern", kInjective);
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<TReshapeOp>("TReshapeOp", true);
 
 // gather operator
 TVM_REGISTER_NODE_TYPE(GatherAttrs);

diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index eaef0b905079..f1f7a95e33e8 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -948,15 +948,39 @@ class FuseMutator : private MixedModeMutator {
   }
 
   Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
-    // If the function has no call, it is not a primitive function.
-    struct HasCallVisitor : ExprVisitor {
+    // Quickly check special properties of the fused function.
+    // A pass to check if the fused op contains only reshape ops.
+    class CheckReshapeOnly : public ExprVisitor {
+     public:
+      void VisitExpr_(const CallNode* cn) final {
+        this->has_call = true;
+        static auto freshape_op = Op::GetAttrMap<TReshapeOp>("TReshapeOp");
+
+        if (!freshape_op.get(cn->op, false)) {
+          this->reshape_only = false;
+        }
+
+        if (!this->reshape_only) return;
+        ExprVisitor::VisitExpr_(cn);
+      }
+
+      void VisitExpr_(const VarNode* vn) final {
+        if (!vn->type_annotation.defined() || !vn->type_annotation->IsInstance<TensorTypeNode>()) {
+          this->reshape_only = false;
+        }
+      }
+
+      bool reshape_only = true;
       bool has_call = false;
-      void VisitExpr_(const CallNode* op) final { has_call = true; }
     } visitor;
+
     visitor(body);
     const GroupInfo& ginfo = ginfo_[group];
     auto func = Function(ginfo.params, body, ret_type, {});
     func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
+    if (visitor.has_call && visitor.reshape_only) {
+      func = WithAttr(std::move(func), attr::kReshapeOnly, tvm::Integer(visitor.reshape_only));
+    }
     return Call(func, ginfo.arguments, Attrs());
   }

diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc
index 1dc204d43ba1..2b69b02ab999 100644
--- a/src/relay/transforms/memory_alloc.cc
+++ b/src/relay/transforms/memory_alloc.cc
@@ -64,39 +64,13 @@ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dt
   return AllocTensor(storage, offset, shape, dtype, assert_shape);
 }
 
-// A pass to check if the fused op contains only reshape ops.
-class CheckReshapeOnly : public ExprVisitor {
- public:
-  CheckReshapeOnly()
-      : reshape_(Op::Get("reshape")),
-        contr_reshape_(Op::Get("contrib_reverse_reshape")),
-        dyn_reshape_(Op::Get("dyn.reshape")) {}
-
-  void VisitExpr_(const CallNode* cn) final {
-    if (!reshape_only) return;
-    if (cn->op != reshape_ && cn->op != contr_reshape_ && cn->op != dyn_reshape_) {
-      reshape_only = false;
-    }
-    for (auto arg : cn->args) ExprVisitor::VisitExpr(arg);
-  }
-
-  void VisitExpr_(const VarNode* vn) final {
-    if (!vn->checked_type_->IsInstance<TensorTypeNode>()) {
-      reshape_only = false;
-    }
-  }
-
-  const Op& reshape_;
-  const Op& contr_reshape_;
-  const Op& dyn_reshape_;
-  bool reshape_only{true};
-};
-
 // Check if the primitive function contains only reshape ops.
 bool IsReshapeOnly(const Expr& expr) {
-  auto check = CheckReshapeOnly();
-  check.VisitExpr(expr);
-  return check.reshape_only;
+  if (auto* func = expr.as<FunctionNode>()) {
+    return func->HasNonzeroAttr(attr::kReshapeOnly);
+  }
+  return false;
 }
 
 class DialectRewriter : public ExprMutator {

diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py
index b9553d79c3b6..8e6fe298351e 100644
--- a/tests/python/relay/test_backend_graph_executor.py
+++ b/tests/python/relay/test_backend_graph_executor.py
@@ -17,6 +17,7 @@
 import numpy as np
 
 import tvm
+import json
 from tvm import relay
 from tvm.contrib import graph_executor
 from tvm.relay.op import add
@@ -146,6 +147,51 @@ def test_plan_memory():
     assert len(device_types) == 1
 
 
+def test_reshape_nop():
+    # Test that reshape can be turned into a nop.
+    x = relay.var("x", shape=(10, 4))
+    xx = relay.abs(x)
+    y = relay.expand_dims(xx, axis=1)
+    t0 = relay.reshape(y, (1, 40))
+    t1 = relay.abs(y)
+
+    z0 = relay.reshape(t0, (2, 20))
+    z1 = relay.sqrt(t1)
+    z2 = relay.reshape(t1, (1, 40))
+
+    func = relay.Function([x], relay.Tuple([z0, z1, z2]))
+    x_data = np.random.rand(10, 4).astype("float32")
+    graph = relay.build(tvm.IRModule.from_expr(func), "llvm")
+    graph_json_str = graph.get_json()
+
+    graph_json = json.loads(graph_json_str)
+
+    # Reshape must force sharing memory.
+    storage_ids = graph_json["attrs"]["storage_id"][1]
+    assert tuple(storage_ids) == (0, 1, 1, 2, 3, 2)
+    assert graph_json["nodes"][2]["attrs"]["func_name"] == "__nop"
+    assert graph_json["nodes"][5]["attrs"]["func_name"] == "__nop"
+
+    gmod = graph_executor.GraphModule(graph["default"](tvm.cpu(0)))
+
+    gmod.set_input(x=x_data)
+    gmod.run()
+    z0_np = x_data.reshape(2, 20)
+    z1_np = np.sqrt(
+        np.abs(
+            x_data.reshape(
+                10,
+                1,
+                4,
+            )
+        )
+    )
+    z2_np = np.abs(x_data).reshape(1, 40)
+    tvm.testing.assert_allclose(gmod.get_output(0).asnumpy(), z0_np)
+    tvm.testing.assert_allclose(gmod.get_output(1).asnumpy(), z1_np)
+    tvm.testing.assert_allclose(gmod.get_output(2).asnumpy(), z2_np)
+
+
 @tvm.testing.uses_gpu
 def test_gru_like():
     def unit(rnn_dim):
@@ -231,6 +277,7 @@ def test_graph_executor_nested_tuples():
 
 
 if __name__ == "__main__":
+    test_reshape_nop()
     test_plan_memory()
     test_with_params()
     test_add_op_scalar()