diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index bb53ad32d9f4..12e4e3f45fef 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -42,11 +42,16 @@ class DFPatternCallback; class DFPatternCallbackNode : public Object { public: /*! \brief Pattern this callback matches */ - DFPattern pattern_; + DFPattern pattern; /*! \brief Function to call when finding a matched expression */ - PackedFunc function_; + PackedFunc function; + /*! \brief Require InferType to be run before the callback */ + bool require_type; - void VisitAttrs(tvm::AttrVisitor* v) {} + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("require_type", &require_type); + } static constexpr const char* _type_key = "DFPatternCallbackNode"; TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object); @@ -58,7 +63,7 @@ class DFPatternCallbackNode : public Object { */ class DFPatternCallback : public ObjectRef { public: - TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback); + TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type); TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); }; @@ -77,11 +82,12 @@ bool MatchPattern(DFPattern pattern, Expr expr); * * \param callbacks An array of DFPatternCallback Nodes * \param expr The expression to rewrite + * \param mod The module that associates with the expr * * \return Return An Expr with every match of the pattern inside the callbacks rewritten by the * functions inside the callbacks */ -Expr RewritePatterns(Array callbacks, Expr expr); +Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod = IRModule()); /*! * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 1b8b31aee5d1..d995301c1688 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -360,6 +360,13 @@ TVM_DLL Pass Inline(); */ TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); +/*! + * \brief Simplify the Relay expression. + * + * \return The pass. + */ +TVM_DLL Pass SimplifyExpr(); + } // namespace transform /*! diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 317d28e1dbea..03bdd1952fa1 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -22,6 +22,7 @@ from tvm.relay.expr import RelayExpr as Expr from ... import _ffi as tvm_ffi +from ... import ir as _ir from ...ir import make_node from ...ir.base import Node from ...runtime import Object @@ -687,7 +688,15 @@ class DFPatternCallback: the callback returns. Users are expect to inherit from this class and provide a "self.pattern" to match + + Parameters + ---------- + require_type: bool + Whether InferType is required to be run before the callback. """ + def __init__(self, require_type=False): + self.pattern = None + self.require_type = require_type def rewrite(self, expr: Expr) -> Expr: """ @@ -727,11 +736,11 @@ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Exp class _DFPatternCallback(Object): """C++ implemenation""" - def __init__(self, pattern, callback): - self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback) + def __init__(self, pattern, callback, require_type): + self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type) -def rewrite(callbacks, expr: Expr) -> Expr: +def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr: """ Rewrite expression with the given callbacks. @@ -741,20 +750,23 @@ def rewrite(callbacks, expr: Expr) -> Expr: The input callback or list of callbacks. expr : tvm.relay.Expr The expression to rewrite. + mod : Optional[tvm.ir.IRModule] + The module that associates with the expression. Returns ------- result : tvm.relay.Expr The Expression with matched subgraphs rewritten by the callbacks. """ - if isinstance(callbacks, DFPatternCallback): - tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)] - else: - tmp = [] - for callback in callbacks: - tmp.append(_DFPatternCallback(callback.pattern, callback.callback)) + if mod is None: + mod = _ir.IRModule() + callbacks = [callbacks] if isinstance(callbacks, DFPatternCallback) else callbacks + tmp = [] + for callback in callbacks: + assert callback.pattern is not None + tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type)) - return ffi.rewrite(tmp, expr) + return ffi.rewrite(tmp, expr, mod) def partition(pattern: "DFPattern", diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 878b82a19a36..dc1265870475 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -52,7 +52,7 @@ _reg.register_injective_schedule("take") _reg.register_injective_schedule("transpose") _reg.register_injective_schedule("stack") -_reg.register_injective_schedule("_contrib_reverse_reshape") +_reg.register_injective_schedule("contrib_reverse_reshape") _reg.register_injective_schedule("gather") _reg.register_injective_schedule("gather_nd") _reg.register_injective_schedule("sequence_mask") diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 83008a9c1cc5..ae10dd50a87f 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -912,7 +912,7 @@ def reverse_reshape(data, newshape): """ if isinstance(newshape, int): newshape = [newshape] - return _make._contrib_reverse_reshape(data, list(newshape)) + return _make.contrib_reverse_reshape(data, list(newshape)) def gather(data, axis, indices): diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index ede63808d4fd..7db068785ba6 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -909,6 +909,7 @@ def DenseToSparse(weight_name, weight_shape): """ return _ffi_api.DenseToSparse(weight_name, weight_shape) + def SimplifyFCTranspose(target_weight_name): """ Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)``` @@ -926,3 +927,15 @@ def SimplifyFCTranspose(target_weight_name): The registered SimplifyFCTranspose pass. """ return _ffi_api.SimplifyFCTranspose(target_weight_name) + + +def SimplifyExpr(): + """ + Simplify the Relay expression, including merging consecutive reshapes. + + Returns + ------- + ret : tvm.transform.Pass + The registered SimplifyExpr pass. + """ + return _ffi_api.SimplifyExpr() diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b589bcce99fc..b57c0eb8cbdb 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -276,6 +276,7 @@ class RelayBuildModule : public runtime::ModuleNode { } }); pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::SimplifyExpr()); pass_seqs.push_back(transform::CombineParallelConv2D(3)); pass_seqs.push_back(transform::CombineParallelDense(3)); pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index d01dbda24a4c..585b8033be8d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -945,10 +945,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe *rv = false; }); pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::SimplifyExpr()); pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::CombineParallelConv2D(3)); pass_seqs.push_back(transform::CombineParallelDense(3)); + pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); pass_seqs.push_back(transform::FoldConstant()); pass_seqs.push_back(transform::FoldScaleAxis()); pass_seqs.push_back(transform::CanonicalizeCast()); @@ -959,6 +961,8 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe pass_seqs.push_back(transform::AlterOpLayout()); } + // Fast math optimizations. + pass_seqs.push_back(transform::FastMath()); pass_seqs.push_back(transform::FoldConstant()); pass_seqs.push_back(transform::FuseOps()); diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 57b3013fd04b..50c05f2923bc 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -390,6 +390,34 @@ Expr InferType(const Expr& expr) { } } +Expr InferTypeWithModule(const Expr& expr, const IRModule& m) { + IRModule mod(m->functions, m->type_definitions, m->Imports()); + int idx = 0; + std::string gv_name; + do { + std::ostringstream oss; + oss << "_tmp" << idx; + gv_name = oss.str(); + ++idx; + } while (mod->ContainGlobalVar(gv_name)); + GlobalVar gvar(gv_name); + BaseFunc func; + if (expr.as()) { + func = Downcast(expr); + } else { + func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + } + mod->Add(gvar, func); + mod = transform::InferType()(mod); + Expr ret; + if (expr.as()) { + ret = mod->Lookup(gvar); + } else { + ret = mod->Lookup(gvar).as()->body; + } + return ret; +} + bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) { auto expr_type = InferType(expr).as()->checked_type(); return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); @@ -436,7 +464,8 @@ bool MatchPattern(DFPattern pattern, Expr expr) { TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); -/* \brief PatternGrouper does pre-rewriting pattern matching and analysis +/*! + * \brief PatternGrouper does pre-rewriting pattern matching and analysis * * This class creates a number of groups of matched expressions, ensures they don't overlap, and * returns them to the caller for post-analysis rewriting. @@ -446,7 +475,7 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern) */ class PatternGrouper { public: - /* \brief Internal Group class for storing analysis */ + /*! \brief Internal Group class for storing analysis */ struct Group { Expr root_node; int gid; @@ -456,11 +485,11 @@ class PatternGrouper { Array args; }; - /* \brief Return the group assignments of expressions */ + /*! \brief Return the group assignments of expressions */ const std::unordered_map& GetGIDAssignments() { return gid_assignments_; } - /* \brief Group expressions that match the pattern */ + /*! \brief Group expressions that match the pattern */ const std::unordered_map& GroupMatches(const DFPattern& pattern, const Expr& pre) { groups_.clear(); gid_assignments_.clear(); @@ -474,7 +503,7 @@ class PatternGrouper { } protected: - /* \brief Iteratively traverse the Expression in pre-order to find subgraphs + /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs * * If we traverse the graph in post-order, we can run into situtations where a small subgraph will * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in @@ -501,7 +530,7 @@ class PatternGrouper { } } } - /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform + /*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform * group overlap analysis */ class MatchExtractor : public ExprMutator { public: @@ -563,7 +592,7 @@ class PatternGrouper { const std::unordered_map inputs_; }; - /* \brief Create a group based on a matched expression */ + /*! \brief Create a group based on a matched expression */ void CreateGroup(const Expr& expr) { int var_number = 0; @@ -661,7 +690,7 @@ class PatternGrouper { groups_[group.gid] = std::move(group); } - /* \brief EmbedConst implements rules for embedding constants into partitioned functions or + /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or * lifting them into the function arguments. * * The rules depend on what pattern the ConstantNode matched. @@ -703,28 +732,30 @@ class PatternGrouper { // Rewrite -DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function) { +DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type) { ObjectPtr n = make_object(); - n->pattern_ = std::move(pattern); - n->function_ = std::move(function); + n->pattern = std::move(pattern); + n->function = std::move(function); + n->require_type = require_type; data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") - .set_body_typed([](DFPattern pattern, PackedFunc function) { - return DFPatternCallback(pattern, function); + .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type) { + return DFPatternCallback(pattern, function, require_type); }); -/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback +/*! + * \brief PatternRewriter rewrites the expression by finding matches and allowing user callback * function to rewrite those matches * * The class uses PatternGrouper to support the dominator pattern. */ class PatternRewriter : protected MixedModeMutator { public: - PatternRewriter() {} + PatternRewriter(IRModule mod) : mod_(mod) {} /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the * callbacks until it stops changing */ Expr Rewrite(const Array& callbacks, const Expr& pre) { @@ -732,20 +763,27 @@ class PatternRewriter : protected MixedModeMutator { auto last = post; // rewrite the graph until it stops changing to make sure all rewrites are complete int count = 0; + bool equal = true; + static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); + CHECK(structural_equal) << "node.StructuralEqual is not registered."; do { last = post; for (auto callback : callbacks) { callback_ = callback; + if (callback_->require_type) { + post = InferTypeWithModule(post, mod_); + } auto grouper = PatternGrouper(); - groups_ = grouper.GroupMatches(callback_->pattern_, post); + groups_ = grouper.GroupMatches(callback_->pattern, post); gid_assignments_ = grouper.GetGIDAssignments(); memo_.clear(); post = this->VisitExpr(post); count++; } - } while (last != post || count >= 100); + equal = (*structural_equal)(last, post, false, true); + } while (!equal && count < 100); if (count >= 100) { - throw("Observed 100 rewrite passes, possible conflicting passes?"); + LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; } return post; } @@ -765,23 +803,25 @@ class PatternRewriter : protected MixedModeMutator { node_map.insert({kv.first, tmp}); } // run the user callback function - return callback_->function_(pre, post, Map>(node_map)); + return callback_->function(pre, post, Map>(node_map)); } return post; } + IRModule mod_; DFPatternCallback callback_; std::unordered_map groups_; std::unordered_map gid_assignments_; }; -Expr RewritePatterns(Array callbacks, Expr expr) { - return PatternRewriter().Rewrite(callbacks, expr); +Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod) { + return PatternRewriter(mod).Rewrite(callbacks, expr); } TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns); -/* \brief PatternPartitioner replaces expressions that match a pattern with function call that +/*! + * \brief PatternPartitioner replaces expressions that match a pattern with function call that * perform the same computation but allow for further analysis and lowering. * * The class uses PatternGrouper to support the dominator pattern. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 85e8671cf8d5..cc1150cb9bae 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2554,13 +2554,13 @@ Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = true; - static const Op& op = Op::Get("_contrib_reverse_reshape"); + static const Op& op = Op::Get("contrib_reverse_reshape"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape").set_body_typed(MakeReverseReshape); +TVM_REGISTER_GLOBAL("relay.op._make.contrib_reverse_reshape").set_body_typed(MakeReverseReshape); -RELAY_REGISTER_OP("_contrib_reverse_reshape") +RELAY_REGISTER_OP("contrib_reverse_reshape") .describe(R"code(Reshapes the input array where the special values are inferred from right to left. diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc new file mode 100644 index 000000000000..079b86715a48 --- /dev/null +++ b/src/relay/transforms/simplify_expr.cc @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/simplify_expr.cc + * \brief A pass for simplifying the Relay expression. + */ + +#include +#include +#include +#include +#include + +#include "../op/tensor/transform.h" + +namespace tvm { +namespace relay { + +static Op reshape_op = Op::Get("reshape"); +static Op reverse_reshape_op = Op::Get("contrib_reverse_reshape"); + +/*! + * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops, + * and merges into one reshape op. + */ +class SimplifyReshape { + public: + SimplifyReshape() { + x_ = WildcardPattern(make_object()); + auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op)); + auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op)); + pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_}, Attrs{}, {})}, Attrs{}, {}); + } + + Expr callback(const Expr& pre, const Expr& post, const Map>& node_map) { + auto x = node_map[x_][0]; + bool const_shape = true; + Array newshape; + for (auto dim : Downcast(pre->checked_type())->shape) { + if (dim.as() == nullptr) { + const_shape = false; + break; + } + newshape.push_back(Downcast(dim)); + } + if (const_shape) { + return MakeReshape(x, newshape); + } + return post; + } + + DFPattern pattern() const { return pattern_; } + + private: + /*! \brief Pattern input */ + DFPattern x_; + /*! \brief Pattern for consecutive reshape or reverse_reshape ops */ + DFPattern pattern_; +}; + +/*! + * \brief ExprSimplifier simplifies the Relay expression. + */ +class ExprSimplifier { + public: + explicit ExprSimplifier(IRModule mod) : mod_(mod) { + auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) { + Expr pre = args[0]; + Expr post = args[1]; + Map> node_map = args[2]; + *rv = simplify_reshape_.callback(pre, post, node_map); + }; + callbacks_.push_back( + DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true)); + } + + Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); } + + private: + IRModule mod_; + /*! \brief Simplify reshape pattern */ + SimplifyReshape simplify_reshape_; + /*! \brief Callbacks for expr simplification */ + Array callbacks_; +}; + +Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { + return ExprSimplifier(mod).Simplify(expr); +} + +namespace transform { + +Pass SimplifyExpr() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(SimplifyExpr(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index f390b720b80a..34a098731b86 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -599,6 +599,7 @@ def test_rewrite(): class TestRewrite(DFPatternCallback): def __init__(self): + super(TestRewrite, self).__init__() self.pattern = add_pattern def callback(self, pre, post, node_map): @@ -617,6 +618,7 @@ def test_rewrite_func(): class TestRewrite(DFPatternCallback): def __init__(self): + super(TestRewrite, self).__init__() self.pattern = add_pattern def callback(self, pre, post, node_map): @@ -634,6 +636,7 @@ def callback(self, pre, post, node_map): def test_nested_rewrite(): class PatternCallback(DFPatternCallback): def __init__(self, pattern): + super(PatternCallback, self).__init__() self.pattern = pattern def callback(self, pre, post, node_map): @@ -682,6 +685,7 @@ def test_not_fuse_multi_diamond(): class BatchnormCallback(DFPatternCallback): def __init__(self): + super(BatchnormCallback, self).__init__() self.x = wildcard() self.var = wildcard() self.mean = wildcard() @@ -798,6 +802,7 @@ def test_fuse_batchnorm_commutation(): def test_quadruple_rewrite_dominator(): class DominatorRemovalCallback(DFPatternCallback): def __init__(self): + super(DominatorRemovalCallback, self).__init__() self.inp = wildcard() self.weight = wildcard() is_conv2d = is_op('nn.conv2d')(self.inp, self.weight) @@ -860,31 +865,37 @@ def callback(self, pre, post, node_map): class AddCallback(ElwiseNullCallback): def __init__(self): + super(AddCallback, self).__init__() self.x = wildcard() self.pattern = self.x + zero class SubCallback(ElwiseNullCallback): def __init__(self): + super(SubCallback, self).__init__() self.x = wildcard() self.pattern = self.x - zero class MulCallback(ElwiseNullCallback): def __init__(self): + super(MulCallback, self).__init__() self.x = wildcard() self.pattern = self.x * one class DivCallback(ElwiseNullCallback): def __init__(self): + super(DivCallback, self).__init__() self.x = wildcard() self.pattern = self.x / one class MulZeroCallback(ElwiseNullCallback): def __init__(self): + super(MulZeroCallback, self).__init__() self.x = zero self.pattern = self.x * wildcard() class ZeroDivCallback(ElwiseNullCallback): def __init__(self): + super(ZeroDivCallback, self).__init__() self.x = zero self.pattern = self.x / wildcard() @@ -1265,6 +1276,7 @@ def test_match_match(): add_pattern = is_op('add')(wildcard(), wildcard()) class TestRewrite(DFPatternCallback): def __init__(self): + super(TestRewrite, self).__init__() self.pattern = add_pattern def callback(self, pre, post, node_map): return post.args[0] - post.args[1] diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py new file mode 100644 index 000000000000..e934c11a6370 --- /dev/null +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.testing import run_opt_pass + +def test_simplify_reshape(): + def before(): + x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") + w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") + y = relay.nn.conv2d(x, w, padding=(1, 1)) + y = relay.reshape(y, newshape=(1, 16, -1)) + y = relay.reshape(y, newshape=(4, 8, -1, 16)) + y = relay.reverse_reshape(y, newshape=(32, 0, -1)) + return relay.Function([x, w], y) + + def expected(): + x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") + w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") + y = relay.nn.conv2d(x, w, padding=(1, 1)) + y = relay.reshape(y, newshape=(32, 16, 16)) + return relay.Function([x, w], y) + + def symbolic(): + b = tvm.te.size_var('b') + x = relay.var("x", shape=(b, 16, 16, 16), dtype="float32") + w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") + y = relay.nn.conv2d(x, w, padding=(1, 1)) + y = relay.reshape(y, newshape=(1, 16, -1)) + y = relay.reshape(y, newshape=(4, 8, -1, 16)) + y = relay.reverse_reshape(y, newshape=(32, 0, -1)) + return relay.Function([x, w], y) + + z = before() + zz = run_opt_pass(z, transform.SimplifyExpr()) + after = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(zz, after) + + z = symbolic() + zz = run_opt_pass(z, transform.SimplifyExpr()) + after = run_opt_pass(symbolic(), transform.InferType()) + assert tvm.ir.structural_equal(zz, after) + +if __name__ == "__main__": + test_simplify_reshape()