Heterogeneous Runtime #1695
Conversation
python/tvm/contrib/graph_runtime.py
Outdated
# CPU is always used as the host processor. Its device type is 1 as
# defined in TVMContext and dlpack.h. The libmod_ctx is sorted according
# to the device type field in TVMContext. It is used to guarantee that the
# first lib and context in the array belong to CPU.
Not good at first glance. Relying on the device type number is a little bit tricky, and is it really necessary to make the host device the head of the list, given the fact that you pass host_ctx as an arg?
@yzhliu Thanks. Yeah, I am also aware of this. The host_ctx here is actually a bad name. It is actually the local context where the module is deployed, as used in graph_runtime.create(); it is not guaranteed to be the context of the host processor. I will change the name to local_ctx.
Another option was to search the dictionary for the CPU context. Then we can put the CPU-related lib/device_type/device_id as the first element in each list and append the fields for the other devices, like:
for lib, ctx in libmod_ctx.items():
    if ctx == tvm.cpu(ctx.device_id):
        libs.insert(0, ...)   # likewise for device_types and device_ids
    else:
        libs.append(...)      # likewise for device_types and device_ids
if device_types[0] != STR2MASK["cpu"]:
    raise ...
I went with the sorting approach because I think it is transparent to users and well documented. I can probably also add a check on the first element of device_types to make sure it is CPU.
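For reference, a minimal sketch of the sorting approach itself, assuming libmod_ctx maps each compiled Module to its TVMContext as in the quoted comment (names are illustrative, not the final code):

# Sketch only: sort modules/contexts so that CPU (device type 1 in
# dlpack/TVMContext) comes first, then verify the first entry really is CPU.
sorted_pairs = sorted(libmod_ctx.items(), key=lambda kv: kv[1].device_type)
libs = [lib for lib, _ in sorted_pairs]
device_types = [c.device_type for _, c in sorted_pairs]
device_ids = [c.device_id for _, c in sorted_pairs]
if device_types[0] != tvm.cpu(0).device_type:
    raise RuntimeError("CPU must be the first (host/local) context.")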
src/runtime/c_runtime_api.cc
Outdated
@@ -73,7 +52,7 @@ class DeviceAPIManager {
  if (api_[type] != nullptr) return api_[type];
  std::lock_guard<std::mutex> lock(mutex_);
  if (api_[type] != nullptr) return api_[type];
  api_[type] = GetAPI(DeviceName(type), allow_missing);
  api_[type] = GetAPI(tvm::runtime::DeviceName(type), allow_missing);
no need to change
src/runtime/graph/graph_runtime.cc
Outdated
StorageDeviceMap sid_dev_map;
for (uint32_t nid = 0; nid < this->num_nodes(); ++nid) {
  const auto &inode = nodes_[nid];
  for (const auto &e : inode.inputs) {
Minor issue: use auto& var instead.
src/runtime/graph/graph_runtime.cc
Outdated
for (const auto &e : inode.inputs) {
  uint32_t eid = this->entry_id(e);
  uint32_t sid = attrs_.storage_id[eid];
  sid_dev_map[sid] = nodes_[e.node_id].device;
CHECK(sid_dev_map.count(sid) == 0)?
@yzhliu I think you probably meant:
CHECK(sid_dev_map.count(sid) == 0 || sid_dev_map[sid] == nodes_[e.node_id].device)
    << "Cannot assign the storage id to multiple devices";
right?
yes
src/runtime/graph/graph_runtime.cc
Outdated
size_t size = 1;
for (int64_t sz : attrs_.shape[i]) {
  size *= static_cast<size_t>(sz);
}
CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
CHECK_GE(sid, 0) << "Do not support runtime shape op";
confusing, cast to uint and check >= 0?
src/runtime/graph/graph_runtime.cc
Outdated
}
pool_entry_bytes[sid] = std::max(pool_entry_bytes[sid], bytes);
DLDeviceType dev_type = sid_dev_map[sid];
device_pool_entry_bytes[dev_type][sid] =
Looks like there's no need to make dev_type a key; it can be obtained from sid_dev_map, right? It just feels better to keep the data structures simple.
@yzhliu Good point. Thanks.
src/runtime/graph/graph_runtime.cc
Outdated
DLTensor *tensor;
TVM_CCALL(TVMArrayAlloc(shape, 1, kDLFloat, 32, 1, ctx.device_type,
                        ctx.device_id, &tensor));
device_storage_pool_[it.first][pit.first] = tensor;
same as above.
src/runtime/graph/graph_runtime.cc
Outdated
@@ -482,27 +136,28 @@ void GraphRuntime::SetupStorage() {

void GraphRuntime::SetupOpExecs() {
  op_execs_.resize(this->num_nodes());
  std::vector<DLTensor> ids;
remove
src/runtime/graph/graph_runtime.cc
Outdated
    runtime_host_module_.GetFunction(param.func_name, false);
if (pf == nullptr) {
  for (const auto& it : runtime_device_mod_ctx_map_) {
    pf = it.first->GetFunction(param.func_name, false);
What if two modules have functions with the same name?
Thanks for pointing it out. This can be solved by using device information.
include/tvm/runtime/device_api.h
Outdated
@@ -36,6 +36,39 @@ constexpr int kTempAllocaAlignment = 64;
/*! \brief Maximum size that can be allocated on stack */
constexpr int kMaxStackAlloca = 1024;

/*! \brief The default device allocated to an operator */
constexpr DLDeviceType kDLDefaultDevice = kDLCPU;
By this default, do you mean the fallback device?
@jackwish Yes.
This should not appear here; instead, we should pass the fallback device as an argument or put it in a context.
@tqchen +1
@srkreddy1238 @nishi-t @eqy please take a bit of time to review this.
python/tvm/contrib/graph_runtime.py
Outdated
if ctx.device_type >= rpc_base.RPC_SESS_MASK:
    raise RuntimeError(
        "rpc is not supported for heterogeneous execution yet.")
if ctx == tvm.cpu(ctx.device_id):
In this code segment, are we assuming that, considering the subgraph of the whole graph, each target device will only be used by one subgraph at most?
@jackwish Not really. The "subgraph" is actually just a fused node here. For example, we annotate the graph and provide the context information for each node. During fusion, nodes with different context will not be grouped together. Instead, different ops marked with the same device/context will be compiled into the same binary.
Remarkable work!
btw, is type casting used a bit too frequently?
Some high-level comments:
@tqchen Sorry. Could you please elaborate more on how we can pack them into a single module? It seems that tvm.build only takes one target. Or do you mean we need to pack the modules after they are generated for different devices? Neither is clear to me.
You might find some insight by checking out the implementation of the code here: https://github.com/dmlc/tvm/blob/master/python/tvm/build_module.py#L503. build can take in a list of LoweredFunc and build a module, either a host one or a device one (with a host module). What we really have to do is to delay the generation of the host module and return it as a list of LoweredFunc. The device modules can simply be built already for each device. Then we can build a single host module and import the device modules from there. If you are building a CPU/GPU mixed runtime, things can be even simpler: just put everything into a list of LoweredFunc and build with GPU as the target. The CPU code will be compiled as the host module.
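A rough sketch of that flow as it ends up in this PR, using the postpone_host_codegen flag and combine_modules discussed further below (the schedules s_gpu/s_cpu and their argument lists are placeholders; the exact API may differ):

import tvm
from tvm.build_module import combine_modules  # helper added in this PR

fhost_all = []      # lowered host functions collected across targets
device_mods = []    # device modules built per target
for sch, args, target in [(s_gpu, args_gpu, "cuda"), (s_cpu, args_cpu, "llvm")]:
    # With host codegen postponed, build returns the lowered host funcs
    # together with the (possibly None) device module.
    fhost, mdev = tvm.build(sch, args, target, target_host="llvm",
                            postpone_host_codegen=True)
    fhost_all += fhost
    if mdev is not None:
        device_mods.append(mdev)

# Generate a single host module at the end and import every device module into it.
mhost = combine_modules(fhost_all, device_mods, target_host="llvm")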
Thanks for @tqchen's suggestion. I looked into tvm.build. I think there are probably two ways to generate just one module.
I quickly tested both methods locally and it seemed that both worked. I think both solutions are simple, but the second one needs fewer changes to the current code base. There might be some other solutions. @tqchen Am I missing something, or do you have any comments/advice? Thanks.
One possible approach
@tqchen Thanks for your suggestion. I think that postponing codegen should be able to solve the hierarchical module problem. I would like to go with the approach of returning (fhost, device_module) because I don't want to have LoweredFunc as a member of the module class. I think it would be good to keep runtime and compilation separate. On the other hand, returning the (fhost, device_module) tuple also keeps the code clean and simple. I will update the PR soon to 1) move the definition of
@tqchen @yzhliu @srkreddy1238 Now we only have one graph_runtime interface for both C++ and Python. Currently we can pass either 4 or 5 arguments to graph_runtime.create in the C++ backend to keep support for Java and js, because heterogeneous execution for them has not been implemented yet. This is also documented in the code.
python/tvm/build_module.py
Outdated
# collected.
mdev = codegen.build_module(fdevice, str(target_device)) if fdevice else None
if postpone_host_codegen:
    return mdev, fhost
return fhost, mdev: host first, as the device code can be None.
python/tvm/build_module.py
Outdated
    mhost.import_module(mdev)
    return mhost


def combine_modules(host_funcs, device_modules, target_host=None):
I feel we can directly call codegen.build_module and do not need this additional level of abstraction for now; move the comment to the place that actually calls combine_modules.
@tqchen Yes, we can do that. The reason I have this function here is that we will need it anyway in the compiler PR. I called this function in the unit test. I can remove it for now and add it back later.
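For reference, a hedged sketch of the direct call, based on the build_module.py snippets quoted in this review (fhost_all, device_mods, and target_host are illustrative names, as in the earlier sketch):

from tvm import codegen

# Build the host module directly from the collected lowered host functions,
# then import each device module, instead of going through combine_modules.
mhost = codegen.build_module(fhost_all, str(target_host))
for mdev in device_mods:
    mhost.import_module(mdev)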
python/tvm/contrib/graph_runtime.py
Outdated
"CPU should be the host processor for heterogenous execution, but" | ||
" not found in ctx.") | ||
|
||
device_type_arr = (ctypes.c_int * num_devices)(*device_types) |
We should not pass raw integer pointers into a function; it will not be RPC compatible. Instead, we can pass things in positionally, simply by:
fcreate(json, libmod, ndevices, device_type0, device_id0, device_type1, device_id1); ...
@tqchen I am aware of this; we can pass fcreate(json, libmod, ctx[0], ctx[1]). But it seems to me that we need to check the number of contexts, although we usually have at most 2 or 3:
if len(ctx) == 1:
    fcreate(json, libmod, ctx[0])
elif len(ctx) == 2:
    fcreate(json, libmod, ctx[0], ctx[1])
Am I missing something?
in python you can just do fcreate(json, libmod, *ctx)
@tqchen btw, passing a TVMContext seems to only work for Python, not Java and js, right? If so, we probably still need to pass (json, mod, num_dev, dev_types, dev_ids).
src/runtime/graph/graph_runtime.cc
Outdated
@@ -277,10 +297,16 @@ class GraphRuntime : public ModuleNode {
    this->LoadAttrs(reader, &param);
  } else if (key == "control_deps") {
    reader->Read(&control_deps);
  } else if (key == "device_type") {
device_index? As we are doing a virtual device_index to device type mapping.
src/runtime/graph/graph_runtime.cc
Outdated
namespace tvm {
namespace runtime {
using StorageDeviceMap = std::unordered_map<uint32_t, DLDeviceType>;
using DeviceStoragePoolMap = std::unordered_map<size_t, NDArray>;
using ModuleContextMap = std::unordered_map<tvm::runtime::Module*, TVMContext>;
remove module context map
src/runtime/graph/graph_runtime.cc
Outdated
// to make sure homogeneous execution works correctly. It will be removed
// when we add the compiler pass as we will read the serialized value from
// json for all execution.
DLDeviceType device_type{kDLCPU};
Let us not put device_type here; instead, introduce a device_index column attribute, just like storage_id, to indicate device assignment.
If the column does not exist, fall back to the default (the primary context passed in at index 0).
@tqchen So we use ctxs_[0] as the default, right?
I think that is a reasonable choice
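To make the suggestion concrete, here is a hypothetical fragment of the graph JSON "attrs" section with such a per-node column, written as a Python dict (key names and values are purely illustrative; this PR currently serializes a device_type column instead):

graph_attrs = {
    "dltype":       ["list_str", ["float32", "float32", "float32", "float32"]],
    "storage_id":   ["list_int", [0, 1, 2, 3]],
    "shape":        ["list_shape", [[4], [4], [4], [4]]],
    # One entry per node: an index into the context list handed to the runtime.
    # If the column is absent, every node falls back to ctxs_[0].
    "device_index": ["list_int", [0, 1, 1, 0]],
}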
src/runtime/graph/graph_runtime.cc
Outdated
/*! \brief Execution context of all devices including the host. */
std::vector<TVMContext> ctxs_;
/*! \brief Common storage pool for each device. */
DeviceStoragePoolMap device_storage_pool_;
It is likely we can still just use a vector for storage_pool, as long as the storage index sharing algorithm is device aware.
@tqchen I am not sure. I thought the storage index sharing algorithm was not device aware because we don't have multiple devices.
That could be true, but we can still introduce post-processing to make them device aware. The general principle is that we want the runtime to be as dumb as possible and let the compiler do most of the job.
@tqchen Thanks. I agree we should keep the runtime minimal. We can remove all the maps if a storage id is guaranteed to be assigned to only one device. I can also use a vector to represent the sid to device_type mapping and use sid for indexing. I think we need this mapping to make memory allocation on the correct device more convenient, right?
src/runtime/graph/graph_runtime.cc
Outdated
for (size_t i = 0; i < pool_entry_bytes.size(); ++i) {

// Allocate the space on each device.
for (const auto& pit : device_pool_entry_bytes) {
Likely we do not need to do it per device. We can still iterate over each entry in the pool, but when we allocate the memory for an entry, we look up which device that entry belongs to.
@tqchen But the number of iterations would still be the same, right? Going by device seems more intuitive to me.
I agree going by device is more intuitive, but going by pool likely removes the additional data structures we need (we only need a vector pool and a vector of arrays). As per the last comment, one goal of the runtime is to keep it as simple as possible.
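A Python-flavoured sketch of that per-entry scheme (the real code is C++ in graph_runtime.cc; pool_entry_bytes, sid_device_type, and ctxs are illustrative names for the pool sizes, the sid-to-device mapping, and the context list):

import numpy as np
import tvm

def alloc_storage_pool(pool_entry_bytes, sid_device_type, ctxs):
    """Allocate one flat float32 buffer per storage entry on its own device."""
    storage_pool = []
    for sid, nbytes in enumerate(pool_entry_bytes):
        # Look up the device of this entry; fall back to the primary context.
        dev_type = sid_device_type.get(sid, ctxs[0].device_type)
        ctx = next((c for c in ctxs if c.device_type == dev_type), ctxs[0])
        nelem = int(np.ceil(nbytes / 4.0))  # pool entries are float32 words
        storage_pool.append(tvm.nd.empty((nelem,), "float32", ctx))
    return storage_pool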
python/tvm/contrib/graph_runtime.py
Outdated
fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx)
return GraphModule(fcreate(graph_json_str, libmod, num_devices,
just pass in ctx[0]
@tqchen So we actually pass fcreate(json, libmod, device_type_id[0], device_type_id[1], *device_type_id_others), right?
yah, that can be a solution
python/tvm/contrib/graph_runtime.py
Outdated
# Assume CPU is the host processor when there are multiple devices on
# a hardware platform.
if (num_devices > 1) and (cpu_ctx_index < 0):
Avoid doing the context check for now and just use ctx[0] as the primary context.
src/runtime/graph/graph_runtime.cc
Outdated
// This for loop is very fast since there are usually only a couple of
// devices available on the same hardware.
for (const auto& cit : ctxs_) {
  if (pool_entry[i].device_type == static_cast<int>(cit.device_type)) {
we can use std::find_if
src/runtime/graph/graph_runtime.cc
Outdated
@@ -508,8 +554,9 @@ void GraphRuntime::SetupOpExecs() {
    uint32_t eid = this->entry_id(nid, index);
    args.push_back(*(data_entry_[eid].operator->()));
  }
  CHECK_EQ(inode.op_type, "tvm_op")
      << "Can only take tvm_op as op";
  CHECK(inode.op_type == "tvm_op" || inode.op_type == "device_copy_op")
since we are already using __copy for cross-device copy, we just need to make sure op_type is "tvm_op"
src/runtime/graph/graph_runtime.cc
Outdated
std::vector<TVMContext> ret(1);
if (args.num_args == 4) {
  int dev_type = args[2];
  ret[0].device_type = static_cast<DLDeviceType>(dev_type);
Just do push_back so there is no special logic involved.
src/runtime/graph/graph_runtime.cc
Outdated
int dev_type = args[3 + i];
ctx.device_type = static_cast<DLDeviceType>(dev_type);
ctx.device_id = args[3 + i + 1];
if (ctx.device_type == static_cast<int>(kDLCPU)) {
Do not do magic like this; let's just push things and build up the ctx.
@zhiics some final followup comments. @jackwish @srkreddy1238 @yzhliu @tmoreau89 please take a round of review and see https://docs.tvm.ai/contribute/code_review.html#approve-and-request-changes-explicitly
python/tvm/build_module.py
Outdated
"""Build a function with arguments as signiture. | ||
binds=None, | ||
postpone_host_codegen=False): | ||
"""Build a function with arguments as signiture. Code will be generated |
signiture -> signature
def get_simplex_graph(host_dev_type, device_dev_type):
    r""" Return the hand-crafted json object where only one copy node is
    inserted. Tis node copies data from the target device to cpu.
Tis?
(I assume it's probably a typo on "the")
@tmoreau89 Thanks. It was a typo. I was trying to say "This".
Overall, excellent work! Thank you for providing well written examples. This will open a lot of interesting work on heterogeneous execution on CPU+GPU or CPU+FPGA systems.
@zhiics please address the review comments
some followup comments
python/tvm/contrib/graph_runtime.py
Outdated
raise ValueError("ctx has to be the type of TVMContext or a list " | ||
"of TVMContext") | ||
if cur_ctx.device_type >= rpc_base.RPC_SESS_MASK: | ||
ctx[0], ctx[i] = ctx[i], ctx[0] |
what is the purpose of this swapping? can we just remove it?
@tqchen Sorry. I think I misunderstood RPC here. I thought we just need one of them to be remote. So I put it as the first one. I updated it. Please take another look and see if it makes sense. Thanks.
python/tvm/contrib/graph_runtime.py
Outdated
    device_type = device_type % rpc_base.RPC_SESS_MASK
    return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx)
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create")
device_type = ctx[0].device_type % rpc_base.RPC_SESS_MASK
We need to do the RPC session stripping for all the contexts.
Need to assert that all the contexts are remote and belong to the same session.
python/tvm/contrib/graph_runtime.py
Outdated
# ctx[0] is used as the primary/fallback context. All other ones are used
# as device context for heterogeneous execution.
device_type_id = [x for c in ctx[1:] for x in [c.device_type, c.device_id]]
Consider using a for loop to populate this, since we need to strip the RPC sess mask from all of them; maybe some of the logic needs to be put into loops.
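A hedged sketch of such a loop (cur_ctx and device_type_id follow the names in the quoted diffs; the same-session assertion mirrors the comments above, and the attributes used, e.g. _rpc_sess, appear in the existing code):

device_type_id = []
for cur_ctx in ctx:
    device_type = cur_ctx.device_type
    if device_type >= rpc_base.RPC_SESS_MASK:
        # All remote contexts must come from the same RPC session.
        assert cur_ctx._rpc_sess == ctx[0]._rpc_sess, \
            "All remote contexts must belong to the same RPC session."
        device_type = device_type % rpc_base.RPC_SESS_MASK
    device_type_id.append(device_type)
    device_type_id.append(cur_ctx.device_id)

The flattened list can then be passed positionally, e.g. fcreate(graph_json_str, libmod, *device_type_id).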
python/tvm/contrib/graph_runtime.py
Outdated
fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, ctx[0].device_type,
                           ctx[0].device_id, *device_type_id))
return GraphModule(fcreate(graph_json_str, libmod, device_type_id[0],
can just do *device_type_id
combine modules for heterogeneous execution
Thanks @zhiics @yzhliu @jackwish @tmoreau89, this is merged
This is the first part of the PR #1688.
This PR only focuses on making the runtime able to take heterogeneous graphs. Changes are mainly made to the graph runtime C++ and Python interfaces. Meanwhile, to test the execution, I manually created two simple graphs containing only addition, subtraction, and copy nodes. One test, test_simplex_data_transferring, tests data transfer from GPU to CPU at runtime. The other one, test_duplex_data_transferring, tests duplex data transfer back and forth between GPU and CPU.
A new column, device_type, is added to the json file, which indicates which device a node should be scheduled to. In this PR this column is manually created as part of a graph json file. This field will also be used to annotate the graph nodes in the compiler PR. The serialization of this column is similar to that of the dtype column in the current json file. The loading/saving json APIs need to be modified slightly to support this field. I tested the functionality on a MacBook with an Intel CPU and Intel Graphics GPU, both using the generated module in memory and exporting/importing it from disk.
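For illustration, a hedged sketch of how the new interface is exercised, roughly mirroring the tests (graph_json_str is a hand-crafted, device-annotated graph and mhost a module combined for both devices; both names are placeholders, and the exact argument handling may differ):

import numpy as np
import tvm
from tvm.contrib import graph_runtime

# ctx[0] acts as the primary/fallback context; the GPU runs the annotated nodes.
ctxs = [tvm.cpu(0), tvm.gpu(0)]
mod = graph_runtime.create(graph_json_str, mhost, ctxs)

shape = (4,)
mod.set_input("A", np.random.uniform(size=shape).astype("float32"))
mod.set_input("B", np.random.uniform(size=shape).astype("float32"))
mod.run()
out = mod.get_output(0, tvm.nd.empty(shape, "float32", ctxs[0]))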
The next PR will focus on the compiler part to generate heterogeneous binaries to feed into the runtime. Major changes will be needed for the compiler.build interface. Another issue is the removal of with target statements in the high-level build interface.