Skip to content

Commit

Permalink
[Bugfix][IR][ATTRS] Fix AttrEqual for Array and StrMap, double
Browse files Browse the repository at this point in the history
- Use fuzzy comparison for double.
- Removed the hack for BatchNormAttrs and DictAttr.

Also removed a warning from text printer printing.
  • Loading branch information
tqchen committed Mar 12, 2020
1 parent ec86d7f commit 6092be3
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 64 deletions.
7 changes: 6 additions & 1 deletion include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,13 @@ class AttrsEqualHandler;
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
return lhs == rhs;
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}

bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
Expand Down
29 changes: 21 additions & 8 deletions src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ using namespace tir;
// Equal handler.
bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
return this->VisitAttr(lhs, rhs);
}

Expand All @@ -96,22 +97,25 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& ot
bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<IntImmNode>()) {
return lhs->value == rhs->value;
} else {
return false;
}
return false;
}

bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<FloatImmNode>()) {
return lhs->value == rhs->value;
} else {
return false;
}
return false;
}

bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<StringImmNode>()) {
return lhs->value == rhs->value;
} else {
return false;
}
return false;
}

bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
Expand All @@ -120,8 +124,10 @@ bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other)
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!Equal(lhs->data[i], rhs->data[i])) return false;
}
return true;
} else {
return false;
}
return true;
}

bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) {
Expand All @@ -132,8 +138,10 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other
if (it == rhs->data.end()) return false;
if (!Equal(kv.second, it->second)) return false;
}
return true;
} else {
return false;
}
return true;
}

#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \
Expand Down Expand Up @@ -340,8 +348,13 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
}

TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Attrs()->ListFieldInfo();
.set_body_typed([](Attrs attrs) {
return attrs->ListFieldInfo();
});

TVM_REGISTER_GLOBAL("ir.AttrsEqual")
.set_body_typed([](ObjectRef lhs, ObjectRef rhs) {
return AttrsEqual()(lhs, rhs);
});

} // namespace tvm
10 changes: 7 additions & 3 deletions src/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ class DocTextNode : public DocAtomNode {

explicit DocTextNode(std::string str_val)
: str(str_val) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
}
}

static constexpr const char* _type_key = "printer.DocText";
Expand All @@ -54,6 +51,9 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode);
class DocText : public DocAtom {
public:
explicit DocText(std::string str) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
}
data_ = runtime::make_object<DocTextNode>(str);
}

Expand Down Expand Up @@ -125,6 +125,10 @@ Doc Doc::Text(std::string text) {
return Doc() << DocText(text);
}

Doc Doc::RawText(std::string text) {
return Doc() << DocAtom(runtime::make_object<DocTextNode>(text));
}

Doc Doc::Indent(int indent, Doc doc) {
for (size_t i = 0; i < doc.stream_.size(); ++i) {
if (auto* line = doc.stream_[i].as<DocLineNode>()) {
Expand Down
5 changes: 5 additions & 0 deletions src/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ class Doc {
* \return The created doc.
*/
static Doc Text(std::string value);
/*!
* \brief Create a doc that represents raw text(can have new lines)
* \return The created doc.
*/
static Doc RawText(std::string value);
/*!
* \brief Create a doc that represents a new line.
* \return The created doc.
Expand Down
2 changes: 1 addition & 1 deletion src/printer/meta_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class TextMetaDataContext {
*/
Doc GetMetaSection() const {
if (meta_data_.size() == 0) return Doc();
return Doc::Text(
return Doc::RawText(
SaveJSON(Map<std::string, ObjectRef>(meta_data_.begin(), meta_data_.end())));
}

Expand Down
88 changes: 38 additions & 50 deletions src/relay/analysis/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "../../ir/attr_functor.h"


namespace tvm {
namespace relay {

Expand All @@ -50,37 +52,7 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (!lhs.defined() || !rhs.defined()) return false;
if (lhs.same_as(rhs)) return true;
if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<IRModuleNode>()) {
auto rhsm = rhs.as<IRModuleNode>();
if (!rhsm) return false;
if (lhsm->functions.size() != rhsm->functions.size()) return false;
for (const auto& p : lhsm->functions) {
if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
return false;
}
}
return true;
}
return AttrEqual(lhs, rhs);
}

bool DoubleEqual(double l, double r) {
return true;
return VisitAttr(lhs, rhs);
}
/*!
* Check equality of two attributes.
Expand All @@ -90,25 +62,7 @@ class AlphaEqualHandler:
*/
bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
auto compute = [&]() {
if (&lhs == &rhs) return true;
if (auto lhsd = lhs.as<DictAttrsNode>()) {
auto rhsd = rhs.as<DictAttrsNode>();
if (!rhsd) return false;
if (lhsd->dict.size() != rhsd->dict.size()) return false;
for (const auto& k : lhsd->dict) {
if (!Equal(k.second, rhsd->dict[k.first])) return false;
}
return true;
}
if (auto lhsbn = lhs.as<BatchNormAttrs>()) {
auto rhsbn = rhs.as<BatchNormAttrs>();
if (!rhsbn) return false;
return (lhsbn->axis == rhsbn->axis)
&& DoubleEqual(lhsbn->epsilon, rhsbn->epsilon)
&& (lhsbn->center == rhsbn->center)
&& (lhsbn->scale == rhsbn->scale);
}
return AttrsEqualHandler::Equal(lhs, rhs);
return VisitAttr(lhs, rhs);
};
return Compare(compute(), lhs, rhs);
}
Expand Down Expand Up @@ -164,6 +118,40 @@ class AlphaEqualHandler:
}

protected:
// So that the new definition of equality in relay can be handled directly.
// Specifically, if a DictAttr contains a value defined by a relay AST.
// We want to able to recursively check the equality in the attr defined by the relay AST.
bool VisitAttr(const ObjectRef& lhs, const ObjectRef& rhs) final {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<IRModuleNode>()) {
auto rhsm = rhs.as<IRModuleNode>();
if (!rhsm) return false;
if (lhsm->functions.size() != rhsm->functions.size()) return false;
for (const auto& p : lhsm->functions) {
if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
return false;
}
}
return true;
}
// Fall back to the object equal case.
return AttrsEqualHandler::VisitAttr(lhs, rhs);
}
/*!
* \brief Check if data type equals each other.
* \param lhs The left hand operand.
Expand Down
15 changes: 14 additions & 1 deletion tests/python/unittest/test_ir_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import tvm.ir._ffi_api

def test_make_attrs():
try:
Expand Down Expand Up @@ -50,6 +50,19 @@ def test_dict_attrs():
assert len(dattr.items()) == 4


def test_attrs_equal():
attr_equal = tvm.ir._ffi_api.AttrsEqual
dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20])
dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1)
dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None)
assert attr_equal(dattr0, dattr1)
assert not attr_equal(dattr0, dattr2)
assert not attr_equal({"x": 1}, tvm.runtime.convert(1))
assert not attr_equal([1, 2], tvm.runtime.convert(1))



if __name__ == "__main__":
test_make_attrs()
test_dict_attrs()
test_attrs_equal()

0 comments on commit 6092be3

Please sign in to comment.