Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Simplify cast #7045

Merged
merged 17 commits into from
Jan 7, 2021
Merged
161 changes: 161 additions & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
}
}

/*!
* \brief check if value fits in dtype
* \param value The value to be analyzed
* \param dtype The target dtype
* \param analyzer The analyzer
* \return whether value fits in dtype
*/
bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) {
if (!IsIndexType(dtype)) {
return false;
}
ConstIntBound bound = analyzer->const_int_bound(value);
int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
if (value.dtype().bits() <= dtype.bits() || // upcast is safe
(bound->max_value <= ubound && bound->min_value >= lbound)) {
return true;
}
return false;
}

/*!
* \brief Internal "Split normal form" of expression.
*
Expand Down Expand Up @@ -128,6 +149,58 @@ class SplitExprNode : public CanonicalExprNode {

void MulToSelf(int64_t scale) { this->scale *= scale; }

/*!
* \brief check if cast can be pushed to sub-expressions
* \param dtype The target datatype
* \param analyzer The analyzer
* \return whether the cast can be safely pushed to children
*/
bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
// cast(dtype, index % upper_factor / lower_factor * scale) ==
// cast(dtype, index) % upper_factor / lower_factor * scale
// iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
// its intermediate results fit in the range of dtype
if (dtype.bits() >= this->dtype.bits()) {
return true; // upcast is safe
}
PrimExpr res = this->index;
if (this->scale == 0) {
return true;
}
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
if (this->upper_factor != SplitExprNode::kPosInf) {
res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
if (this->lower_factor != 1) {
res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
if (this->scale != 1) {
ICHECK(!this->dtype.is_uint() || this->scale > 0);
res = res * make_const(this->dtype, this->scale);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
return true;
}

/*!
* \brief self = cast(dtype, self)
* \param dtype The target datatype
*/
void PushCastToChildren(DataType dtype) {
this->index = cast(dtype, this->index);
this->dtype = dtype;
}

inline bool IndexEqual(const SplitExpr& other) const;
inline bool DivModeCompatibleTo(DivMode mode) const;

Expand Down Expand Up @@ -255,6 +328,69 @@ class SumExprNode : public CanonicalExprNode {

void AddToSelf(const SumExpr& other, int64_t scale);

/*!
* \brief check if cast can be pushed to sub-expressions
* \param dtype The target datatype
* \param analyzer The analyzer
* \return whether the cast can be safely pushed to children
*/
bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
// cast(dtype, arg_1 + arg_2 + ... arg_n) ==
// cast(dtype, arg_1) + ... + cast(dtype, arg_n)
// iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
// its intermediate results fit in the range of dtype
if (dtype.bits() >= this->dtype.bits()) {
return true; // upcast is safe
}
PrimExpr res = make_const(dtype, 0);
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->scale > 0) {
res = res + args[i]->Normalize();
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
}
if (base > 0) {
res = res + make_const(dtype, base);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
// negative scales follows using sub.
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->scale < 0) {
res = res - args[i]->NormalizeWithScale(-1);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
}
if (base < 0) {
res = res - make_const(dtype, -base);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
for (const auto& arg : args) {
if (!arg->CanPushCastToChildren(dtype, analyzer)) {
return false;
}
}
return true;
}

/*!
* \brief self = cast(dtype, self)
* \param dtype The target datatype
*/
void PushCastToChildren(DataType dtype) {
for (auto& arg : args) {
arg.CopyOnWrite()->PushCastToChildren(dtype);
}
this->dtype = dtype;
}

static constexpr const char* _type_key = "arith.SumExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);

Expand Down Expand Up @@ -430,6 +566,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
PrimExpr VisitExpr_(const FloorDivNode* op) final;
PrimExpr VisitExpr_(const FloorModNode* op) final;
PrimExpr VisitExpr_(const ReduceNode* op) final;
PrimExpr VisitExpr_(const CastNode* op) final;

private:
/*!
Expand Down Expand Up @@ -1071,6 +1208,30 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
return ret;
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
// normalize
PrimExpr value = this->CanonicalMutate(op->value);
// PushCastToChildren
if (value.as<SumExprNode>()) {
SumExpr se = Downcast<SumExpr>(value);
if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
se.CopyOnWrite()->PushCastToChildren(op->dtype);
return std::move(se);
}
}
if (value.as<SplitExprNode>()) {
SplitExpr se = Downcast<SplitExpr>(value);
if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
se.CopyOnWrite()->PushCastToChildren(op->dtype);
return std::move(se);
}
}
return Rewriter::VisitExpr_(op);
}

PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
return impl_->CanonicalSimplify(expr);
}
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,46 @@ def test_complex_cases():
ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))


def test_simplify_cast():
ck = CanonicalChecker()
tcast = tvm.tir.Cast
fld = tvm.te.floordiv
flm = tvm.te.floormod
# cast(i64, i + j + 1) - cast(i64, i)
i = te.var("i", dtype="int32")
j = te.var("j", dtype="int32")
res = tcast("int64", i + j + 1) - tcast("int64", i)
ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64"))
# cast(i32, i + j + 1) - cast(i32, i)
i = te.var("i", dtype="int64")
j = te.var("j", dtype="int64")
ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
res = tcast("int32", i + j + 1) - tcast("int32", i)
ck.verify(res, tcast("int32", j) + 1)
# cast(i32, i + j - 100)
i = te.var("i", dtype="int64")
j = te.var("j", dtype="int64")
ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2 ** 31 - 1))
ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
res = tcast("int32", i + j - 100)
ck.verify(res, res)
# cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
# - cast(i32, flm(axis, 7i64) * 2i64)
axis = te.var("axis", dtype="int64")
ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
res = (
tcast(
"int32",
flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64")
+ tvm.tir.const(1, "int64"),
)
+ tvm.tir.const(1, "int32")
- tcast("int32", flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64"))
)
ck.verify(res, 2)


if __name__ == "__main__":
test_floormod_simplify()
test_mul_sum_simplify()
Expand All @@ -321,3 +361,4 @@ def test_complex_cases():
test_split_index_simplify()
test_canonical_mixed()
test_complex_cases()
test_simplify_cast()