From 7e93f5203c3ab657a9d4384629b89d143928a746 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 7 Dec 2020 15:58:31 +0800 Subject: [PATCH 01/17] canonical simplification supports cast --- src/arith/canonical_simplify.cc | 44 +++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index d0a0702a0fb0..0bc41e4f8307 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -127,6 +127,7 @@ class SplitExprNode : public CanonicalExprNode { PrimExpr Normalize() const final { return NormalizeWithScale(1); } void MulToSelf(int64_t scale) { this->scale *= scale; } + void CastTo(DataType dtype) { this->index = cast(dtype, this->index); } inline bool IndexEqual(const SplitExpr& other) const; inline bool DivModeCompatibleTo(DivMode mode) const; @@ -255,6 +256,16 @@ class SumExprNode : public CanonicalExprNode { void AddToSelf(const SumExpr& other, int64_t scale); + /*! + * \brief self = cast(dtype, self) + * \param dtype The target datatype + */ + void CastTo(DataType dtype) { + for (auto& arg : args) { + arg.CopyOnWrite()->CastTo(dtype); + } + } + static constexpr const char* _type_key = "arith.SumExpr"; TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode); @@ -430,6 +441,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { PrimExpr VisitExpr_(const FloorDivNode* op) final; PrimExpr VisitExpr_(const FloorModNode* op) final; PrimExpr VisitExpr_(const ReduceNode* op) final; + PrimExpr VisitExpr_(const CastNode* op) final; private: /*! @@ -448,6 +460,13 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \return The result expression; */ SplitExpr SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode); + // /*! + // * \brief cast value to dtype + // * \param dtype The target datatype + // * \param value The SplitExpr to be casted + // * \return The result expression; + // */ + // SplitExpr CastSplitExpr(DataType dtype, SplitExpr value); /*! * \brief Separate psum into divisible and non-divisible parts. * \param psum The sum expression. @@ -689,6 +708,11 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, return lhs; } +// SplitExpr CanonicalSimplifier::Impl::CastSplitExpr(DataType dtype, SplitExpr value) { +// value.CopyOnWrite()->index = cast(dtype, value->index); +// return value; +// } + PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); @@ -1071,6 +1095,26 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { return ret; } +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { + if (!IsIndexType(op->dtype)) { + return Rewriter::VisitExpr_(op); + } + // normalize + PrimExpr value = this->CanonicalMutate(op->value); + if (value.as()) { + SumExpr se = Downcast(value); + se.CopyOnWrite()->CastTo(op->dtype); + return se; + } else if (value.as()) { + SplitExpr se = Downcast(value); + se.CopyOnWrite()->CastTo(op->dtype); + return se; + } else { + return Rewriter::VisitExpr_(op); + } +} + + PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { return impl_->CanonicalSimplify(expr); } From 99074cbc8241cd4c4fdd68277f85e983bf0afcbb Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 7 Dec 2020 16:17:58 +0800 Subject: [PATCH 02/17] fix --- src/arith/canonical_simplify.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 0bc41e4f8307..d19fb6848f0c 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -127,7 +127,11 @@ class SplitExprNode : public CanonicalExprNode { PrimExpr Normalize() const final { return NormalizeWithScale(1); } void MulToSelf(int64_t scale) { this->scale *= scale; } - void CastTo(DataType dtype) { this->index = cast(dtype, this->index); } + + void CastTo(DataType dtype) { + this->index = cast(dtype, this->index); + this->dtype = dtype; + } inline bool IndexEqual(const SplitExpr& other) const; inline bool DivModeCompatibleTo(DivMode mode) const; @@ -264,6 +268,7 @@ class SumExprNode : public CanonicalExprNode { for (auto& arg : args) { arg.CopyOnWrite()->CastTo(dtype); } + this->dtype = dtype; } static constexpr const char* _type_key = "arith.SumExpr"; From 44bf2c1e384b95dee409795b4968594cbd134069 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 7 Dec 2020 17:58:47 +0800 Subject: [PATCH 03/17] add condition and tests --- src/arith/canonical_simplify.cc | 121 ++++++++++++++++-- .../unittest/test_arith_canonical_simplify.py | 10 ++ 2 files changed, 118 insertions(+), 13 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index d19fb6848f0c..dee2614b849f 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -77,6 +77,23 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { } } +bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) { + if (!IsIndexType(dtype)) { + return false; + } + ConstIntBound bound = analyzer->const_int_bound(value); + int64_t ubound = Downcast(max_value(dtype))->value; + int64_t lbound = Downcast(min_value(dtype))->value; + if (value.dtype().bits() <= dtype.bits() || // upcast is safe + (bound->max_value <= ubound && bound->min_value >= lbound)) { + return true; + } + return false; +} + +#define CHECK_CAST(DTYPE, VALUE) \ + if (!CheckCastImpl(DTYPE, VALUE, analyzer)) { return false; } + /*! * \brief Internal "Split normal form" of expression. * @@ -128,6 +145,46 @@ class SplitExprNode : public CanonicalExprNode { void MulToSelf(int64_t scale) { this->scale *= scale; } + /*! + * \brief check if cast(dtype, self) is safe + * \param dtype The target datatype + * \param analyzer The analyzer + * \return whether the cast is safe or not + */ + bool CheckCast(DataType dtype, Analyzer* analyzer) const { + // cast(dtype, index % upper_factor / lower_factor * scale) == + // cast(dtype, index) % upper_factor / lower_factor * scale + // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of + // its intermediate results fit in the range of dtype + if (dtype.bits() >= this->dtype.bits()) { + return true; // upcast is safe + } + PrimExpr res = this->index; + DataType dtype = this->dtype; + if (this->scale == 0) { + return true; + } + CHECK_CAST(dtype, res) + if (this->upper_factor != SplitExprNode::kPosInf) { + res = ModImpl(res, make_const(dtype, this->upper_factor), div_mode); + CHECK_CAST(dtype, res) + } + if (this->lower_factor != 1) { + res = DivImpl(res, make_const(dtype, this->lower_factor), div_mode); + CHECK_CAST(dtype, res) + } + if (this->scale != 1) { + ICHECK(!dtype.is_uint() || this->scale > 0); + res = res * make_const(dtype, this->scale); + CHECK_CAST(dtype, res) + } + return true; + } + + /*! + * \brief self = cast(dtype, self) + * \param dtype The target datatype + */ void CastTo(DataType dtype) { this->index = cast(dtype, this->index); this->dtype = dtype; @@ -260,6 +317,48 @@ class SumExprNode : public CanonicalExprNode { void AddToSelf(const SumExpr& other, int64_t scale); + /*! + * \brief check if cast(dtype, self) is safe + * \param dtype The target datatype + * \param analyzer The analyzer + * \return whether the cast is safe or not + */ + bool CheckCast(DataType dtype, Analyzer* analyzer) const { + // cast(dtype, arg_1 + arg_2 + ... arg_n) == + // cast(dtype, arg_1) + ... + cast(dtype, arg_n) + // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of + // its intermediate results fit in the range of dtype + if (dtype.bits() >= this->dtype.bits()) { + return true; // upcast is safe + } + PrimExpr res = make_const(dtype, 0); + for (size_t i = 0; i < args.size(); ++i) { + if (args[i]->scale > 0) { + res = res + args[i]->Normalize(); + CHECK_CAST(dtype, res) + } + } + if (base > 0) { + res = res + make_const(dtype, base); + CHECK_CAST(dtype, res) + } + // negative scales follows using sub. + for (size_t i = 0; i < args.size(); ++i) { + if (args[i]->scale < 0) { + res = res - args[i]->NormalizeWithScale(-1); + CHECK_CAST(dtype, res) + } + } + if (base < 0) { + res = res - make_const(dtype, -base); + CHECK_CAST(dtype, res) + } + for (const auto& arg : args) { + if (!arg->CheckCast(dtype, analyzer)) { return false; } + } + return true; + } + /*! * \brief self = cast(dtype, self) * \param dtype The target datatype @@ -465,13 +564,6 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \return The result expression; */ SplitExpr SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode); - // /*! - // * \brief cast value to dtype - // * \param dtype The target datatype - // * \param value The SplitExpr to be casted - // * \return The result expression; - // */ - // SplitExpr CastSplitExpr(DataType dtype, SplitExpr value); /*! * \brief Separate psum into divisible and non-divisible parts. * \param psum The sum expression. @@ -1108,15 +1200,18 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { PrimExpr value = this->CanonicalMutate(op->value); if (value.as()) { SumExpr se = Downcast(value); - se.CopyOnWrite()->CastTo(op->dtype); - return se; + if (se->CheckCast(op->dtype, analyzer_)) { + se.CopyOnWrite()->CastTo(op->dtype); + return se; + } } else if (value.as()) { SplitExpr se = Downcast(value); - se.CopyOnWrite()->CastTo(op->dtype); - return se; - } else { - return Rewriter::VisitExpr_(op); + if (se->CheckCast(op->dtype, analyzer_)) { + se.CopyOnWrite()->CastTo(op->dtype); + return se; + } } + return Rewriter::VisitExpr_(op); } diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 65c8ec3dfe02..e699c567d885 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -24,6 +24,7 @@ def __init__(self): def verify(self, data, expected): res = self.analyzer.canonical_simplify(data) + print(res) expected = tvm.runtime.convert(expected) assert tvm.ir.structural_equal(res, expected), "\ndata={}\nres={}\nexpected={}".format( data, res, expected @@ -310,6 +311,14 @@ def test_complex_cases(): ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4)) +def test_simplify_cast(): + ck = CanonicalChecker() + i = te.var("i", dtype="int32") + tcast = tvm.tir.Cast + res = tcast("int64", i + j + 1) - tcast("int64", i) + ck.verify(res, 1) + + if __name__ == "__main__": test_floormod_simplify() test_mul_sum_simplify() @@ -321,3 +330,4 @@ def test_complex_cases(): test_split_index_simplify() test_canonical_mixed() test_complex_cases() + test_simplify_cast() From f59eabdf4a6045457cb4af72e0009276a3eb238c Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 7 Dec 2020 10:33:50 +0000 Subject: [PATCH 04/17] tests --- .../unittest/test_arith_canonical_simplify.py | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index e699c567d885..5003acf664a5 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -24,7 +24,6 @@ def __init__(self): def verify(self, data, expected): res = self.analyzer.canonical_simplify(data) - print(res) expected = tvm.runtime.convert(expected) assert tvm.ir.structural_equal(res, expected), "\ndata={}\nres={}\nexpected={}".format( data, res, expected @@ -313,10 +312,41 @@ def test_complex_cases(): def test_simplify_cast(): ck = CanonicalChecker() - i = te.var("i", dtype="int32") tcast = tvm.tir.Cast + fld = tvm.te.floordiv + flm = tvm.te.floormod + # cast(i64, i + j + 1) - cast(i64, i) + i = te.var("i", dtype="int32") + j = te.var("j", dtype="int32") res = tcast("int64", i + j + 1) - tcast("int64", i) - ck.verify(res, 1) + ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64")) + # cast(i32, i + j + 1) - cast(i32, i) + i = te.var("i", dtype="int64") + j = te.var("j", dtype="int64") + ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10)) + ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10)) + res = tcast("int32", i + j + 1) - tcast("int32", i) + ck.verify(res, tcast("int32", j) + 1) + # cast(i32, i + j - 100) + i = te.var("i", dtype="int64") + j = te.var("j", dtype="int64") + ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2**31 - 1)) + ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10)) + res = tcast("int32", i + j - 100) + ck.verify(res, res) + # cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32 + # - cast(i32, flm(axis, 7i64) * 2i64) + axis = te.var("axis", dtype="int64") + ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42)) + res = ( + tcast("int32", flm(axis, tvm.tir.const(7, "int64")) + * tvm.tir.const(2, "int64") + + tvm.tir.const(1, "int64")) + + tvm.tir.const(1, "int32") + - tcast("int32", flm(axis, tvm.tir.const(7, "int64")) + * tvm.tir.const(2, "int64")) + ) + ck.verify(res, 2) if __name__ == "__main__": From 49c37c95eb6f154f44a34377e34123da277f8f89 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 7 Dec 2020 18:51:23 +0800 Subject: [PATCH 05/17] fix --- src/arith/canonical_simplify.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index dee2614b849f..4af57a99afab 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -805,11 +805,6 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, return lhs; } -// SplitExpr CanonicalSimplifier::Impl::CastSplitExpr(DataType dtype, SplitExpr value) { -// value.CopyOnWrite()->index = cast(dtype, value->index); -// return value; -// } - PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); From 92cc8d75b448c9dd749637a6c0938581f7258c82 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 7 Dec 2020 19:09:04 +0800 Subject: [PATCH 06/17] fix --- src/arith/canonical_simplify.cc | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 4af57a99afab..4738d09f52bd 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -78,21 +78,23 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { } bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) { - if (!IsIndexType(dtype)) { - return false; - } - ConstIntBound bound = analyzer->const_int_bound(value); - int64_t ubound = Downcast(max_value(dtype))->value; - int64_t lbound = Downcast(min_value(dtype))->value; - if (value.dtype().bits() <= dtype.bits() || // upcast is safe - (bound->max_value <= ubound && bound->min_value >= lbound)) { - return true; - } - return false; + if (!IsIndexType(dtype)) { + return false; + } + ConstIntBound bound = analyzer->const_int_bound(value); + int64_t ubound = Downcast(max_value(dtype))->value; + int64_t lbound = Downcast(min_value(dtype))->value; + if (value.dtype().bits() <= dtype.bits() || // upcast is safe + (bound->max_value <= ubound && bound->min_value >= lbound)) { + return true; + } + return false; } -#define CHECK_CAST(DTYPE, VALUE) \ - if (!CheckCastImpl(DTYPE, VALUE, analyzer)) { return false; } +#define CHECK_CAST(DTYPE, VALUE) \ + if (!CheckCastImpl(DTYPE, VALUE, analyzer)) { \ + return false; \ + } /*! * \brief Internal "Split normal form" of expression. @@ -354,7 +356,9 @@ class SumExprNode : public CanonicalExprNode { CHECK_CAST(dtype, res) } for (const auto& arg : args) { - if (!arg->CheckCast(dtype, analyzer)) { return false; } + if (!arg->CheckCast(dtype, analyzer)) { + return false; + } } return true; } @@ -1209,7 +1213,6 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { return Rewriter::VisitExpr_(op); } - PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { return impl_->CanonicalSimplify(expr); } From 1265518c0b7e8dbf754ef8b9a1fa915347a683d0 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 7 Dec 2020 19:11:06 +0800 Subject: [PATCH 07/17] fix --- src/arith/canonical_simplify.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 4738d09f52bd..996c84b2f618 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -79,7 +79,7 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) { if (!IsIndexType(dtype)) { - return false; + return false; } ConstIntBound bound = analyzer->const_int_bound(value); int64_t ubound = Downcast(max_value(dtype))->value; From f2a8e752dd1c450bb76d0932dbbeecc727f64317 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 7 Dec 2020 19:18:38 +0800 Subject: [PATCH 08/17] fix --- .../unittest/test_arith_canonical_simplify.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 5003acf664a5..c241b81da986 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -330,7 +330,7 @@ def test_simplify_cast(): # cast(i32, i + j - 100) i = te.var("i", dtype="int64") j = te.var("j", dtype="int64") - ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2**31 - 1)) + ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2 ** 31 - 1)) ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10)) res = tcast("int32", i + j - 100) ck.verify(res, res) @@ -339,12 +339,13 @@ def test_simplify_cast(): axis = te.var("axis", dtype="int64") ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42)) res = ( - tcast("int32", flm(axis, tvm.tir.const(7, "int64")) - * tvm.tir.const(2, "int64") - + tvm.tir.const(1, "int64")) + tcast( + "int32", + flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64") + + tvm.tir.const(1, "int64"), + ) + tvm.tir.const(1, "int32") - - tcast("int32", flm(axis, tvm.tir.const(7, "int64")) - * tvm.tir.const(2, "int64")) + - tcast("int32", flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64")) ) ck.verify(res, 2) From c91fcc220e5400a556f2d5da4a2977b4b5f5b887 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 7 Dec 2020 19:23:50 +0800 Subject: [PATCH 09/17] fix --- src/arith/canonical_simplify.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 996c84b2f618..48fdf2533dc2 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -162,22 +162,21 @@ class SplitExprNode : public CanonicalExprNode { return true; // upcast is safe } PrimExpr res = this->index; - DataType dtype = this->dtype; if (this->scale == 0) { return true; } CHECK_CAST(dtype, res) if (this->upper_factor != SplitExprNode::kPosInf) { - res = ModImpl(res, make_const(dtype, this->upper_factor), div_mode); + res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode); CHECK_CAST(dtype, res) } if (this->lower_factor != 1) { - res = DivImpl(res, make_const(dtype, this->lower_factor), div_mode); + res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode); CHECK_CAST(dtype, res) } if (this->scale != 1) { - ICHECK(!dtype.is_uint() || this->scale > 0); - res = res * make_const(dtype, this->scale); + ICHECK(!this->dtype.is_uint() || this->scale > 0); + res = res * make_const(this->dtype, this->scale); CHECK_CAST(dtype, res) } return true; From b549c5474dca092b80dc1c02c7eea7628cd0864d Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 8 Dec 2020 15:40:48 +0800 Subject: [PATCH 10/17] fix --- src/arith/canonical_simplify.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 48fdf2533dc2..5cc5552c7857 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1196,20 +1196,23 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { } // normalize PrimExpr value = this->CanonicalMutate(op->value); + PrimExpr ret; if (value.as()) { SumExpr se = Downcast(value); if (se->CheckCast(op->dtype, analyzer_)) { se.CopyOnWrite()->CastTo(op->dtype); - return se; + ret = se; } } else if (value.as()) { SplitExpr se = Downcast(value); if (se->CheckCast(op->dtype, analyzer_)) { se.CopyOnWrite()->CastTo(op->dtype); - return se; + ret = se; } + } else { + ret = Rewriter::VisitExpr_(op); } - return Rewriter::VisitExpr_(op); + return ret; } PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { From 0ddc33a7dfbb1bfb3f4912d8e8d492947140ca79 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 8 Dec 2020 16:02:55 +0800 Subject: [PATCH 11/17] fix --- src/arith/canonical_simplify.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 5cc5552c7857..dc235b65d196 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1209,7 +1209,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { se.CopyOnWrite()->CastTo(op->dtype); ret = se; } - } else { + } + if (!ret.defined()) { ret = Rewriter::VisitExpr_(op); } return ret; From d3e9196f190704d1b6a15b79d33fe3f4a8d72630 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 28 Dec 2020 13:59:08 +0800 Subject: [PATCH 12/17] resolve comments --- src/arith/canonical_simplify.cc | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index dc235b65d196..76d7da53a242 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -91,9 +91,9 @@ bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) { return false; } -#define CHECK_CAST(DTYPE, VALUE) \ - if (!CheckCastImpl(DTYPE, VALUE, analyzer)) { \ - return false; \ +#define TVM_CHECK_CANONICAL_SIMPLIFY_CAST(DTYPE, VALUE) \ + if (!CheckCastImpl(DTYPE, VALUE, analyzer)) { \ + return false; \ } /*! @@ -165,19 +165,19 @@ class SplitExprNode : public CanonicalExprNode { if (this->scale == 0) { return true; } - CHECK_CAST(dtype, res) + TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) if (this->upper_factor != SplitExprNode::kPosInf) { res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode); - CHECK_CAST(dtype, res) + TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) } if (this->lower_factor != 1) { res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode); - CHECK_CAST(dtype, res) + TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) } if (this->scale != 1) { ICHECK(!this->dtype.is_uint() || this->scale > 0); res = res * make_const(this->dtype, this->scale); - CHECK_CAST(dtype, res) + TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) } return true; } @@ -336,23 +336,23 @@ class SumExprNode : public CanonicalExprNode { for (size_t i = 0; i < args.size(); ++i) { if (args[i]->scale > 0) { res = res + args[i]->Normalize(); - CHECK_CAST(dtype, res) + TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) } } if (base > 0) { res = res + make_const(dtype, base); - CHECK_CAST(dtype, res) + TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) } // negative scales follows using sub. for (size_t i = 0; i < args.size(); ++i) { if (args[i]->scale < 0) { res = res - args[i]->NormalizeWithScale(-1); - CHECK_CAST(dtype, res) + TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) } } if (base < 0) { res = res - make_const(dtype, -base); - CHECK_CAST(dtype, res) + TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) } for (const auto& arg : args) { if (!arg->CheckCast(dtype, analyzer)) { @@ -502,6 +502,8 @@ class SumExprNode : public CanonicalExprNode { } }; +#undef TVM_CHECK_CANONICAL_SIMPLIFY_CAST + class SumExpr : public PrimExpr { public: TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode); From 89ce0d5058be57f9bb78c9e6b61f375b76524207 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 28 Dec 2020 14:01:24 +0800 Subject: [PATCH 13/17] fix --- src/arith/canonical_simplify.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 76d7da53a242..3b33799f5ed5 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -91,9 +91,9 @@ bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) { return false; } -#define TVM_CHECK_CANONICAL_SIMPLIFY_CAST(DTYPE, VALUE) \ - if (!CheckCastImpl(DTYPE, VALUE, analyzer)) { \ - return false; \ +#define TVM_CHECK_CANONICAL_SIMPLIFY_CAST(DTYPE, VALUE) \ + if (!CheckCastImpl(DTYPE, VALUE, analyzer)) { \ + return false; \ } /*! From 25dd8849b43073c7622fb00b3c881a2cc472a4af Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 5 Jan 2021 16:55:35 +0800 Subject: [PATCH 14/17] rename --- src/arith/canonical_simplify.cc | 77 ++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 30 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 3b33799f5ed5..4a58ad031887 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -77,7 +77,14 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { } } -bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) { +/*! + * \brief check if value fits in dtype + * \param value The value to be analyzed + * \param dtype The target dtype + * \param analyzer The analyzer + * \return whether value fits in dtype + */ +bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) { if (!IsIndexType(dtype)) { return false; } @@ -91,11 +98,6 @@ bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) { return false; } -#define TVM_CHECK_CANONICAL_SIMPLIFY_CAST(DTYPE, VALUE) \ - if (!CheckCastImpl(DTYPE, VALUE, analyzer)) { \ - return false; \ - } - /*! * \brief Internal "Split normal form" of expression. * @@ -148,12 +150,12 @@ class SplitExprNode : public CanonicalExprNode { void MulToSelf(int64_t scale) { this->scale *= scale; } /*! - * \brief check if cast(dtype, self) is safe + * \brief check if cast can be pushed to sub-expressions * \param dtype The target datatype * \param analyzer The analyzer - * \return whether the cast is safe or not + * \return whether the cast can be safely pushed to children */ - bool CheckCast(DataType dtype, Analyzer* analyzer) const { + bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const { // cast(dtype, index % upper_factor / lower_factor * scale) == // cast(dtype, index) % upper_factor / lower_factor * scale // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of @@ -165,19 +167,27 @@ class SplitExprNode : public CanonicalExprNode { if (this->scale == 0) { return true; } - TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) + if (!CastIsSafe(dtype, res, analyzer)) { + return false; + } if (this->upper_factor != SplitExprNode::kPosInf) { res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode); - TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) + if (!CastIsSafe(dtype, res, analyzer)) { + return false; + } } if (this->lower_factor != 1) { res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode); - TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) + if (!CastIsSafe(dtype, res, analyzer)) { + return false; + } } if (this->scale != 1) { ICHECK(!this->dtype.is_uint() || this->scale > 0); res = res * make_const(this->dtype, this->scale); - TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) + if (!CastIsSafe(dtype, res, analyzer)) { + return false; + } } return true; } @@ -186,7 +196,7 @@ class SplitExprNode : public CanonicalExprNode { * \brief self = cast(dtype, self) * \param dtype The target datatype */ - void CastTo(DataType dtype) { + void PushCastToChildren(DataType dtype) { this->index = cast(dtype, this->index); this->dtype = dtype; } @@ -319,12 +329,12 @@ class SumExprNode : public CanonicalExprNode { void AddToSelf(const SumExpr& other, int64_t scale); /*! - * \brief check if cast(dtype, self) is safe + * \brief check if cast can be pushed to sub-expressions * \param dtype The target datatype * \param analyzer The analyzer - * \return whether the cast is safe or not + * \return whether the cast can be safely pushed to children */ - bool CheckCast(DataType dtype, Analyzer* analyzer) const { + bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const { // cast(dtype, arg_1 + arg_2 + ... arg_n) == // cast(dtype, arg_1) + ... + cast(dtype, arg_n) // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of @@ -336,26 +346,34 @@ class SumExprNode : public CanonicalExprNode { for (size_t i = 0; i < args.size(); ++i) { if (args[i]->scale > 0) { res = res + args[i]->Normalize(); - TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) + if (!CastIsSafe(dtype, res, analyzer)) { + return false; + } } } if (base > 0) { res = res + make_const(dtype, base); - TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) + if (!CastIsSafe(dtype, res, analyzer)) { + return false; + } } // negative scales follows using sub. for (size_t i = 0; i < args.size(); ++i) { if (args[i]->scale < 0) { res = res - args[i]->NormalizeWithScale(-1); - TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) + if (!CastIsSafe(dtype, res, analyzer)) { + return false; + } } } if (base < 0) { res = res - make_const(dtype, -base); - TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res) + if (!CastIsSafe(dtype, res, analyzer)) { + return false; + } } for (const auto& arg : args) { - if (!arg->CheckCast(dtype, analyzer)) { + if (!arg->CanPushCastToChildren(dtype, analyzer)) { return false; } } @@ -366,9 +384,9 @@ class SumExprNode : public CanonicalExprNode { * \brief self = cast(dtype, self) * \param dtype The target datatype */ - void CastTo(DataType dtype) { + void PushCastToChildren(DataType dtype) { for (auto& arg : args) { - arg.CopyOnWrite()->CastTo(dtype); + arg.CopyOnWrite()->PushCastToChildren(dtype); } this->dtype = dtype; } @@ -502,8 +520,6 @@ class SumExprNode : public CanonicalExprNode { } }; -#undef TVM_CHECK_CANONICAL_SIMPLIFY_CAST - class SumExpr : public PrimExpr { public: TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode); @@ -1199,16 +1215,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { // normalize PrimExpr value = this->CanonicalMutate(op->value); PrimExpr ret; + // PushCastToChildren if (value.as()) { SumExpr se = Downcast(value); - if (se->CheckCast(op->dtype, analyzer_)) { - se.CopyOnWrite()->CastTo(op->dtype); + if (se->CanPushCastToChildren(op->dtype, analyzer_)) { + se.CopyOnWrite()->PushCastToChildren(op->dtype); ret = se; } } else if (value.as()) { SplitExpr se = Downcast(value); - if (se->CheckCast(op->dtype, analyzer_)) { - se.CopyOnWrite()->CastTo(op->dtype); + if (se->CanPushCastToChildren(op->dtype, analyzer_)) { + se.CopyOnWrite()->PushCastToChildren(op->dtype); ret = se; } } From c43a8c86a89b9826d49434023c4cfa2146b174d7 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 5 Jan 2021 16:58:04 +0800 Subject: [PATCH 15/17] fix --- src/arith/canonical_simplify.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 4a58ad031887..6446b54a8f5d 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -78,12 +78,12 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { } /*! - * \brief check if value fits in dtype - * \param value The value to be analyzed - * \param dtype The target dtype - * \param analyzer The analyzer - * \return whether value fits in dtype - */ + * \brief check if value fits in dtype + * \param value The value to be analyzed + * \param dtype The target dtype + * \param analyzer The analyzer + * \return whether value fits in dtype + */ bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) { if (!IsIndexType(dtype)) { return false; From 577f19487d353a0782c9a27c30ff9368920d32ef Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 7 Jan 2021 12:12:46 +0800 Subject: [PATCH 16/17] directly return --- src/arith/canonical_simplify.cc | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 6446b54a8f5d..5545626c7480 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1220,19 +1220,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { SumExpr se = Downcast(value); if (se->CanPushCastToChildren(op->dtype, analyzer_)) { se.CopyOnWrite()->PushCastToChildren(op->dtype); - ret = se; + return std::move(se); } - } else if (value.as()) { + } + if (value.as()) { SplitExpr se = Downcast(value); if (se->CanPushCastToChildren(op->dtype, analyzer_)) { se.CopyOnWrite()->PushCastToChildren(op->dtype); - ret = se; + return std::move(se); } } - if (!ret.defined()) { - ret = Rewriter::VisitExpr_(op); - } - return ret; + return Rewriter::VisitExpr_(op); } PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { From e55ba942543018f55ef4e39ac04cb568013e80d8 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 7 Jan 2021 12:13:47 +0800 Subject: [PATCH 17/17] fix --- src/arith/canonical_simplify.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 5545626c7480..ba549959ac98 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1214,7 +1214,6 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { } // normalize PrimExpr value = this->CanonicalMutate(op->value); - PrimExpr ret; // PushCastToChildren if (value.as()) { SumExpr se = Downcast(value);