From 7068d2f00be45136d4d0a436297bd4a8704feee4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 10 Oct 2022 09:45:48 -0500 Subject: [PATCH] [Arith][TIR] Check for constant offsets of known literal constraints Previously, the checks for a literal constraint would find exact matches for an inequality, but any alterations to the conditional would break this exact matching. This commit introduces checks for constant offsets relative to a known value. These checks are not always expressible using the existing `ConstIntSetAnalyzer`, which represents allowed values using a single contiguous region. (e.g. `i!=5` is not representable, because it requires a region for `i<5` and another for `i>5`.) This implementation reuses the internal representation for inequalities introduced in https://github.com/apache/tvm/pull/12863, along with much of its implementation. However, the indirect comparisons (e.g. using `a < b` and `b < c` to prove that `a < c`) introduced in that PR still require an explicit flag to be used. --- include/tvm/arith/analyzer.h | 11 +- src/arith/rewrite_simplify.cc | 7 +- src/arith/transitive_comparison_analyzer.cc | 168 +++++++++++++----- .../unittest/test_tir_transform_simplify.py | 12 ++ 4 files changed, 153 insertions(+), 45 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index b80d75a170585..3c22c449adcd7 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -396,10 +396,19 @@ class TransitiveComparisonAnalyzer { * * \param rhs The right-hand side of the comparison * + * \param propagate_inequalities If true, attempt to find a sequence + * of transitive inequalities that allow the lhs and rhs to be + * compared. If false, only use the known comparison that have been + * directly provided. Using `propagate_inequalities = false` is + * roughly equivalent to comparing against all known inequality + * expressions using `ExprDeepEqual`, but also allows for constant + * offsets on either side of the inequality. + * * \return The most specific result that can be proven about the * comparison. If nothing can be proven, returns kUnknown. */ - TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs); + TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs, + bool propagate_inequalities = true); /*! \brief Bind a variable as being equal to a known expression * diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 5e565d7e36c6f..8e0c62e063602 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -84,9 +84,7 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimE if (is_finished()) return output; - if (enabled_extensions_ & kTransitivelyProveInequalities) { - output = CompareResult(output & TryCompareUsingKnownInequalities(x, y)); - } + output = CompareResult(output & TryCompareUsingKnownInequalities(x, y)); return output; } @@ -98,7 +96,8 @@ CompareResult RewriteSimplifier::Impl::TryCompareUsingConstIntBounds(const PrimE CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const PrimExpr& x, const PrimExpr& y) { - return analyzer_->transitive_comparisons.TryCompare(x, y); + bool propagate_inequalities = enabled_extensions_ & kTransitivelyProveInequalities; + return analyzer_->transitive_comparisons.TryCompare(x, y, propagate_inequalities); } // try to prove x equals val diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index 9a835f7fdec85..b71096a479b56 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -43,10 +43,19 @@ class TransitiveComparisonAnalyzer::Impl { * * \param rhs The right-hand side of the comparison * + * \param propagate_inequalities If true, attempt to find a sequence + * of transitive inequalities that allow the lhs and rhs to be + * compared. If false, only use the known comparison that have been + * directly provided. Using `propagate_inequalities = false` is + * roughly equivalent to comparing against all known values with + * `ExprDeepEqual`, but also allowing for constant offsets on either + * side of the inequality. + * * \return The most specific result that can be proven about the * comparison. If nothing can be proven, returns kUnknown. */ - CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs, + bool propagate_inequalities = true) const; /*! \brief Bind a variable as being equal to a known expression * @@ -192,7 +201,37 @@ class TransitiveComparisonAnalyzer::Impl { */ void AddKnown(const PrimExpr& expr, std::vector* vec); - /*! \brief Attempt to compare the expressions, starting at the lhs. + /*! Collect known comparisons between LHS and RHS, without propagation + * + * Allows the internal representation to handle any constant + * offsets, without searching for a sequence of inequalities. + * + * \param lhs_key The left-hand side of the comparison + * + * \param rhs_key The right-hand side of the comparison + * + * \returns A subset of `knowns_` and `scoped_knowns_`, filtered to + * only include comparisons between `lhs_key` and `rhs_key`, + * normalized such that `lhs_key` is on the left-hand side. + */ + std::vector CollectDirectComparisons(Key lhs_key, Key rhs_key) const; + + /*! Collect known comparisons between LHS and RHS, with propagation + * + * \param lhs_key The left-hand side of the comparison + * + * \param rhs_key The right-hand side of the comparison + * + * \returns All comparisons between `lhs_key` and `rhs_key`, + * including the explicitly-provided comparisons in `knowns_` and + * `scoped_knowns_`, and comparisons provable through a series of + * comparisons through other values. All comparisons returned are + * between `lhs_key` and `rhs_key`, and are normalized such that + * `lhs_key` is on the left-hand side. + */ + std::vector CollectIndirectComparisons(Key lhs_key, Key rhs_key) const; + + /*! \brief Internal function used by CollectIndirectComparisons * * Perform a depth-first search through the space of known * expressions, starting at the LHS of a comparison. In this @@ -208,14 +247,29 @@ class TransitiveComparisonAnalyzer::Impl { * expression D, then combine the comparisons that compose the path * into the expression A<=D-4. * - * \param lhs The left-hand side of the comparison + * \param lhs_key The left-hand side of the comparison * - * \param rhs The right-hand side of the comparison + * \param rhs_key The right-hand side of the comparison + * + * \returns A vector of comparisons between the two expressions. + */ + std::vector DFSFromLHS(Key lhs_key, Key rhs_key) const; + + /*! \brief Combine a set of comparisons that share a LHS and RHS + * + * \param lhs_to_rhs The comparisons to merge. These should all + * have the same LHS and RHS. This parameter will typically be the + * result from `CollectDirectComparisons` or + * `CollectIndirectComparisons`. * - * \return The result of the comparison + * \param offset The constant offset in the comparison being proven. + * This is extracted from any additive/subtractive constants in the + * `PrimExpr` arguments to `TryCompare`. + * + * \returns The possible comparisons between LHS and RHS provided + * inequalities. */ - CompareResult DFSFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, - const PrimExpr& rhs) const; + CompareResult MergeComparisons(const std::vector& lhs_to_rhs, int64_t offset) const; /*! \brief Previous Range bindings * @@ -475,8 +529,9 @@ bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique()) {} TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {} -CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) { - return impl_->TryCompare(lhs, rhs); +CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs, + bool propagate_inequalities) { + return impl_->TryCompare(lhs, rhs, propagate_inequalities); } void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { @@ -547,7 +602,8 @@ std::function TransitiveComparisonAnalyzer::Impl::EnterConstraint(const } CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr, - const PrimExpr& rhs_expr) const { + const PrimExpr& rhs_expr, + bool propagate_inequalities) const { // Currently only supports integer checks if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) { return CompareResult::kUnknown; @@ -575,29 +631,59 @@ CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs return CompareResult::kUnknown; } - auto from_lhs = DFSFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs); - auto from_rhs = Reverse(DFSFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs)); - auto output = from_lhs & from_rhs; + auto lhs_to_rhs = [&]() { + if (propagate_inequalities) { + return CollectIndirectComparisons(lhs_key.value(), rhs_key.value()); + } else { + return CollectDirectComparisons(lhs_key.value(), rhs_key.value()); + } + }(); + return MergeComparisons(lhs_to_rhs, offset); +} + +std::vector +TransitiveComparisonAnalyzer::Impl::CollectDirectComparisons(Key lhs_key, Key rhs_key) const { + std::vector output; + + auto append_known = [&](Comparison cmp) { + if (auto normalized = cmp.WithLHS(lhs_key)) { + if (normalized.value().rhs_ == rhs_key) { + output.push_back(normalized.value()); + } + } + }; + + for (const auto& known : knowns_) { + append_known(known); + } + for (const auto& known : scoped_knowns_) { + append_known(known); + } return output; } -CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input, Key rhs_key_input, - int64_t offset_input, - const PrimExpr& lhs_input, - const PrimExpr& rhs_input) const { - Key lhs_key = lhs_key_input; - Key rhs_key = rhs_key_input; - int64_t offset = offset_input; +std::vector +TransitiveComparisonAnalyzer::Impl::CollectIndirectComparisons(Key lhs_key, Key rhs_key) const { + auto output = DFSFromLHS(lhs_key, rhs_key); + for (Comparison cmp : DFSFromLHS(rhs_key, lhs_key)) { + auto opt_normalized = cmp.WithLHS(lhs_key); + ICHECK(opt_normalized.has_value()); + output.push_back(opt_normalized.value()); + } + return output; +} +std::vector +TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const { // Everything in `to_visit` has lhs as its lhs. std::unordered_set seen; std::unordered_set to_visit; - std::unordered_map> compared_to_x; + std::unordered_map> compared_to_lhs; // Utility function to add a new known statement auto declare_known = [&](Comparison cmp) { - std::vector& knowns = compared_to_x[cmp.rhs_]; + std::vector& knowns = compared_to_lhs[cmp.rhs_]; // The comparison adds no new information, no modification // required. @@ -646,8 +732,8 @@ CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input, Key middle_key = *to_visit.begin(); to_visit.erase(to_visit.begin()); - std::vector& prev_knowns_using_middle = compared_to_x.at(middle_key); - ICHECK(compared_to_x.count(middle_key)); + std::vector& prev_knowns_using_middle = compared_to_lhs.at(middle_key); + ICHECK(compared_to_lhs.count(middle_key)); std::vector new_knowns_using_lhs; @@ -721,27 +807,29 @@ CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input, } } - // It's possible that we don't have any transitive comparisons that - // can prove something about LHS and RHS. - auto it = compared_to_x.find(rhs_key); - if (it == compared_to_x.end()) { - return CompareResult::kUnknown; + if (auto it = compared_to_lhs.find(rhs_key); it != compared_to_lhs.end()) { + return it->second; + } else { + // There are known comparisons involving the LHS and the RHS, but + // no path that connects the two expressions. + return {}; } +} - const std::vector& known_between_lhs_and_rhs = it->second; - +CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons( + const std::vector& lhs_to_rhs, int64_t offset) const { // Just because we found a comparison involving LHS and RHS doesn't // mean that it's useful. e.g. Knowing that `x < y` doesn't let us // prove whether `x + 5 < y`. CompareResult result = CompareResult::kUnknown; - for (const auto& known : known_between_lhs_and_rhs) { - switch (known.result_) { + for (const auto& cmp : lhs_to_rhs) { + switch (cmp.result_) { case CompareResult::kInconsistent: result = CompareResult::kInconsistent; break; case CompareResult::kEQ: - if (offset == known.offset_) { + if (offset == cmp.offset_) { result = result & CompareResult::kEQ; } else { result = result & CompareResult::kNE; @@ -749,23 +837,23 @@ CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input, break; case CompareResult::kLE: - if (known.offset_ < offset) { + if (cmp.offset_ < offset) { result = result & CompareResult::kLT; - } else if (known.offset_ <= offset) { + } else if (cmp.offset_ <= offset) { result = result & CompareResult::kLE; } break; case CompareResult::kGE: - if (known.offset_ > offset) { + if (cmp.offset_ > offset) { result = result & CompareResult::kGT; - } else if (known.offset_ >= offset) { + } else if (cmp.offset_ >= offset) { result = result & CompareResult::kGE; } break; case CompareResult::kNE: - if (offset == known.offset_) { + if (offset == cmp.offset_) { result = result & CompareResult::kNE; } break; @@ -779,7 +867,7 @@ CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input, return CompareResult::kInconsistent; default: - LOG(FATAL) << "Invalid CompareResult: " << static_cast(known.result_); + LOG(FATAL) << "Invalid CompareResult: " << static_cast(cmp.result_); return CompareResult::kInconsistent; } } diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 2eb9c3546ee5b..6236791a75eee 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -816,5 +816,17 @@ def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): A[0] = (i != 30) or (j == 0) +class TestProvableConditionWithOffset(BaseBeforeAfter): + transitively_prove_inequalities = False + + def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32): + if i < j: + A[0] = i < j + 1 + + def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32): + if i < j: + A[0] = True + + if __name__ == "__main__": tvm.testing.main()