Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Jun 25, 2020
1 parent db5db74 commit 0649788
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/codegen_json/codegen_json.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEn
node->SetNumOutput(tuple_type->fields.size());
} else {
const auto* tensor_type = checked_type.as<TensorTypeNode>();
CHECK(tensor_type) << "Expect TensorType, but received: ." << checked_type->GetTypeKey();
CHECK(tensor_type) << "Expect TensorType, but received: " << checked_type->GetTypeKey();
shape.emplace_back(GetIntShape(tensor_type->shape));
dtype.emplace_back(DType2String(tensor_type->dtype));
ret.push_back(JSONGraphNodeEntry(node_id, 0));
Expand Down
26 changes: 26 additions & 0 deletions src/runtime/contrib/json/json_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,32 @@ class JSONGraphNode {

namespace dmlc {
namespace json {
template <typename T>
inline bool SameType(const dmlc::any& data) {
return std::type_index(data.type()) == std::type_index(typeid(T));
}

template <>
struct Handler<std::unordered_map<std::string, dmlc::any>> {
inline static void Write(dmlc::JSONWriter* writer,
const std::unordered_map<std::string, dmlc::any>& data) {
for (const auto& kv : data) {
auto k = kv.first;
const dmlc::any& v = kv.second;
if (SameType<std::vector<dmlc::any>>(v)) {
writer->WriteObjectKeyValue(k, dmlc::get<std::vector<dmlc::any>>(v));
} else {
LOG(FATAL) << "Not supported";
}
}
writer->EndObject();
}
inline static void Read(dmlc::JSONReader* reader,
std::unordered_map<std::string, dmlc::any>* data) {
LOG(FATAL) << "Not implemented";
}
};

template <>
struct Handler<std::shared_ptr<tvm::runtime::json::JSONGraphNode>> {
inline static void Write(dmlc::JSONWriter* writer,
Expand Down
10 changes: 5 additions & 5 deletions src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class JSONRuntimeBase : public ModuleNode {
} else if (key == "heads") {
reader->Read(&outputs_);
} else {
LOG(FATAL) << "Unknow key: " << key;
LOG(FATAL) << "Unknown key: " << key;
}
}
}
Expand All @@ -237,11 +237,11 @@ class JSONRuntimeBase : public ModuleNode {
uint32_t NumEntries() const { return node_row_ptr_.back(); }

protected:
/* The only subgraph name for this module. */
/*! \brief The only subgraph name for this module. */
std::string symbol_name_;
/* The graph. */
/*! \brief The graph. */
std::string graph_json_;
/* The required constant names. */
/*! \brief The required constant names. */
Array<String> const_names_;
/*! \brief The json graph nodes. */
std::vector<JSONGraphNode> nodes_;
Expand All @@ -257,7 +257,7 @@ class JSONRuntimeBase : public ModuleNode {
std::vector<uint32_t> input_var_idx_;
/*! \brief input const index. */
std::vector<uint32_t> const_idx_;
/* Indicate if the engine has been initialized. */
/*! \brief Indicate if the engine has been initialized. */
bool initialized_{false};
};

Expand Down

0 comments on commit 0649788

Please sign in to comment.