Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] Merge two consecutive reshape ops #6052

Merged
merged 13 commits into from
Jul 16, 2020

Conversation

icemelon
Copy link
Member

@icemelon icemelon commented Jul 14, 2020

Use pattern matching rewriter to merge two consecutive reshape ops.

@mbrookhart I added an InferType pass after rewriting each pattern. I think this change can make the pattern rewriter more useful, at least I need this feature in my pass. Does it make sense to you?

@icemelon
Copy link
Member Author

cc @zhiics @comaniac @jroesch

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. It seems fine to add InferType to pattern rewrite.

Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor changes, but overall it's great, thank you!

I've been thinking about adding an algebraic simplifier to Relay, and this is a perfect first step to that. One question though: I would think that keeping the pattern/callbacks in python would cause a performance hit as we grow the number of simplification cases. Would it be better to lower this into C++?

tests/python/relay/test_pass_simplify_expr.py Show resolved Hide resolved
@@ -740,10 +740,10 @@ class PatternRewriter : protected MixedModeMutator {
groups_ = grouper.GroupMatches(callback_->pattern_, post);
gid_assignments_ = grouper.GetGIDAssignments();
memo_.clear();
post = this->VisitExpr(post);
post = InferType(this->VisitExpr(post));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is failing all of the pattern language unit tests because they don't assume you need a typed graph for pattern matching. Maybe we should make this behavior optional? Or do we change the API to assert that Expressions have to be well typed to run the pattern rewriter?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree to make the InferType optional, but assertion may not work, as one pattern may rewrite a graph multiple times, so the rewritten nodes are still not typed even the original nodes are well typed before running rewriter. One solution is requiring users to manually type new nodes in the rewrite callback, but it seems not trivial.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, in the multi-stage rewrite scenario, it makes sense to have the InferType here.

Comment on lines 38 to 45
@transform.function_pass(opt_level=0, required=["InferType"])
class SimplifyExpr:
""" A pass to simplify the Relay expression."""
def __init__(self):
self.callbacks = [SimplifyReshapeCallback()]

def transform_function(self, func, mod, _):
return rewrite(self.callbacks, func)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:) I've been thinking about putting together an algebraic simplifier for a while, this seems like a great first step.

@icemelon
Copy link
Member Author

@mbrookhart I'll try to move the pass to C++ and make the infer type to optional.

@mbrookhart
Copy link
Contributor

Thanks!

Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Two minor nitpicks, up to you if you want to address them.

Thanks!

Comment on lines +45 to +49
DFPattern pattern;
/*! \brief Function to call when finding a matched expression */
PackedFunc function_;
PackedFunc function;
/*! \brief Require InferType to be run before the callback */
bool require_type;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://tvm.apache.org/docs/contribute/code_guide.html
https://google.github.io/styleguide/cppguide.html#Variable_Names

Why the move away from the Google Style Guide convention? You seem to use the var_name_ convention in simplify_expr.cc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because these variables are public, it's probably better and more consistent to name it without "_" at the end imo.

Comment on lines 83 to 89
auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
Expr pre = args[0];
Expr post = args[1];
Map<DFPattern, Array<Expr>> node_map = args[2];
*rv = simplify_reshape_.callback(pre, post, node_map);
};
callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe have SimplifyReshape directly inherit DFPatternCallback? You could fold this directly into that and keep it out of the main Simplifier.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason that I didn't inherit directly from DFPatternCallback is because you need to create the pattern somewhere else as it's required in the DFPatternCallback constructor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:/ I think I focused too much on the Python API and left an Ugly C++ API. I'll see if I can clean that up in a follow up PR. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. :)

@@ -599,6 +599,7 @@ def test_rewrite():

class TestRewrite(DFPatternCallback):
def __init__(self):
super(TestRewrite, self).__init__()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating these!

Copy link
Member

@zhiics zhiics left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tqchen tqchen merged commit 6360ad1 into apache:master Jul 16, 2020
@tqchen
Copy link
Member

tqchen commented Jul 16, 2020

Thanks @zhiics @icemelon9 @mbrookhart

@icemelon icemelon deleted the simpl-expr branch July 16, 2020 04:34
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Sep 2, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Sep 3, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants