From d0ec768f82c9a369f75a490781d1bd44b3e65e64 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 26 May 2020 13:30:04 -0700 Subject: [PATCH] add a testcase for #5674 --- tests/python/relay/test_dataflow_pattern.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index ed90873421f3..6a66f60aa68b 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -464,6 +464,23 @@ def callback(self, pre, post, node_map): out = rewrite(TestRewrite(), x + y) assert sub_pattern.match(out) +def test_rewrite_func(): + x = relay.var('x') + w = relay.var('w') + y = relay.var('y') + add_pattern = is_op('add')(wildcard(), wildcard()) + sub_pattern = is_op('subtract')(wildcard(), wildcard()) + class TestRewrite(DFPatternCallback): + def __init__(self): + self.pattern = add_pattern + def callback(self, pre, post, node_map): + return post.args[0] - post.args[1] + inpf = relay.var("input") + weightf = relay.var("weight") + func = relay.Function([inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None) + out = rewrite(TestRewrite(), func(x,w) + y) + assert sub_pattern.match(out) + def test_nested_rewrite(): class PatternCallback(DFPatternCallback): def __init__(self, pattern):