[ARITH] Analyzer CanonicalSimplifier #2891
Also, as a side note, this PR helps demonstrate how we can consolidate some of the simplification infra around the Analyzer, which could help improve some open PRs that @sgrechanik-h is working on.
337f7d1 to b433c00
CI is now green; it would be great to get some input on the PR.
@Hzfengsy can you also help review this PR?
Sorry, I have little time currently. I'll try to look again today or tomorrow.
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr));
Expr new_expr = expr;
I don't understand why you copy expr to new_expr here.
To make new expr mutable
But none of the subsequent lines change it, or am I wrong?
src/arithmetic/canonical_simplify.cc
Outdated
/*!
* \brief Internal "Split normal form" of expression.
*
* This is a special expression that represent
typo: represents
src/arithmetic/canonical_simplify.cc
Outdated
Expr NormalizeWithScale(int64_t sscale) const {
Expr res = this->index;
Type dtype = this->type;
CHECK_EQ(this->type, dtype);
This check seems redundant.
Just like an assert, to check runtime consistency.
But dtype gets initialized with this->type in the previous line, so this check is obviously true.
src/arithmetic/canonical_simplify.cc
Outdated
* args are divided into segments with the same index.
* within each segment, the SplitExpr is ordered in descending order of lower_factor.
*
* \note Can be mutated by TryMergeSplitExpr, which is idempotent
What is TryMergeSplitExpr? (I didn't find it in the code.)
src/arithmetic/canonical_simplify.cc
Outdated
//
// ((x / (c * s)) * s + (x % (c * s)) / c
// => ((x / c) / s) * s + ((x / c) % s)
// => (x / c)
Honestly speaking, I can't understand this algorithm. Probably I have to return to it in a better state of mind. Expanding the explanation may be helpful too.
The simplification rule and proof are correct. It's based on two basic rules:
Rule 1: (x % (c * s)) / c = (x / c) % s
Proof:
x can always be decomposed into p * c * s + q * c + r where 0 <= q * c + r < c * s and 0 <= r < c.
Then, lhs = ((p * c * s + q * c + r) % (c * s)) / c = (q * c + r) / c = q
rhs = ((p * c * s + q * c + r) / c) % s = (p * s + q) % s = q
Thus, lhs = rhs
Rule 2: (x / s) * s + x % s = x
Thanks, although it is not obvious to me whether the rules are still correct for the C/C++ division used in TVM.
The rules work for both the trunc div and the floor div versions, mainly because the first rule only involves mul, div and mod, so you can simply take the abs of all operands and then add the final sign. The second rule is an invariant for both types of div.
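The claim that both rules survive the switch between truncating and flooring division can be spot-checked by brute force. The following is a minimal sketch, not code from the PR: the helper names (floordiv, floormod, truncdiv, truncmod, RulesHold) are made up for illustration, and c and s are assumed to be positive constants, as in the thread.

```cpp
#include <cassert>
#include <cstdint>

// Floor division and modulo, built on top of C++'s truncating operators.
int64_t floordiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  // Adjust when the division is inexact and the signs differ.
  return (a % b != 0 && ((a < 0) != (b < 0))) ? q - 1 : q;
}
int64_t floormod(int64_t a, int64_t b) { return a - floordiv(a, b) * b; }

// Truncating division and modulo: the built-in C++ behaviour.
int64_t truncdiv(int64_t a, int64_t b) { return a / b; }
int64_t truncmod(int64_t a, int64_t b) { return a % b; }

using BinOp = int64_t (*)(int64_t, int64_t);

// Returns true if both rules hold for every x in [-limit, limit] and
// every pair of positive constants c, s in [1, max_factor].
bool RulesHold(BinOp div, BinOp mod, int64_t limit, int64_t max_factor) {
  for (int64_t c = 1; c <= max_factor; ++c) {
    for (int64_t s = 1; s <= max_factor; ++s) {
      for (int64_t x = -limit; x <= limit; ++x) {
        // Rule 1: (x % (c * s)) / c == (x / c) % s
        if (div(mod(x, c * s), c) != mod(div(x, c), s)) return false;
        // Rule 2: (x / s) * s + x % s == x
        if (div(x, s) * s + mod(x, s) != x) return false;
      }
    }
  }
  return true;
}
```

Calling RulesHold(truncdiv, truncmod, 200, 8) and RulesHold(floordiv, floormod, 200, 8) exercises negative values of x under both semantics. A finite search like this is only a sanity check and does not replace the proof.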
src/arithmetic/canonical_simplify.cc
Outdated
}
// sort by the entry
auto fcompare = [](const SplitExpr& lhs, const SplitExpr& rhs) {
// order by scale first
Shouldn't it be ordered by index first? Or at least, if the indices are different, the elements should be incomparable.
This is a good point; however, it can be quite costly to deep compare indices. So instead we just order by the scale and factor so that the result is mostly in a consistent form (ignoring the indices).
Since the algorithm assumes that the vector contains contiguous segments of same-index elements, we have to check in the comparison function whether lhs and rhs have the same index; otherwise this assumption may be destroyed by sorting.
Also, I would still suggest sorting by index, because otherwise we may often get into a situation where something like f(x + y) - f(y + x) doesn't get simplified. And if it really leads to performance problems, we should think about optimizing deep comparison somehow (probably we can cache the size or some other measure of an expression).
For now the code is fine because the result is only directly used by normalize, which is the intended use case. It does break if we call it in the middle.
The only reason I am not sure about comparing indices is that var comparison can depend on the runtime, which makes its behavior non-deterministic. I want to think a bit more about this before we revisit it.
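The concern about segments can be made concrete with a small sketch. Everything here is hypothetical: Term is a stand-in for SplitExpr with a string in place of the real index expression, and the comparators are illustrations, not the ones in the PR. It shows how a comparator that ignores the index can interleave same-index terms, while one that compares the index first keeps them contiguous.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical stand-in for SplitExpr; the index is a plain string here.
struct Term {
  std::string index;
  int64_t scale;
  int64_t lower_factor;
};

// Orders by scale, then lower_factor, both descending; ignores the index.
bool ByScaleOnly(const Term& a, const Term& b) {
  return std::tie(a.scale, a.lower_factor) > std::tie(b.scale, b.lower_factor);
}

// Orders by index first, then falls back to the scale-based order.
bool ByIndexThenScale(const Term& a, const Term& b) {
  if (a.index != b.index) return a.index < b.index;
  return ByScaleOnly(a, b);
}

// True if every index occupies exactly one contiguous run of the vector.
bool IsSegmented(const std::vector<Term>& terms) {
  std::set<std::string> seen;
  for (std::size_t i = 0; i < terms.size(); ++i) {
    if (i > 0 && terms[i].index == terms[i - 1].index) continue;
    // A new run starts here; the index must not have had an earlier run.
    if (!seen.insert(terms[i].index).second) return false;
  }
  return true;
}

// Returns a sorted copy of the input under the given comparator.
std::vector<Term> Sorted(std::vector<Term> terms,
                         bool (*cmp)(const Term&, const Term&)) {
  std::stable_sort(terms.begin(), terms.end(), cmp);
  return terms;
}
```

For the terms {x, 4, 1}, {x, 1, 1}, {y, 2, 1}, sorting with ByScaleOnly yields the index order x, y, x, which breaks the segment assumption, whereas ByIndexThenScale keeps both x terms adjacent. String comparison is deterministic; the non-determinism raised above would only appear if the index order depended on something like object addresses.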
void DivideBy(int64_t scale) {
this->base /= scale;
for (size_t i = 0; i < this->args.size(); ++i) {
args[i].CopyOnWrite()->scale /= scale;
Shouldn't we raise an error if some of the scales are not divisible by the argument?
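One way to read this suggestion is sketched below. SumSketch is a hypothetical, simplified stand-in for the sum expression (a base plus a list of scales), and the exception is only a placeholder for whatever check macro the codebase would use; none of this is taken from the PR.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical, simplified stand-in for the sum expression in the PR.
struct SumSketch {
  int64_t base = 0;
  std::vector<int64_t> scales;

  // True if the base and every scale are exact multiples of `scale`.
  bool DivisibleBy(int64_t scale) const {
    if (scale == 0 || base % scale != 0) return false;
    for (int64_t s : scales) {
      if (s % scale != 0) return false;
    }
    return true;
  }

  // Divides everything by `scale`, refusing to silently drop remainders.
  void DivideBy(int64_t scale) {
    if (!DivisibleBy(scale)) {
      throw std::invalid_argument("DivideBy: not divisible by scale");
    }
    base /= scale;
    for (int64_t& s : scales) s /= scale;
  }
};
```

Checking before mutating means a failed call leaves the object unchanged, so the caller can fall back to a different rewrite.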
return;
}
}
// Insert other in the end.
Probably we should also sort by index.
* \param other The expression to be added.
* \param scale The additional scale on value.
*/
void AddToSelf(SplitExpr other, int64_t scale) {
It may be better to use const SplitExpr &other instead of SplitExpr other.
We need to call CopyOnWrite on the argument, so we take SplitExpr by value. If the item is newly constructed, other is passed in as a unique copy, and CopyOnWrite reuses its data directly.
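The trade-off between by-value and const-reference parameters under copy-on-write can be illustrated with a small sketch. Handle and Node are hypothetical; they use std::shared_ptr to imitate a reference-counted handle and are not TVM's actual node classes.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>

struct Node {
  int64_t scale = 1;
};

// Hypothetical reference-counted handle with copy-on-write semantics.
class Handle {
 public:
  Handle() : data_(std::make_shared<Node>()) {}
  const Node* get() const { return data_.get(); }

  // Returns a mutable node, cloning it first if it is shared with others.
  Node* CopyOnWrite() {
    if (data_.use_count() != 1) data_ = std::make_shared<Node>(*data_);
    return data_.get();
  }

 private:
  std::shared_ptr<Node> data_;
};

// Takes the handle by value so that a caller passing a temporary (or
// using std::move) hands over the only reference and no clone is made.
Handle Scaled(Handle other, int64_t scale) {
  other.CopyOnWrite()->scale *= scale;
  return other;
}
```

With a const reference parameter the function would have to copy the handle before mutating it, which always bumps the reference count and forces a clone. Passing by value pays for the clone only when the caller keeps its own reference.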
if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
// normalize
We could probably add a helper function to reduce the duplicated code in the three functions above.
src/arithmetic/canonical_simplify.cc
Outdated
// note: x = z, c = 3, s = 2
//
// ((z % 12) / 6) * 6 + ((z % 6) / 3) * 3
// => (((z % 12) / 6) * 2 + ((z % 12) % 6) / 3) * 3
Shouldn't there be a condition that lhs->upper_factor % rhs->upper_factor == 0 so that we can perform the transformation z % rhs->upper_factor => (z % lhs->upper_factor) % rhs->upper_factor?
This is an invariant condition that lhs->upper_factor % lhs->lower_factor == 0
@Hzfengsy @sxjscience @sgrechanik-h thanks for the reviews, I have updated the comment blocks to add more explanations about the proof.
we have a case like this, can this PR handle it?
@xqdan you should try it out. In theory this canonical simplifier is able to handle all kinds of div, mod and mul patterns that come out of split and re-fuse.
nice, I will try after this is merged.
// note also the invariance lhs->upper_factor % lhs->lower_factor == 0
//
SplitExprNode* merged = rhs.CopyOnWrite();
merged->upper_factor = lhs->upper_factor;
Is this correct only when lhs->upper_factor == kPosInf? For example, ((x % 5) / (3 * 2)) * 2 + (x % (3 * 2)) / 3 is simplified to x % 5 / 3, but this is not correct when x == 5.
See comment above on invariance
Ah I see, sorry for my confusion.
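The exchange above can be checked numerically. This is a minimal sketch with made-up function names (Unmerged, Merged, MergeIsValid), restricted to non-negative x: it compares the two-term form against the merged form for a given upper factor. With upper = 12, c = 3, s = 2 the invariant holds and the merge is exact; with upper = 5 the left term would have an upper factor that is not a multiple of its lower factor 6, which the invariant rules out, and the merge indeed fails at x == 5.

```cpp
#include <cassert>
#include <cstdint>

// Two split terms over the same index x, before merging:
//   ((x % upper) / (c * s)) * s + (x % (c * s)) / c
int64_t Unmerged(int64_t x, int64_t upper, int64_t c, int64_t s) {
  return ((x % upper) / (c * s)) * s + (x % (c * s)) / c;
}

// The merged term: (x % upper) / c
int64_t Merged(int64_t x, int64_t upper, int64_t c) {
  return (x % upper) / c;
}

// True if both forms agree for every x in [0, limit].
bool MergeIsValid(int64_t upper, int64_t c, int64_t s, int64_t limit) {
  for (int64_t x = 0; x <= limit; ++x) {
    if (Unmerged(x, upper, c, s) != Merged(x, upper, c)) return false;
  }
  return true;
}
```

MergeIsValid(12, 3, 2, 1000) returns true, while MergeIsValid(5, 3, 2, 1000) returns false because Unmerged(5, 5, 3, 2) is 1 and Merged(5, 5, 3) is 0.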
src/arithmetic/canonical_simplify.cc
Outdated
// - s = lhs->scale / rhs->scale
// - c = rhs->lower_factor
//
// ((x / (c * s)) * s + (x % (c * s)) / c
Redundant "(". It should be:
(x / (c * s)) * s + (x % (c * s)) / c
src/arithmetic/canonical_simplify.cc
Outdated
if (cval % lhs->scale == 0) {
int64_t scaled_cval = cval / lhs->scale;
lhs.CopyOnWrite()->scale = 1;
lhs.CopyOnWrite()->lower_factor *= scaled_cval;
Does this guarantee the invariance lhs->upper_factor % lhs->lower_factor == 0? It doesn't look obvious, and I wonder if we should call lhs->Verify() here.
Nice catch :)
@Hzfengsy @sxjscience @sgrechanik-h @kazum thanks for the reviews, please take another look.
I skimmed through the rest, seems ok. My main concern is sorting by the index field, but it can be done in subsequent PRs.
src/arithmetic/canonical_simplify.cc
Outdated
* \param coeff The co-efficient.
* \param out_divisible The result divisible component.
* \param out_non_divisible The non-divisible component.
* \return Whetjer detection is successful.
typo whether
src/arithmetic/stmt_simplify.cc
Outdated
ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition)));
else_case = this->Mutate(op->else_case);
}
if (is_one(condition)) return op->then_case;
Shouldn't we return then_case instead of op->then_case here? (And the same thing with the else_case.)
return lhs;
} else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) {
// (x % c1) / c2 => 0 when c2 >= c1
return ToSplitExpr(make_zero(lhs.type()));
Can we also return zero when cval % lhs->scale != 0? I mean, the code below looks correct in all cases.
if (lhs->upper_factor <= (lhs->lower_factor * cval / lhs->scale))
return ToSplitExpr(make_zero(lhs.type()));
Because mul and division are not necessarily exchangeable, the case of cval % lhs->scale != 0 needs more careful consideration of the consequences. Since such a case is rare, we just skip the optimization.
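The point about mul and division not being exchangeable can be seen on the rewrite the code performs in the divisible case, (y * scale) / cval => y / (cval / scale). The sketch below uses a made-up helper, FoldIsExact, restricted to non-negative y and positive constants; it does not evaluate the reviewer's proposed condition, it only shows where folding the two constants stops being exact.

```cpp
#include <cassert>
#include <cstdint>

// True if (y * scale) / cval == y / (cval / scale) for all y in [0, limit].
// Assumes positive scale and cval.
bool FoldIsExact(int64_t scale, int64_t cval, int64_t limit) {
  int64_t folded = cval / scale;
  if (folded == 0) return false;  // the folded divisor would be zero
  for (int64_t y = 0; y <= limit; ++y) {
    if ((y * scale) / cval != y / folded) return false;
  }
  return true;
}
```

For scale = 2 and cval = 6 the fold is exact, since (y * 2) / 6 equals y / 3. For scale = 2 and cval = 3 it is not: at y = 1 the left side is 2 / 3 = 0 while the folded form gives 1 / 1 = 1.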
Thanks, @sgrechanik-h @sxjscience @kazum @xqdan @Hzfengsy, this is now merged.
This PR contains one step of #2588
The main highlight of this PR is the introduction of the "split normal form", so we can simplify the following expression.
It is quite fun to implement the split normalization. Currently, we only support constant div and mod co-efficients for simplicity; we can consider adding symbolic support later.