Skip to content

Commit

Permalink
address comment; fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Jun 29, 2020
1 parent 4190373 commit f5487b4
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json,
const Array<String>& const_names) {
auto n = make_object<DNNLJSONRuntime>(symbol_name.operator std::string(),
graph_json.operator std::string(), const_names);
auto n = make_object<DNNLJSONRuntime>(symbol_name, graph_json, const_names);
return runtime::Module(n);
}

Expand Down
6 changes: 3 additions & 3 deletions src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class JSONRuntimeBase : public ModuleNode {
*/
void SetupConstants(const Array<NDArray>& consts) {
for (size_t i = 0; i < consts.size(); ++i) {
data_entry_[const_idx_[i]] = consts[i].operator->();
data_entry_[EntryID(const_idx_[i], 0)] = consts[i].operator->();
}
}

Expand Down Expand Up @@ -253,9 +253,9 @@ class JSONRuntimeBase : public ModuleNode {
std::vector<JSONGraphNodeEntry> outputs_;
/*! \brief Data of that entry. */
std::vector<const DLTensor*> data_entry_;
/*! \brief Map the input name to index. */
/*! \brief Map the input name to node index. */
std::vector<uint32_t> input_var_idx_;
/*! \brief input const index. */
/*! \brief input const node index. */
std::vector<uint32_t> const_idx_;
/*! \brief Indicate if the engine has been initialized. */
bool initialized_{false};
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_json_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def check_result(mod,

# Run the reference result
compile_engine.get().clear()
with relay.build_config(opt_level=3):
with tvm.transform.PassContext(opt_level=3):
json, lib, param = relay.build(ref_mod, target=target, params=params)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)

Expand Down

0 comments on commit f5487b4

Please sign in to comment.