Skip to content

Commit

Permalink
Handle recursive attr equal
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Mar 12, 2020
1 parent 073f469 commit 87db224
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 32 deletions.
1 change: 1 addition & 0 deletions src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
if (this == other) return true;
if (other == nullptr) return false;
if (this->type_index() != other->type_index()) return false;
LOG(INFO) << "Content equal ";
return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
}

Expand Down
66 changes: 34 additions & 32 deletions src/relay/analysis/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,37 +52,7 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (!lhs.defined() || !rhs.defined()) return false;
if (lhs.same_as(rhs)) return true;
if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<IRModuleNode>()) {
auto rhsm = rhs.as<IRModuleNode>();
if (!rhsm) return false;
if (lhsm->functions.size() != rhsm->functions.size()) return false;
for (const auto& p : lhsm->functions) {
if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
return false;
}
}
return true;
}
return AttrEqual(lhs, rhs);
}

bool DoubleEqual(double l, double r) {
return true;
return VisitAttr(lhs, rhs);
}
/*!
* Check equality of two attributes.
Expand All @@ -92,7 +62,7 @@ class AlphaEqualHandler:
*/
bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
auto compute = [&]() {
return AttrsEqualHandler::Equal(lhs, rhs);
return VisitAttr(lhs, rhs);
};
return Compare(compute(), lhs, rhs);
}
Expand Down Expand Up @@ -148,6 +118,38 @@ class AlphaEqualHandler:
}

protected:
// So that the new definition of equality in relay can be handled directly.
bool VisitAttr(const ObjectRef& lhs, const ObjectRef& rhs) final {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<IRModuleNode>()) {
auto rhsm = rhs.as<IRModuleNode>();
if (!rhsm) return false;
if (lhsm->functions.size() != rhsm->functions.size()) return false;
for (const auto& p : lhsm->functions) {
if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
return false;
}
}
return true;
}
// Fall back to the object equal case.
return AttrsEqualHandler::VisitAttr(lhs, rhs);
}
/*!
* \brief Check if data type equals each other.
* \param lhs The left hand operand.
Expand Down

0 comments on commit 87db224

Please sign in to comment.