Support concatenate
merrymercy committed Nov 28, 2018
1 parent 13338b9 commit 51a20ba
Showing 7 changed files with 153 additions and 68 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
10 changes: 7 additions & 3 deletions src/relay/op/nn/convolution.cc
@@ -104,12 +104,16 @@ bool Conv2DRel(const Array<Type>& types,
}

template<typename T>
Array<Array<Layout> > Conv2DInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& in_layouts,
const Array<Array<IndexExpr>> &in_shapes) {
Array<Array<Layout> > Conv2DInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const T* params = attrs.as<T>();
Layout out_layout(params->out_layout);

// We always make other operators fit the layouts of convolution layers
// So this inference ignores all inputs
return Array<Array<Layout> >{{params->data_layout, params->weight_layout},
{out_layout.defined() ? out_layout : params->data_layout}};
}
28 changes: 13 additions & 15 deletions src/relay/op/nn/pooling.cc
@@ -21,24 +21,22 @@ TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs);
template <typename T>
Array<Array<Layout> > Pool2DInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& in_layouts,
const Array<Array<IndexExpr>> &in_shapes) {
CHECK_EQ(in_layouts.size(), 1);

const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
// NOTE: Discard "const" qualifier here.
T *params = const_cast<T*>(attrs.as<T>());
Layout input = in_layouts[0];
const Layout raw_layout(params->layout);
if (input.defined()) {
CHECK(input.Convertible(raw_layout));
if (input.Indexof('W') != raw_layout.Indexof('W') ||
input.Indexof('H') != raw_layout.Indexof('H') ||
input.Contains('w') || input.Contains('h')) {
// if the new layout changes width or height dimension,
// fallback to old layout;
input = raw_layout;

if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);

Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.Indexof('W') == raw_layout.Indexof('W') &&
input.Indexof('H') == raw_layout.Indexof('H') &&
!input.Contains('w') && !input.Contains('h')) {
params->layout = input.name(); // modify self to follow the input layout
}
params->layout = input.name(); // modify self to follow the input layout
}

return Array<Array<Layout> >{{params->layout}, {params->layout}};
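For intuition, the condition above only lets pooling adopt an altered input layout when H and W keep their original positions and are not split into sub-dimensions. A minimal sketch of the same check, assuming params->layout is "NCHW" and using hypothetical layouts that are not taken from this diff:

// Illustrative only: the test Pool2DInferCorrectLayout performs, traced by hand.
Layout raw("NCHW");
Layout blocked("NCHW16c");   // only C is split; H and W stay at positions 2 and 3
bool adopt = blocked.Indexof('W') == raw.Indexof('W') &&
             blocked.Indexof('H') == raw.Indexof('H') &&
             !blocked.Contains('w') && !blocked.Contains('h');  // true, so layout becomes "NCHW16c"
// Layout("NHWC") would fail the Indexof checks and Layout("NCHW4h4w") would fail
// the Contains checks, so in both cases the pool would keep "NCHW".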
40 changes: 39 additions & 1 deletion src/relay/op/tensor/transform.cc
@@ -13,6 +13,7 @@
#include <vector>
#include "../op_common.h"
#include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h"
#include "../layout.h"

namespace tvm {
@@ -202,6 +203,42 @@ bool ConcatenateRel(const Array<Type>& types,
return true;
}

Array<Array<Layout>> ConcatenateLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();

size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
static_cast<size_t>(param->axis);

Layout ret;
if (new_in_layouts.defined()) { // this function is called after some operators are altered.
Layout::LayoutDim concate_dim = old_in_layouts[0][axis];
for (size_t i = 0; i < new_in_layouts.size(); ++i) {
if (new_in_layouts[i].ndim() > axis &&
new_in_layouts[i][axis] == concate_dim) {
ret = new_in_layouts[i];
break;
}
}
} else { // this function is called on the original, correct Relay IR
for (size_t i = 0; i < old_in_layouts.size(); ++i) {
if (old_in_layouts[i].defined()) {
ret = old_in_layouts[i];
break;
}
}

if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) {
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
}

return Array<Array<Layout> > {Array<Layout>(old_in_layouts.size(), ret), {ret}};
}
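For intuition, a hand trace of the selection rule above, using hypothetical layouts that are not part of this diff (two inputs concatenated on axis 1, where a preceding conv2d has already been altered to a blocked layout):

// Illustrative only: ConcatenateLayout with param->axis == 1.
Layout old0("NCHW"), old1("NCHW");       // layouts before AlterOpLayout
Layout new0("NCHW16c"), new1("NCHW");    // the first input was altered by a preceding conv2d
// old0[1] == 'C' and new0[1] == 'C', so the loop picks ret = NCHW16c and the function
// returns {{NCHW16c, NCHW16c}, {NCHW16c}}: the pass then inserts a layout_transform
// on the second input and the concatenate output stays in NCHW16c.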

Expr MakeConcatenate(Expr data,
int axis) {
auto attrs = make_node<ConcatenateAttrs>();
@@ -227,7 +264,8 @@ RELAY_REGISTER_OP("concatenate")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1)
.add_type_rel("Concatenate", ConcatenateRel);
.add_type_rel("Concatenate", ConcatenateRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout);

/* relay.transpose */
TVM_REGISTER_NODE_TYPE(TransposeAttrs);
61 changes: 40 additions & 21 deletions src/relay/pass/alter_op_layout.cc
@@ -66,6 +66,9 @@ class TransformMemorizer : public NodeRef {

// Transform layout with memorizer
Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) {
if (src_layout.Equals(dst_layout))
return raw;

std::tuple<const Node*, std::string, std::string> key =
std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name());
auto& memo = operator->()->memo;
@@ -116,14 +119,16 @@ RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr);
// Return inferred_input_layout, inferred_output_layout, success
std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
const Call& call,
const Array<Layout>& in_layouts,
const Array<Array<IndexExpr>>& in_shapes) {
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");

Op op = Downcast<Op>(call->op);
if (finfer_layout.count(op)) {
Array<Array<Layout> > inferred_layouts;
inferred_layouts = finfer_layout[op](call->attrs, in_layouts, in_shapes);
inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts,
old_in_layouts, old_in_shapes);
CHECK_EQ(inferred_layouts.size(), 2)
<< "FInferCorrectLayout should return an array with size of 2";
for (auto x : inferred_layouts) {
@@ -180,17 +185,27 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
// NOTE: discard the "const" qualifier
TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);

// fill incomplete state
for (auto arg : new_args) {
if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
normal_new_args.push_back(inp->value);
// fill incomplete state and expand tuple
for (auto new_arg : new_args) {
auto push_back_one_arg = [&](Expr arg) {
if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
normal_new_args.push_back(inp->value);
} else {
auto inode = make_node<LayoutAlternatedExprNode>();
inode->value = arg;
inode->memorizer = memorizer;
inputs.push_back(LayoutAlternatedExpr(inode));
normal_new_args.push_back(arg);
}
};
if (new_arg->is_type<TupleNode>()) {
Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
for (auto x : tuple_new_arg->fields) {
push_back_one_arg(x);
}
} else {
auto inode = make_node<LayoutAlternatedExprNode>();
inode->value = arg;
inode->memorizer = memorizer;
inputs.push_back(LayoutAlternatedExpr(inode));
normal_new_args.push_back(arg);
push_back_one_arg(new_arg);
}
}

@@ -202,12 +217,21 @@
}

for (auto arg : ref_call->args) {
input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape);
if (arg->is_type<TupleNode>()) { // expand tuple
Tuple tuple_arg = Downcast<Tuple>(arg);
for (auto x : tuple_arg->fields) {
input_shapes.push_back(x->type_as<TensorTypeNode>()->shape);
}
} else {
input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape);
}
}

// old_in, old_out = op.infer(old_in)
bool success = false;
std::tie(old_in, old_out, success) = CallInfer(ref_call, old_in, input_shapes);
std::tie(old_in, old_out, success) = CallInfer(ref_call,
Array<Layout>(nullptr),
old_in, input_shapes);
if (!success) { return Expr(nullptr); }
CHECK_EQ(old_in.size(), new_in.size());

@@ -224,12 +248,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
// new_in2, new_out = op.infer(new_in)
if (new_call->op->is_type<OpNode>()) {
success = false;
for (size_t i = 0; i < input_shapes.size(); ++i) {
if (old_in.defined()) {
input_shapes.Set(i, ConvertLayout(input_shapes[i], old_in[i], new_in[i]));
}
}
std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, input_shapes);
std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, old_in, input_shapes);
if (!success) { return Expr(nullptr); }
} else {
return Expr(nullptr);
64 changes: 37 additions & 27 deletions src/relay/pass/alter_op_layout.h
@@ -19,44 +19,54 @@ namespace relay {
/*!
* \brief Infer & correct function of node layout. See \p Layout for layout convention
* \param attrs The attribute of the node.
* \param in_layouts The layouts of input arguments.
* \param in_shapes The shapes of input arguments.
* \param new_in_layouts The layouts of input arguments after alter_op_layout.
*        This can be undefined, which means we call this function before altering
*        any operators.
* \param old_in_layouts The layouts of input arguments before alter_op_layout.
* \param old_in_shapes The shapes of the old input arguments.
* \return inferred_layout An array of two elements that are the inferred input layouts and
*         inferred output layouts.
*/
using FInferCorrectLayout = runtime::TypedPackedFunc<
Array<Array<Layout>>(const Attrs& attrs,
const Array<Layout>& in_layouts,
const Array<Array<IndexExpr>> &in_shapes)>;
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes)>;
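As a sketch of how an operator adopts the new three-argument signature (the op "my_op" below is hypothetical and not part of this commit; the registration pattern mirrors the one used for concatenate in transform.cc):

// Illustrative only: prefer the layout chosen after AlterOpLayout when it is available,
// otherwise fall back to the first layout seen in the original program.
inline Array<Array<Layout> > MyOpInferCorrectLayout(const Attrs& attrs,
                                                    const Array<Layout>& new_in_layouts,
                                                    const Array<Layout>& old_in_layouts,
                                                    const Array<Array<IndexExpr>> &old_in_shapes) {
  Layout ret = new_in_layouts.defined() ? new_in_layouts[0] : old_in_layouts[0];
  return Array<Array<Layout> >{Array<Layout>(old_in_layouts.size(), ret), {ret}};
}
// Registered on the op the same way ConcatenateLayout is:
// RELAY_REGISTER_OP("my_op")
// .set_attr<FInferCorrectLayout>("FInferCorrectLayout", MyOpInferCorrectLayout);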

/*! \brief take arbitrary input layout and copy to output */
inline Array<Array<Layout> > ElemwiseArbitraryLayout(const Attrs& attrs,
const Array<Layout>& in_layouts,
const Array<Array<IndexExpr> > &in_shapes) {
Array<Layout> inferred_ins;

Layout in;
for (size_t i = 0; i < in_layouts.size(); ++i) {
if (!in.defined()) in = in_layouts[i];
CHECK(in.Equals(in_layouts[i]))
<< "Incompatible layout at " << i << "-th input: expected " << in
<< ", got " << in_layouts[i];
}
for (size_t i = 0; i < in_layouts.size(); ++i) {
inferred_ins.push_back(in);
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
Layout ret;

if (new_in_layouts.defined()) {
CHECK_GE(new_in_layouts.size(), 1);
ret = new_in_layouts[0];
} else {
for (size_t i = 0; i < old_in_layouts.size(); ++i) {
if (old_in_layouts[i].defined()) {
ret = old_in_layouts[i];
break;
}
}
}

return Array<Array<Layout> >{inferred_ins, {in}};
return Array<Array<Layout> >{Array<Layout>(old_in_layouts.size(), ret), {ret}};
}

/*! \brief Infer layout for binary broadcast operators */
inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
const Array<Layout>& in_layouts,
const Array<Array<IndexExpr> > &in_shapes) {
CHECK_EQ(in_layouts.size(), 2);
CHECK_EQ(in_shapes.size(), 2);
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
Array<Layout> layouts;

Array<Layout> layouts = in_layouts;
if (new_in_layouts.defined()) {
layouts.assign(new_in_layouts.begin(), new_in_layouts.end());
} else {
layouts.assign(old_in_layouts.begin(), old_in_layouts.end());
}

if (!layouts[0].defined() && !layouts[1].defined()) {
// both undefined, infer fails
@@ -66,11 +76,11 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
int defined_idx = layouts[0].defined() ? 0 : 1;
int undef_idx = 1 - defined_idx;

if (in_shapes[defined_idx].size() >= in_shapes[undef_idx].size()) {
if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
layouts.Set(undef_idx,
layouts[defined_idx].Sublayout(
in_shapes[defined_idx].size() - in_shapes[undef_idx].size(),
in_shapes[undef_idx].size()));
old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
old_in_shapes[undef_idx].size()));
return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
} else {
// only know the tensor with smaller dimensions,
@@ -79,7 +89,7 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
} else {
// try to broadcast to the tensors to the larger dimension
// try to broadcast the tensors to the larger dimension
int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1;
int small_idx = 1 - large_idx;
Layout ret = layouts[large_idx];
16 changes: 16 additions & 0 deletions src/relay/pass/forward_rewrite.cc
@@ -112,6 +112,22 @@ class ForwardRewriter : private ExprMutator {
}
}

Expr VisitExpr_(const TupleNode* op) final {
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
for (auto field : op->fields) {
auto new_field = this->GetTempExpr(field);
fields.push_back(new_field);
all_fields_unchanged &= new_field.same_as(field);
}

if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
return TupleNode::make(fields);
}
}

Expr VisitExpr_(const CallNode* call_node) final {
const Call& ref_call = GetRef<Call>(call_node);
PackedFunc frewrite;
