Skip to content

Commit

Permalink
[ARITH] Simplify casts of constants 0 and 1 (apache#3758)
Browse files Browse the repository at this point in the history
* [ARITH] Simplify casts of constants 0 and 1

* [EXPR] is_const_value to check whether non-ints are consts

* Revert "[EXPR] is_const_value to check whether non-ints are consts"

This reverts commit 7e1b346.

* Use tvm::cast
  • Loading branch information
sgrechanik-h authored and wweic committed Aug 16, 2019
1 parent 99101d9 commit 18934bd
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1757,6 +1757,13 @@ Mutate_(const Variable* op, const Expr& self) {
return self;
}

Expr RewriteSimplifier::Impl::
Mutate_(const Cast* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Cast>();
return cast(op->type, op->value);
}

Expr RewriteSimplifier::operator()(const Expr& expr) {
// Run simplification in post order
Expr res = expr;
Expand Down
1 change: 1 addition & 0 deletions src/arithmetic/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr Mutate_(const Call* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;
Expr Mutate_(const Variable* op, const Expr& self) override;
Expr Mutate_(const Cast* op, const Expr& self) override;

protected:
/*! \brief internal structure for comparison. */
Expand Down
5 changes: 5 additions & 0 deletions src/lang/expr_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,15 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {

Expr cast(const Type& t, Expr value) {
using ir::IntImm;
using ir::UIntImm;
using ir::FloatImm;
if (value.type() == t) return value;
// const fold IntImm as they are used in index computations
if (t.lanes() == 1) {
if (const IntImm* op = value.as<IntImm>()) {
return make_const(t, op->value);
} else if (const UIntImm* op = value.as<UIntImm>()) {
return make_const(t, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) {
return make_const(t, op->value);
}
Expand All @@ -122,6 +125,8 @@ Expr cast(const Type& t, Expr value) {
if (value.type() != vtype) {
if (const IntImm* op = value.as<IntImm>()) {
value = make_const(vtype, op->value);
} else if (const UIntImm* op = value.as<UIntImm>()) {
return make_const(t, op->value);
} else if (const FloatImm* op = value.as<FloatImm>()) {
value = make_const(vtype, op->value);
} else {
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,18 @@ def test_let_simplify():
z = tvm.expr.Let(x, 1, x + 1)
ck.verify(z + z, 4)

def test_cast_simplify():
ck = RewriteChecker()
x = tvm.var("x")

dtypes = ["float32", "float16", "int32", "int8", "bool"]
for dtype1 in dtypes:
ck.verify(tvm.expr.Cast(dtype1, x - x), tvm.const(0, dtype1))
ck.verify(tvm.expr.Cast(dtype1, x == x), tvm.const(1, dtype1))
for dtype2 in dtypes:
for i in [0, 1, 2, 3]:
ck.verify(tvm.expr.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1))

if __name__ == "__main__":
test_floordiv_index_simplify()
test_floormod_index_simplify()
Expand All @@ -819,3 +831,4 @@ def test_let_simplify():
test_select_simplify()
test_logical_simplify()
test_let_simplify()
test_cast_simplify()

0 comments on commit 18934bd

Please sign in to comment.