Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr #7731

Merged
merged 11 commits into from
Mar 31, 2021

Conversation

altanh
Copy link
Contributor

@altanh altanh commented Mar 24, 2021

This PR introduces two new rewrites:

  • ConcretizeLike: replaces *_like operators with their concrete-shape equivalent when the result shape is concrete.
  • EliminateIdentity: eliminates identity expressions like x + 0, 1 * x, etc. Expressions that broadcast x to a new shape are not removed, although we could explicitly replace them with broadcasting ops (not sure of the performance difference for this). This pass supports eliminating scalar constant 0s and 1s, but will not eliminate constants with more than a single element.

I also refactored the existing DFPatternCallback-based passes slightly to lift out common machinery.

Together, these rewrites should help optimize the generated AD code (and credit to @t-vi for prototyping them in the blog post).

cc @tqchen @comaniac @MarisaKirisame @yzhliu

(I'll work on a DeadParameterElimination pass to complement ConcretizeLike as we discussed in the previous PR, but will send as a follow-up.)

@tqchen
Copy link
Member

tqchen commented Mar 24, 2021

cc @comaniac @yzhliu please help to review this PR

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM. Meanwhile I have some questions about the design:

  1. It seems to me that ConcretizeLike and EliminateIdentity can also be merged to SImplifyExpr in terms of the implementation and semantic. What's the concern of having 3 separate passes?

  2. You mentioned that for a certain reason, EliminateIdentity should be run before SimplifyExpr, but I didn't get the point about what would happen if we run them in the reverse order. Could you elaborate a bit further?

src/relay/transforms/concretize_like.cc Outdated Show resolved Hide resolved
src/relay/transforms/concretize_like.cc Outdated Show resolved Hide resolved
src/relay/transforms/concretize_like.cc Outdated Show resolved Hide resolved
src/relay/transforms/simplify_expr.cc Outdated Show resolved Hide resolved
@comaniac
Copy link
Contributor

Also cc @mbrookhart

@altanh
Copy link
Contributor Author

altanh commented Mar 24, 2021

Overall LGTM. Meanwhile I have some questions about the design:

1. It seems to me that ConcretizeLike and EliminateIdentity can also be merged to SImplifyExpr in terms of the implementation and semantic. What's the concern of having 3 separate passes?

2. You mentioned that for a certain reason, EliminateIdentity should be run before SimplifyExpr, but I didn't get the point about what would happen if we run them in the reverse order. Could you elaborate a bit further?
  1. It is definitely possible- I separated them mainly for ability to test them separately, as otherwise the overall semantics of the combined pass might be a bit tricky to write test cases for (e.g. will need to adjust the cases where we are adding 0 or multiplying by 1). I can definitely add additional test cases that run all of them in sequence (as if it was 1 single pass), or just try to merge them into SimplifyExpr and update the test cases. lmk
  2. Yeah, so SimplifyExpr has a rewrite called FullElementwise that takes (for example) x + zeros_like(x) and rewrites it to x + const(0). I couldn't think of a portable way to rewrite x + const(0) to x in EliminateIdentity, so it won't reduce this expression. For this reason you should run EliminateIdentity first- hope this makes sense. That being said, if there is a good way to examine constant values for any dtype (e.g. casting?) then we could also eliminate this.

@comaniac
Copy link
Contributor

If the purpose is just for testing, then I'll prefer to have them in a single pass. You can still test the pattern one-by-one as SimplifyExpr does now. Since the unrelated patterns won't be matched, I didn't see the problem of testing. In this case, you can also control the order of rewriting patterns. i.e., always run EliminateIdentity before FullElementWise in the SimplifyExpr pass. This can also reduce the possible confusion from users.

@mbrookhart
Copy link
Contributor

I think I agree that merging these with SimplifyExpr would be a win in terms of our ability to control the order of execution.

On the simplification of things like x * const(0), you can get the ndarray out of the const as ConstantNode->data, and then you can pass that to this utility, which will return a long double version of the value, which shouldn't loose precision for any of the 64 bit or smaller datatypes we use. You can then do your comparison in a single dtype.

I've been meaning to implement this for like 6 months, and I haven't had a strong enough forcing function to bubble it up to the top of my priority list.

@mbrookhart
Copy link
Contributor

It looks like you have a windows build issue?

@altanh
Copy link
Contributor Author

altanh commented Mar 24, 2021

I think I agree that merging these with SimplifyExpr would be a win in terms of our ability to control the order of execution.

On the simplification of things like x * const(0), you can get the ndarray out of the const as ConstantNode->data, and then you can pass that to this utility, which will return a long double version of the value, which shouldn't loose precision for any of the 64 bit or smaller datatypes we use. You can then do your comparison in a single dtype.

I've been meaning to implement this for like 6 months, and I haven't had a strong enough forcing function to bubble it up to the top of my priority list.

This is just what I needed, thanks! I think since I'm also just checking 0 or 1, there should be no problem casting since the floating point repr should all be the same.

I'll fix the Windows issue, it was because I overloaded the same name too much.

@altanh
Copy link
Contributor Author

altanh commented Mar 24, 2021

Ah, well looks like this might only make sense for constants that only have one element, unless we want to loop over every single element and check that it is equal to 0 or 1. But if I understand correctly, the FullElementwise pass only rewrites to scalar constants so those will be rewritten correctly; it's just that if the input IR has non-scalar constants that it won't be simplified.

Do you guys think this is a reasonable tradeoff?

@mbrookhart
Copy link
Contributor

Yeah, as long as we aren't commonly manifesting full sized arrays of zero or one, that should be fine. Given the full/zeros/ones ops and their like counterparts, plus auto-broadcasting, I think that's generally a reasonable assumption to make.

@altanh
Copy link
Contributor Author

altanh commented Mar 24, 2021

@comaniac @mbrookhart I've merged them and updated the unit tests

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Just nits.

src/relay/transforms/simplify_expr.cc Show resolved Hide resolved
};

Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
return ExprSimplifier(mod).Simplify(expr);
static Array<DFPatternCallback> callbacks = {ConcretizeZerosLikeRewrite::GetCallback(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may need to comment here if the order is enforced.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @mbrookhart , I believe the ordering is respected by the rewriter but because the rewriter iterates until fixed point, I don't think it would be correct to say that globally the order is enforced.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(in any case it shouldn't matter now that I've added support for eliminating constants)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, they will do as much as they can to the graph in this order, and then loop back and try again. That being said, that's the same thing you'd get if you ran them as separate passes, you'd rewrite everything you could before the next pass, but the next pass might open opportunities to do more with the current pass if you ran it again.

src/relay/transforms/simplify_expr.h Show resolved Hide resolved
@altanh altanh changed the title [Relay][Pass] ConcretizeLike and EliminateIdentity Passes [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr Mar 24, 2021
Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few design issues and nitpicks

};

Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
return ExprSimplifier(mod).Simplify(expr);
static Array<DFPatternCallback> callbacks = {ConcretizeZerosLikeRewrite::GetCallback(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, they will do as much as they can to the graph in this order, and then loop back and try again. That being said, that's the same thing you'd get if you ran them as separate passes, you'd rewrite everything you could before the next pass, but the next pass might open opportunities to do more with the current pass if you ran it again.

Comment on lines 388 to 389
DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also simplify x * 0 = 0 and x + 0 = x?

The pattern matcher should be able to match these ops irrespective of order:

if (const OpNode* op_node = get_op_node(op)) {
if ((op_node->name == "add") || (op_node->name == "multiply")) {
if (match_args(reverse(op->args), call_node->args)) {
return true;
}
}
}

You can probably get away without the AltPattern here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add x * 0 = 0 although to me it semantically doesn't exactly fit and would need slightly different logic, perhaps a new rewrite ZeroMultiply or something?

And I wasn't aware of the commutative matching, that's helpful (although I wonder if it should be more visibly defined somewhere?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean by x + 0 = x as I think I already have that covered, can you confirm?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, you do, sorry!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had this roughly implementend in the pattern matcher tests, but never really productized:

def algebraic_simplify(expr):
zero = is_expr(relay.const(0)) | is_expr(relay.const(0.0))
one = is_expr(relay.const(1)) | is_expr(relay.const(1.0))
class ElwiseNullCallback(DFPatternCallback):
def callback(self, pre, post, node_map):
return node_map[self.x][0] # pylint: disable=no-member
class AddCallback(ElwiseNullCallback):
def __init__(self):
super(AddCallback, self).__init__()
self.x = wildcard()
self.pattern = self.x + zero
class SubCallback(ElwiseNullCallback):
def __init__(self):
super(SubCallback, self).__init__()
self.x = wildcard()
self.pattern = self.x - zero
class MulCallback(ElwiseNullCallback):
def __init__(self):
super(MulCallback, self).__init__()
self.x = wildcard()
self.pattern = self.x * one
class DivCallback(ElwiseNullCallback):
def __init__(self):
super(DivCallback, self).__init__()
self.x = wildcard()
self.pattern = self.x / one
class MulZeroCallback(ElwiseNullCallback):
def __init__(self):
super(MulZeroCallback, self).__init__()
self.x = zero
self.pattern = self.x * wildcard()
class ZeroDivCallback(ElwiseNullCallback):
def __init__(self):
super(ZeroDivCallback, self).__init__()
self.x = zero
self.pattern = self.x / wildcard()
return rewrite(
[
AddCallback(),
SubCallback(),
MulCallback(),
DivCallback(),
MulZeroCallback(),
ZeroDivCallback(),
],
expr,
)
def test_algebraic_simplify():
x = relay.Var("x")
y = relay.Var("y")
one = relay.const(1)
zero = relay.const(0)
onef = relay.const(1.0)
zerof = relay.const(0.0)
assert algebraic_simplify(x + zero) == x
assert algebraic_simplify(x + zerof) == x
assert algebraic_simplify(zero + x) == x
assert algebraic_simplify(zerof + x) == x
assert algebraic_simplify(x - zero) == x
assert algebraic_simplify(x - zerof) == x
assert algebraic_simplify(x * one) == x
assert algebraic_simplify(x * onef) == x
assert algebraic_simplify(one * x) == x
assert algebraic_simplify(onef * x) == x
assert algebraic_simplify(x * zero) == zero
assert algebraic_simplify(x * zerof) == zerof
assert algebraic_simplify(x / one) == x
assert algebraic_simplify(x / onef) == x
assert algebraic_simplify(zero / x) == zero
assert algebraic_simplify(zerof / x) == zerof
assert tvm.ir.structural_equal(
algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't have to do the full thing for this PR, you can keep this as it is and we can extend later

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep sounds good

Comment on lines 37 to 45
#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType) \
static DFPatternRewrite* Get() { \
static RewriteType rw; \
return &rw; \
} \
static DFPatternCallback GetCallback() { \
static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \
return cb; \
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I don't like Macros
  2. I don't like static initialization

Why not just initialize the object and call the method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed the overhead of initializing and calling is probably negligible compared to running the pass itself, I did this following comments on my previous PR. @comaniac maybe you can comment? In the end I am fine with either way

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previously PR has a different implemetation and my point was the pattern table itself should be static. Given the current implemntation is based on SimplifyExpr, I agree with @mbrookhart that we don't need to make those functions static in the pattern class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, I'll remove this, thanks for the clarification

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed, made a helper class for composing rewrites since we need to ensure the lifetimes of the DFPatternCallbacks do not exceed the Rewrite objects

Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@altanh
Copy link
Contributor Author

altanh commented Mar 25, 2021

Ended up having to make ToScalar return an optional value due to a custom datatype test (which as far as I can tell, we don't have a good way of supporting conversion at compile time in C++ currently). Let me know if this is fine, as I don't see an alternative within the scope of this PR.

@mbrookhart
Copy link
Contributor

@altanh I'm confused why you need this change to ToScalar. What changed elsewhere in your PR that broke this unit test?

@altanh
Copy link
Contributor Author

altanh commented Mar 25, 2021

@altanh I'm confused why you need this change to ToScalar. What changed elsewhere in your PR that broke this unit test?

here's the offending test https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/detail/PR-7731/7/pipeline

the alternative I see to changing ToScalar is basically just mirroring the code exactly but instead returning a bool if there is a way to convert, which seems less sustainable but perhaps better in the name of API stability.

@altanh
Copy link
Contributor Author

altanh commented Mar 31, 2021

ping @mbrookhart @comaniac, I adjusted the change to ToScalar by making a new function TryToScalar so that the existing API does not need to change (although we should probably keep in mind where we use ToScalar for bring-your-own-datatype compatibility)

@mbrookhart
Copy link
Contributor

I'm happy with this. I'll merge this afternoon unless @comaniac objects.

@comaniac comaniac merged commit b3ab19e into apache:main Mar 31, 2021
@comaniac
Copy link
Contributor

Sorry I missed the previous message. Yeah I'm good with it so I've merged it. Thanks @altanh @mbrookhart

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…ifyExpr (apache#7731)

* factor out some common code for DF rewriting, add ConcretizeLike

* slight refactoring, add EliminateIdentity pass

* lint

* merge ConcretizeLike and EliminateIdentity into SimplifyExpr

* nits and lint

* remove static stuff

* document

* definitely ran clang-format but ok

* make ToScalar return optional, fix missing virtual destructor

* lint

* tweak scalar conversion API to maintain compatibility
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…ifyExpr (apache#7731)

* factor out some common code for DF rewriting, add ConcretizeLike

* slight refactoring, add EliminateIdentity pass

* lint

* merge ConcretizeLike and EliminateIdentity into SimplifyExpr

* nits and lint

* remove static stuff

* document

* definitely ran clang-format but ok

* make ToScalar return optional, fix missing virtual destructor

* lint

* tweak scalar conversion API to maintain compatibility
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…ifyExpr (apache#7731)

* factor out some common code for DF rewriting, add ConcretizeLike

* slight refactoring, add EliminateIdentity pass

* lint

* merge ConcretizeLike and EliminateIdentity into SimplifyExpr

* nits and lint

* remove static stuff

* document

* definitely ran clang-format but ok

* make ToScalar return optional, fix missing virtual destructor

* lint

* tweak scalar conversion API to maintain compatibility
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021
…ifyExpr (apache#7731)

* factor out some common code for DF rewriting, add ConcretizeLike

* slight refactoring, add EliminateIdentity pass

* lint

* merge ConcretizeLike and EliminateIdentity into SimplifyExpr

* nits and lint

* remove static stuff

* document

* definitely ran clang-format but ok

* make ToScalar return optional, fix missing virtual destructor

* lint

* tweak scalar conversion API to maintain compatibility
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants