diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index c2747d1458579..a79537f46e744 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -441,8 +441,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, const Array& const_names) { - auto n = make_object(symbol_name.operator std::string(), - graph_json.operator std::string(), const_names); + auto n = make_object(symbol_name, graph_json, const_names); return runtime::Module(n); } diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index a2769568cf043..c4f126e8ccba9 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -204,7 +204,7 @@ class JSONRuntimeBase : public ModuleNode { */ void SetupConstants(const Array& consts) { for (size_t i = 0; i < consts.size(); ++i) { - data_entry_[const_idx_[i]] = consts[i].operator->(); + data_entry_[EntryID(const_idx_[i], 0)] = consts[i].operator->(); } } @@ -253,9 +253,9 @@ class JSONRuntimeBase : public ModuleNode { std::vector outputs_; /*! \brief Data of that entry. */ std::vector data_entry_; - /*! \brief Map the input name to index. */ + /*! \brief Map the input name to node index. */ std::vector input_var_idx_; - /*! \brief input const index. */ + /*! \brief input const node index. */ std::vector const_idx_; /*! \brief Indicate if the engine has been initialized. */ bool initialized_{false}; diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py index d15468c6a9423..a886692c58386 100644 --- a/tests/python/relay/test_json_runtime.py +++ b/tests/python/relay/test_json_runtime.py @@ -53,7 +53,7 @@ def check_result(mod, # Run the reference result compile_engine.get().clear() - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(ref_mod, target=target, params=params) rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)