diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 899db08055fc3..4413fc36879ce 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -143,8 +143,13 @@ class AttrsEqualHandler; class AttrsEqual { public: bool operator()(const double& lhs, const double& rhs) const { - return lhs == rhs; + // fuzzy float pt comparison + constexpr double atol = 1e-9; + if (lhs == rhs) return true; + double diff = lhs - rhs; + return diff > -atol && diff < atol; } + bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 4c4c997774700..868fec640352b 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -79,7 +79,8 @@ using namespace tir; // Equal handler. bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; - if (!lhs.defined() || !rhs.defined()) return false; + if (!lhs.defined() && rhs.defined()) return false; + if (!rhs.defined() && lhs.defined()) return false; return this->VisitAttr(lhs, rhs); } @@ -96,22 +97,25 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& ot bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; + } else { + return false; } - return false; } bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; + } else { + return false; } - return false; } bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; + } else { + return false; } - return false; } bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) { @@ -120,8 +124,10 @@ bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) for (size_t i = 0; i < lhs->data.size(); ++i) { if (!Equal(lhs->data[i], rhs->data[i])) return false; } + return true; + } else { + return false; } - return true; } bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) { @@ -132,8 +138,10 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other if (it == rhs->data.end()) return false; if (!Equal(kv.second, it->second)) return false; } + return true; + } else { + return false; } - return true; } #define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \ @@ -340,8 +348,13 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const { } TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Attrs()->ListFieldInfo(); +.set_body_typed([](Attrs attrs) { + return attrs->ListFieldInfo(); +}); + +TVM_REGISTER_GLOBAL("ir.AttrsEqual") +.set_body_typed([](ObjectRef lhs, ObjectRef rhs) { + return AttrsEqual()(lhs, rhs); }); } // namespace tvm diff --git a/src/printer/doc.cc b/src/printer/doc.cc index c5595dbeeba8f..ee260f41df55e 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -40,9 +40,6 @@ class DocTextNode : public DocAtomNode { explicit DocTextNode(std::string str_val) : str(str_val) { - if (str.find_first_of("\t\n") != str.npos) { - LOG(WARNING) << "text node: '" << str << "' should not has tab or newline."; - } } static constexpr const char* _type_key = "printer.DocText"; @@ -54,6 +51,9 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode); class DocText : public DocAtom { public: explicit DocText(std::string str) { + if (str.find_first_of("\t\n") != str.npos) { + LOG(WARNING) << "text node: '" << str << "' should not has tab or newline."; + } data_ = runtime::make_object(str); } @@ -125,6 +125,10 @@ Doc Doc::Text(std::string text) { return Doc() << DocText(text); } +Doc Doc::RawText(std::string text) { + return Doc() << DocAtom(runtime::make_object(text)); +} + Doc Doc::Indent(int indent, Doc doc) { for (size_t i = 0; i < doc.stream_.size(); ++i) { if (auto* line = doc.stream_[i].as()) { diff --git a/src/printer/doc.h b/src/printer/doc.h index 34a284b0f116a..7d8d72e00b4ca 100644 --- a/src/printer/doc.h +++ b/src/printer/doc.h @@ -110,6 +110,11 @@ class Doc { * \return The created doc. */ static Doc Text(std::string value); + /*! + * \brief Create a doc that represents raw text(can have new lines) + * \return The created doc. + */ + static Doc RawText(std::string value); /*! * \brief Create a doc that represents a new line. * \return The created doc. diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h index 6c300fd851760..d3906926363cc 100644 --- a/src/printer/meta_data.h +++ b/src/printer/meta_data.h @@ -121,7 +121,7 @@ class TextMetaDataContext { */ Doc GetMetaSection() const { if (meta_data_.size() == 0) return Doc(); - return Doc::Text( + return Doc::RawText( SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); } diff --git a/src/relay/analysis/alpha_equal.cc b/src/relay/analysis/alpha_equal.cc index 8a07a19b2af78..3721f35e4a148 100644 --- a/src/relay/analysis/alpha_equal.cc +++ b/src/relay/analysis/alpha_equal.cc @@ -30,6 +30,8 @@ #include #include #include "../../ir/attr_functor.h" + + namespace tvm { namespace relay { @@ -90,24 +92,6 @@ class AlphaEqualHandler: */ bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) { auto compute = [&]() { - if (&lhs == &rhs) return true; - if (auto lhsd = lhs.as()) { - auto rhsd = rhs.as(); - if (!rhsd) return false; - if (lhsd->dict.size() != rhsd->dict.size()) return false; - for (const auto& k : lhsd->dict) { - if (!Equal(k.second, rhsd->dict[k.first])) return false; - } - return true; - } - if (auto lhsbn = lhs.as()) { - auto rhsbn = rhs.as(); - if (!rhsbn) return false; - return (lhsbn->axis == rhsbn->axis) - && DoubleEqual(lhsbn->epsilon, rhsbn->epsilon) - && (lhsbn->center == rhsbn->center) - && (lhsbn->scale == rhsbn->scale); - } return AttrsEqualHandler::Equal(lhs, rhs); }; return Compare(compute(), lhs, rhs); diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py index a2be2b7bf11f2..f4148caa0642e 100644 --- a/tests/python/unittest/test_ir_attrs.py +++ b/tests/python/unittest/test_ir_attrs.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te +import tvm.ir._ffi_api def test_make_attrs(): try: @@ -50,6 +50,19 @@ def test_dict_attrs(): assert len(dattr.items()) == 4 +def test_attrs_equal(): + attr_equal = tvm.ir._ffi_api.AttrsEqual + dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20]) + dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1) + dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None) + assert attr_equal(dattr0, dattr1) + assert not attr_equal(dattr0, dattr2) + assert not attr_equal({"x": 1}, tvm.runtime.convert(1)) + assert not attr_equal([1, 2], tvm.runtime.convert(1)) + + + if __name__ == "__main__": test_make_attrs() test_dict_attrs() + test_attrs_equal()