diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 77ec869a171e..05843ede9284 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -22,4 +22,4 @@ from .pattern import detect_linear_equation, detect_clip_bound from .int_solver import solve_linear_equations, solve_linear_inequalities from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr -from .iter_affine_map import detect_iter_map +from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 4033d797dff8..5aa817bd7a24 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -112,3 +112,19 @@ def detect_iter_map(indices, input_iters, predicate=True, require_bijective=Fals Empty array if no match can be found. """ return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective) + + +def normalize_iter_map_to_expr(expr): + """Given an IterMapExpr, transform it to normal PrimExpr + + Parameters + ---------- + expr : IterMapExpr + the input IterMapExpr + + Returns + ------- + result : PrimExpr + the corresponding normal PrimExpr + """ + return _ffi_api.NormalizeIterMapToExpr(expr) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 3757b5eb0d51..a49478a43635 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1028,5 +1028,63 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { } } +/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */ +class IterMapToExprNormalizer { + public: + explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {} + + PrimExpr Convert(const IterMapExpr& expr) { + if (const auto* op = expr.as()) { + return ConvertIterSplitExpr(GetRef(op)); + } else if (const auto* op = expr.as()) { + return ConvertIterSumExpr(GetRef(op)); + } else { + ICHECK(expr.defined()); + LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey(); + return 0; + } + } + + PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) { + PrimExpr res = 0; + for (const IterSplitExpr& arg : expr->args) { + res += ConvertIterSplitExpr(arg); + } + res += expr->base; + return res; + } + + PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) { + PrimExpr source; + if (const auto* op = expr->source->source.as()) { + source = GetRef(op); + } else if (const auto* op = expr->source->source.as()) { + source = ConvertIterSumExpr(GetRef(op)); + } else { + LOG(FATAL) << "Unexpected source of IterSplitExpr"; + } + if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) { + return source * expr->scale; + } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) { + return floordiv(source, expr->lower_factor) * expr->scale; + } else { + return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale; + } + } + + private: + Analyzer* analyzer_; +}; + +PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr) { + arith::Analyzer analyzer; + IterMapToExprNormalizer normalizer(&analyzer); + return normalizer.Convert(expr); +} + +TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const IterMapExpr& expr) { + return NormalizeIterMapToExpr(expr); +}); + } // namespace arith } // namespace tvm diff --git a/src/arith/expr_complexity.cc b/src/tir/analysis/expr_complexity.cc similarity index 100% rename from src/arith/expr_complexity.cc rename to src/tir/analysis/expr_complexity.cc diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index ac05809449bd..5ce68aaaf51b 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -285,9 +285,30 @@ def test_predicate(): assert len(res) == 0 +def test_normalize_iter_map_to_expr(): + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + x = tvm.tir.Var("x", "int32"), 10 + y = tvm.tir.Var("y", "int32"), 9 + + xo, xi = isplit(x, 5) + yo, yi = isplit(y, 3) + z = ifuse([yo, xo, yi]) + + res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) + + tvm.ir.assert_structural_equal( + tvm.arith.normalize_iter_map_to_expr(res[0]), + fld(y[0], 3) * 6 + fld(x[0], 5) * 3 + flm(y[0], 3), + ) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) + + if __name__ == "__main__": test_split() test_trivial() test_fuse() test_compound() test_predicate() + test_normalize_iter_map_to_expr()