-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Pass] Merge two consecutive reshape ops #6052
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. It seems fine to add InferType to pattern rewrite.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor changes, but overall it's great, thank you!
I've been thinking about adding an algebraic simplifier to Relay, and this is a perfect first step to that. One question though: I would think that keeping the pattern/callbacks in python would cause a performance hit as we grow the number of simplification cases. Would it be better to lower this into C++?
src/relay/ir/dataflow_matcher.cc
Outdated
@@ -740,10 +740,10 @@ class PatternRewriter : protected MixedModeMutator { | |||
groups_ = grouper.GroupMatches(callback_->pattern_, post); | |||
gid_assignments_ = grouper.GetGIDAssignments(); | |||
memo_.clear(); | |||
post = this->VisitExpr(post); | |||
post = InferType(this->VisitExpr(post)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is failing all of the pattern language unit tests because they don't assume you need a typed graph for pattern matching. Maybe we should make this behavior optional? Or do we change the API to assert that Expressions have to be well typed to run the pattern rewriter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree to make the InferType optional, but assertion may not work, as one pattern may rewrite a graph multiple times, so the rewritten nodes are still not typed even the original nodes are well typed before running rewriter. One solution is requiring users to manually type new nodes in the rewrite callback, but it seems not trivial.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, in the multi-stage rewrite scenario, it makes sense to have the InferType here.
@transform.function_pass(opt_level=0, required=["InferType"]) | ||
class SimplifyExpr: | ||
""" A pass to simplify the Relay expression.""" | ||
def __init__(self): | ||
self.callbacks = [SimplifyReshapeCallback()] | ||
|
||
def transform_function(self, func, mod, _): | ||
return rewrite(self.callbacks, func) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:) I've been thinking about putting together an algebraic simplifier for a while, this seems like a great first step.
@mbrookhart I'll try to move the pass to C++ and make the infer type to optional. |
Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Two minor nitpicks, up to you if you want to address them.
Thanks!
DFPattern pattern; | ||
/*! \brief Function to call when finding a matched expression */ | ||
PackedFunc function_; | ||
PackedFunc function; | ||
/*! \brief Require InferType to be run before the callback */ | ||
bool require_type; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://tvm.apache.org/docs/contribute/code_guide.html
https://google.github.io/styleguide/cppguide.html#Variable_Names
Why the move away from the Google Style Guide convention? You seem to use the var_name_ convention in simplify_expr.cc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because these variables are public, it's probably better and more consistent to name it without "_" at the end imo.
auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) { | ||
Expr pre = args[0]; | ||
Expr post = args[1]; | ||
Map<DFPattern, Array<Expr>> node_map = args[2]; | ||
*rv = simplify_reshape_.callback(pre, post, node_map); | ||
}; | ||
callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe have SimplifyReshape directly inherit DFPatternCallback? You could fold this directly into that and keep it out of the main Simplifier.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason that I didn't inherit directly from DFPatternCallback
is because you need to create the pattern somewhere else as it's required in the DFPatternCallback
constructor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:/ I think I focused too much on the Python API and left an Ugly C++ API. I'll see if I can clean that up in a follow up PR. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. :)
@@ -599,6 +599,7 @@ def test_rewrite(): | |||
|
|||
class TestRewrite(DFPatternCallback): | |||
def __init__(self): | |||
super(TestRewrite, self).__init__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for updating these!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Thanks @zhiics @icemelon9 @mbrookhart |
Use pattern matching rewriter to merge two consecutive reshape ops.
@mbrookhart I added an InferType pass after rewriting each pattern. I think this change can make the pattern rewriter more useful, at least I need this feature in my pass. Does it make sense to you?