From f10565b089f48632453416fe9d87329be27fd5c2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 16 Jul 2025 10:00:08 -0400 Subject: [PATCH] [TIR] Decouple DeepEqual from StructuralEqual This PR decouples deep equal from the structural equal implementation by providing a more direct implementation through a functor. DeepEqual is used at the heart of arith simplification as a subroutine, and it performs more direct nested checking without doing var remapping as structural equal does, for efficiency reasons. It also does not need to trace the wrong comparison since the failed path is also expected to happen often. This step will likely improve the deep equal efficiency because of the more direct approach and gives us the opportunity to simplify a future refactor of structural equal to focus on struct path tracing. --- src/tir/analysis/deep_equal.cc | 179 ++++++++++++++++++++++++++++----- 1 file changed, 154 insertions(+), 25 deletions(-) diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index bb105be02ada..fe22d152cb5e 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -25,48 +25,177 @@ #include #include #include -#include #include +#include namespace tvm { namespace tir { -class DeepCmpSEqualHandler : public SEqualReducer::Handler { +#define DEFINE_DEEP_EQUAL_BIN_EXPR(OpNode) \ + bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final { \ + const auto* prhs = rhs.as(); \ + return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a) && \ + VisitExpr(plhs->b, prhs->b); \ + } + +#define DEFINE_DEEP_EQUAL_IMM_EXPR(OpNode) \ + bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final { \ + const auto* prhs = rhs.as(); \ + return plhs->dtype == prhs->dtype && plhs->value == prhs->value; \ + } + +class ExprDeepEqualChecker : private ExprFunctor { public: - // use direct recursion.
- bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, - const Optional&) final { + static bool Check(const PrimExpr& lhs, const PrimExpr& rhs) { + // quick path without constructing the object if (lhs.same_as(rhs)) return true; if (!lhs.defined() && rhs.defined()) return false; if (!rhs.defined() && lhs.defined()) return false; if (lhs->type_index() != rhs->type_index()) return false; - return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, nullptr, false)) && - !fail_; + if (auto* plhs = lhs.as()) { + auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && plhs->value == prhs->value; + } + return ExprDeepEqualChecker().VisitExpr(lhs, rhs); } - void DeferFail(const ObjectPathPair&) final { fail_ = true; } - bool IsFailDeferralEnabled() final { return false; } - - ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return lhs; } - void MarkGraphNode() final {} + bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() && rhs.defined()) return false; + if (!rhs.defined() && lhs.defined()) return false; + if (lhs->type_index() != rhs->type_index()) return false; + return ExprFunctor::VisitExpr(lhs, rhs); + } private: - // reflection vtable - ReflectionVTable* vtable_ = ReflectionVTable::Global(); - bool fail_ = false; + bool ArrayDeepEqual(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); i++) { + if (!VisitExpr(lhs[i], rhs[i])) return false; + } + return true; + } + + bool ArrayDeepEqual(const Array& lhs, const Array& rhs) { + // for iter var, we require pointer equality; any mismatch makes the arrays unequal + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); i++) { + if (!lhs[i].same_as(rhs[i])) return false; + } + return true; + } + + bool OptionalDeepEqual(const Optional& lhs, const Optional& rhs) { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() && rhs.defined()) return
false; + if (lhs.defined() && !rhs.defined()) return false; + return VisitExpr(*lhs, *rhs); + } + + bool VisitExpr_(const VarNode* plhs, const PrimExpr& rhs) final { + // for var, we require pointer equality + return plhs == rhs.get(); + } + + bool VisitExpr_(const SizeVarNode* plhs, const PrimExpr& rhs) final { + // for var, we require pointer equality + return plhs == rhs.get(); + } + + bool VisitExpr_(const BufferLoadNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + // we run pointer comparison of the buffer + return plhs->dtype == prhs->dtype && plhs->buffer.same_as(prhs->buffer) && + ArrayDeepEqual(plhs->indices, prhs->indices) && + OptionalDeepEqual(plhs->predicate, prhs->predicate); + } + + bool VisitExpr_(const ProducerLoadNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + // run shallow pointer comparison of the producer + return plhs->dtype == prhs->dtype && plhs->producer.same_as(prhs->producer) && + ArrayDeepEqual(plhs->indices, prhs->indices); + } + + bool VisitExpr_(const LetNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && VisitExpr(plhs->var, prhs->var) && + VisitExpr(plhs->value, prhs->value) && VisitExpr(plhs->body, prhs->body); + } + + bool VisitExpr_(const CallNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && plhs->op.same_as(prhs->op) && + ArrayDeepEqual(plhs->args, prhs->args); + } + + bool VisitExpr_(const ReduceNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && plhs->combiner.same_as(prhs->combiner) && + ArrayDeepEqual(plhs->source, prhs->source) && ArrayDeepEqual(plhs->init, prhs->init) && + ArrayDeepEqual(plhs->axis, prhs->axis) && VisitExpr(plhs->condition, prhs->condition) && + plhs->value_index == prhs->value_index; + } + + bool VisitExpr_(const CastNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = 
rhs.as(); + return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value); + } + + bool VisitExpr_(const NotNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a); + } + + bool VisitExpr_(const SelectNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && VisitExpr(plhs->condition, prhs->condition) && + VisitExpr(plhs->true_value, prhs->true_value) && + VisitExpr(plhs->false_value, prhs->false_value); + } + + bool VisitExpr_(const RampNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && VisitExpr(plhs->base, prhs->base) && + VisitExpr(plhs->stride, prhs->stride) && VisitExpr(plhs->lanes, prhs->lanes); + } + + bool VisitExpr_(const ShuffleNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && ArrayDeepEqual(plhs->vectors, prhs->vectors) && + ArrayDeepEqual(plhs->indices, prhs->indices); + } + + bool VisitExpr_(const BroadcastNode* plhs, const PrimExpr& rhs) final { + const auto* prhs = rhs.as(); + return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value) && + VisitExpr(plhs->lanes, prhs->lanes); + } + + DEFINE_DEEP_EQUAL_BIN_EXPR(AddNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(SubNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(MulNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(DivNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(ModNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(FloorDivNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(FloorModNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(MinNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(MaxNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(EQNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(NENode) + DEFINE_DEEP_EQUAL_BIN_EXPR(LTNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(LENode) + DEFINE_DEEP_EQUAL_BIN_EXPR(GTNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(GENode) + DEFINE_DEEP_EQUAL_BIN_EXPR(AndNode) + DEFINE_DEEP_EQUAL_BIN_EXPR(OrNode) + DEFINE_DEEP_EQUAL_IMM_EXPR(IntImmNode) + 
DEFINE_DEEP_EQUAL_IMM_EXPR(FloatImmNode) + DEFINE_DEEP_EQUAL_IMM_EXPR(StringImmNode) }; bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { - // quick path - if (lhs.same_as(rhs)) return true; - if (!lhs.defined() && rhs.defined()) return false; - if (!rhs.defined() && lhs.defined()) return false; - if (lhs->type_index() != rhs->type_index()) return false; - if (auto* plhs = lhs.as()) { - auto* prhs = rhs.as(); - return plhs->dtype == prhs->dtype && plhs->value == prhs->value; - } - return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, std::nullopt); + return ExprDeepEqualChecker::Check(lhs, rhs); } TVM_FFI_STATIC_INIT_BLOCK({