Skip to content

Commit

Permalink
clean ops
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Nov 21, 2018
1 parent 023f12c commit 66717fc
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ class Conv2DAttrs(Attrs):

@register_relay_attr_node
class GlobalPool2DAttrs(Attrs):
"""Attribute of a Convolution Operator"""
"""Attribute of a Global 2D Pooling Operator"""
pass
2 changes: 1 addition & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def slice_like(data, shape_like, axes=None):


def layout_transform(data, src_layout, dst_layout):
"""Strided slice of an array..
"""Transform the layout of an tensor
Parameters
----------
Expand Down
20 changes: 11 additions & 9 deletions src/relay/pass/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr);
// Return inferred_input_layout, inferred_output_layout, success
std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
const Call& call,
const Array<Layout>& inputs,
const Array<Layout>& last_inputs) {
const Array<Layout>& inputs) {
static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");

Op op = Downcast<Op>(call->op);
Expand Down Expand Up @@ -193,12 +192,12 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
new_in.push_back(inp->new_layout);
}

// old_in, old_out = op.infer(old_in, old_out)
// old_in, old_out = op.infer(old_in)
bool success = false;
std::tie(old_in, old_out, success) = CallInfer(ref_call, old_in, Array<Layout>(nullptr));
std::tie(old_in, old_out, success) = CallInfer(ref_call, old_in);
if (!success) { return Expr(nullptr); }

CHECK_EQ(old_in.size(), new_in.size());

// if new_in == 'undef': new_in = old_in
for (size_t i = 0; i < new_in.size(); ++i) {
if (!new_in[i].defined()) {
Expand All @@ -209,11 +208,14 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
// new_op = alter(op)
Call new_call = CallAlter(ref_call, normal_new_args);

success = false;
if (new_call->op->is_type<OpNode>()) { // try infer after alternating op
std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, old_in);
// new_in2, new_out = op.infer(new_in)
if (new_call->op->is_type<OpNode>()) {
success = false;
std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in);
if (!success) { return Expr(nullptr); }
} else {
return Expr(nullptr);
}
if (!success) { return Expr(nullptr); }

CHECK_EQ(new_out.size(), old_out.size())
<< "The number of output nodes should keep the same during alter_op_layout";
Expand Down

0 comments on commit 66717fc

Please sign in to comment.