Skip to content

Commit

Permalink
add a testcase for apache#5674
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed May 26, 2020
1 parent 6100112 commit 3e61832
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,36 @@ def callback(self, pre, post, node_map):
out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)

def test_rewrite():
x = relay.var('x')
y = relay.var('y')
add_pattern = is_op('add')(wildcard(), wildcard())
sub_pattern = is_op('subtract')(wildcard(), wildcard())
class TestRewrite(DFPatternCallback):
def __init__(self):
self.pattern = add_pattern
def callback(self, pre, post, node_map):
return post.args[0] - post.args[1]
out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)

def test_rewrite_func():
x = relay.var('x')
w = relay.var('w')
y = relay.var('y')
add_pattern = is_op('add')(wildcard(), wildcard())
sub_pattern = is_op('subtract')(wildcard(), wildcard())
class TestRewrite(DFPatternCallback):
def __init__(self):
self.pattern = add_pattern
def callback(self, pre, post, node_map):
return post.args[0] - post.args[1]
inpf = relay.var("input")
weightf = relay.var("weight")
func = relay.Function([inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None)
out = rewrite(TestRewrite(), func(x,w) + y)
assert sub_pattern.match(out)

def test_nested_rewrite():
class PatternCallback(DFPatternCallback):
def __init__(self, pattern):
Expand Down

0 comments on commit 3e61832

Please sign in to comment.