-
Notifications
You must be signed in to change notification settings - Fork 7
Add optional bool preserve_error for expr simplifier
#2534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| // simplifications involving div and mod requires the divisor to be | ||
| // non-zero, and I don't want this edge case to block the simplification of | ||
| // normal cases. | ||
| if (ns->getParallelDim().has_value() || ns->isTensorSize()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't make sense to assume tensor size to be positive globally. All we want is to unblock the simplification of i / size * size + i % size -> i, but if we do assume size to be positive, we would also simplify predicates like 0 < size as true, which could potentially break zero tensor support. Instead, we should have a preserve_error flag for expr simplified to unblock that simplification even if we can not prove size to be non-zero.
|
Marking this as ready, but I would like to wait for #2500 because I don't want this to conflict with the new |
| std::unordered_map<Val*, VarInfo> var_info_map_; | ||
| std::vector<Val*> var_order_; | ||
| std::unordered_set<Val*> set_; | ||
| bool preserve_error_ = false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add a comment what this flag means?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I found a comment at simplifyExpr().
|
|
||
| namespace { | ||
|
|
||
| // If we can not prove `value` to be zero, then treat it as non-zero, unless we |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this comment is a little confusing to me.
If we can not prove
valueto be zero, then treat it as non-zero
Are you saying "treat it as potentially non-zero" rather than "definitely non-zero", right?
"If we can not prove value to be zero, then treat it as potentially non-zero" <- This seems generally true, so why "unless we want to preserve error"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very confusing, I renamed it and rewrite its comments:
// If we want to do simplifications like (a * b) / b -> a, depending on whether
// we want to preserve error, the behavior could be different. If we don't care
// about preserving error, we can just go ahead and do the simplification.
// However, if we do want the division-by-zero error to be preserved, then we
// can only do the simplification if we can prove b != 0. This function tells us
// if the value of b is safe to do such optimization. Instead of completely
// ignoring error case, we do a bit extra: if b is proved to be zero, then we
// are sure that there will be an error, then we don't remove the error. That is,
// if we don't know if there will be an error, we procceed assuming no error. If
// we are sure there will be an error, then don't procceed.
bool isValidDenominator(Val* value, const Context& context) {
if (context.preserveError()) {
return prove::isNonZero(value, context);
}
return !foldConstants(value)->isZero();
}| if (context.preserveError()) { | ||
| return prove::isNonZero(value, context); | ||
| } | ||
| return !foldConstants(value)->isZero(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does folding of (a * b) / b to b happen in foldConstants(value)? Does foldConstants fold the expression no matter if b is zero or not?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, foldConstants is just constant folding, that is, changing 3 - 3 into 0.
| // we can prove that b is not zero. If we don't care about error case, then | ||
| // having `b` in the denominator already indicates that `b` is non-zero. So we | ||
| // can just do the simplification without worrying about `b` being zero. | ||
| bool notSurelyZero(Val* value, const Context& context) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this name doesn't sound very intuitive to me. What this tells is value is not proven to be zero, right? It seems it just says it may be or may not be zero. Am I right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your understanding is correct. The name and comment of this function is very confusing, I renamed it and rewite its comment. See #2534 (comment)
naoyam
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update. It's much clearer now. I added some comments, which I think would make it further easier to follow.
| // are sure that there will be an error, then we don't remove the error. That | ||
| // is, if we don't know if there will be an error, we procceed assuming no | ||
| // error. If we are sure there will be an error, then don't procceed. | ||
| bool isValidDenominator(Val* value, const Context& context) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: value -> denominator
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent suggestion!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
| // Q1: Can we prove that value is nonzero? | ||
| // Q2: Can we prove that value is zero? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"that value" -> "the denominator"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
| if (context.preserveError()) { | ||
| return false; | ||
| } | ||
| bool proved_zero = foldConstants(value)->isZero(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have prove::isNonZero(). Don't we have prove::isZero()? If so, why isn't it used here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are doing the same thing. I created prove::isZero when there were no Val::isZero, but later, I added Val::isZero, so prove::isZero is no longer needed. I am removing it in another PR, so I am not using it here.
No description provided.