From af1154549a25d5b29bcd247dd1d1e39fbcf4a043 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 27 Mar 2021 21:44:34 +0800 Subject: [PATCH 1/7] move expr complexity --- src/{arith => tir/analysis}/expr_complexity.cc | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/{arith => tir/analysis}/expr_complexity.cc (100%) diff --git a/src/arith/expr_complexity.cc b/src/tir/analysis/expr_complexity.cc similarity index 100% rename from src/arith/expr_complexity.cc rename to src/tir/analysis/expr_complexity.cc From 2473559e6cd0b90c77320a762230a7676a4c058e Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 27 Mar 2021 22:07:04 +0800 Subject: [PATCH 2/7] [ARITH] normalize iter expr --- src/arith/iter_affine_map.cc | 56 ++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 3757b5eb0d51..3d0b41f573ec 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1028,5 +1028,61 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { } } +/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */ +class IterMapToExprNormalizer { + public: + explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {} + + PrimExpr Convert(const IterMapExpr& expr) { + if (const auto* op = expr.as()) { + return ConvertIterSplitExpr(GetRef(op)); + } else if (const auto* op = expr.as()) { + return ConvertIterSumExpr(GetRef(op)); + } else { + ICHECK(expr.defined()); + LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey(); + return 0; + } + } + + PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) { + PrimExpr res = 0; + for (const IterSplitExpr& arg : expr->args) { + res += ConvertIterSplitExpr(arg); + } + res += expr->base; + return res; + } + + PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) { + PrimExpr source; + if (const auto* op = expr->source->source.as()) { + source = GetRef(op); + } else if (const auto& op = expr->source->source.as()) { + source = ConvertIterSumExpr(GetRef(op)); + } + if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) { + return source * expr->scale; + } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) { + return floordiv(source, expr->lower_factor) * expr->scale; + } else { + return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale; + } + } + + private: + Analyzer* analyzer_; +}; + +PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr) { + arith::Analyzer analyzer; + IterMapToExprNormalizer normalizer(&analyzer); + return normalizer.Convert(expr); +} + +TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const IterMapExpr& expr) { + return NormalizeIterMapToExpr(expr); +}); + } // namespace arith } // namespace tvm From b864ea6d0838bb1d5a520213819ed4693da0cf08 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 27 Mar 2021 23:43:26 +0800 Subject: [PATCH 3/7] [ARITH] normalize iter expr --- python/tvm/arith/iter_affine_map.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 4033d797dff8..5aa817bd7a24 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -112,3 +112,19 @@ def detect_iter_map(indices, input_iters, predicate=True, require_bijective=Fals Empty array if no match can be found. """ return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective) + + +def normalize_iter_map_to_expr(expr): + """Given an IterMapExpr, transform it to normal PrimExpr + + Parameters + ---------- + expr : IterMapExpr + the input IterMapExpr + + Returns + ------- + result : PrimExpr + the corresponding normal PrimExpr + """ + return _ffi_api.NormalizeIterMapToExpr(expr) From 8a97482fc0719d0072941baa4ebac7efb775d27f Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 27 Mar 2021 23:45:24 +0800 Subject: [PATCH 4/7] [ARITH] normalize iter expr --- python/tvm/arith/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 77ec869a171e..05843ede9284 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -22,4 +22,4 @@ from .pattern import detect_linear_equation, detect_clip_bound from .int_solver import solve_linear_equations, solve_linear_inequalities from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr -from .iter_affine_map import detect_iter_map +from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr From e257042ab99b0e35d5c79cb45142e55c4c89261b Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sun, 28 Mar 2021 10:41:30 +0800 Subject: [PATCH 5/7] [ARITH] add testcase --- .../unittest/test_arith_iter_affine_map.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index ac05809449bd..f0b46b3e5500 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -285,9 +285,29 @@ def test_predicate(): assert len(res) == 0 +def test_normalize_iter_map_to_expr(): + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + x = tvm.tir.Var("x", "int32"), 10 + y = tvm.tir.Var("y", "int32"), 9 + + xo, xi = isplit(x, 5) + yo, yi = isplit(y, 3) + z = ifuse([yo, xo, yi]) + + res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) + + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[0]), + fld(y[0], 3)*6 + fld(x[0], 5)*3 + flm(y[0], 3)) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), + flm(x[0], 5)) + + if __name__ == "__main__": test_split() test_trivial() test_fuse() test_compound() test_predicate() + test_normalize_iter_map_to_expr() From 4a59b4567bcadb56b984f849e6b10feaab8e100d Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sun, 28 Mar 2021 10:41:54 +0800 Subject: [PATCH 6/7] [ARITH] add testcase --- tests/python/unittest/test_arith_iter_affine_map.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index f0b46b3e5500..5ce68aaaf51b 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -298,10 +298,11 @@ def test_normalize_iter_map_to_expr(): res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) - tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[0]), - fld(y[0], 3)*6 + fld(x[0], 5)*3 + flm(y[0], 3)) - tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), - flm(x[0], 5)) + tvm.ir.assert_structural_equal( + tvm.arith.normalize_iter_map_to_expr(res[0]), + fld(y[0], 3) * 6 + fld(x[0], 5) * 3 + flm(y[0], 3), + ) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) if __name__ == "__main__": From 92044e63e0bdd673ae13f32af9017a9e8e41cc86 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Mon, 29 Mar 2021 15:48:57 +0800 Subject: [PATCH 7/7] [ARITH] process comments --- src/arith/iter_affine_map.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 3d0b41f573ec..a49478a43635 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1058,8 +1058,10 @@ class IterMapToExprNormalizer { PrimExpr source; if (const auto* op = expr->source->source.as()) { source = GetRef(op); - } else if (const auto& op = expr->source->source.as()) { + } else if (const auto* op = expr->source->source.as()) { source = ConvertIterSumExpr(GetRef(op)); + } else { + LOG(FATAL) << "Unexpected source of IterSplitExpr"; } if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) { return source * expr->scale;