From e35663f8a07d807acb12378580ea9d4e1c155f1a Mon Sep 17 00:00:00 2001 From: syang-ng Date: Wed, 8 Sep 2021 09:08:31 +0000 Subject: [PATCH 1/8] fix div zero error in rewrite_simplify --- src/arith/rewrite_simplify.cc | 2 ++ .../unittest/test_arith_rewrite_simplify.py | 27 ++++++++----------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 1d3475b13dad..1cdd6f0b4cb5 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -474,6 +474,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + // If divisor is equal to zero + ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 641eed51d5cf..ba42b7cb920c 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import py +import pytest import tvm from tvm import te @@ -931,20 +933,13 @@ def test_shift_left_simplify(): ck.verify(z, tvm.tir.const(1 << 10, "int32")) +def test_div_zero_simplify(): + ck = RewriteChecker() + + with pytest.raises(tvm.error.TVMError) as cm: + ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1,1,2), tvm.tir.Broadcast(0, 2))) + assert "division by zero" in str(cm.execption) + + if __name__ == "__main__": - test_floordiv_index_simplify() - test_floormod_index_simplify() - test_cmp_simplify() - test_vector_simplify() - test_add_index_simplify() - test_sub_index_simplify() - test_mul_index_simplify() - test_div_index_simplify() - test_max_index_simplify() - test_min_index_simplify() - test_mod_index_simplify() - test_select_simplify() - test_logical_simplify() - test_let_simplify() - test_cast_simplify() - test_shift_left_simplify() + pytest.main([__file__]) From 17a517470ae9300fff6afc2b8424074bf58a87a3 Mon Sep 17 00:00:00 2001 From: syang-ng Date: Wed, 8 Sep 2021 10:55:13 +0000 Subject: [PATCH 2/8] update the style to fix ci error --- tests/python/unittest/test_arith_rewrite_simplify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index ba42b7cb920c..95d24d68ee7f 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -937,7 +937,7 @@ def test_div_zero_simplify(): ck = RewriteChecker() with pytest.raises(tvm.error.TVMError) as cm: - ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1,1,2), tvm.tir.Broadcast(0, 2))) + ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) assert "division by zero" in str(cm.execption) From 564bf4a0de08628a3ff14bcb20a423a2595f1590 Mon Sep 17 00:00:00 2001 From: syang-ng Date: Thu, 9 Sep 2021 00:13:37 +0000 Subject: [PATCH 3/8] remove useless code and comment --- src/arith/rewrite_simplify.cc | 1 - tests/python/unittest/test_arith_rewrite_simplify.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 1cdd6f0b4cb5..0087866ea4f8 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -474,7 +474,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - // If divisor is equal to zero ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 95d24d68ee7f..9ff9ff18e5b5 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import py import pytest import tvm from tvm import te From d4b33b8c7f073674788bbdce5334a4de1c046d37 Mon Sep 17 00:00:00 2001 From: syang-ng Date: Fri, 10 Sep 2021 14:22:04 +0000 Subject: [PATCH 4/8] fix div zero error of mod, floordiv, floormod in rewrite_simplify --- src/arith/rewrite_simplify.cc | 3 +++ tests/python/unittest/test_arith_rewrite_simplify.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 0087866ea4f8..4a99e10211b7 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -645,6 +645,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { if (truncmod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return broadcast(truncmod(b1, c2), lanes).Eval(); } @@ -724,6 +725,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { if (floordiv(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval(); } @@ -852,6 +854,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (floormod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return broadcast(floormod(b1, c2), lanes).Eval(); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 9ff9ff18e5b5..99c52fd855dd 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -939,6 +939,18 @@ def test_div_zero_simplify(): ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) assert "division by zero" in str(cm.execption) + with pytest.raises(tvm.error.TVMError) as cm: + ck.analyzer.rewrite_simplify(tvm.tir.Mod(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) + assert "division by zero" in str(cm.execption) + + with pytest.raises(tvm.error.TVMError) as cm: + ck.analyzer.rewrite_simplify(tvm.tir.FloorDiv(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) + assert "division by zero" in str(cm.execption) + + with pytest.raises(tvm.error.TVMError) as cm: + ck.analyzer.rewrite_simplify(tvm.tir.FloorMod(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) + assert "division by zero" in str(cm.execption) + if __name__ == "__main__": pytest.main([__file__]) From 158eb5d37f495f9b2b9eba271de9028afc1e5bf3 Mon Sep 17 00:00:00 2001 From: syang-ng Date: Fri, 10 Sep 2021 15:00:21 +0000 Subject: [PATCH 5/8] rewrite the test case of divison by zero to fix ci error --- tests/python/unittest/test_arith_rewrite_simplify.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 99c52fd855dd..c32038d4a39a 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -934,21 +934,23 @@ def test_shift_left_simplify(): def test_div_zero_simplify(): ck = RewriteChecker() + ramp = tvm.tir.Ramp(1, 1, 2) + broadcast = tvm.tir.Broadcast(0, 2) with pytest.raises(tvm.error.TVMError) as cm: - ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) + ck.analyzer.rewrite_simplify(tvm.tir.Div(ramp, broadcast)) assert "division by zero" in str(cm.execption) with pytest.raises(tvm.error.TVMError) as cm: - ck.analyzer.rewrite_simplify(tvm.tir.Mod(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) + ck.analyzer.rewrite_simplify(tvm.tir.Mod(ramp, broadcast)) assert "division by zero" in str(cm.execption) with pytest.raises(tvm.error.TVMError) as cm: - ck.analyzer.rewrite_simplify(tvm.tir.FloorDiv(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) + ck.analyzer.rewrite_simplify(tvm.tir.FloorDiv(ramp, broadcast)) assert "division by zero" in str(cm.execption) with pytest.raises(tvm.error.TVMError) as cm: - ck.analyzer.rewrite_simplify(tvm.tir.FloorMod(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) + ck.analyzer.rewrite_simplify(tvm.tir.FloorMod(ramp, broadcast)) assert "division by zero" in str(cm.execption) From 8f99035c13ce967dc80ae16103e84fab9fa76aab Mon Sep 17 00:00:00 2001 From: syang-ng Date: Fri, 10 Sep 2021 16:00:24 +0000 Subject: [PATCH 6/8] remove useless tab --- tests/python/unittest/test_arith_rewrite_simplify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index c32038d4a39a..6ca2a2a5fcb0 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -944,11 +944,11 @@ def test_div_zero_simplify(): with pytest.raises(tvm.error.TVMError) as cm: ck.analyzer.rewrite_simplify(tvm.tir.Mod(ramp, broadcast)) assert "division by zero" in str(cm.execption) - + with pytest.raises(tvm.error.TVMError) as cm: ck.analyzer.rewrite_simplify(tvm.tir.FloorDiv(ramp, broadcast)) assert "division by zero" in str(cm.execption) - + with pytest.raises(tvm.error.TVMError) as cm: ck.analyzer.rewrite_simplify(tvm.tir.FloorMod(ramp, broadcast)) assert "division by zero" in str(cm.execption) From f4015d90fda45a1d28bc1bd7991ce7476c88edb4 Mon Sep 17 00:00:00 2001 From: syang-ng Date: Thu, 16 Sep 2021 01:28:26 +0000 Subject: [PATCH 7/8] retrigger ci --- src/arith/rewrite_simplify.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4a99e10211b7..46846b1e3bba 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -474,7 +474,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - ICHECK(c2val != 0) << "division by zero"; + ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); } From c185796f19b3f380ceeea2f5cc3fca8e92a24c1e Mon Sep 17 00:00:00 2001 From: syang-ng Date: Thu, 16 Sep 2021 01:28:50 +0000 Subject: [PATCH 8/8] remove useless blank to retrigger ci --- src/arith/rewrite_simplify.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 46846b1e3bba..4a99e10211b7 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -474,7 +474,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - ICHECK(c2val != 0) << "division by zero"; + ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); }