From 95b3ad971dfed6aa4e9f18d4f43c846806fd5292 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 28 May 2020 16:48:07 -0700
Subject: [PATCH] [PatternLang] Add ConstantPattern (#5689)

* Add ConstantPattern

* update doc
---
 docs/langref/relay_pattern.rst                |  43 +++
 include/tvm/relay/dataflow_pattern.h          |  18 +
 include/tvm/relay/dataflow_pattern_functor.h  |   3 +
 python/tvm/relay/dataflow_pattern/__init__.py |   8 +
 src/relay/ir/dataflow_matcher.cc              |   5 +
 src/relay/ir/dataflow_pattern.cc              |  12 +
 src/relay/ir/dataflow_pattern_functor.cc      |   2 +
 src/relay/ir/indexed_graph.cc                 |   2 +
 tests/python/relay/test_dataflow_pattern.py   | 359 ++++++++++++------
 9 files changed, 337 insertions(+), 115 deletions(-)

diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index c129544b4a79..7bb7bdfa6e82 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -37,6 +37,11 @@ for more use cases.
 
 .. _tests/python/relay/test_dataflow_pattern.py: https://github.com/apache/incubator-tvm/blob/master/tests/python/relay/test_dataflow_pattern.py
 
+.. note::
+
+    If you cannot find the corresponding pattern node to match the Relay node you want,
+    you are welcome to raise an issue or submit a PR to add it.
+
 Matching One of Two Ops
 ***********************
 
@@ -131,6 +136,44 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
     out = relay.nn.relu(tuple_get_item_node)
     pat.match(out)
 
+The next example matches a constant node regardless of its value. This is useful for checking
+whether a specific parameter in a subgraph has been bound or not.
+
+.. code-block:: python
+
+    def test_match_constant():
+        conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
+        pattern = is_op('nn.bias_add')(conv2d, wildcard())
+
+        x = relay.var('x', shape=(1, 3, 224, 224))
+        w = relay.var('w', shape=(3, 3, 3, 3))
+        b = relay.var('b', shape=(3, ))
+        conv2d = relay.op.nn.conv2d(x, w)
+        out = relay.op.nn.bias_add(conv2d, b)
+        func = relay.Function([x, w, b], out)
+        mod = tvm.IRModule.from_expr(func)
+
+        # Both inputs of the conv2d in the graph are VarNodes by default, so no match.
+        assert not pattern.match(mod['main'].body)
+
+        # The second input (weight) has been bound to constant values, so it is now a constant node.
+        mod["main"] = bind_params_by_name(mod["main"],
+                                          {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
+        assert pattern.match(mod['main'].body)
+
+On the other hand, if you need to match a constant with a specific value, you can directly
+use ``ExprPattern``. This can be useful for algebraic simplification.
+
+.. code-block:: python
+
+    def test_match_plus_zero():
+        zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
+        pattern = wildcard() + zero
+
+        x = relay.Var('x')
+        y = x + relay.const(0)
+        assert pattern.match(y)
+
 The next example is matching function nodes with a specific attribute:
 
 .. code-block:: python
diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h
index a8db51f74574..80a5d6f52617 100644
--- a/include/tvm/relay/dataflow_pattern.h
+++ b/include/tvm/relay/dataflow_pattern.h
@@ -111,6 +111,24 @@ class VarPattern : public DFPattern {
   TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode);
 };
 
+/*!
+ * \brief A Pattern to Match a Relay Constant
+ */
+class ConstantPattern;
+/*! \brief Container for Constant */
+class ConstantPatternNode : public DFPatternNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "relay.dataflow_pattern.ConstantPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode);
+};
+
+class ConstantPattern : public DFPattern {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode);
+};
+
 /*!
  * \brief Call corresponds to operator invocation.
  *  Corresponds to the operator in computational graph terminology.
diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h
index 05c2147c2c49..a1140ae4f54e 100644
--- a/include/tvm/relay/dataflow_pattern_functor.h
+++ b/include/tvm/relay/dataflow_pattern_functor.h
@@ -91,6 +91,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
   virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPatternDefault_(const Object* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
@@ -111,6 +112,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
     return vtable;
   }
@@ -134,6 +136,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
   void VisitDFPattern_(const TuplePatternNode* op) override;
   void VisitDFPattern_(const TypePatternNode* op) override;
   void VisitDFPattern_(const VarPatternNode* op) override;
+  void VisitDFPattern_(const ConstantPatternNode* op) override;
   void VisitDFPattern_(const WildcardPatternNode* op) override;
 
 protected:
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py
index ca83e06c6e95..e8f73ed08f4e 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -331,6 +331,14 @@ def __init__(self, name_hint: str, type_annotation=None):
             ffi.VarPattern, name_hint, type_annotation)
 
 
+@register_df_node
+class ConstantPattern(DFPattern):
+    """A pattern matching a Relay Constant.
+    """
+    def __init__(self):
+        self.__init_handle_by_constructor__(ffi.ConstantPattern)
+
+
 @register_df_node
 class CallPattern(DFPattern):
     """A pattern matching a function call node in Relay.
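
The Python binding above can be exercised on its own. As a quick sanity check, here is a
minimal sketch (not part of the patch) that assumes only the API registered in this diff
and matches any constant operand of a multiply:

.. code-block:: python

    from tvm import relay
    from tvm.relay.dataflow_pattern import ConstantPattern, is_op, wildcard

    # Match a multiply whose second operand is any Relay constant,
    # whatever its value is.
    pat = is_op('multiply')(wildcard(), ConstantPattern())

    x = relay.var('x')
    assert pat.match(x * relay.const(2.0))    # constant operand -> match
    assert not pat.match(x * relay.var('y'))  # no constant operand -> no match
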
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 409dbc339c65..a7e4b3714fc1 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -56,6 +56,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const E
   bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
 
   bool VisitDFPatternDefault_(const Object* op, const Expr& expr) override;
@@ -380,6 +381,10 @@ bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& ex
   return matches;
 }
 
+bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) {
+  return expr.as<ConstantNode>() != nullptr;
+}
+
 bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
   return true;
 }
diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc
index 826a035ca6ba..280913164fd5 100644
--- a/src/relay/ir/dataflow_pattern.cc
+++ b/src/relay/ir/dataflow_pattern.cc
@@ -69,6 +69,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << ")";
     });
 
+TVM_REGISTER_NODE_TYPE(ConstantPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ConstantPattern").set_body_typed([]() {
+  auto c = ConstantPattern(make_object<ConstantPatternNode>());
+  return c;
+});
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ConstantPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      p->stream << "ConstantPattern()";
+    });
+
 CallPattern::CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
   ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>();
   n->op = std::move(op);
diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc
index c7c34c804449..ee44bcb43c8b 100644
--- a/src/relay/ir/dataflow_pattern_functor.cc
+++ b/src/relay/ir/dataflow_pattern_functor.cc
@@ -71,6 +71,8 @@ void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPatte
 
 void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
 
+void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {}
+
 void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}
 
 }  // namespace relay
diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc
index 7f7a5ff66853..0d4b90da0293 100644
--- a/src/relay/ir/indexed_graph.cc
+++ b/src/relay/ir/indexed_graph.cc
@@ -270,6 +270,8 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
 
     void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {}
 
+    void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {}
+
     void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {}
   };
   return Annotator(Creator().CreateGraph(pattern)).Annotate();
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index 6a66f60aa68b..5d91dcb14056 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -14,11 +14,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=unused-wildcard-import
+import numpy as np
+
 import tvm
 from tvm import relay
+from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.dataflow_pattern import *
 from tvm.relay.testing import run_opt_pass
-import numpy as np
 
 # NB: 1 corresponds to the C++ enum that specifies this
 # we lose the type safety due to the Python/C++ calling
@@ -26,21 +29,30 @@
 K_ELEMWISE = 0
 K_BROADCAST = 1
 
+
 ## NODE TESTS
 def test_expr_pattern():
     ep = ExprPattern(relay.var('x', shape=(4, 1)))
     assert isinstance(ep, ExprPattern)
     assert isinstance(ep.expr, relay.Var)
 
+
 def test_var_pattern():
     v = is_input("x")
     assert isinstance(v, VarPattern)
     assert v.name == "x"
 
+
+def test_constant_pattern():
+    c = ConstantPattern()
+    assert isinstance(c, ConstantPattern)
+
+
 def test_wildcard_pattern():
     wc = wildcard()
     assert isinstance(wc, WildcardPattern)
 
+
 def test_CallPattern():
     wc1 = wildcard()
     wc2 = wildcard()
@@ -49,6 +61,7 @@ def test_CallPattern():
     assert isinstance(c.args[0], WildcardPattern)
     assert isinstance(c.args[1], WildcardPattern)
 
+
 def test_TuplePattern():
     wc1 = wildcard()
     wc2 = wildcard()
@@ -57,6 +70,7 @@ def test_TuplePattern():
     assert isinstance(t.fields[0], WildcardPattern)
     assert isinstance(t.fields[1], WildcardPattern)
 
+
 def test_TupleGetItemPattern():
     wc1 = wildcard()
     wc2 = wildcard()
@@ -67,34 +81,42 @@ def test_TupleGetItemPattern():
     assert isinstance(tgi.tuple.fields[0], WildcardPattern)
     assert isinstance(tgi.tuple.fields[1], WildcardPattern)
 
+
 def test_AltPattern():
     is_add_or_sub = is_op('add') | is_op('subtract')
     assert isinstance(is_add_or_sub, AltPattern)
 
+
 def test_TypePattern():
     ttype = relay.TensorType((10, 10), "float32")
     ty_pat = has_type(ttype)
     assert isinstance(ty_pat, TypePattern)
     assert ty_pat.type == ttype
 
+
 def test_AttrPattern():
     op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
     assert isinstance(op, AttrPattern)
     assert op.attrs["TOpPattern"] == K_ELEMWISE
 
+
 ## MATCHER TESTS
 
+
 def test_match_op():
     assert is_op('add').match(relay.op.op.get("add"))
 
+
 def test_no_match_op():
     assert not is_op('add').match(relay.op.op.get("subtract"))
 
+
 def test_match_op_or():
     is_add_or_sub = is_op('add') | is_op('subtract')
     assert is_add_or_sub.match(relay.op.op.get("add"))
     assert is_add_or_sub.match(relay.op.op.get("subtract"))
 
+
 def test_match_call_commutive():
     x = relay.var('x')
     y = relay.var('y')
@@ -105,6 +127,7 @@ def test_match_call_commutive():
     assert mul_pattern.match(x * y)
     assert mul_pattern.match(y * x)
 
+
 def test_no_match_call_commutive():
     x = relay.var('x')
     y = relay.var('y')
@@ -115,26 +138,27 @@ def test_no_match_call_commutive():
     assert add_pattern.match(x / y)
     assert not add_pattern.match(y / x)
 
+
 def test_match_call():
     x = relay.var('x')
     y = relay.var('y')
     add_pattern = is_op('add')(wildcard(), wildcard())
     assert add_pattern.match(x + y)
 
+
 def test_no_match_call():
     x = relay.var('x')
     y = relay.var('y')
     add_pattern = is_op('add')(wildcard(), wildcard())
     assert not add_pattern.match(x - y)
 
+
 def test_match_option():
     x = relay.var('x')
     w = relay.var('w')
     b = relay.var('b')
-    pattern = is_op("nn.relu")(
-        is_op("nn.conv2d")(wildcard(), wildcard()
-                           ).optional(lambda x: is_op("nn.bias_add")(x, wildcard()))
-    )
+    pattern = is_op("nn.relu")(is_op("nn.conv2d")(
+        wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard())))
 
     conv2d = relay.op.nn.conv2d(x, w)
     relu = relay.op.nn.relu(conv2d)
@@ -159,14 +183,13 @@ def test_match_option():
     assert pattern.match(tanh2)
     assert not pattern.match(relu2)
 
+
 def test_no_match_option():
     x = 
relay.var('x') w = relay.var('w') b = relay.var('b') - pattern = is_op("nn.relu")( - is_op("nn.conv2d")(wildcard(), wildcard() - ).optional(lambda x: is_op("nn.bias_add")(x, wildcard())) - ) + pattern = is_op("nn.relu")(is_op("nn.conv2d")( + wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard()))) conv2d = relay.op.nn.conv2d(x, w) relu = relay.op.tanh(conv2d) @@ -186,46 +209,62 @@ def test_no_match_option(): relu = relay.op.nn.relu(bias_add) assert not pattern.match(relu) -def test_match_tuple(): - x = relay.var('x') - y = relay.var('y') - z = relay.op.op.get("add") - tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add"))) - assert tuple_pattern.match(relay.expr.Tuple((x,y,z))) -def test_no_match_tuple(): - x = relay.var('x') - y = relay.var('y') - z = relay.op.op.get("add") - tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard())) - assert not tuple_pattern.match(relay.expr.Tuple((x,y,z))) +def test_match_const(): + conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern()) + pattern = is_op('nn.bias_add')(conv2d, wildcard()) + + x = relay.var('x', shape=(1, 3, 224, 224)) + w = relay.var('w', shape=(3, 3, 3, 3)) + b = relay.var('b', shape=(3, )) + conv2d = relay.op.nn.conv2d(x, w) + out = relay.op.nn.bias_add(conv2d, b) + func = relay.Function([x, w, b], out) + mod = tvm.IRModule.from_expr(func) + + assert not pattern.match(mod['main'].body) + mod["main"] = bind_params_by_name(mod["main"], + {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))}) + assert pattern.match(mod['main'].body) + def test_match_tuple(): x = relay.var('x') y = relay.var('y') z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add"))) + assert tuple_pattern.match(relay.expr.Tuple((x, y, z))) + tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add"))) tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1) - assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 1)) + assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1)) + def test_no_match_tuple(): x = relay.var('x') y = relay.var('y') z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard())) + assert not tuple_pattern.match(relay.expr.Tuple((x, y, z))) + tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"))) tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1) - assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 2)) + assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple( + (x, y, z)), 2)) + def test_match_type(): x = relay.var('x', shape=(10, 10), dtype="float32") ty_pat = has_type(relay.TensorType((10, 10), "float32")) assert ty_pat.match(x) + def test_no_match_type(): x = relay.var('x', shape=(10, 10), dtype="int32") ty_pat = has_type(relay.TensorType((10, 10), "float32")) assert not ty_pat.match(x) + def test_match_op_attr(): op = is_op('add').has_attr({"TOpPattern": K_BROADCAST}) op_pat = op(wildcard(), wildcard()) @@ -233,6 +272,7 @@ def test_match_op_attr(): y = relay.var('y') assert op_pat.match(x + y) + def test_no_match_op_attr(): op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE}) op_pat = op(wildcard(), wildcard()) @@ -245,6 +285,7 @@ def test_no_match_op_attr(): y = relay.var('y') assert not op_pat.match(x - y) + def test_match_func_attr(): pattern = wildcard().has_attr({"Composite": "add"}) x = 
relay.var('x') @@ -252,6 +293,7 @@ def test_match_func_attr(): f = relay.Function([x, y], x + y).with_attr("Composite", "add") assert pattern.match(f) + def test_no_match_func_attr(): pattern = wildcard().has_attr({"Composite": "add"}) x = relay.var('x') @@ -262,12 +304,14 @@ def test_no_match_func_attr(): f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias") assert not pattern.match(f) + def test_match_call_attr(): is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"}) x = relay.var('x') y = relay.var('y') assert is_conv2d.match(relay.op.nn.conv2d(x, y)) + def test_no_match_call_attr(): x = relay.var('x') y = relay.var('y') @@ -278,6 +322,7 @@ def test_no_match_call_attr(): is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"}) assert not is_conv2d.match(relay.op.nn.conv2d(x, y)) + def test_match_diamond(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) @@ -296,6 +341,7 @@ def test_match_diamond(): # Check assert diamond.match(out) + def test_no_match_diamond(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) @@ -309,12 +355,12 @@ def test_no_match_diamond(): conv2d = relay.op.nn.conv2d(inp, weight) relu = relay.op.nn.relu(conv2d) leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) - out = relu + leaky_relu # Check assert not diamond.match(leaky_relu) assert not diamond.match(relu) + def test_match_fake_diamond(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) @@ -380,10 +426,11 @@ def test_match_dominator(): # Check assert diamond.match(out) - + # Fuzzy path/nested Diamond is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) - is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))( + wildcard()) | is_op('add')(wildcard(), wildcard()) reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_unary_elemwise, reduction) @@ -398,6 +445,7 @@ def test_match_dominator(): assert diamond.match(out) + def test_not_match_dominator(): is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) @@ -451,36 +499,47 @@ def test_not_match_dominator(): # Check assert not diamond.match(out) + def test_rewrite(): x = relay.var('x') y = relay.var('y') add_pattern = is_op('add')(wildcard(), wildcard()) sub_pattern = is_op('subtract')(wildcard(), wildcard()) + class TestRewrite(DFPatternCallback): def __init__(self): self.pattern = add_pattern + def callback(self, pre, post, node_map): return post.args[0] - post.args[1] + out = rewrite(TestRewrite(), x + y) assert sub_pattern.match(out) + def test_rewrite_func(): x = relay.var('x') w = relay.var('w') y = relay.var('y') add_pattern = is_op('add')(wildcard(), wildcard()) sub_pattern = is_op('subtract')(wildcard(), wildcard()) + class TestRewrite(DFPatternCallback): def __init__(self): self.pattern = add_pattern + def callback(self, pre, post, node_map): return post.args[0] - post.args[1] + inpf = relay.var("input") weightf = relay.var("weight") - func = relay.Function([inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None) - out = rewrite(TestRewrite(), func(x,w) + y) + func = relay.Function([inpf, weightf], + relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), + attrs=None) + out = rewrite(TestRewrite(), func(x, w) + y) assert sub_pattern.match(out) + def 
test_nested_rewrite(): class PatternCallback(DFPatternCallback): def __init__(self, pattern): @@ -510,6 +569,7 @@ def pattern(): assert tvm.ir.structural_equal(out, new_out) + def test_not_fuse_multi_diamond(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) @@ -528,6 +588,7 @@ def test_not_fuse_multi_diamond(): # Check assert not diamond.match(out) + class BatchnormCallback(DFPatternCallback): def __init__(self): self.x = wildcard() @@ -536,8 +597,9 @@ def __init__(self): self.beta = wildcard() self.gamma = wildcard() self.eps = wildcard() - - self.pattern = self.gamma * (self.x - self.mean)/is_op("sqrt")(self.var + self.eps) + self.beta + + self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \ + self.beta def callback(self, pre, post, node_map): x = node_map[self.x][0] @@ -546,7 +608,9 @@ def callback(self, pre, post, node_map): beta = node_map[self.beta][0] gamma = node_map[self.gamma][0] eps = node_map[self.eps][0] - return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item())[0] + return relay.op.nn.batch_norm(x, gamma, beta, mean, var, + epsilon=eps.data.asnumpy().item())[0] + def test_fuse_batchnorm(): x = relay.var('x') @@ -554,11 +618,14 @@ def test_fuse_batchnorm(): mean = relay.var('mean') beta = relay.var('beta') gamma = relay.var('gamma') - - BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + + BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta out = rewrite(BatchnormCallback(), BN) - assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) + assert tvm.ir.structural_equal( + out, + relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]) + def test_no_fuse_batchnorm(): x = relay.var('x') @@ -566,75 +633,85 @@ def test_no_fuse_batchnorm(): mean = relay.var('mean') beta = relay.var('beta') gamma = relay.var('gamma') - - fake_BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta + + fake_BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta out = rewrite(BatchnormCallback(), fake_BN) assert tvm.ir.structural_equal(out, fake_BN) + def test_fuse_double_batchnorm(): x = relay.var('x') var = relay.var('var') mean = relay.var('mean') beta = relay.var('beta') gamma = relay.var('gamma') - - BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta - BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + + BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta + BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta out = rewrite(BatchnormCallback(), BN2) - bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0] - bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon = 1e-5)[0] + bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0] + bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon=1e-5)[0] assert tvm.ir.structural_equal(out, bn2) + def test_partial_fuse_double_batchnorm(): x = relay.var('x') var = relay.var('var') mean = relay.var('mean') beta = relay.var('beta') gamma = relay.var('gamma') - - BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta - BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta + + BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta + BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta out = rewrite(BatchnormCallback(), BN2) - bn2 = 
relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon = 1e-5)[0] + bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon=1e-5)[0] assert tvm.ir.structural_equal(out, bn2) + def test_fuse_batchnorm_commutation(): x = relay.var('x') var = relay.var('var') mean = relay.var('mean') beta = relay.var('beta') gamma = relay.var('gamma') - + #commute add - BN = beta + gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + BN = beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) out = rewrite(BatchnormCallback(), BN) - assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) + assert tvm.ir.structural_equal( + out, + relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]) # associate divide/multiply - BN = (gamma * (x - mean)) /relay.op.sqrt(var + relay.const(1e-5)) + beta + BN = (gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta out = rewrite(BatchnormCallback(), BN) - assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) + assert tvm.ir.structural_equal( + out, + relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]) # associate multiply/divide - BN = gamma * ((x - mean)/relay.op.sqrt(var + relay.const(1e-5))) + beta + BN = gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta out = rewrite(BatchnormCallback(), BN) - assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]) + assert tvm.ir.structural_equal( + out, + relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]) + def test_quadruple_rewrite_dominator(): class DominatorRemovalCallback(DFPatternCallback): def __init__(self): self.inp = wildcard() self.weight = wildcard() - is_conv2d = is_op('nn.conv2d')(self.inp, self.weight) - is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))( + wildcard()) | is_op('add')(wildcard(), wildcard()) reduction = is_op('add')(wildcard(), wildcard()) self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction) @@ -674,7 +751,6 @@ def callback(self, pre, post, node_map): tanh = relay.op.tanh(relu) leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) out = tanh + leaky_relu - one = relay.op.nn.conv2d(inp, weight) two = relay.op.nn.conv2d(one, weight) three = relay.op.nn.conv2d(two, weight) @@ -682,18 +758,20 @@ def callback(self, pre, post, node_map): assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), four) + def algebraic_simplify(expr): zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0))) one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0))) + class ElwiseNullCallback(DFPatternCallback): def callback(self, pre, post, node_map): - return node_map[self.x][0] + return node_map[self.x][0] # pylint: disable=no-member class AddCallback(ElwiseNullCallback): def __init__(self): self.x = wildcard() self.pattern = self.x + zero - + class SubCallback(ElwiseNullCallback): def __init__(self): self.x = wildcard() @@ -719,17 +797,19 @@ def __init__(self): self.x = zero self.pattern = self.x / wildcard() - return rewrite([AddCallback(), - SubCallback(), - MulCallback(), - DivCallback(), - MulZeroCallback(), - ZeroDivCallback() - ], expr); + return rewrite([ + AddCallback(), + SubCallback(), + MulCallback(), + DivCallback(), + MulZeroCallback(), + ZeroDivCallback() + ], expr) + def 
test_algebraic_simplify(): x = relay.Var('x') - y = relay.Var('y') + y = relay.Var('y') one = relay.const(1) zero = relay.const(0) @@ -740,23 +820,25 @@ def test_algebraic_simplify(): assert algebraic_simplify(x + zerof) == x assert algebraic_simplify(zero + x) == x assert algebraic_simplify(zerof + x) == x - + assert algebraic_simplify(x - zero) == x assert algebraic_simplify(x - zerof) == x - + assert algebraic_simplify(x * one) == x assert algebraic_simplify(x * onef) == x assert algebraic_simplify(one * x) == x assert algebraic_simplify(onef * x) == x assert algebraic_simplify(x * zero) == zero assert algebraic_simplify(x * zerof) == zerof - + assert algebraic_simplify(x / one) == x assert algebraic_simplify(x / onef) == x assert algebraic_simplify(zero / x) == zero assert algebraic_simplify(zerof / x) == zerof - assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y) + assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), + x + y) + def test_double_partition(): # Pattern 1 @@ -780,19 +862,30 @@ def test_double_partition(): for pat, label in [(relu_p, "conv_bias_relu"), (bias_add_p, "conv_bias")]: partitioned = pat.partition(partitioned, {"Composite": label}) - inpf = relay.var("input") weightf = relay.var("weight") biasf = relay.var("bias") - func0 = relay.Function([inpf, weightf, biasf], relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf))).with_attr("Composite", "conv_bias_relu").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_") + func0 = relay.Function( + [inpf, weightf, biasf], + relay.op.nn.relu(relay.op.nn.bias_add( + relay.op.nn.conv2d(inpf, weightf), + biasf))).with_attr("Composite", + "conv_bias_relu").with_attr("PartitionedFromPattern", + "nn.conv2d_nn.bias_add_nn.relu_") inpf = relay.var("input") weightf = relay.var("weight") biasf = relay.var("bias") - func1 = relay.Function([inpf, weightf, biasf], relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)).with_attr("Composite", "conv_bias").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_") + func1 = relay.Function([inpf, weightf, biasf], + relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), + biasf)).with_attr("Composite", + "conv_bias").with_attr( + "PartitionedFromPattern", + "nn.conv2d_nn.bias_add_") expected = func1(func0(x, w, b), w2, b2) assert tvm.ir.structural_equal(partitioned, expected) + def test_partition_dominator(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) @@ -803,31 +896,36 @@ def test_partition_dominator(): # Classic Diamond inp = relay.var('input') weight = relay.var('weight') + def generate_diamond(inp, weight): conv2d = relay.op.nn.conv2d(inp, weight) relu = relay.op.nn.relu(conv2d) relu = relay.op.nn.relu(relu) leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) return relu + leaky_relu - out = generate_diamond(inp*inp, weight*weight) + + out = generate_diamond(inp * inp, weight * weight) # Check partitioned = diamond.partition(out) - + i = relay.Var("input") w = relay.Var("weight") - f = relay.Function([i, w], generate_diamond(i, w)).with_attr("PartitionedFromPattern","nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_") - assert tvm.ir.structural_equal(partitioned, f(inp*inp, weight*weight)) + f = relay.Function([i, w], generate_diamond(i, w)).with_attr( + "PartitionedFromPattern", "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_") + assert tvm.ir.structural_equal(partitioned, f(inp * inp, weight * weight)) + def 
test_quadruple_partition_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
+        wildcard()) | is_op('add')(wildcard(), wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
-
     inp = relay.var('input')
     weight = relay.var('weight')
+
     # Classic Diamond
     def classic_diamond(inp, weight):
         conv2d = relay.op.nn.conv2d(inp, weight)
         relu = relay.op.nn.relu(conv2d)
         relu = relay.op.nn.relu(relu)
         leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
         return relu + leaky_relu
@@ -863,40 +961,31 @@ def nested_diamond(inp, weight):
         return tanh + leaky_relu
 
     partitioned = diamond.partition(
-        nested_diamond(
-            single_branch(
-                deeper_diamond(
-                    classic_diamond(inp, weight),
-                    weight),
-                weight),
-            weight
-        )
-    )
+        nested_diamond(single_branch(deeper_diamond(classic_diamond(inp, weight), weight), weight),
+                       weight))
 
     functions = []
     partition_names = [
         "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_",
-        "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_",
-        "nn.conv2d_nn.relu_nn.relu_tanh_add_",
+        "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_", "nn.conv2d_nn.relu_nn.relu_tanh_add_",
        "nn.conv2d_nn.relu_add_tanh_nn.leaky_relu_add_"
     ]
 
     for i, f in enumerate([classic_diamond, deeper_diamond, single_branch, nested_diamond]):
         inpf = relay.var("input")
         weightf = relay.var("weight")
-        functions.append(relay.Function([inpf, weightf], f(inpf, weightf)).with_attr("PartitionedFromPattern", partition_names[i]))
-
-    reference = functions[3](
-        functions[2](
-            functions[1](
-                functions[0](inp, weight),
-                weight),
-            weight),
-        weight
-    )
+        functions.append(
+            relay.Function([inpf, weightf], f(inpf,
+                                              weightf)).with_attr("PartitionedFromPattern",
+                                                                  partition_names[i]))
+
+    reference = functions[3](functions[2](functions[1](functions[0](inp, weight), weight), weight),
+                             weight)
     assert tvm.ir.structural_equal(partitioned, reference)
 
+
 def get_BN(x, var, mean, beta, gamma, eps):
-    return gamma * (x - mean)/relay.op.sqrt(var + eps) + beta
+    return gamma * (x - mean) / relay.op.sqrt(var + eps) + beta
+
 
 def test_partition_batchnorm():
     x = relay.var('x')
@@ -907,7 +996,6 @@ def test_partition_batchnorm():
     eps = relay.const(1e-5)
 
     BN = get_BN(x, var, mean, beta, gamma, eps)
-
     xf = relay.var('xf')
     varf = relay.var('varf')
     meanf = relay.var('meanf')
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
     epsf = relay.var('epsf')
     # Put the arguments in topological order for the reference
-    f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+    f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+                       get_BN(xf, varf, meanf, betaf, gammaf,
+                              epsf)).with_attr("PartitionedFromPattern",
+                                               "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN)
     assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))
 
+
 def test_partition_double_batchnorm():
     x = relay.var('x')
     var = relay.var('var')
@@ -927,9 +1019,9 @@ def test_partition_double_batchnorm():
     mean = relay.var('mean')
     beta = relay.var('beta')
     gamma = relay.var('gamma')
     eps = relay.const(1e-5)
-
-    BN = gamma * (x - mean)/relay.op.sqrt(var + eps) + beta
-    BN2 = gamma * (BN - mean)/relay.op.sqrt(var + eps) + beta
+
+    BN = gamma * (x - mean) / relay.op.sqrt(var + eps) + beta
+    BN2 = gamma * (BN - mean) / relay.op.sqrt(var + eps) + beta
 
     xf = relay.var('xf')
varf = relay.var('varf') @@ -937,7 +1029,10 @@ def test_partition_double_batchnorm(): betaf = relay.var('betaf') gammaf = relay.var('gammaf') epsf = relay.var('epsf') - f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") + f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], + get_BN(xf, varf, meanf, betaf, gammaf, + epsf)).with_attr("PartitionedFromPattern", + "subtract_multiply_add_sqrt_divide_add_") # The partitioner doesn't replace duplicates, so we use two copies of the function xf2 = relay.var('xf2') varf2 = relay.var('varf2') @@ -945,14 +1040,19 @@ def test_partition_double_batchnorm(): betaf2 = relay.var('betaf2') gammaf2 = relay.var('gammaf2') epsf2 = relay.var('epsf2') - f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2, epsf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") + f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2], + get_BN(xf2, varf2, meanf2, betaf2, gammaf2, + epsf2)).with_attr("PartitionedFromPattern", + "subtract_multiply_add_sqrt_divide_add_") partitioned = BatchnormCallback().pattern.partition(BN2) reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta) assert tvm.ir.structural_equal(partitioned, reference) + def test_partition_check(): pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard())) + def check(pre): return pre.args[0].attrs.data_layout == "NCHW" @@ -965,7 +1065,8 @@ def check(pre): wf = relay.var('weight') conv2df = relay.op.nn.conv2d(xf, wf) reluf = relay.op.nn.relu(conv2df) - func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_") + func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", + "nn.conv2d_nn.relu_") reference = func(x, w) partitioned = pattern.partition(relu, check=check) @@ -975,8 +1076,10 @@ def check(pre): relu = relay.op.nn.relu(conv2d) assert relu == pattern.partition(relu, check=check) + def test_partition_check_types(): pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard())) + def check(pre): conv = pre.args[0] return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1) @@ -1002,6 +1105,7 @@ def check(pre): relu = run_opt_pass(relu, relay.transform.InferType()) assert relu == pattern.partition(relu, check=check) + def test_partition_option(): x = relay.var('x') w = relay.var('w') @@ -1020,12 +1124,15 @@ def conv_bias_relu(x, w, b): bias_add = relay.op.nn.bias_add(conv2d, b) relu = relay.op.nn.relu(bias_add) return relu + relu = conv_bias_relu(x, w, b) xf = relay.var('x') wf = relay.var('w') bf = relay.var('b') - func = relay.Function([xf, wf, bf], conv_bias_relu(xf, wf, bf)).with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_") + func = relay.Function([xf, wf, bf], + conv_bias_relu(xf, wf, bf)).with_attr("PartitionedFromPattern", + "nn.conv2d_nn.bias_add_nn.relu_") assert pattern1.match(relu) assert tvm.ir.structural_equal(func(x, w, b), pattern1.partition(relu)) @@ -1033,33 +1140,55 @@ def conv_bias_relu(x, w, b): assert pattern2.match(relu) assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu)) + if __name__ == "__main__": + test_expr_pattern() + test_var_pattern() + test_constant_pattern() + test_wildcard_pattern() + test_CallPattern() + test_TuplePattern() + test_TupleGetItemPattern() + test_AltPattern() + 
test_TypePattern() + test_AttrPattern() test_match_op() test_no_match_op() test_match_op_or() - test_match_call() - test_no_match_call() test_match_call_commutive() test_no_match_call_commutive() + test_match_call() + test_no_match_call() + test_match_option() + test_no_match_option() + test_match_const() test_match_tuple() test_no_match_tuple() test_match_type() test_no_match_type() - test_match_attr() - test_no_match_attr() + test_match_op_attr() + test_no_match_op_attr() + test_match_func_attr() + test_no_match_func_attr() + test_match_call_attr() + test_no_match_call_attr() test_match_diamond() test_no_match_diamond() test_match_fake_diamond() + test_match_dominator() + test_not_match_dominator() test_rewrite() + test_rewrite_func() test_nested_rewrite() + test_not_fuse_multi_diamond() test_fuse_batchnorm() test_no_fuse_batchnorm() test_fuse_double_batchnorm() test_partial_fuse_double_batchnorm() test_fuse_batchnorm_commutation() - test_match_dominator() - test_not_match_dominator() + test_quadruple_rewrite_dominator() test_algebraic_simplify() + test_double_partition() test_partition_dominator() test_quadruple_partition_dominator() test_partition_batchnorm()
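
To see ConstantPattern working end to end with the rewrite machinery exercised above, here is a
minimal standalone sketch (not part of the patch; ``FoldAddZero`` is a hypothetical callback
name) that folds ``x + c`` to ``x`` whenever the constant ``c`` is all zeros:

.. code-block:: python

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.dataflow_pattern import (ConstantPattern, DFPatternCallback,
                                            is_op, rewrite, wildcard)

    class FoldAddZero(DFPatternCallback):
        """Rewrite `x + c` to `x` when the constant `c` is all zeros."""
        def __init__(self):
            self.x = wildcard()
            self.const = ConstantPattern()
            self.pattern = is_op('add')(self.x, self.const)

        def callback(self, pre, post, node_map):
            const = node_map[self.const][0]
            if np.all(const.data.asnumpy() == 0):
                return node_map[self.x][0]
            return post  # leave non-zero constants alone

    x = relay.var('x', shape=(2, 2))
    zeros = relay.const(np.zeros((2, 2), dtype='float32'))
    assert tvm.ir.structural_equal(rewrite(FoldAddZero(), x + zeros), x)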