diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index fbca3bb5ef62..0fc832e0fb7a 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -46,6 +46,8 @@ #include #include +#include +#include #include #include @@ -131,95 +133,6 @@ class AttrFieldInfo : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode); }; -class AttrsHashHandler; -class AttrsEqualHandler; -/*! - * \brief Content-aware Equality comparator for attrs. - * - * This comparator will recursively deep compare the following Attributes. - * - * - IntImm, UIntImm, FloatImm, StringImm - * - Any subclass of BaseAttrsNode - * - Array of Attributes. - * - Map from string to Attributes. - */ -class AttrsEqual { - public: - bool operator()(const double& lhs, const double& rhs) const { - // fuzzy float pt comparison - constexpr double atol = 1e-9; - if (lhs == rhs) return true; - double diff = lhs - rhs; - return diff > -atol && diff < atol; - } - - bool operator()(const int64_t& lhs, const int64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const int& lhs, const int& rhs) const { - return lhs == rhs; - } - bool operator()(const bool& lhs, const bool& rhs) const { - return lhs == rhs; - } - bool operator()(const std::string& lhs, const std::string& rhs) const { - return lhs == rhs; - } - bool operator()(const DataType& lhs, const DataType& rhs) const { - return lhs == rhs; - } - // node comparator - TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; - - protected: - friend class AttrsEqualHandler; - /*! \brief internal handle. */ - AttrsEqualHandler* handler_{nullptr}; -}; - -/*! - * \brief Content-aware hash function. - * - * This hash functor will recursively hash the content of the Attributes. - * It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b); - */ -class AttrsHash { - public: - size_t operator()(const double& value) const { - return std::hash()(value); - } - size_t operator()(const int64_t& value) const { - return std::hash()(value); - } - size_t operator()(const uint64_t& value) const { - return std::hash()(value); - } - size_t operator()(const int& value) const { - return std::hash()(value); - } - size_t operator()(const bool& value) const { - return std::hash()(value); - } - size_t operator()(const std::string& value) const { - return std::hash()(value); - } - size_t operator()(const DataType& value) const { - return std::hash()( - static_cast(value.code()) | - (static_cast(value.bits()) << 8) | - (static_cast(value.lanes()) << 16)); - } - TVM_DLL size_t operator()(const ObjectRef& value) const; - - private: - friend class AttrsHashHandler; - /*! \brief internal handle. */ - AttrsHashHandler* handler_{nullptr}; -}; - /*! * \brief Base class of all attribute class * \note Do not subclass AttrBaseNode directly, @@ -266,20 +179,6 @@ class BaseAttrsNode : public Object { * \note This function throws when the required field is not present. */ TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0; - /*! - * \brief Whether this attribute's content equals to another node. - * \param other The pointer to another node. - * \param equal The equal comparator - * \return The comparison result. - */ - TVM_DLL virtual bool ContentEqual( - const Object* other, AttrsEqual equal) const = 0; - /*! - * \brief Content aware hash. - * \param hasher The hasher to run the hash. - * \return the hash result. - */ - TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -320,8 +219,6 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; - bool ContentEqual(const Object* other, AttrsEqual equal) const final; - size_t ContentHash(AttrsHash hasher) const final; // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); @@ -386,34 +283,6 @@ class AttrNormalVisitor { AttrVisitor* visitor_; }; -// Wrapper for normal visitor. -class AttrsEqualVisitor { - public: - bool result_{true}; - // constructor - AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal) - : lhs_(lhs), rhs_(rhs), equal_(equal) { - } - template - AttrNopEntry operator()(const char* key, T* lhs_value) { - if (!result_) return AttrNopEntry(); - const T* rhs_value = - reinterpret_cast( - reinterpret_cast(rhs_) + - (reinterpret_cast(lhs_value) - - reinterpret_cast(lhs_))); - if (!equal_(*lhs_value, *rhs_value)) { - result_ = false; - } - return AttrNopEntry(); - } - - private: - const Object* lhs_; - const Object* rhs_; - const AttrsEqual& equal_; -}; - class AttrsSEqualVisitor { public: bool result_{true}; @@ -441,23 +310,6 @@ class AttrsSEqualVisitor { const SEqualReducer& equal_; }; -class AttrsHashVisitor { - public: - explicit AttrsHashVisitor(const AttrsHash& hasher) - : hasher_(hasher) {} - - size_t result_{0}; - - template - AttrNopEntry operator()(const char* key, T* value) { - result_ = dmlc::HashCombine(result_, hasher_(*value)); - return AttrNopEntry(); - } - - private: - const AttrsHash& hasher_; -}; - class AttrsSHashVisitor { public: explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) @@ -760,7 +612,7 @@ struct AttrTriggerNonDefaultEntry { return *this; } TSelf& set_default(const T& value) { - if (AttrsEqual()(value, *data_)) { + if (tvm::StructuralEqual()(value, *data_)) { trigger_ = false; } return *this; @@ -890,23 +742,6 @@ class AttrsNode : public BaseAttrsNode { return visitor.fields_; } - bool ContentEqual(const Object* other, AttrsEqual equal) const final { - DerivedType* pself = self(); - if (pself == other) return true; - if (other == nullptr) return false; - if (pself->type_index() != other->type_index()) return false; - ::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal); - self()->__VisitAttrs__(visitor); - return visitor.result_; - } - - size_t ContentHash(AttrsHash hasher) const final { - ::tvm::detail::AttrsHashVisitor visitor(hasher); - visitor.result_ = this->GetTypeKeyHash(); - self()->__VisitAttrs__(visitor); - return visitor.result_; - } - private: DerivedType* self() const { return const_cast( diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index 9acc4651d089..dbd5a4fab23b 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -147,94 +147,5 @@ class AttrFunctor { } }; -class AttrsEqualHandler : - protected AttrFunctor { - public: - /*! - * \brief Check if lhs equals rhs - * \param lhs The left operand. - * \param rhs The right operand. - */ - bool Equal(const ObjectRef& lhs, const ObjectRef& rhs); - - protected: - bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::IntImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::FloatImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::StringImmNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::AddNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::SubNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::MulNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::DivNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::ModNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::FloorDivNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::FloorModNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::MinNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::MaxNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::GENode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::GTNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::LTNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::LENode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::EQNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::NENode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::AndNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::OrNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::NotNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::CastNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::CallNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const tir::SelectNode* lhs, const ObjectRef& other) final; -}; - -class AttrsHashHandler : - protected AttrFunctor { - public: - /*! - * \brief Get hash value of node - * \param node The node to be hashed. - */ - size_t Hash(const ObjectRef& node) { - if (!node.defined()) return 0; - return this->VisitAttr(node); - } - - protected: - size_t VisitAttrDefault_(const Object* lhs) final; - size_t VisitAttr_(const tir::IntImmNode* lhs) final; - size_t VisitAttr_(const tir::FloatImmNode* lhs) final; - size_t VisitAttr_(const tir::StringImmNode* lhs) final; - size_t VisitAttr_(const ArrayNode* lhs) final; - size_t VisitAttr_(const StrMapNode* lhs) final; - size_t VisitAttr_(const tir::AddNode* op) final; - size_t VisitAttr_(const tir::SubNode* op) final; - size_t VisitAttr_(const tir::MulNode* op) final; - size_t VisitAttr_(const tir::DivNode* op) final; - size_t VisitAttr_(const tir::ModNode* op) final; - size_t VisitAttr_(const tir::FloorDivNode* op) final; - size_t VisitAttr_(const tir::FloorModNode* op) final; - size_t VisitAttr_(const tir::MinNode* op) final; - size_t VisitAttr_(const tir::MaxNode* op) final; - size_t VisitAttr_(const tir::GENode* op) final; - size_t VisitAttr_(const tir::GTNode* op) final; - size_t VisitAttr_(const tir::LENode* op) final; - size_t VisitAttr_(const tir::LTNode* op) final; - size_t VisitAttr_(const tir::EQNode* op) final; - size_t VisitAttr_(const tir::NENode* op) final; - size_t VisitAttr_(const tir::AndNode* op) final; - size_t VisitAttr_(const tir::OrNode* op) final; - size_t VisitAttr_(const tir::NotNode* op) final; - size_t VisitAttr_(const tir::CastNode* op) final; - size_t VisitAttr_(const tir::CallNode* op) final; - size_t VisitAttr_(const tir::SelectNode* op) final; - /*! - * \brief alias of dmlc::HashCombine - * \param lhs The first hash value. - * \param rhs The second hash value. - */ - static size_t Combine(size_t lhs, size_t rhs) { - return dmlc::HashCombine(lhs, rhs); - } -}; } // namespace tvm #endif // TVM_IR_ATTR_FUNCTOR_H_ diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 868fec640352..066b8f99ea7c 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -74,287 +74,9 @@ TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict") return attrs->dict; }); - -using namespace tir; -// Equal handler. -bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) { - if (lhs.same_as(rhs)) return true; - if (!lhs.defined() && rhs.defined()) return false; - if (!rhs.defined() && lhs.defined()) return false; - return this->VisitAttr(lhs, rhs); -} - -bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) { - if (lhs->IsInstance()) { - AttrsEqual equal; - equal.handler_ = this; - return static_cast(lhs)->ContentEqual( - other.get(), equal); - } - return lhs == other.get(); -} - -bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return lhs->value == rhs->value; - } else { - return false; - } -} - -bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return lhs->value == rhs->value; - } else { - return false; - } -} - -bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return lhs->value == rhs->value; - } else { - return false; - } -} - -bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - if (rhs->data.size() != lhs->data.size()) return false; - for (size_t i = 0; i < lhs->data.size(); ++i) { - if (!Equal(lhs->data[i], rhs->data[i])) return false; - } - return true; - } else { - return false; - } -} - -bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - if (rhs->data.size() != lhs->data.size()) return false; - for (const auto& kv : lhs->data) { - auto it = rhs->data.find(kv.first); - if (it == rhs->data.end()) return false; - if (!Equal(kv.second, it->second)) return false; - } - return true; - } else { - return false; - } -} - -#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \ - bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \ - if (const auto* rhs = other.as()) { \ - if (!Equal(lhs->a, rhs->a)) return false; \ - if (!Equal(lhs->b, rhs->b)) return false; \ - return true; \ - } else { \ - return false; \ - } \ - } \ - -TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode); - -bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return Equal(lhs->a, rhs->a); - } else { - return false; - } -} - -bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - if (lhs->dtype != rhs->dtype) return false; - return Equal(lhs->value, rhs->value); - } else { - return false; - } -} - -bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return - lhs->name == rhs->name && - lhs->dtype == rhs->dtype && - lhs->call_type == rhs->call_type && - Equal(lhs->args, rhs->args); - } else { - return false; - } -} - -bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { - return - Equal(lhs->condition, rhs->condition) && - Equal(lhs->true_value, rhs->true_value) && - Equal(lhs->false_value, rhs->false_value); - } else { - return false; - } -} - -// Hash Handler. -size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) { - if (value->IsInstance()) { - AttrsHash hasher; - hasher.handler_ = this; - return static_cast(value)->ContentHash(hasher); - } else { - return ObjectHash()(GetRef(value)); - } -} - -size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) { - return std::hash()(op->value); -} - -size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) { - return std::hash()(op->value); -} - -size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) { - return std::hash()(op->value); -} - -size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) { - size_t result = op->data.size(); - for (size_t i = 0; i < op->data.size(); ++i) { - result = Combine(result, this->Hash(op->data[i])); - } - return result; -} - -size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) { - using Entry = std::pair; - std::vector data(lhs->data.begin(), lhs->data.end()); - std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) { - return a.first < b.first; - }); - size_t result = 0; - for (const Entry& kv : data) { - result = Combine(result, std::hash()(kv.first)); - result = Combine(result, this->Hash(kv.second)); - } - return result; -} - - -#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName) \ - size_t AttrsHashHandler::VisitAttr_(const NodeName* op) { \ - static size_t key = std::hash()(NodeName::_type_key); \ - return Combine(key, Combine(Hash(op->a), Hash(op->b))); \ - } \ - -TVM_DEFINE_ATTRS_BINOP_HASH(AddNode); -TVM_DEFINE_ATTRS_BINOP_HASH(SubNode); -TVM_DEFINE_ATTRS_BINOP_HASH(MulNode); -TVM_DEFINE_ATTRS_BINOP_HASH(DivNode); -TVM_DEFINE_ATTRS_BINOP_HASH(ModNode); -TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode); -TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode); -TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode); -TVM_DEFINE_ATTRS_BINOP_HASH(MinNode); -TVM_DEFINE_ATTRS_BINOP_HASH(GENode); -TVM_DEFINE_ATTRS_BINOP_HASH(GTNode); -TVM_DEFINE_ATTRS_BINOP_HASH(LENode); -TVM_DEFINE_ATTRS_BINOP_HASH(LTNode); -TVM_DEFINE_ATTRS_BINOP_HASH(EQNode); -TVM_DEFINE_ATTRS_BINOP_HASH(NENode); -TVM_DEFINE_ATTRS_BINOP_HASH(AndNode); -TVM_DEFINE_ATTRS_BINOP_HASH(OrNode); - -size_t AttrsHashHandler::VisitAttr_(const NotNode* op) { - static size_t key = std::hash()(NotNode::_type_key); - return Combine(key, Hash(op->a)); -} - -size_t AttrsHashHandler::VisitAttr_(const CastNode* op) { - static size_t key = std::hash()(CastNode::_type_key); - AttrsHash hasher; - size_t res = key; - res = Combine(res, hasher(op->dtype)); - res = Combine(res, Hash(op->value)); - return res; -} - -size_t AttrsHashHandler::VisitAttr_(const CallNode* op) { - static size_t key = std::hash()(CallNode::_type_key); - AttrsHash hasher; - size_t res = key; - res = Combine(res, hasher(op->name)); - res = Combine(res, hasher(op->dtype)); - res = Combine(res, Hash(op->args)); - return res; -} - -size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) { - static size_t key = std::hash()(SelectNode::_type_key); - size_t res = key; - res = Combine(res, Hash(op->condition)); - res = Combine(res, Hash(op->true_value)); - res = Combine(res, Hash(op->false_value)); - return res; -} - - -// Default case -bool AttrsEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { - if (lhs.same_as(rhs)) return true; - if (handler_ == nullptr) { - return AttrsEqualHandler().Equal(lhs, rhs); - } else { - return handler_->Equal(lhs, rhs); - } -} - -size_t AttrsHash::operator()(const ObjectRef& node) const { - if (!node.defined()) return 0; - if (handler_ == nullptr) { - return AttrsHashHandler().Hash(node); - } else { - return handler_->Hash(node); - } -} - -size_t DictAttrsNode::ContentHash(AttrsHash hasher) const { - return hasher(this->dict); -} - -bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const { - if (this == other) return true; - if (other == nullptr) return false; - if (this->type_index() != other->type_index()) return false; - return equal(this->dict, static_cast(other)->dict); -} - TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo") .set_body_typed([](Attrs attrs) { return attrs->ListFieldInfo(); }); -TVM_REGISTER_GLOBAL("ir.AttrsEqual") -.set_body_typed([](ObjectRef lhs, ObjectRef rhs) { - return AttrsEqual()(lhs, rhs); -}); - } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index df7b8ff22850..b2191c1f890c 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -103,6 +103,7 @@ class RemapVarSEqualHandler : // Function that implements actual equality check. bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { + if (!lhs.defined() && !rhs.defined()) return true; task_stack_.clear(); pending_tasks_.clear(); equal_map_lhs_.clear(); diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index 0dbce9bf5dd6..3884dacbb22c 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -59,7 +59,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { } bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { - AttrsEqual eq; + StructuralEqual eq; const Layout kOIHW("OIHW"); const auto* attrs_a = a->attrs.as(); const auto* attrs_b = b->attrs.as(); @@ -112,7 +112,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { } bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { - AttrsEqual eq; + StructuralEqual eq; auto ta = a->args[index]->type_as(); auto tb = b->args[index]->type_as(); auto toutput_a = a->type_as(); diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index cd234bbd9fa9..612dae5ef00c 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -54,7 +54,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner { protected: virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { - AttrsEqual eq; + StructuralEqual eq; const auto* attrs_a = a->attrs.as(); const auto* attrs_b = b->attrs.as(); CHECK(attrs_a); diff --git a/src/relay/transforms/combine_parallel_op.cc b/src/relay/transforms/combine_parallel_op.cc index 6b9926c698d6..a7f7af2b79e5 100644 --- a/src/relay/transforms/combine_parallel_op.cc +++ b/src/relay/transforms/combine_parallel_op.cc @@ -23,6 +23,7 @@ * \brief Abstract class to combine parallel ops and their successive element-wise ops. */ +#include #include #include #include @@ -155,7 +156,7 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) { bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) { const CallNode* call = branches[0][depth]; - AttrsEqual attrs_equal; + tvm::StructuralEqual attrs_equal; // check if all branches in current depth can be combined for (auto it = branches.begin() + 1; it != branches.end(); it++) { const Branch& branch = *it; diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index fa63573afd50..55ca3f62bec0 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -76,7 +76,7 @@ bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode return false; } - AttrsEqual eq; + StructuralEqual eq; for (size_t i = 0; i < a->args.size(); i++) { auto ta = a->args[i]->type_as(); auto tb = b->args[i]->type_as(); @@ -112,7 +112,7 @@ Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) { } bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { - AttrsEqual eq; + StructuralEqual eq; auto ta = a->args[index]->type_as(); auto tb = b->args[index]->type_as(); diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index bb31d3222690..f905ba55719d 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -45,7 +45,7 @@ class CommonSubexprEliminator : public ExprMutator { const CallNode* new_call = new_expr.as(); CHECK(new_call); const OpNode* op = new_call->op.as(); - AttrsEqual attrs_equal; + StructuralEqual attrs_equal; if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef(op), false)) { return new_expr; diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index c3114c78bd1d..49f6e3fd01cd 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -765,7 +765,7 @@ RELAY_REGISTER_OP("nn.leaky_relu") Message AddSubBackwardPrep(const Call& call, const Array& in_messages) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); - AttrsEqual equal; + StructuralEqual equal; if (in_messages[0].defined() && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { return in_messages[0]; @@ -795,7 +795,7 @@ Expr AddSubBackwardTransform(const Call& call, } Message lhs_message = transformer->GetMessage(call->args[0]); Message rhs_message = transformer->GetMessage(call->args[1]); - AttrsEqual equal; + StructuralEqual equal; if (lhs_message.defined() && rhs_message.defined()) { CHECK(equal(lhs_message->axes, rhs_message->axes)); diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 6e95441ea162..9168898cae36 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -162,7 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // The output. IndexedForwardGraph graph_; // attribute equal comparator - AttrsEqual attr_equal_; + StructuralEqual attr_equal_; // Update the message stored at the node. void Update(const Expr& node, IndexedForwardGraph::Node* parent, diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 8ce42a2023d8..350d9e1f31fa 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -104,7 +104,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, const Array& lhs_axes, Expr* rhs_value = nullptr) { if (tlhs->shape.size() < trhs->shape.size()) return false; - AttrsEqual equal; + StructuralEqual equal; size_t base = tlhs->shape.size() - trhs->shape.size(); size_t j = 0; diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 233bfa51d614..46d0f67c6d51 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -101,18 +101,6 @@ TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore") return RewriteForTensorCore(stmt, schedule, extern_buffer); }); -TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual") -.set_body_typed( - [](const ObjectRef& lhs, const ObjectRef& rhs) { - return AttrsEqual()(lhs, rhs); - }); - -TVM_REGISTER_GLOBAL("ir_pass.AttrsHash") -.set_body_typed([](const ObjectRef &node) -> int64_t { - return AttrsHash()(node); -}); - - TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var()); diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 6d4a685ecebb..dbd5934c38ac 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -106,7 +106,6 @@ def test_function(): check_json_roundtrip(fn) -@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.") def test_function_attrs(): param_names = ['a', 'b', 'c', 'd'] params = tvm.runtime.convert([relay.var(n, shape=(5, 2)) for n in param_names]) diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py index f4148caa0642..8f2e9bb8a80d 100644 --- a/tests/python/unittest/test_ir_attrs.py +++ b/tests/python/unittest/test_ir_attrs.py @@ -51,14 +51,13 @@ def test_dict_attrs(): def test_attrs_equal(): - attr_equal = tvm.ir._ffi_api.AttrsEqual dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20]) dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1) dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None) - assert attr_equal(dattr0, dattr1) - assert not attr_equal(dattr0, dattr2) - assert not attr_equal({"x": 1}, tvm.runtime.convert(1)) - assert not attr_equal([1, 2], tvm.runtime.convert(1)) + assert tvm.ir.structural_equal(dattr0, dattr1) + assert not tvm.ir.structural_equal(dattr0, dattr2) + assert not tvm.ir.structural_equal({"x": 1}, tvm.runtime.convert(1)) + assert not tvm.ir.structural_equal([1, 2], tvm.runtime.convert(1)) diff --git a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py b/tests/python/unittest/test_tir_pass_attrs_hash_equal.py index b3587cd7cb3d..9a115be74559 100644 --- a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py +++ b/tests/python/unittest/test_tir_pass_attrs_hash_equal.py @@ -21,28 +21,28 @@ def test_attrs_equal(): x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1)) - assert tvm.tir.ir_pass.AttrsEqual(x, y) - assert not tvm.tir.ir_pass.AttrsEqual(x, z) + assert tvm.ir.structural_equal(x, y) + assert not tvm.ir.structural_equal(x, z) dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) - assert not tvm.tir.ir_pass.AttrsEqual(dattr, x) + assert not tvm.ir.structural_equal(dattr, x) dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) - assert tvm.tir.ir_pass.AttrsEqual(dattr, dattr2) + assert tvm.ir.structural_equal(dattr, dattr2) - assert tvm.tir.ir_pass.AttrsEqual({"x": x}, {"x": y}) + assert tvm.ir.structural_equal({"x": x}, {"x": y}) # array related checks - assert tvm.tir.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]}) - assert not tvm.tir.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]}) + assert tvm.ir.structural_equal({"x": [x, x]}, {"x": [y, x]}) + assert not tvm.ir.structural_equal({"x": [x, 1]}, {"x": [y, 2]}) n = te.var("n") - assert tvm.tir.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1}) + assert tvm.ir.structural_equal({"x": n+1}, {"x": n+1}) def test_attrs_hash(): - fhash = tvm.tir.ir_pass.AttrsHash + fhash = tvm.ir.structural_hash x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) assert fhash({"x": x}) == fhash({"x": y})