From f010494f86a5a839ca1e4377b3271165b0839ce7 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 18 Jan 2017 20:18:24 -0800 Subject: [PATCH 1/2] [API] Move all RTTI related code to one place --- HalideIR | 2 +- dmlc-core | 2 +- src/c_api/c_api_pass.cc | 9 ++---- src/c_api/c_api_registry.h | 61 ++++++++++++++++++++++++++++---------- 4 files changed, 50 insertions(+), 24 deletions(-) diff --git a/HalideIR b/HalideIR index 6375e6b76f6b..594e091a7d85 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 6375e6b76f6b70d58f66b357d946c971843f3169 +Subproject commit 594e091a7d857d4142394e86b9177881cb290ed8 diff --git a/dmlc-core b/dmlc-core index 749e570c1942..3a51614d39b6 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 749e570c19423fe679a5f496e2394ba3bed75a16 +Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6 diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc index e45e25a265d0..d7d41c4f693e 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/c_api/c_api_pass.cc @@ -15,7 +15,7 @@ using RetValue = APIVariantValue; TVM_REGISTER_API(_pass_Simplify) .set_body([](const ArgStack& args, RetValue *ret) { - if (dynamic_cast(args.at(0).sptr.get())) { + if (NodeTypeChecker::Check(args.at(0).sptr.get())) { *ret = Simplify(args.at(0).operator Stmt()); } else { *ret = Simplify(args.at(0).operator Expr()); @@ -24,13 +24,10 @@ TVM_REGISTER_API(_pass_Simplify) TVM_REGISTER_API(_pass_Equal) .set_body([](const ArgStack& args, RetValue *ret) { - if (dynamic_cast(args.at(0).sptr.get())) { - CHECK(args.at(1).type_id == kNodeHandle); + if (NodeTypeChecker::Check(args.at(0).sptr.get())) { *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); } else { - Expr a = args.at(0).operator Expr(); - Expr b = args.at(1).operator Expr(); - *ret = Equal(a, b); + *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); } }); diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 835368f500ce..07b61acb6f89 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -33,41 +33,69 @@ inline const char* TypeId2Str(ArgVariantID type_id) { template struct NodeTypeChecker { - static inline void Check(Node* sptr) { + static inline bool Check(Node* sptr) { +#if DMLC_ENABLE_RTTI + // This is the only place in the project where RTTI is used + // can be turned off: which will cause non-strict checking. using ContainerType = typename T::ContainerType; - // use dynamic RTTI for safety - CHECK(dynamic_cast(sptr)) - << "wrong type specified, expected " << typeid(ContainerType).name(); + return (dynamic_cast(sptr) != nullptr); +#else + return true; +#endif + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + using ContainerType = typename T::ContainerType; + os << ContainerType::_type_key; } }; template struct NodeTypeChecker > { - static inline void Check(Node* sptr) { + static inline bool Check(Node* sptr) { // use dynamic RTTI for safety - CHECK(sptr != nullptr && sptr->is_type()) - << "wrong type specified, expected Array"; + if (sptr == nullptr) return false; + if (!sptr->is_type()) return false; ArrayNode* n = static_cast(sptr); for (const auto& p : n->data) { - NodeTypeChecker::Check(p.get()); + if (!NodeTypeChecker::Check(p.get())) return false; } + return true; + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "array<"; + NodeTypeChecker::PrintName(os); + os << ">"; } }; template struct NodeTypeChecker > { - static inline void Check(Node* sptr) { - // use dynamic RTTI for safety - CHECK(sptr != nullptr && sptr->is_type()) - << "wrong type specified, expected Map"; + static inline bool Check(Node* sptr) { + if (sptr == nullptr) return false; + if (!sptr->is_type()) return false; MapNode* n = static_cast(sptr); for (const auto& kv : n->data) { - NodeTypeChecker::Check(kv.first.get()); - NodeTypeChecker::Check(kv.second.get()); + if (!NodeTypeChecker::Check(kv.first.get())) return false; + if (!NodeTypeChecker::Check(kv.second.get())) return false; } + return true; + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "map<"; + NodeTypeChecker::PrintName(os); + os << ','; + NodeTypeChecker::PrintName(os); + os << '>'; } }; +template +inline std::string NodeTypeName() { + std::ostringstream os; + NodeTypeChecker::PrintName(os); + return os.str(); +} + /*! \brief Variant container for API calls */ class APIVariantValue { public: @@ -127,7 +155,8 @@ class APIVariantValue { inline operator T() const { if (type_id == kNull) return T(); CHECK_EQ(type_id, kNodeHandle); - NodeTypeChecker::Check(sptr.get()); + CHECK(NodeTypeChecker::Check(sptr.get())) + << "Did not get expected type " << NodeTypeName(); return T(sptr); } inline operator Expr() const { @@ -140,7 +169,7 @@ class APIVariantValue { if (sptr->is_type()) { return IterVar(sptr)->var; } else { - CHECK(dynamic_cast(sptr.get())) + CHECK(NodeTypeChecker::Check(sptr.get())) << "did not pass in Expr in a place need Expr"; return Expr(sptr); } From 667abbb1826cd583c2a924e54ba6e62caaaf4a2e Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 18 Jan 2017 22:18:19 -0800 Subject: [PATCH 2/2] add back rtti comment --- HalideIR | 2 +- src/c_api/c_api_registry.h | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/HalideIR b/HalideIR index 594e091a7d85..af2a2fcee593 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 594e091a7d857d4142394e86b9177881cb290ed8 +Subproject commit af2a2fcee59378f33817d7745a8110b9cc836438 diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 07b61acb6f89..7223ebaee8b9 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -34,14 +34,11 @@ inline const char* TypeId2Str(ArgVariantID type_id) { template struct NodeTypeChecker { static inline bool Check(Node* sptr) { -#if DMLC_ENABLE_RTTI // This is the only place in the project where RTTI is used - // can be turned off: which will cause non-strict checking. + // It can be turned off, but will make non strict checking. + // TODO(tqchen) possibly find alternative to turn of RTTI using ContainerType = typename T::ContainerType; return (dynamic_cast(sptr) != nullptr); -#else - return true; -#endif } static inline void PrintName(std::ostringstream& os) { // NOLINT(*) using ContainerType = typename T::ContainerType; @@ -52,7 +49,6 @@ struct NodeTypeChecker { template struct NodeTypeChecker > { static inline bool Check(Node* sptr) { - // use dynamic RTTI for safety if (sptr == nullptr) return false; if (!sptr->is_type()) return false; ArrayNode* n = static_cast(sptr);