
Heterogeneous Runtime #1695

Merged
merged 10 commits into apache:master on Sep 22, 2018

Conversation

zhiics
Member

@zhiics zhiics commented Sep 7, 2018

This is the first part of PR #1688.

This PR focuses only on making the runtime able to execute heterogeneous graphs. Changes are mainly made to the graph runtime C++ and Python interfaces. To test the execution, I manually created two simple graphs containing only addition, subtraction, and copy nodes. One test, test_simplex_data_transferring, tests data transfer from GPU to CPU at runtime. The other, test_duplex_data_transferring, tests duplex data transfer back and forth between GPU and CPU.

A new column, device_type, is added to the graph JSON file; it indicates which device each node should be scheduled to. In this PR the column is created manually as part of the graph JSON file. The same field will also be used to annotate graph nodes in the compiler PR. The serialization of this column is similar to that of the dtype column in the current JSON file. The JSON loading/saving APIs need to be modified slightly to support this field.

The functionality was tested on a MacBook with an Intel CPU and Intel Graphics GPU, both by using the generated module in memory and by exporting it to and importing it from disk.

The next PR will focus on the compiler part, which generates heterogeneous binaries to feed into the runtime. Major changes will be needed for the compiler.build interface. Another issue is the removal of with-target statements in the high-level build interface.
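For illustration, a minimal sketch of how the new per-node column could be serialized inside the attrs section of the graph JSON, following the style of the existing storage_id and dltype columns (shown as a Python dict; the values are hypothetical):

# Device type 1 is CPU and 2 is GPU, as defined in dlpack.h.
graph_attrs = {
    "storage_id": ["list_int", [0, 1, 2, 3]],
    "dltype": ["list_str", ["float32", "float32", "float32", "float32"]],
    "device_type": ["list_int", [2, 2, 1, 1]],
    "shape": ["list_shape", [[4], [4], [4], [4]]],
}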

# CPU is always used as the host processor. Its device type is 1 as
# defined in TVMContext and dlpack.h. The libmod_ctx is sorted according
# to the device type field in TVMContext. It is used to guarantee that the
# first lib and context in the array belong to CPU.
Member

Not good at first glance. Relying on the device type number is a little bit tricky, and is it really necessary to make the host device the head of the list, given the fact that you pass host_ctx as an arg?

Member Author
@zhiics zhiics Sep 8, 2018

@yzhliu Thanks. Yeah, I am also aware of this. host_ctx is actually a bad name here: it is the local context where the module is deployed, as used in graph_runtime.create(). It is not guaranteed to be the context of the host processor. I will change the name to local_ctx.

Another option was to search the dictionary for the CPU context. Then we can put the CPU-related lib/device_type/device_id as the first element of each list and append the fields for the other devices, like:

for lib, ctx in libmod_ctx.items():
    if ctx == tvm.cpu(ctx.device_id):
        libs/device_types/device_ids.insert(0, ...)
    else:
        libs/device_types/device_ids.append(...)
if device_types[0] != STR2MASK["cpu"]:
    raise ...

I went with the sorting approach because I think it is transparent to users and easy to document. I can probably also add a check on the first element of device_types to make sure it is CPU.
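For illustration, a small sketch of the sorting approach (reusing the names from the snippet above; not the final implementation):

# Order the (lib, ctx) pairs by device type so that the CPU entry
# (device type 1, per dlpack.h) comes first.
pairs = sorted(libmod_ctx.items(), key=lambda kv: kv[1].device_type)
libs = [lib for lib, _ in pairs]
device_types = [c.device_type for _, c in pairs]
device_ids = [c.device_id for _, c in pairs]
if device_types[0] != 1:  # 1 == kDLCPU
    raise RuntimeError("CPU must be available to act as the host device")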

@@ -73,7 +52,7 @@ class DeviceAPIManager {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
api_[type] = GetAPI(DeviceName(type), allow_missing);
api_[type] = GetAPI(tvm::runtime::DeviceName(type), allow_missing);
Member

no need to change

StorageDeviceMap sid_dev_map;
for (uint32_t nid = 0; nid < this->num_nodes(); ++nid) {
const auto &inode = nodes_[nid];
for (const auto &e : inode.inputs) {
Member

minor issue, auto& var instead

for (const auto &e : inode.inputs) {
uint32_t eid = this->entry_id(e);
uint32_t sid = attrs_.storage_id[eid];
sid_dev_map[sid] = nodes_[e.node_id].device;
Member

CHECK(sid_dev_map.count(sid) == 0)?

Member Author

@yzhliu I think you probably meant:
CHECK(sid_dev_map.count(sid) == 0 || sid_dev_map[sid] == nodes_[e.node_id].device) << "Cannot assign the storage id to multiple devices ", right?

Member

yes

size_t size = 1;
for (int64_t sz : attrs_.shape[i]) {
size *= static_cast<size_t>(sz);
}
CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
CHECK_GE(sid, 0) << "Do not support runtime shape op";
Member

confusing, cast to uint and check >= 0?

}
pool_entry_bytes[sid] = std::max(pool_entry_bytes[sid], bytes);
DLDeviceType dev_type = sid_dev_map[sid];
device_pool_entry_bytes[dev_type][sid] =
Member

Looks like there's no need to make dev_type a key; it can be obtained from sid_dev_map, right? It just feels better to keep the data structure simple.

Member Author
@zhiics zhiics Sep 8, 2018

@yzhliu Good point. Thanks.

DLTensor *tensor;
TVM_CCALL(TVMArrayAlloc(shape, 1, kDLFloat, 32, 1, ctx.device_type,
ctx.device_id, &tensor));
device_storage_pool_[it.first][pit.first] = tensor;
Member

same as above.

@@ -482,27 +136,28 @@ void GraphRuntime::SetupStorage() {

void GraphRuntime::SetupOpExecs() {
op_execs_.resize(this->num_nodes());
std::vector<DLTensor> ids;
Member

remove

runtime_host_module_.GetFunction(param.func_name, false);
if (pf == nullptr) {
for (const auto& it : runtime_device_mod_ctx_map_) {
pf = it.first->GetFunction(param.func_name, false);
Member

What if two modules have functions of the same name?

Member Author

Thanks for pointing it out. This can be solved by using device information.

@@ -36,6 +36,39 @@ constexpr int kTempAllocaAlignment = 64;
/*! \brief Maximum size that can be allocated on stack */
constexpr int kMaxStackAlloca = 1024;

/*! \brief The default device allocated to an operator */
constexpr DLDeviceType kDLDefaultDevice = kDLCPU;
Contributor

By this default, do you mean the fallback device?

Member Author

@jackwish Yes.

Member

This should not appear here; instead, we should pass the fallback device as an argument or put it in a context.

Contributor

@tqchen +1

@tqchen
Member

tqchen commented Sep 9, 2018

@srkreddy1238 @nishi-t @eqy please take a bit of time to review this

@tqchen tqchen self-assigned this Sep 9, 2018
if ctx.device_type >= rpc_base.RPC_SESS_MASK:
raise RuntimeError(
"rpc is not supported for heterogeneous execution yet.")
if ctx == tvm.cpu(ctx.device_id):
Contributor

In this code segment, are we assuming that, considering the subgraphs of the whole graph, each target device will be used by at most one subgraph?

Member Author
@zhiics zhiics Sep 10, 2018

@jackwish Not really. The "subgraph" is actually just a fused node here. For example, we annotate the graph and provide the context information for each node. During fusion, nodes with different contexts will not be grouped together. Instead, ops marked with the same device/context will be compiled into the same binary.

Contributor
@zhenhuaw-me zhenhuaw-me left a comment

Remarkable work!

btw, is type casting used a bit too frequently?

@tqchen
Member

tqchen commented Sep 11, 2018

Some high-level comments:

  • Please rebase against master and add the changes on top of the current graph runtime, so that the diff is clearer (currently it moves the whole implementation to the header and removes the previous NDArray changes).
  • Please document the new context support format, specifically what serialization metadata is necessary (e.g. context_id).
  • Enhance graph runtime create so that it simply takes a list of contexts, giving the contexts needed for each assignment; maybe with a good CPU default (see the sketch after this list).
  • We don't need to pass in the modules for different devices separately; instead, we should pack them into a single module. If something is not supported by TVM, we should add support for it.

@zhiics
Member Author

zhiics commented Sep 11, 2018

@tqchen Sorry, could you please elaborate on how we can pack them into a single module? It seems that tvm.build only takes one target. Or do you mean we need to pack the modules after they are generated for the different devices? Neither option is clear to me.

@tqchen
Member

tqchen commented Sep 11, 2018

You might find some insight by checking out the implementation of the code here: https://github.com/dmlc/tvm/blob/master/python/tvm/build_module.py#L503

build can take in a list of LoweredFunc and build a module, either a host one or a device one (with a host module). What we really have to do is delay the generation of the host module and return it as a list of LoweredFunc. The device modules can already be built, one per device. Then we can build a single host module and import the device modules from there.

If you are building a mixed CPU/GPU runtime, things can be even simpler: just put everything into a list of LoweredFunc and build with GPU as the target. The CPU code will be compiled into the host module.

@zhiics
Member Author

zhiics commented Sep 13, 2018

Thanks for @tqchen's suggestion. I looked into tvm.build. I think there are probably two ways to generate just one module.

  • Modify tvm.build somehow so that we only need to call it once with all the lowered functions. I think the most important thing in this solution is to know which target a func (https://github.com/dmlc/tvm/blob/master/python/tvm/build_module.py#L456) belongs to. It means we probably need to have a target attribute in LoweredFunc, and this attribute could be set when we call MakeAPI. It also means we probably want to pass the target to tvm.lower when a lowered function is generated.

  • Call tvm.build multiple times, passing a different target each time. There would be multiple binaries, one per device, as we are currently doing. In the end, we can probably implement a combine_modules(device_modules, host_module) method that combines all the device modules into the host_module using import_module. This way, we don't need any changes to the current tvm.build, tvm.lower, and TVM C++ pass APIs, except for a simple combining function. But we would need to change graph_compile.cc slightly to pass individual targets to tvm.build multiple times when we work on the compiler PR.
    Update: this solution has another problem, because the current TVM doesn't support exporting hierarchical modules to binary files (https://github.com/dmlc/tvm/blob/master/src/codegen/codegen.cc#L40).

I quickly tested both methods locally and it seemed that both worked. I think both solutions are simple, but the second one needs fewer changes to the current code base. There might be other solutions. @tqchen Am I missing something, or do you have any comments/advice? Thanks.

@tqchen
Member

tqchen commented Sep 15, 2018

One possible approach:

  • Allow tvm.build to return (list_of_lowered_host_funcs, device_module), in one of two ways:
    • Add a flag delay_host_codegen; when set to True, return fhost, fdevice instead.
    • Attach a private field mhost._lowered_funcs to the host module.
  • Call tvm.build on each device to get the device modules.
  • Now we have a device module for each target and several lists of lowered funcs; combine the lists of LoweredFunc into a single list and call codegen to get the host module.
  • Import all the device modules into that host module (see the sketch after this list).
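A rough sketch of this flow (postpone_host_codegen, the (mdev, fhost) return order, and lowered_funcs_by_target are assumptions drawn from this thread, not a settled API):

import tvm
from tvm import codegen

host_funcs = []
device_modules = []
for target, flist in lowered_funcs_by_target.items():  # hypothetical dict
    # With host codegen postponed, each per-target build returns the device
    # module plus the lowered host-side functions.
    mdev, fhost = tvm.build(flist, target=target, target_host="llvm",
                            postpone_host_codegen=True)
    host_funcs.extend(fhost)
    if mdev is not None:
        device_modules.append(mdev)

# Generate a single host module from all collected host functions, then
# import every device module into it.
mhost = codegen.build_module(host_funcs, "llvm")
for mdev in device_modules:
    mhost.import_module(mdev)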

@zhiics
Member Author

zhiics commented Sep 17, 2018

@tqchen Thanks for your suggestion. I think that postponing codegen should be able to solve the hierarchical module problem.

I would like to go with the approach of returning (fhost, device_module) because I don't want to have LoweredFunc as a member of the module class. I think it would be good to keep runtime and compilation separate. On the other hand, returning the (fhost, device_module) tuple also keeps the code clean and simple.

I will update the PR soon to 1) move the definition of GraphRuntime back to cc file, 2) document the device column in the json file, and 3) clean the other files accordingly.

@zhiics zhiics force-pushed the runtime branch 3 times, most recently from b257811 to 63e3519 Compare September 19, 2018 16:36
@zhiics
Member Author

zhiics commented Sep 19, 2018

@tqchen @yzhliu @srkreddy1238 Now we only have one graph_runtime interface for both C++ and Python. Currently we can pass either 4 or 5 arguments to graph_runtime.create in the C++ backend to keep support for Java and JS, because heterogeneous execution has not been implemented for them yet. This is also documented in the code.

# collected.
mdev = codegen.build_module(fdevice, str(target_device)) if fdevice else None
if postpone_host_codegen:
return mdev, fhost
Member

return fhost, mdev: host first, as the device code can be None

mhost.import_module(mdev)
return mhost

def combine_modules(host_funcs, device_modules, target_host=None):
Member

I feel we can directly call codegen.build_module and do not need this additional level of abstraction for now; move the comment to the place that actually calls combine_modules.

Member Author
@zhiics zhiics Sep 20, 2018

@tqchen Yes, we can do that. The reason I have this function here is that we will need it anyway in the compiler PR. I called it in the unit test. I can remove it for now and add it back later.

"CPU should be the host processor for heterogenous execution, but"
" not found in ctx.")

device_type_arr = (ctypes.c_int * num_devices)(*device_types)
Member

We should not pass raw integer pointers into a function; it will not be RPC compatible. Instead, we can pass things in positionally, simply by

fcreate(json, libmod, ndevices, device_type0, device_id0, device_type1, device_id1); ...

Member Author

@tqchen I am aware of this; we can pass fcreate(json, libmod, ctx[0], ctx[1]). But it seems to me that we need to check the number of contexts, although we usually have at most 2 or 3:

if len(ctx) == 1:
    fcreate(json, libmod, ctx[0])
elif len(ctx) == 2:
    fcreate(json, libmod, ctx[0], ctx[1])

Am I missing something?

Member

in python you can just do fcreate(json, libmod, *ctx)

Member Author
@zhiics zhiics Sep 20, 2018

@tqchen btw, passing TVMContext seems to work only for Python, not Java and JS, right? If so, we probably still need to pass (json, mod, num_dev, dev_types, dev_ids).

@@ -277,10 +297,16 @@ class GraphRuntime : public ModuleNode {
this->LoadAttrs(reader, &param);
} else if (key == "control_deps") {
reader->Read(&control_deps);
} else if (key == "device_type") {
Member

device_index? As we are doing a virtual device_index to device type mapping.


namespace tvm {
namespace runtime {
using StorageDeviceMap = std::unordered_map<uint32_t, DLDeviceType>;
using DeviceStoragePoolMap = std::unordered_map<size_t, NDArray>;
using ModuleContextMap = std::unordered_map<tvm::runtime::Module*, TVMContext>;
Member

remove module context map

// to make sure homogeneous execution works correctly. It will be removed
// when we add the compiler pass as we will read the serialized value from
// json for all execution.
DLDeviceType device_type{kDLCPU};
Member

Let us not put device_type here; instead, introduce a device_index column attribute, just like storage_id, to indicate device assignment.

Member

If the column does not exist, fall back to the default (the 0th/primary context passed in).

Member Author

@tqchen So we use ctxs_[0] as the default, right?

Member

I think that is a reasonable choice

/*! \brief Execution context of all devices including the host. */
std::vector<TVMContext> ctxs_;
/*! \brief Common storage pool for each device. */
DeviceStoragePoolMap device_storage_pool_;
Member

It is likely we can still just use vector for storage_pool, as long as the storage index sharing algorithm is device aware.

Member Author

@tqchen I am not sure. I thought the storage index sharing algorithm was not device aware because we don't have multiple devices.

Member

That could be true, but we can still introduce post-processing to make them device aware. The general principle is that we want the runtime to be as dumb as possible and let the compiler do most of the work.

Member Author
@zhiics zhiics Sep 20, 2018

@tqchen Thanks. I agree we should keep the runtime minimal. We can remove all the maps if a storage id is guaranteed to be assigned to only one device. I can also use a vector to represent the sid to device_type mapping and use sid for indexing. I think we need this mapping to allocate memory on the correct device more conveniently, right?

for (size_t i = 0; i < pool_entry_bytes.size(); ++i) {

// Allocate the space on each device.
for (const auto& pit : device_pool_entry_bytes) {
Member

Likely we do not need to do it per device. We can still go through each entry in the pool, but when we try to allocate the memory for the pool, we look up which device that entry belongs to.

Member Author

@tqchen But the number of iterations would still be the same, right? Going by device seems more intuitive to me.

Member

I agree going by device is more intuitive, but going by pool likely removes the additional data structures we need (we only need a vector pool and a vector of arrays). As per the last comment, one goal of the runtime is to keep it as simple as possible.

fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx)
return GraphModule(fcreate(graph_json_str, libmod, num_devices,
Member

just pass in ctx[0]

Member Author

@tqchen So we actually pass fcreate(json, libmod, device_type_id[0], device_type_id[1], *device_type_id_others), right?

Member

yah, that can be a solution


# Assume CPU is the host processor when there are multiple devices on
# a hardware platform.
if (num_devices > 1) and (cpu_ctx_index < 0):
Member

avoid doing context check for now and just use ctx[0] as primary context

// This for loop is very fast since there are usually only a couple of
// devices available on the same hardware.
for (const auto& cit : ctxs_) {
if (pool_entry[i].device_type == static_cast<int>(cit.device_type)) {
Member

we can use std::find_if

@@ -508,8 +554,9 @@ void GraphRuntime::SetupOpExecs() {
uint32_t eid = this->entry_id(nid, index);
args.push_back(*(data_entry_[eid].operator->()));
}
CHECK_EQ(inode.op_type, "tvm_op")
<< "Can only take tvm_op as op";
CHECK(inode.op_type == "tvm_op" || inode.op_type == "device_copy_op")
Member

since we are already using __copy for cross-device copy, we just need to make sure op_type is "tvm_op"

std::vector<TVMContext> ret(1);
if (args.num_args == 4) {
int dev_type = args[2];
ret[0].device_type = static_cast<DLDeviceType>(dev_type);
Member

just do push_back so there is no special logic getting involved

int dev_type = args[3 + i];
ctx.device_type = static_cast<DLDeviceType>(dev_type);
ctx.device_id = args[3 + i + 1];
if (ctx.device_type == static_cast<int>(kDLCPU)) {
Member

Do not do magic like this; let's just push things and build up the ctx.

@tqchen
Member

tqchen commented Sep 21, 2018

@tqchen tqchen added the status: need update label and removed the status: review in progress label on Sep 21, 2018
"""Build a function with arguments as signiture.
binds=None,
postpone_host_codegen=False):
"""Build a function with arguments as signiture. Code will be generated
Contributor

signiture -> signature


def get_simplex_graph(host_dev_type, device_dev_type):
r""" Return the hand-crafted json object where only one copy node is
inserted. Tis node copies data from the target device to cpu.
Contributor

Tis?

Contributor

(I assume it's probably a typo on "the")

Member Author

@tmoreau89 Thanks. It was a typo. I was trying to say "This".

Contributor
@tmoreau89 tmoreau89 left a comment

Overall, excellent work! Thank you for providing well-written examples. This will open up a lot of interesting work on heterogeneous execution on CPU+GPU or CPU+FPGA systems.

@tqchen
Member

tqchen commented Sep 21, 2018

@zhiics please address the review comments

@zhiics zhiics force-pushed the runtime branch 2 times, most recently from 9683264 to f12e36e Compare September 21, 2018 16:36
Member
@tqchen tqchen left a comment

some followup comments

raise ValueError("ctx has to be the type of TVMContext or a list "
"of TVMContext")
if cur_ctx.device_type >= rpc_base.RPC_SESS_MASK:
ctx[0], ctx[i] = ctx[i], ctx[0]
Member

what is the purpose of this swapping? can we just remove it?

Member Author

@tqchen Sorry, I think I misunderstood RPC here. I thought we just needed one of them to be remote, so I put it as the first one. I updated it. Please take another look and see if it makes sense. Thanks.

device_type = device_type % rpc_base.RPC_SESS_MASK
return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx)
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create")
device_type = ctx[0].device_type % rpc_base.RPC_SESS_MASK
Member

We need to do the RPC session stripping for all the contexts.

Member

We need to assert that all the contexts are remote and belong to the same session.


# ctx[0] is used as the primary/fallback context. All other ones are used
# as device context for heterogeneous execution.
device_type_id = [x for c in ctx[1:] for x in [c.device_type, c.device_id]]
Member

Consider using a for loop to populate this, since we need to strip the RPC sess mask from all of them; maybe some of the logic needs to be put into the loop.
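A possible loop-based version of the list comprehension above (a sketch only; the names follow the surrounding snippets, and the same-session assertion suggested earlier would also fit in this loop):

device_type_id = []
for c in ctx[1:]:
    dev_type = c.device_type
    # Remote contexts carry the RPC session offset; strip it so the runtime
    # sees the raw device type.
    if dev_type >= rpc_base.RPC_SESS_MASK:
        dev_type = dev_type % rpc_base.RPC_SESS_MASK
    device_type_id.append(dev_type)
    device_type_id.append(c.device_id)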


fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, ctx[0].device_type,
ctx[0].device_id, *device_type_id))
return GraphModule(fcreate(graph_json_str, libmod, device_type_id[0],
Member

can just do *device_type_id

@tqchen tqchen added the status: accepted label and removed the status: need update label on Sep 22, 2018
@tqchen tqchen merged commit 7c3ec7d into apache:master Sep 22, 2018
@tqchen
Member

tqchen commented Sep 22, 2018

Thanks @zhiics @yzhliu @jackwish @tmoreau89, this is merged

@zhiics zhiics deleted the runtime branch September 22, 2018 03:33
@zhiics zhiics restored the runtime branch September 22, 2018 03:33
@zhiics zhiics deleted the runtime branch September 22, 2018 15:48
FrozenGene pushed a commit to FrozenGene/tvm that referenced this pull request Dec 27, 2018
@ZihengJiang ZihengJiang mentioned this pull request Feb 2, 2019