diff --git a/HalideIR b/HalideIR index 6375e6b76f6b..af2a2fcee593 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 6375e6b76f6b70d58f66b357d946c971843f3169 +Subproject commit af2a2fcee59378f33817d7745a8110b9cc836438 diff --git a/dmlc-core b/dmlc-core index 749e570c1942..3a51614d39b6 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 749e570c19423fe679a5f496e2394ba3bed75a16 +Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6 diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc index e45e25a265d0..d7d41c4f693e 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/c_api/c_api_pass.cc @@ -15,7 +15,7 @@ using RetValue = APIVariantValue; TVM_REGISTER_API(_pass_Simplify) .set_body([](const ArgStack& args, RetValue *ret) { - if (dynamic_cast(args.at(0).sptr.get())) { + if (NodeTypeChecker::Check(args.at(0).sptr.get())) { *ret = Simplify(args.at(0).operator Stmt()); } else { *ret = Simplify(args.at(0).operator Expr()); @@ -24,13 +24,10 @@ TVM_REGISTER_API(_pass_Simplify) TVM_REGISTER_API(_pass_Equal) .set_body([](const ArgStack& args, RetValue *ret) { - if (dynamic_cast(args.at(0).sptr.get())) { - CHECK(args.at(1).type_id == kNodeHandle); + if (NodeTypeChecker::Check(args.at(0).sptr.get())) { *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); } else { - Expr a = args.at(0).operator Expr(); - Expr b = args.at(1).operator Expr(); - *ret = Equal(a, b); + *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); } }); diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 835368f500ce..7223ebaee8b9 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -33,41 +33,65 @@ inline const char* TypeId2Str(ArgVariantID type_id) { template struct NodeTypeChecker { - static inline void Check(Node* sptr) { + static inline bool Check(Node* sptr) { + // This is the only place in the project where RTTI is used + // It can be turned off, but will make non strict checking. + // TODO(tqchen) possibly find alternative to turn of RTTI using ContainerType = typename T::ContainerType; - // use dynamic RTTI for safety - CHECK(dynamic_cast(sptr)) - << "wrong type specified, expected " << typeid(ContainerType).name(); + return (dynamic_cast(sptr) != nullptr); + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + using ContainerType = typename T::ContainerType; + os << ContainerType::_type_key; } }; template struct NodeTypeChecker > { - static inline void Check(Node* sptr) { - // use dynamic RTTI for safety - CHECK(sptr != nullptr && sptr->is_type()) - << "wrong type specified, expected Array"; + static inline bool Check(Node* sptr) { + if (sptr == nullptr) return false; + if (!sptr->is_type()) return false; ArrayNode* n = static_cast(sptr); for (const auto& p : n->data) { - NodeTypeChecker::Check(p.get()); + if (!NodeTypeChecker::Check(p.get())) return false; } + return true; + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "array<"; + NodeTypeChecker::PrintName(os); + os << ">"; } }; template struct NodeTypeChecker > { - static inline void Check(Node* sptr) { - // use dynamic RTTI for safety - CHECK(sptr != nullptr && sptr->is_type()) - << "wrong type specified, expected Map"; + static inline bool Check(Node* sptr) { + if (sptr == nullptr) return false; + if (!sptr->is_type()) return false; MapNode* n = static_cast(sptr); for (const auto& kv : n->data) { - NodeTypeChecker::Check(kv.first.get()); - NodeTypeChecker::Check(kv.second.get()); + if (!NodeTypeChecker::Check(kv.first.get())) return false; + if (!NodeTypeChecker::Check(kv.second.get())) return false; } + return true; + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "map<"; + NodeTypeChecker::PrintName(os); + os << ','; + NodeTypeChecker::PrintName(os); + os << '>'; } }; +template +inline std::string NodeTypeName() { + std::ostringstream os; + NodeTypeChecker::PrintName(os); + return os.str(); +} + /*! \brief Variant container for API calls */ class APIVariantValue { public: @@ -127,7 +151,8 @@ class APIVariantValue { inline operator T() const { if (type_id == kNull) return T(); CHECK_EQ(type_id, kNodeHandle); - NodeTypeChecker::Check(sptr.get()); + CHECK(NodeTypeChecker::Check(sptr.get())) + << "Did not get expected type " << NodeTypeName(); return T(sptr); } inline operator Expr() const { @@ -140,7 +165,7 @@ class APIVariantValue { if (sptr->is_type()) { return IterVar(sptr)->var; } else { - CHECK(dynamic_cast(sptr.get())) + CHECK(NodeTypeChecker::Check(sptr.get())) << "did not pass in Expr in a place need Expr"; return Expr(sptr); }