Skip to content

Commit

Permalink
default device_type to CPU for transition purpose
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Sep 19, 2018
1 parent 02a5e35 commit b257811
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/runtime/graph/graph_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,12 @@ class GraphRuntime : public ModuleNode {
// inputs
std::vector<NodeEntry> inputs;
// device_type is used to indicate where the node should be scheduled to.
DLDeviceType device_type;
// TODO(zhiics) device_type is defaulted to CPU for transition purpose only
// because it will have random value otherwise. Using this default value is
// to make sure homogeneous execution works correctly. It will be removed
// when we add the compiler pass as we will read the serialized value from
// json for all execution.
DLDeviceType device_type{kDLCPU};
// control deps
std::vector<uint32_t> control_deps;
// JSON Loader
Expand Down Expand Up @@ -496,7 +501,7 @@ StorageDeviceMap GraphRuntime::GetStorageDeviceMap() const {
for (const auto& output : outputs_) {
uint32_t eid = this->entry_id(output);
uint32_t sid = attrs_.storage_id[eid];
auto en_dev = nodes_[eid].device_type;
auto en_dev = nodes_[output.node_id].device_type;
CHECK(sid_dev_map.count(sid) == 0 || sid_dev_map[sid] == en_dev)
<< "Cannot map the same storage id to multiple devices.";
sid_dev_map[sid] = en_dev;
Expand Down Expand Up @@ -728,13 +733,13 @@ std::vector<TVMContext> GetAllContext(const TVMArgs& args) {
TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<TVMContext> contexts;
// 4 argument version is currently reserved to keep support of calling
// from jvm4j and js, since they don't have heterogeneous execution
// support yet. For heterogenenous execution, 5 arguments will be passed
// in. They are graph_json, module, list of context types, list of context
// ids, and the number of devices. Eventually, we will only have the
// version with 5 parameters when we support heterogeneous execution for
// Java and js.
// 4-argument version is currently reserved to keep support of calling
// from tvm4j and javascript, since they don't have heterogeneous
// execution support yet. For heterogenenous execution, 5 arguments will
// be passed in. They are graph_json, module, list of context types, list
// of context ids, and the number of devices.
// Eventually, we will only have the version with 5 arguments when we
// support heterogeneous execution for Java and js.
if (args.num_args == 4) {
TVMContext ctx;
int dev_type = args[2];
Expand All @@ -745,7 +750,8 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
contexts = GetAllContext(args);
} else {
LOG(FATAL)
<< "The number arguments of creaet must be 4 or 5, but it has "
<< "The expected number of arguments for graph_runtime.create is "
"4 or 5, but it has "
<< args.num_args;
}
*rv = GraphRuntimeCreate(args[0], args[1], contexts);
Expand Down

0 comments on commit b257811

Please sign in to comment.