From 51a20ba9a88f8779fba5bfc27536a443210c9fe7 Mon Sep 17 00:00:00 2001
From: Mercy
Date: Tue, 27 Nov 2018 22:24:05 -0800
Subject: [PATCH] Support concatenate

---
 3rdparty/HalideIR                 |  2 +-
 src/relay/op/nn/convolution.cc    | 10 +++--
 src/relay/op/nn/pooling.cc        | 28 +++++++-------
 src/relay/op/tensor/transform.cc  | 40 ++++++++++++++++++-
 src/relay/pass/alter_op_layout.cc | 61 +++++++++++++++++++----------
 src/relay/pass/alter_op_layout.h  | 64 ++++++++++++++++++-------------
 src/relay/pass/forward_rewrite.cc | 16 ++++++++
 7 files changed, 153 insertions(+), 68 deletions(-)

diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR
index e4a4c02764d37..a08e26e5a97f4 160000
--- a/3rdparty/HalideIR
+++ b/3rdparty/HalideIR
@@ -1 +1 @@
-Subproject commit e4a4c02764d37c9c3db0d64c4996651a3ef9513c
+Subproject commit a08e26e5a97f4ef4d566a42f6c78704b3f9c7b8a
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 65fb09a3e8421..170b6b6d13c5c 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -104,12 +104,16 @@ bool Conv2DRel(const Array<Type>& types,
 }
 
 template <typename T>
-Array<Array<Layout> > Conv2DInferCorrectLayout(const Attrs& attrs,
-                                               const Array<Layout>& in_layouts,
-                                               const Array<Array<IndexExpr>> &in_shapes) {
+Array<Array<Layout> > Conv2DInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
   const T* params = attrs.as<T>();
 
   Layout out_layout(params->out_layout);
+  // We always make other operators fit the layouts of convolution layers,
+  // so this inference ignores all inputs.
   return Array<Array<Layout> >{{params->data_layout, params->weight_layout},
                                {out_layout.defined() ? out_layout : params->data_layout}};
 }
diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index bcc329bdbc11e..6233e6d51776b 100644
--- a/src/relay/op/nn/pooling.cc
+++ b/src/relay/op/nn/pooling.cc
@@ -21,24 +21,22 @@ TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs);
 template <typename T>
 Array<Array<Layout> > Pool2DInferCorrectLayout(
     const Attrs& attrs,
-    const Array<Layout>& in_layouts,
-    const Array<Array<IndexExpr>> &in_shapes) {
-  CHECK_EQ(in_layouts.size(), 1);
-
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
   // NOTE: Discard "const" qualifier here.
   T *params = const_cast<T*>(attrs.as<T>());
-  Layout input = in_layouts[0];
-  const Layout raw_layout(params->layout);
-  if (input.defined()) {
-    CHECK(input.Convertible(raw_layout));
-    if (input.Indexof('W') != raw_layout.Indexof('W') ||
-        input.Indexof('H') != raw_layout.Indexof('H') ||
-        input.Contains('w') || input.Contains('h')) {
-      // if the new layout changes width or height dimension,
-      // fallback to old layout;
-      input = raw_layout;
+
+  if (new_in_layouts.defined()) {
+    CHECK_EQ(new_in_layouts.size(), 1);
+
+    Layout raw_layout(params->layout);
+    Layout input = new_in_layouts[0];
+    if (input.Indexof('W') == raw_layout.Indexof('W') &&
+        input.Indexof('H') == raw_layout.Indexof('H') &&
+        !input.Contains('w') && !input.Contains('h')) {
+      params->layout = input.name(); // modify self to follow the input layout
     }
-    params->layout = input.name(); // modify self to follow the input layout
   }
 
   return Array<Array<Layout> >{{params->layout}, {params->layout}};
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 455dffa568695..74a182e2e984b 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -13,6 +13,7 @@
 #include
 #include "../op_common.h"
 #include "../../../arithmetic/compute_expr.h"
+#include "../../pass/alter_op_layout.h"
 #include "../layout.h"
 
 namespace tvm {
@@ -202,6 +203,42 @@ bool ConcatenateRel(const Array<Type>& types,
   return true;
 }
 
+Array<Array<Layout>> ConcatenateLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
+  const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
+
+  size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
+                static_cast<size_t>(param->axis);
+
+  Layout ret;
+  if (new_in_layouts.defined()) {  // this function is called after some operators have been altered
+    Layout::LayoutDim concate_dim = old_in_layouts[0][axis];
+    for (size_t i = 0; i < new_in_layouts.size(); ++i) {
+      if (new_in_layouts[i].ndim() > axis &&
+          new_in_layouts[i][axis] == concate_dim) {
+        ret = new_in_layouts[i];
+        break;
+      }
+    }
+  } else {  // this function is called on the original, unaltered Relay IR
+    for (size_t i = 0; i < old_in_layouts.size(); ++i) {
+      if (old_in_layouts[i].defined()) {
+        ret = old_in_layouts[i];
+        break;
+      }
+    }
+
+    if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) {
+      return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
+    }
+  }
+
+  return Array<Array<Layout> > {Array<Layout>(old_in_layouts.size(), ret), {ret}};
+}
+
 Expr MakeConcatenate(Expr data,
                      int axis) {
   auto attrs = make_node<ConcatenateAttrs>();
@@ -227,7 +264,8 @@ RELAY_REGISTER_OP("concatenate")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input list of tensors.")
 .set_support_level(1)
-.add_type_rel("Concatenate", ConcatenateRel);
+.add_type_rel("Concatenate", ConcatenateRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout);
 
 /* relay.transpose */
 TVM_REGISTER_NODE_TYPE(TransposeAttrs);
diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc
index 78632942877cc..b3f7a478dcc39 100644
--- a/src/relay/pass/alter_op_layout.cc
+++ b/src/relay/pass/alter_op_layout.cc
@@ -66,6 +66,9 @@ class TransformMemorizer : public NodeRef {
 
   // Transform layout with memorizer
   Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) {
+    if (src_layout.Equals(dst_layout))
+      return raw;
+
     std::tuple<const Node*, std::string, std::string> key =
         std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name());
     auto& memo = operator->()->memo;
@@ -116,14 +119,16 @@ RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr);
 
 // Return inferred_input_layout, inferred_output_layout, success
 std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
     const Call& call,
-    const Array<Layout>& in_layouts,
-    const Array<Array<IndexExpr>>& in_shapes) {
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
   static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
 
   Op op = Downcast<Op>(call->op);
   if (finfer_layout.count(op)) {
     Array<Array<Layout> > inferred_layouts;
-    inferred_layouts = finfer_layout[op](call->attrs, in_layouts, in_shapes);
+    inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts,
+                                         old_in_layouts, old_in_shapes);
     CHECK_EQ(inferred_layouts.size(), 2)
       << "FInferCorrectLayout should return an array with size of 2";
     for (auto x : inferred_layouts) {
@@ -180,17 +185,27 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
   // NOTE: discard the "const" qualifier
   TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);
 
-  // fill incomplete state
-  for (auto arg : new_args) {
-    if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
-      inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
-      normal_new_args.push_back(inp->value);
+  // fill incomplete state and expand tuple
+  for (auto new_arg : new_args) {
+    auto push_back_one_arg = [&](Expr arg) {
+      if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
+        inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
+        normal_new_args.push_back(inp->value);
+      } else {
+        auto inode = make_node<LayoutAlternatedExprNode>();
+        inode->value = arg;
+        inode->memorizer = memorizer;
+        inputs.push_back(LayoutAlternatedExpr(inode));
+        normal_new_args.push_back(arg);
+      }
+    };
+    if (new_arg->is_type<TupleNode>()) {
+      Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
+      for (auto x : tuple_new_arg->fields) {
+        push_back_one_arg(x);
+      }
     } else {
-      auto inode = make_node<LayoutAlternatedExprNode>();
-      inode->value = arg;
-      inode->memorizer = memorizer;
-      inputs.push_back(LayoutAlternatedExpr(inode));
-      normal_new_args.push_back(arg);
+      push_back_one_arg(new_arg);
     }
   }
 
@@ -202,12 +217,21 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
   }
 
   for (auto arg : ref_call->args) {
-    input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape);
+    if (arg->is_type<TupleNode>()) {  // expand tuple
+      Tuple tuple_arg = Downcast<Tuple>(arg);
+      for (auto x : tuple_arg->fields) {
+        input_shapes.push_back(x->type_as<TensorTypeNode>()->shape);
+      }
+    } else {
+      input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape);
+    }
   }
 
   // old_in, old_out = op.infer(old_in)
   bool success = false;
-  std::tie(old_in, old_out, success) = CallInfer(ref_call, old_in, input_shapes);
+  std::tie(old_in, old_out, success) = CallInfer(ref_call,
+                                                 Array<Layout>(nullptr),
+                                                 old_in, input_shapes);
   if (!success) { return Expr(nullptr); }
   CHECK_EQ(old_in.size(), new_in.size());
 
@@ -224,12 +248,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
   // new_in2, new_out = op.infer(new_in)
   if (new_call->op->is_type<OpNode>()) {
     success = false;
-    for (size_t i = 0; i < input_shapes.size(); ++i) {
-      if (old_in.defined()) {
-        input_shapes.Set(i, ConvertLayout(input_shapes[i], old_in[i], new_in[i]));
-      }
-    }
-    std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, input_shapes);
+    std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, old_in, input_shapes);
     if (!success) { return Expr(nullptr); }
   } else {
     return Expr(nullptr);
diff --git a/src/relay/pass/alter_op_layout.h b/src/relay/pass/alter_op_layout.h
index b5de670bdb577..fcb7b379a0ec1 100644
--- a/src/relay/pass/alter_op_layout.h
+++ b/src/relay/pass/alter_op_layout.h
@@ -19,44 +19,54 @@ namespace relay {
 /*!
  * \brief Infer & correct function of node layout. See \p Layout for layout convention
  * \param attrs The attribute of the node.
- * \param in_layouts The layouts of input arguments.
- * \param in_shapes The shapes of input arguments.
+ * \param new_in_layouts The layouts of input arguments after alter_op_layout.
+ *                       This can be undefined, which means we call this function before
+ *                       altering any operators.
+ * \param old_in_layouts The layouts of input arguments before alter_op_layout.
+ * \param old_in_shapes The shapes of old input arguments.
 * \return infered_layout An array of two elements that are inferred input layouts and
 *                        inferred output layouts.
 */
 using FInferCorrectLayout = runtime::TypedPackedFunc<
   Array<Array<Layout>>(const Attrs& attrs,
-                       const Array<Layout>& in_layouts,
-                       const Array<Array<IndexExpr>> &in_shapes)>;
+                       const Array<Layout>& new_in_layouts,
+                       const Array<Layout>& old_in_layouts,
+                       const Array<Array<IndexExpr>> &old_in_shapes)>;
 
 /*! \brief take arbitrary input layout and copy to output */
 inline Array<Array<Layout> > ElemwiseArbitraryLayout(const Attrs& attrs,
-                                     const Array<Layout>& in_layouts,
-                                     const Array<Array<IndexExpr> > &in_shapes) {
-  Array<Layout> inferred_ins;
-
-  Layout in;
-  for (size_t i = 0; i < in_layouts.size(); ++i) {
-    if (!in.defined()) in = in_layouts[i];
-    CHECK(in.Equals(in_layouts[i]))
-      << "Incompatible layout at " << i << "-th input: expected " << in
-      << ", got " << in_layouts[i];
-  }
-  for (size_t i = 0; i < in_layouts.size(); ++i) {
-    inferred_ins.push_back(in);
+                                     const Array<Layout>& new_in_layouts,
+                                     const Array<Layout>& old_in_layouts,
+                                     const Array<Array<IndexExpr>> &old_in_shapes) {
+  Layout ret;
+
+  if (new_in_layouts.defined()) {
+    CHECK_GE(new_in_layouts.size(), 1);
+    ret = new_in_layouts[0];
+  } else {
+    for (size_t i = 0; i < old_in_layouts.size(); ++i) {
+      if (old_in_layouts[i].defined()) {
+        ret = old_in_layouts[i];
+        break;
+      }
+    }
   }
-  return Array<Array<Layout> >{inferred_ins, {in}};
+
+  return Array<Array<Layout> >{Array<Layout>(old_in_layouts.size(), ret), {ret}};
 }
 
 /*! \brief Infer layout for binary broadcast operators */
 inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
-                                     const Array<Layout>& in_layouts,
-                                     const Array<Array<IndexExpr> > &in_shapes) {
-  CHECK_EQ(in_layouts.size(), 2);
-  CHECK_EQ(in_shapes.size(), 2);
+                                     const Array<Layout>& new_in_layouts,
+                                     const Array<Layout>& old_in_layouts,
+                                     const Array<Array<IndexExpr>> &old_in_shapes) {
+  Array<Layout> layouts;
 
-  Array<Layout> layouts = in_layouts;
+  if (new_in_layouts.defined()) {
+    layouts.assign(new_in_layouts.begin(), new_in_layouts.end());
+  } else {
+    layouts.assign(old_in_layouts.begin(), old_in_layouts.end());
+  }
 
   if (!layouts[0].defined() && !layouts[1].defined()) {
     // both undefined, infer fails
@@ -66,11 +76,11 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
     int defined_idx = layouts[0].defined() ? 0 : 1;
     int undef_idx = 1 - defined_idx;
 
-    if (in_shapes[defined_idx].size() >= in_shapes[undef_idx].size()) {
+    if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
       layouts.Set(undef_idx,
                   layouts[defined_idx].Sublayout(
-                    in_shapes[defined_idx].size() - in_shapes[undef_idx].size(),
-                    in_shapes[undef_idx].size()));
+                    old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
+                    old_in_shapes[undef_idx].size()));
       return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
     } else {
       // only know the tensor with smaller dimensions,
       // so we cannot infer the final broadcasted output.
       // fails in this case.
       return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
     }
   } else {
-    // try to broadcast to the tensors to the larger dimension
+    // try to broadcast the tensors to the larger dimension
     int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1;
     int small_idx = 1 - large_idx;
     Layout ret = layouts[large_idx];
diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc
index a0cbc4a502c58..4f33d4a053b75 100644
--- a/src/relay/pass/forward_rewrite.cc
+++ b/src/relay/pass/forward_rewrite.cc
@@ -112,6 +112,22 @@ class ForwardRewriter : private ExprMutator {
     }
   }
 
+  Expr VisitExpr_(const TupleNode* op) final {
+    tvm::Array<Expr> fields;
+    bool all_fields_unchanged = true;
+    for (auto field : op->fields) {
+      auto new_field = this->GetTempExpr(field);
+      fields.push_back(new_field);
+      all_fields_unchanged &= new_field.same_as(field);
+    }
+
+    if (all_fields_unchanged) {
+      return GetRef<Expr>(op);
+    } else {
+      return TupleNode::make(fields);
+    }
+  }
+
   Expr VisitExpr_(const CallNode* call_node) final {
     const Call& ref_call = GetRef<Call>(call_node);
     PackedFunc frewrite;
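
For operator authors, here is a minimal sketch of how a downstream op could adopt the new four-argument FInferCorrectLayout convention introduced by this patch. The op name "my_custom_relu" and its registration are hypothetical illustrations and are not part of this change; only FInferCorrectLayout, Layout, RELAY_REGISTER_OP and set_attr are taken from the code shown above.

// Hypothetical example, not part of this patch: a layout-agnostic unary op
// that follows whatever layout its producer uses.
#include <tvm/relay/op.h>
#include "alter_op_layout.h"  // i.e. src/relay/pass/alter_op_layout.h, for FInferCorrectLayout

namespace tvm {
namespace relay {

Array<Array<Layout> > MyCustomReluInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<Array<IndexExpr>> &old_in_shapes) {
  Layout ret;
  if (new_in_layouts.defined()) {
    // A producer (e.g. a convolution) has already been altered: follow it,
    // so no layout_transform needs to be inserted between the two ops.
    ret = new_in_layouts[0];
  } else if (old_in_layouts.size() >= 1) {
    // Called on the original program: just report the layout we already see.
    ret = old_in_layouts[0];
  }
  return Array<Array<Layout> >{{ret}, {ret}};
}

RELAY_REGISTER_OP("my_custom_relu")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", MyCustomReluInferCorrectLayout);

}  // namespace relay
}  // namespace tvm

The two branches mirror the convention used throughout the patch: an undefined new_in_layouts means the pass is still probing the original IR, while a defined one means some producer has been altered and the op should follow it.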