Skip to content

Commit

Permalink
[API] Move all RTTI related code to one place (#20)
Browse files Browse the repository at this point in the history
* [API] Move all RTTI related code to one place

* add back rtti comment
  • Loading branch information
tqchen authored Jan 19, 2017
1 parent 4d4e19c commit 383494a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 25 deletions.
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from 6375e6 to af2a2f
2 changes: 1 addition & 1 deletion dmlc-core
9 changes: 3 additions & 6 deletions src/c_api/c_api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using RetValue = APIVariantValue;

TVM_REGISTER_API(_pass_Simplify)
.set_body([](const ArgStack& args, RetValue *ret) {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
*ret = Simplify(args.at(0).operator Stmt());
} else {
*ret = Simplify(args.at(0).operator Expr());
Expand All @@ -24,13 +24,10 @@ TVM_REGISTER_API(_pass_Simplify)

TVM_REGISTER_API(_pass_Equal)
.set_body([](const ArgStack& args, RetValue *ret) {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
CHECK(args.at(1).type_id == kNodeHandle);
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
} else {
Expr a = args.at(0).operator Expr();
Expr b = args.at(1).operator Expr();
*ret = Equal(a, b);
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
}
});

Expand Down
59 changes: 42 additions & 17 deletions src/c_api/c_api_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,41 +33,65 @@ inline const char* TypeId2Str(ArgVariantID type_id) {

template<typename T>
struct NodeTypeChecker {
static inline void Check(Node* sptr) {
static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
// use dynamic RTTI for safety
CHECK(dynamic_cast<ContainerType*>(sptr))
<< "wrong type specified, expected " << typeid(ContainerType).name();
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};

template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<ArrayNode>())
<< "wrong type specified, expected Array";
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
NodeTypeChecker<T>::Check(p.get());
if (!NodeTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
}
};

template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<MapNode>())
<< "wrong type specified, expected Map";
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
NodeTypeChecker<K>::Check(kv.first.get());
NodeTypeChecker<V>::Check(kv.second.get());
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};

template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}

/*! \brief Variant container for API calls */
class APIVariantValue {
public:
Expand Down Expand Up @@ -127,7 +151,8 @@ class APIVariantValue {
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
NodeTypeChecker<T>::Check(sptr.get());
CHECK(NodeTypeChecker<T>::Check(sptr.get()))
<< "Did not get expected type " << NodeTypeName<T>();
return T(sptr);
}
inline operator Expr() const {
Expand All @@ -140,7 +165,7 @@ class APIVariantValue {
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
} else {
CHECK(dynamic_cast<typename Expr::ContainerType*>(sptr.get()))
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "did not pass in Expr in a place need Expr";
return Expr(sptr);
}
Expand Down

0 comments on commit 383494a

Please sign in to comment.