Skip to content

Commit

Permalink
[Arith] Simplify cast (#7045)
Browse files Browse the repository at this point in the history
  • Loading branch information
hzfan authored Jan 7, 2021
1 parent 93d79ba commit 9815ae2
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 0 deletions.
161 changes: 161 additions & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
}
}

/*!
* \brief check if value fits in dtype
* \param value The value to be analyzed
* \param dtype The target dtype
* \param analyzer The analyzer
* \return whether value fits in dtype
*/
bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) {
if (!IsIndexType(dtype)) {
return false;
}
ConstIntBound bound = analyzer->const_int_bound(value);
int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
if (value.dtype().bits() <= dtype.bits() || // upcast is safe
(bound->max_value <= ubound && bound->min_value >= lbound)) {
return true;
}
return false;
}

/*!
* \brief Internal "Split normal form" of expression.
*
Expand Down Expand Up @@ -128,6 +149,58 @@ class SplitExprNode : public CanonicalExprNode {

void MulToSelf(int64_t scale) { this->scale *= scale; }

/*!
* \brief check if cast can be pushed to sub-expressions
* \param dtype The target datatype
* \param analyzer The analyzer
* \return whether the cast can be safely pushed to children
*/
bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
// cast(dtype, index % upper_factor / lower_factor * scale) ==
// cast(dtype, index) % upper_factor / lower_factor * scale
// iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
// its intermediate results fit in the range of dtype
if (dtype.bits() >= this->dtype.bits()) {
return true; // upcast is safe
}
PrimExpr res = this->index;
if (this->scale == 0) {
return true;
}
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
if (this->upper_factor != SplitExprNode::kPosInf) {
res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
if (this->lower_factor != 1) {
res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
if (this->scale != 1) {
ICHECK(!this->dtype.is_uint() || this->scale > 0);
res = res * make_const(this->dtype, this->scale);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
return true;
}

/*!
* \brief self = cast(dtype, self)
* \param dtype The target datatype
*/
void PushCastToChildren(DataType dtype) {
this->index = cast(dtype, this->index);
this->dtype = dtype;
}

inline bool IndexEqual(const SplitExpr& other) const;
inline bool DivModeCompatibleTo(DivMode mode) const;

Expand Down Expand Up @@ -255,6 +328,69 @@ class SumExprNode : public CanonicalExprNode {

void AddToSelf(const SumExpr& other, int64_t scale);

/*!
* \brief check if cast can be pushed to sub-expressions
* \param dtype The target datatype
* \param analyzer The analyzer
* \return whether the cast can be safely pushed to children
*/
bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
// cast(dtype, arg_1 + arg_2 + ... arg_n) ==
// cast(dtype, arg_1) + ... + cast(dtype, arg_n)
// iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
// its intermediate results fit in the range of dtype
if (dtype.bits() >= this->dtype.bits()) {
return true; // upcast is safe
}
PrimExpr res = make_const(dtype, 0);
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->scale > 0) {
res = res + args[i]->Normalize();
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
}
if (base > 0) {
res = res + make_const(dtype, base);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
// negative scales follows using sub.
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->scale < 0) {
res = res - args[i]->NormalizeWithScale(-1);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
}
if (base < 0) {
res = res - make_const(dtype, -base);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
for (const auto& arg : args) {
if (!arg->CanPushCastToChildren(dtype, analyzer)) {
return false;
}
}
return true;
}

/*!
* \brief self = cast(dtype, self)
* \param dtype The target datatype
*/
void PushCastToChildren(DataType dtype) {
for (auto& arg : args) {
arg.CopyOnWrite()->PushCastToChildren(dtype);
}
this->dtype = dtype;
}

static constexpr const char* _type_key = "arith.SumExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);

Expand Down Expand Up @@ -430,6 +566,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
PrimExpr VisitExpr_(const FloorDivNode* op) final;
PrimExpr VisitExpr_(const FloorModNode* op) final;
PrimExpr VisitExpr_(const ReduceNode* op) final;
PrimExpr VisitExpr_(const CastNode* op) final;

private:
/*!
Expand Down Expand Up @@ -1071,6 +1208,30 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
return ret;
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
// normalize
PrimExpr value = this->CanonicalMutate(op->value);
// PushCastToChildren
if (value.as<SumExprNode>()) {
SumExpr se = Downcast<SumExpr>(value);
if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
se.CopyOnWrite()->PushCastToChildren(op->dtype);
return std::move(se);
}
}
if (value.as<SplitExprNode>()) {
SplitExpr se = Downcast<SplitExpr>(value);
if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
se.CopyOnWrite()->PushCastToChildren(op->dtype);
return std::move(se);
}
}
return Rewriter::VisitExpr_(op);
}

PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
return impl_->CanonicalSimplify(expr);
}
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,46 @@ def test_complex_cases():
ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))


def test_simplify_cast():
ck = CanonicalChecker()
tcast = tvm.tir.Cast
fld = tvm.te.floordiv
flm = tvm.te.floormod
# cast(i64, i + j + 1) - cast(i64, i)
i = te.var("i", dtype="int32")
j = te.var("j", dtype="int32")
res = tcast("int64", i + j + 1) - tcast("int64", i)
ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64"))
# cast(i32, i + j + 1) - cast(i32, i)
i = te.var("i", dtype="int64")
j = te.var("j", dtype="int64")
ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
res = tcast("int32", i + j + 1) - tcast("int32", i)
ck.verify(res, tcast("int32", j) + 1)
# cast(i32, i + j - 100)
i = te.var("i", dtype="int64")
j = te.var("j", dtype="int64")
ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2 ** 31 - 1))
ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
res = tcast("int32", i + j - 100)
ck.verify(res, res)
# cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
# - cast(i32, flm(axis, 7i64) * 2i64)
axis = te.var("axis", dtype="int64")
ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
res = (
tcast(
"int32",
flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64")
+ tvm.tir.const(1, "int64"),
)
+ tvm.tir.const(1, "int32")
- tcast("int32", flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64"))
)
ck.verify(res, 2)


if __name__ == "__main__":
test_floormod_simplify()
test_mul_sum_simplify()
Expand All @@ -321,3 +361,4 @@ def test_complex_cases():
test_split_index_simplify()
test_canonical_mixed()
test_complex_cases()
test_simplify_cast()

0 comments on commit 9815ae2

Please sign in to comment.