Skip to content

Commit

Permalink
Fix fusion bug when the call target is a symbol that is not an operator. (apache#2630)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored and wweic committed Mar 12, 2019
1 parent 6ff88b4 commit 9e08e45
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,22 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
Node* node = graph_.node_map.at(call);
static auto fpattern =
Op::GetAttr<TOpPattern>("TOpPattern");
// setup pattern.
// Now we set the pattern of this call.
//
// If we see a call mentioning an operator we should mark it with its
// annotated pattern.
//
// If the pattern is not annotated we will default to opaque.
//
// Finally if the operator position is not a call node we will
// need to call Update, as it may be an arbitrary expression.
OpPatternKind op_pattern = kOpaque;
if (const OpNode* opnode = call->op.as<OpNode>()) {
op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
} else {
this->Update(call->op, node, kOpaque);
}

node->pattern = op_pattern;
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references.
Expand Down
37 changes: 37 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,47 @@ def expected(dshape):
assert relay.ir_pass.alpha_equal(z, after)


def test_fuse_myia_regression():
    """Regression test (apache#2630): fuse_ops on a call whose callee is
    not an operator.

    The call target here is the result of an if/else expression (a
    closure), not an OpNode; before the fix the fusion pass assumed every
    call target was an operator. Only the `greater` comparison should end
    up fused into its own function.
    """
    shape = ()
    dt = 'int64'

    def build_before(shape, dt):
        # fn (x, y) { (if greater(x, y) then fn() { x } else fn() { y })() }
        x = relay.var('x', shape=shape, dtype=dt)
        y = relay.var('y', shape=shape, dtype=dt)
        builder = relay.ScopeBuilder()
        with builder.if_scope(relay.op.greater(x, y)):
            builder.ret(relay.Function([], x))
        with builder.else_scope():
            builder.ret(relay.Function([], y))
        return relay.Function([x, y], relay.Call(builder.get(), []))

    def build_expected(shape, dt):
        # Same program, but with greater(x, y) lifted into a fused function.
        x = relay.var('x', shape=shape, dtype=dt)
        y = relay.var('y', shape=shape, dtype=dt)
        builder = relay.ScopeBuilder()
        p1 = relay.var('p1', shape=shape, dtype=dt)
        p2 = relay.var('p2', shape=shape, dtype=dt)
        gt_fn = relay.Function([p1, p2], relay.op.greater(p1, p2))
        with builder.if_scope(gt_fn(x, y)):
            builder.ret(relay.Function([], x))
        with builder.else_scope():
            builder.ret(relay.Function([], y))
        return relay.Function([x, y], relay.Call(builder.get(), []))

    fused = relay.ir_pass.fuse_ops(
        relay.ir_pass.infer_type(build_before(shape, dt)))
    reference = relay.ir_pass.infer_type(build_expected(shape, dt))
    assert relay.ir_pass.alpha_equal(fused, reference)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_concatenate()
test_tuple_root()
test_tuple_strided_slice()
test_stop_fusion()
test_fuse_myia_regression()

0 comments on commit 9e08e45

Please sign in to comment.