Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Simplify cast #7045

Merged
merged 17 commits into from
Jan 7, 2021
Merged
147 changes: 147 additions & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
}
}

/*!
* \brief Check whether cast(dtype, value) is considered safe.
*
* The cast is rejected when dtype is not accepted by IsIndexType.
* Otherwise it is accepted when value's bit width does not exceed
* dtype's bit width, or when the constant integer bound the analyzer
* derives for value lies inside [min_value(dtype), max_value(dtype)].
*
* \param dtype The target datatype
* \param value The expression being cast
* \param analyzer The analyzer used to query the constant integer bound of value
* \return whether the cast is safe or not
*/
bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps rename to UpcastIsSafe?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

document this function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess CastIsSafe seems better, since it checks both upcast and downcast

// A target type rejected by IsIndexType is never reported as safe.
if (!IsIndexType(dtype)) {
return false;
}
// Bound of value as known to the analyzer; only consulted on the
// narrowing path of the condition below.
ConstIntBound bound = analyzer->const_int_bound(value);
// Largest and smallest values representable in the target dtype.
int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
// Safe if the target is at least as wide as the source, or if the whole
// known range of value fits in the target range.
if (value.dtype().bits() <= dtype.bits() || // upcast is safe
(bound->max_value <= ubound && bound->min_value >= lbound)) {
return true;
}
return false;
}

// Early-exit helper for the CheckCast methods: makes the enclosing function
// return false as soon as CheckCastImpl rejects cast(DTYPE, VALUE).
// The expansion refers to a variable named `analyzer`, so one must be in
// scope at every use site. The macro expands to a bare `if` statement and
// is used without a trailing semicolon.
#define TVM_CHECK_CANONICAL_SIMPLIFY_CAST(DTYPE, VALUE) \
if (!CheckCastImpl(DTYPE, VALUE, analyzer)) { \
hzfan marked this conversation as resolved.
Show resolved Hide resolved
return false; \
}

/*!
* \brief Internal "Split normal form" of expression.
*
Expand Down Expand Up @@ -128,6 +147,50 @@ class SplitExprNode : public CanonicalExprNode {

void MulToSelf(int64_t scale) { this->scale *= scale; }

/*!
* \brief check if cast(dtype, self) is safe
*
* A cast to a dtype at least as wide as this expression's dtype is always
* reported as safe, as is an expression whose scale is 0. Otherwise the
* expression is rebuilt step by step (index, then mod, then div, then
* scale) and every intermediate result must pass CheckCastImpl.
*
* \param dtype The target datatype
* \param analyzer The analyzer
* \return whether the cast is safe or not
*/
bool CheckCast(DataType dtype, Analyzer* analyzer) const {
hzfan marked this conversation as resolved.
Show resolved Hide resolved
// cast(dtype, index % upper_factor / lower_factor * scale) ==
// cast(dtype, index) % upper_factor / lower_factor * scale
// iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
// its intermediate results fit in the range of dtype
if (dtype.bits() >= this->dtype.bits()) {
return true; // upcast is safe
}
PrimExpr res = this->index;
// A zero scale is accepted without inspecting the index at all.
if (this->scale == 0) {
return true;
}
// Each macro use below returns false from this function if the
// intermediate expression held in res fails CheckCastImpl.
TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
// The modulo step is skipped when upper_factor is kPosInf.
if (this->upper_factor != SplitExprNode::kPosInf) {
res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode);
TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
}
// The division step is skipped when lower_factor is 1.
if (this->lower_factor != 1) {
res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode);
TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
}
// The scaling step is skipped when scale is 1.
if (this->scale != 1) {
// An unsigned expression is required to carry a positive scale.
ICHECK(!this->dtype.is_uint() || this->scale > 0);
res = res * make_const(this->dtype, this->scale);
TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
}
return true;
}

/*!
* \brief self = cast(dtype, self)
*
* Wraps the index in a cast to dtype and records dtype as this node's
* datatype. No safety check is performed here; upper_factor, lower_factor
* and scale are left untouched.
*
* \param dtype The target datatype
*/
void CastTo(DataType dtype) {
hzfan marked this conversation as resolved.
Show resolved Hide resolved
// Only the index expression is rewritten; the cast is pushed inside.
this->index = cast(dtype, this->index);
this->dtype = dtype;
}

inline bool IndexEqual(const SplitExpr& other) const;
inline bool DivModeCompatibleTo(DivMode mode) const;

Expand Down Expand Up @@ -255,6 +318,61 @@ class SumExprNode : public CanonicalExprNode {

void AddToSelf(const SumExpr& other, int64_t scale);

/*!
* \brief Decide whether cast(dtype, self) may be distributed over the terms of this sum.
*
* cast(dtype, arg_1 + ... + arg_n) == cast(dtype, arg_1) + ... + cast(dtype, arg_n)
* holds when the cast does not narrow (dtype.bits >= self.dtype.bits), or
* when every partial sum and every individual argument fits in dtype.
*
* \param dtype The target datatype
* \param analyzer The analyzer used to bound the intermediate results
* \return true when the cast is safe, false otherwise
*/
bool CheckCast(DataType dtype, Analyzer* analyzer) const {
  // A cast that does not narrow is accepted immediately.
  if (this->dtype.bits() <= dtype.bits()) {
    return true;
  }

  // True when the given partial result passes CheckCastImpl for dtype.
  const auto fits = [&](const PrimExpr& partial) {
    return CheckCastImpl(dtype, partial, analyzer);
  };

  // Accumulate the sum in four phases, checking after every step:
  // positively scaled terms, a positive base, negatively scaled terms
  // (via subtraction), and finally a negative base.
  PrimExpr acc = make_const(dtype, 0);
  for (const auto& term : args) {
    if (term->scale > 0) {
      acc = acc + term->Normalize();
      if (!fits(acc)) {
        return false;
      }
    }
  }
  if (base > 0) {
    acc = acc + make_const(dtype, base);
    if (!fits(acc)) {
      return false;
    }
  }
  for (const auto& term : args) {
    if (term->scale < 0) {
      acc = acc - term->NormalizeWithScale(-1);
      if (!fits(acc)) {
        return false;
      }
    }
  }
  if (base < 0) {
    // NOTE(review): -base would overflow if base could equal the minimum
    // of its integer type; confirm that this cannot occur.
    acc = acc - make_const(dtype, -base);
    if (!fits(acc)) {
      return false;
    }
  }

  // Finally, each argument on its own must also be castable.
  for (const auto& term : args) {
    if (!term->CheckCast(dtype, analyzer)) {
      return false;
    }
  }
  return true;
}

/*!
* \brief self = cast(dtype, self)
*
* Applies CastTo(dtype) to every argument of the sum and then records
* dtype as this node's datatype. No safety check is performed here and
* base is left untouched.
*
* \param dtype The target datatype
*/
void CastTo(DataType dtype) {
hzfan marked this conversation as resolved.
Show resolved Hide resolved
// Each argument is obtained through CopyOnWrite() before being modified.
for (auto& arg : args) {
arg.CopyOnWrite()->CastTo(dtype);
}
this->dtype = dtype;
}

static constexpr const char* _type_key = "arith.SumExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);

Expand Down Expand Up @@ -384,6 +502,8 @@ class SumExprNode : public CanonicalExprNode {
}
};

#undef TVM_CHECK_CANONICAL_SIMPLIFY_CAST

class SumExpr : public PrimExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode);
Expand Down Expand Up @@ -430,6 +550,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
PrimExpr VisitExpr_(const FloorDivNode* op) final;
PrimExpr VisitExpr_(const FloorModNode* op) final;
PrimExpr VisitExpr_(const ReduceNode* op) final;
PrimExpr VisitExpr_(const CastNode* op) final;

private:
/*!
Expand Down Expand Up @@ -1071,6 +1192,32 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
return ret;
}

// Visit a cast node. When the target dtype is accepted by IsIndexType and
// the canonicalized operand is a SumExpr or SplitExpr whose CheckCast
// succeeds, the cast is pushed into that expression with CastTo and the
// rewritten expression is returned. In every other case the node is handed
// to Rewriter::VisitExpr_.
PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
// Casts to other target types are left to the base rewriter.
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
// normalize
PrimExpr value = this->CanonicalMutate(op->value);
// Stays undefined unless one of the branches below accepts the cast.
PrimExpr ret;
hzfan marked this conversation as resolved.
Show resolved Hide resolved
if (value.as<SumExprNode>()) {
SumExpr se = Downcast<SumExpr>(value);
if (se->CheckCast(op->dtype, analyzer_)) {
// The node is obtained through CopyOnWrite() before being modified.
se.CopyOnWrite()->CastTo(op->dtype);
ret = se;
hzfan marked this conversation as resolved.
Show resolved Hide resolved
}
} else if (value.as<SplitExprNode>()) {
SplitExpr se = Downcast<SplitExpr>(value);
if (se->CheckCast(op->dtype, analyzer_)) {
se.CopyOnWrite()->CastTo(op->dtype);
ret = se;
hzfan marked this conversation as resolved.
Show resolved Hide resolved
}
}
// Fallback: the operand had another form, or CheckCast rejected the cast.
if (!ret.defined()) {
ret = Rewriter::VisitExpr_(op);
hzfan marked this conversation as resolved.
Show resolved Hide resolved
}
return ret;
}

PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
return impl_->CanonicalSimplify(expr);
}
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,46 @@ def test_complex_cases():
ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))


def test_simplify_cast():
    """Exercise canonical simplification of expressions containing integer casts."""
    checker = CanonicalChecker()
    cast = tvm.tir.Cast
    fld = tvm.te.floordiv
    flm = tvm.te.floormod

    def const_i64(number):
        # int64 constant with the given value
        return tvm.tir.const(number, "int64")

    def bounded(lo, hi):
        # constant integer bound [lo, hi] for the analyzer
        return tvm.arith.ConstIntBound(lo, hi)

    # Upcast: cast(i64, i + j + 1) - cast(i64, i)
    i = te.var("i", dtype="int32")
    j = te.var("j", dtype="int32")
    widened = cast("int64", i + j + 1) - cast("int64", i)
    checker.verify(widened, cast("int64", j) + const_i64(1))

    # Downcast with small bounds: cast(i32, i + j + 1) - cast(i32, i)
    i = te.var("i", dtype="int64")
    j = te.var("j", dtype="int64")
    checker.analyzer.update(i, bounded(0, 10))
    checker.analyzer.update(j, bounded(0, 10))
    narrowed = cast("int32", i + j + 1) - cast("int32", i)
    checker.verify(narrowed, cast("int32", j) + 1)

    # Downcast that must be left untouched: cast(i32, i + j - 100)
    i = te.var("i", dtype="int64")
    j = te.var("j", dtype="int64")
    checker.analyzer.update(i, bounded(0, 2 ** 31 - 1))
    checker.analyzer.update(j, bounded(0, 10))
    untouched = cast("int32", i + j - 100)
    checker.verify(untouched, untouched)

    # cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
    # - cast(i32, flm(axis, 7i64) * 2i64)
    axis = te.var("axis", dtype="int64")
    checker.analyzer.update(axis, bounded(0, 42))

    def scaled_mod():
        # builds a fresh flm(axis, 7i64) * 2i64 expression on every call
        return flm(axis, const_i64(7)) * const_i64(2)

    minuend = cast("int32", scaled_mod() + const_i64(1)) + tvm.tir.const(1, "int32")
    subtrahend = cast("int32", scaled_mod())
    checker.verify(minuend - subtrahend, 2)


if __name__ == "__main__":
test_floormod_simplify()
test_mul_sum_simplify()
Expand All @@ -321,3 +361,4 @@ def test_complex_cases():
test_split_index_simplify()
test_canonical_mixed()
test_complex_cases()
test_simplify_cast()