-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr #7731
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM. Meanwhile I have some questions about the design:
-
It seems to me that ConcretizeLike and EliminateIdentity can also be merged to SImplifyExpr in terms of the implementation and semantic. What's the concern of having 3 separate passes?
-
You mentioned that for a certain reason, EliminateIdentity should be run before SimplifyExpr, but I didn't get the point about what would happen if we run them in the reverse order. Could you elaborate a bit further?
Also cc @mbrookhart |
|
If the purpose is just for testing, then I'll prefer to have them in a single pass. You can still test the pattern one-by-one as SimplifyExpr does now. Since the unrelated patterns won't be matched, I didn't see the problem of testing. In this case, you can also control the order of rewriting patterns. i.e., always run |
I think I agree that merging these with SimplifyExpr would be a win in terms of our ability to control the order of execution. On the simplification of things like I've been meaning to implement this for like 6 months, and I haven't had a strong enough forcing function to bubble it up to the top of my priority list. |
It looks like you have a windows build issue? |
This is just what I needed, thanks! I think since I'm also just checking 0 or 1, there should be no problem casting since the floating point repr should all be the same. I'll fix the Windows issue, it was because I overloaded the same name too much. |
Ah, well looks like this might only make sense for constants that only have one element, unless we want to loop over every single element and check that it is equal to 0 or 1. But if I understand correctly, the Do you guys think this is a reasonable tradeoff? |
Yeah, as long as we aren't commonly manifesting full sized arrays of zero or one, that should be fine. Given the full/zeros/ones ops and their like counterparts, plus auto-broadcasting, I think that's generally a reasonable assumption to make. |
@comaniac @mbrookhart I've merged them and updated the unit tests |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Just nits.
}; | ||
|
||
Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { | ||
return ExprSimplifier(mod).Simplify(expr); | ||
static Array<DFPatternCallback> callbacks = {ConcretizeZerosLikeRewrite::GetCallback(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You may need to comment here if the order is enforced.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @mbrookhart , I believe the ordering is respected by the rewriter but because the rewriter iterates until fixed point, I don't think it would be correct to say that globally the order is enforced.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(in any case it shouldn't matter now that I've added support for eliminating constants)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct, they will do as much as they can to the graph in this order, and then loop back and try again. That being said, that's the same thing you'd get if you ran them as separate passes, you'd rewrite everything you could before the next pass, but the next pass might open opportunities to do more with the current pass if you ran it again.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few design issues and nitpicks
}; | ||
|
||
Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { | ||
return ExprSimplifier(mod).Simplify(expr); | ||
static Array<DFPatternCallback> callbacks = {ConcretizeZerosLikeRewrite::GetCallback(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct, they will do as much as they can to the graph in this order, and then loop back and try again. That being said, that's the same thing you'd get if you ran them as separate passes, you'd rewrite everything you could before the next pass, but the next pass might open opportunities to do more with the current pass if you ran it again.
DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_}); | ||
DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we also simplify x * 0 = 0
and x + 0 = x
?
The pattern matcher should be able to match these ops irrespective of order:
tvm/src/relay/ir/dataflow_matcher.cc
Lines 275 to 281 in cfe2e28
if (const OpNode* op_node = get_op_node(op)) { | |
if ((op_node->name == "add") || (op_node->name == "multiply")) { | |
if (match_args(reverse(op->args), call_node->args)) { | |
return true; | |
} | |
} | |
} |
You can probably get away without the AltPattern here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can add x * 0 = 0 although to me it semantically doesn't exactly fit and would need slightly different logic, perhaps a new rewrite ZeroMultiply
or something?
And I wasn't aware of the commutative matching, that's helpful (although I wonder if it should be more visibly defined somewhere?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what you mean by x + 0 = x
as I think I already have that covered, can you confirm?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, you do, sorry!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had this roughly implementend in the pattern matcher tests, but never really productized:
tvm/tests/python/relay/test_dataflow_pattern.py
Lines 1039 to 1127 in 3ba5868
def algebraic_simplify(expr): | |
zero = is_expr(relay.const(0)) | is_expr(relay.const(0.0)) | |
one = is_expr(relay.const(1)) | is_expr(relay.const(1.0)) | |
class ElwiseNullCallback(DFPatternCallback): | |
def callback(self, pre, post, node_map): | |
return node_map[self.x][0] # pylint: disable=no-member | |
class AddCallback(ElwiseNullCallback): | |
def __init__(self): | |
super(AddCallback, self).__init__() | |
self.x = wildcard() | |
self.pattern = self.x + zero | |
class SubCallback(ElwiseNullCallback): | |
def __init__(self): | |
super(SubCallback, self).__init__() | |
self.x = wildcard() | |
self.pattern = self.x - zero | |
class MulCallback(ElwiseNullCallback): | |
def __init__(self): | |
super(MulCallback, self).__init__() | |
self.x = wildcard() | |
self.pattern = self.x * one | |
class DivCallback(ElwiseNullCallback): | |
def __init__(self): | |
super(DivCallback, self).__init__() | |
self.x = wildcard() | |
self.pattern = self.x / one | |
class MulZeroCallback(ElwiseNullCallback): | |
def __init__(self): | |
super(MulZeroCallback, self).__init__() | |
self.x = zero | |
self.pattern = self.x * wildcard() | |
class ZeroDivCallback(ElwiseNullCallback): | |
def __init__(self): | |
super(ZeroDivCallback, self).__init__() | |
self.x = zero | |
self.pattern = self.x / wildcard() | |
return rewrite( | |
[ | |
AddCallback(), | |
SubCallback(), | |
MulCallback(), | |
DivCallback(), | |
MulZeroCallback(), | |
ZeroDivCallback(), | |
], | |
expr, | |
) | |
def test_algebraic_simplify(): | |
x = relay.Var("x") | |
y = relay.Var("y") | |
one = relay.const(1) | |
zero = relay.const(0) | |
onef = relay.const(1.0) | |
zerof = relay.const(0.0) | |
assert algebraic_simplify(x + zero) == x | |
assert algebraic_simplify(x + zerof) == x | |
assert algebraic_simplify(zero + x) == x | |
assert algebraic_simplify(zerof + x) == x | |
assert algebraic_simplify(x - zero) == x | |
assert algebraic_simplify(x - zerof) == x | |
assert algebraic_simplify(x * one) == x | |
assert algebraic_simplify(x * onef) == x | |
assert algebraic_simplify(one * x) == x | |
assert algebraic_simplify(onef * x) == x | |
assert algebraic_simplify(x * zero) == zero | |
assert algebraic_simplify(x * zerof) == zerof | |
assert algebraic_simplify(x / one) == x | |
assert algebraic_simplify(x / onef) == x | |
assert algebraic_simplify(zero / x) == zero | |
assert algebraic_simplify(zerof / x) == zerof | |
assert tvm.ir.structural_equal( | |
algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't have to do the full thing for this PR, you can keep this as it is and we can extend later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep sounds good
src/relay/transforms/simplify_expr.h
Outdated
#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType) \ | ||
static DFPatternRewrite* Get() { \ | ||
static RewriteType rw; \ | ||
return &rw; \ | ||
} \ | ||
static DFPatternCallback GetCallback() { \ | ||
static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \ | ||
return cb; \ | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I don't like Macros
- I don't like static initialization
Why not just initialize the object and call the method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed the overhead of initializing and calling is probably negligible compared to running the pass itself, I did this following comments on my previous PR. @comaniac maybe you can comment? In the end I am fine with either way
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The previously PR has a different implemetation and my point was the pattern table itself should be static. Given the current implemntation is based on SimplifyExpr, I agree with @mbrookhart that we don't need to make those functions static in the pattern class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it, I'll remove this, thanks for the clarification
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed, made a helper class for composing rewrites since we need to ensure the lifetimes of the DFPatternCallbacks do not exceed the Rewrite objects
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Ended up having to make ToScalar return an optional value due to a custom datatype test (which as far as I can tell, we don't have a good way of supporting conversion at compile time in C++ currently). Let me know if this is fine, as I don't see an alternative within the scope of this PR. |
@altanh I'm confused why you need this change to ToScalar. What changed elsewhere in your PR that broke this unit test? |
here's the offending test https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/detail/PR-7731/7/pipeline the alternative I see to changing |
ping @mbrookhart @comaniac, I adjusted the change to |
I'm happy with this. I'll merge this afternoon unless @comaniac objects. |
Sorry I missed the previous message. Yeah I'm good with it so I've merged it. Thanks @altanh @mbrookhart |
…ifyExpr (apache#7731) * factor out some common code for DF rewriting, add ConcretizeLike * slight refactoring, add EliminateIdentity pass * lint * merge ConcretizeLike and EliminateIdentity into SimplifyExpr * nits and lint * remove static stuff * document * definitely ran clang-format but ok * make ToScalar return optional, fix missing virtual destructor * lint * tweak scalar conversion API to maintain compatibility
…ifyExpr (apache#7731) * factor out some common code for DF rewriting, add ConcretizeLike * slight refactoring, add EliminateIdentity pass * lint * merge ConcretizeLike and EliminateIdentity into SimplifyExpr * nits and lint * remove static stuff * document * definitely ran clang-format but ok * make ToScalar return optional, fix missing virtual destructor * lint * tweak scalar conversion API to maintain compatibility
…ifyExpr (apache#7731) * factor out some common code for DF rewriting, add ConcretizeLike * slight refactoring, add EliminateIdentity pass * lint * merge ConcretizeLike and EliminateIdentity into SimplifyExpr * nits and lint * remove static stuff * document * definitely ran clang-format but ok * make ToScalar return optional, fix missing virtual destructor * lint * tweak scalar conversion API to maintain compatibility
…ifyExpr (apache#7731) * factor out some common code for DF rewriting, add ConcretizeLike * slight refactoring, add EliminateIdentity pass * lint * merge ConcretizeLike and EliminateIdentity into SimplifyExpr * nits and lint * remove static stuff * document * definitely ran clang-format but ok * make ToScalar return optional, fix missing virtual destructor * lint * tweak scalar conversion API to maintain compatibility
This PR introduces two new rewrites:
ConcretizeLike
: replaces*_like
operators with their concrete-shape equivalent when the result shape is concrete.EliminateIdentity
: eliminates identity expressions likex + 0
,1 * x
, etc. Expressions that broadcastx
to a new shape are not removed, although we could explicitly replace them with broadcasting ops (not sure of the performance difference for this). This pass supports eliminating scalar constant 0s and 1s, but will not eliminate constants with more than a single element.I also refactored the existing DFPatternCallback-based passes slightly to lift out common machinery.
Together, these rewrites should help optimize the generated AD code (and credit to @t-vi for prototyping them in the blog post).
cc @tqchen @comaniac @MarisaKirisame @yzhliu
(I'll work on a
DeadParameterElimination
pass to complementConcretizeLike
as we discussed in the previous PR, but will send as a follow-up.)