Skip to content

Commit

Permalink
[RUNTIME][RPC] Update RPC runtime to allow remote module as arg (#4462)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Dec 3, 2019
1 parent 77bdd5f commit 279a8eb
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 67 deletions.
21 changes: 6 additions & 15 deletions python/tvm/contrib/debugger/debug_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tvm._ffi.function import get_global_func
from tvm.contrib import graph_runtime
from tvm.ndarray import array
from tvm.rpc import base as rpc_base
from . import debug_result

_DUMP_ROOT_PREFIX = "tvmdbg_"
Expand Down Expand Up @@ -60,25 +59,17 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str))
try:
fcreate = get_global_func("tvm.graph_runtime_debug.create")
ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
fcreate = ctx[0]._rpc_sess.get_function(
"tvm.graph_runtime_debug.create")
else:
fcreate = get_global_func("tvm.graph_runtime_debug.create")
except ValueError:
raise ValueError(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode"
)

ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
libmod = rpc_base._ModuleHandle(libmod)
try:
fcreate = ctx[0]._rpc_sess.get_function(
"tvm.graph_runtime_debug.remote_create"
)
except ValueError:
raise ValueError(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode"
)
func_obj = fcreate(graph_json_str, libmod, *device_type_id)
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root)

Expand Down
7 changes: 3 additions & 4 deletions python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@ def create(graph_json_str, libmod, ctx):
ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx)

if num_rpc_ctx == len(ctx):
hmod = rpc_base._ModuleHandle(libmod)
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create")
return GraphModule(fcreate(graph_json_str, hmod, *device_type_id))
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
else:
fcreate = get_global_func("tvm.graph_runtime.create")

fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))

def get_device_ctx(libmod, ctx):
Expand Down
15 changes: 0 additions & 15 deletions src/runtime/graph/debug/graph_runtime_debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include <chrono>
#include <sstream>
#include "../graph_runtime.h"
#include "../../object_internal.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -220,19 +219,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
<< args.num_args;
*rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
});

TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeDebugCreate(
args[0], GetRef<Module>(mnode), contexts);
});

} // namespace runtime
} // namespace tvm
15 changes: 0 additions & 15 deletions src/runtime/graph/graph_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include <vector>

#include "graph_runtime.h"
#include "../object_internal.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -511,19 +510,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate(args[0], args[1], contexts);
});

TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);

const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate(
args[0], GetRef<Module>(mnode), contexts);
});
} // namespace runtime
} // namespace tvm
24 changes: 23 additions & 1 deletion src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class RPCWrappedFunc {
}

void operator()(TVMArgs args, TVMRetValue *rv) const {
sess_->CallFunc(handle_, args, rv, &fwrap_);
sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_);
}
~RPCWrappedFunc() {
try {
Expand All @@ -55,6 +55,9 @@ class RPCWrappedFunc {
TVMArgs args,
TVMRetValue* rv);

static void* UnwrapRemote(int rpc_sess_table_index,
const TVMArgValue& arg);

// deleter of RPC remote array
static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
Expand Down Expand Up @@ -181,6 +184,25 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc fwrap_;
};

void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index,
const TVMArgValue& arg) {
if (arg.type_code() == kModuleHandle) {
Module mod = arg;
std::string tkey = mod->type_key();
CHECK_EQ(tkey, "rpc")
<< "ValueError: Cannot pass a non-RPC module to remote";
auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index)
<< "ValueError: Cannot pass in module into a different remote session";
return rmod->module_handle();
} else {
LOG(FATAL) << "ValueError: Cannot pass type "
<< runtime::TypeCode2Str(arg.type_code())
<< " as an argument to the remote";
return nullptr;
}
}

void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
TVMArgs args,
TVMRetValue *rv) {
Expand Down
64 changes: 47 additions & 17 deletions src/runtime/rpc/rpc_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,23 +202,33 @@ class RPCSession::EventHandler : public dmlc::Stream {
return ctx;
}
// Send Packed sequence to writer.
//
// client_mode: whether we are in client mode.
//
// funwrap: auxiliary function to unwrap remote Object
// when it is provided, we need to unwrap objects.
//
// return_ndarray is a special flag to handle returning of ndarray
// In this case, we return the shape, context and data of the array,
// as well as a customized PackedFunc that handles deletion of
// the array in the remote.
void SendPackedSeq(const TVMValue* arg_values,
const int* type_codes,
int n,
int num_args,
bool client_mode,
FUnwrapRemoteObject funwrap = nullptr,
bool return_ndarray = false) {
this->Write(n);
for (int i = 0; i < n; ++i) {
std::swap(client_mode_, client_mode);

this->Write(num_args);
for (int i = 0; i < num_args; ++i) {
int tcode = type_codes[i];
if (tcode == kNDArrayContainer) tcode = kArrayHandle;
this->Write(tcode);
}

// Argument packing.
for (int i = 0; i < n; ++i) {
for (int i = 0; i < num_args; ++i) {
int tcode = type_codes[i];
TVMValue value = arg_values[i];
switch (tcode) {
Expand All @@ -241,7 +251,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
break;
}
case kFuncHandle:
case kModuleHandle:
case kModuleHandle: {
// always send handle in 64 bit.
uint64_t handle;
// allow pass module as argument to remote.
if (funwrap != nullptr) {
void* remote_handle = (*funwrap)(
rpc_sess_table_index_,
runtime::TVMArgValue(value, tcode));
handle = reinterpret_cast<uint64_t>(remote_handle);
} else {
CHECK(!client_mode_)
<< "Cannot directly pass remote object as argument";
handle = reinterpret_cast<uint64_t>(value.v_handle);
}
this->Write(handle);
break;
}
case kHandle: {
// always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
Expand Down Expand Up @@ -300,6 +326,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
}
}
}
std::swap(client_mode_, client_mode);
}

// Endian aware IO handling
Expand Down Expand Up @@ -430,11 +457,11 @@ class RPCSession::EventHandler : public dmlc::Stream {
case kHandle:
case kStr:
case kBytes:
case kModuleHandle:
case kTVMContext: {
this->RequestBytes(sizeof(TVMValue)); break;
}
case kFuncHandle:
case kModuleHandle: {
case kFuncHandle: {
CHECK(client_mode_)
<< "Only client can receive remote functions";
this->RequestBytes(sizeof(TVMValue)); break;
Expand Down Expand Up @@ -656,7 +683,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
TVMValue ret_value;
ret_value.v_str = e.what();
int ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
}
}
this->SwitchToState(kRecvCode);
Expand Down Expand Up @@ -711,7 +738,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
}
}
this->Write(code);
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
arg_recv_stage_ = 0;
this->SwitchToState(kRecvCode);
}
Expand All @@ -734,22 +761,22 @@ class RPCSession::EventHandler : public dmlc::Stream {
if (rv.type_code() == kStr) {
ret_value.v_str = rv.ptr<std::string>()->c_str();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} else if (rv.type_code() == kBytes) {
std::string* bytes = rv.ptr<std::string>();
TVMByteArray arr;
arr.data = bytes->c_str();
arr.size = bytes->length();
ret_value.v_handle = &arr;
ret_tcode = kBytes;
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} else if (rv.type_code() == kFuncHandle ||
rv.type_code() == kModuleHandle) {
// always send handle in 64 bit.
CHECK(!client_mode_)
<< "Only server can send function and module handle back.";
rv.MoveToCHost(&ret_value, &ret_tcode);
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} else if (rv.type_code() == kNDArrayContainer) {
// always send handle in 64 bit.
CHECK(!client_mode_)
Expand All @@ -764,18 +791,18 @@ class RPCSession::EventHandler : public dmlc::Stream {
NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
ret_value_pack[1].v_handle = nd;
ret_tcode_pack[1] = kHandle;
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, true);
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true);
} else {
ret_value = rv.value();
ret_tcode = rv.type_code();
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
}
} catch (const std::runtime_error& e) {
RPCCode code = RPCCode::kException;
this->Write(code);
ret_value.v_str = e.what();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
}
}

Expand Down Expand Up @@ -873,7 +900,7 @@ void RPCSession::Init() {
&reader_, &writer_, table_index_, name_, &remote_key_);
// Quick function to call remote.
call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true);
RPCCode code = HandleUntilReturnEvent(rv, true, nullptr);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
});
Expand Down Expand Up @@ -954,13 +981,16 @@ int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
void RPCSession::CallFunc(void* h,
TVMArgs args,
TVMRetValue* rv,
FUnwrapRemoteObject funwrap,
const PackedFunc* fwrap) {
std::lock_guard<std::recursive_mutex> lock(mutex_);

RPCCode code = RPCCode::kCallFunc;
handler_->Write(code);
uint64_t handle = reinterpret_cast<uint64_t>(h);
handler_->Write(handle);
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
handler_->SendPackedSeq(
args.values, args.type_codes, args.num_args, true, funwrap);
code = HandleUntilReturnEvent(rv, true, fwrap);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
}
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/rpc/rpc_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ enum class RPCCode : int {
kNDArrayFree
};

/*!
* \brief Function that unwraps a remote object to its handle.
* \param rpc_sess_table_index RPC session table index for validation.
* \param obj Handle to the object argument.
* \return The corresponding handle.
*/
typedef void* (*FUnwrapRemoteObject)(
int rpc_sess_table_index,
const TVMArgValue& obj);

/*!
* \brief Abstract channel interface used to create RPCSession.
*/
Expand Down Expand Up @@ -144,11 +154,13 @@ class RPCSession {
* \param handle The function handle
* \param args The arguments
* \param rv The return value.
* \param funpwrap Function that takes a remote object and returns the raw handle.
* \param fwrap Wrapper function to turn Function/Module handle into real return.
*/
void CallFunc(RPCFuncHandle handle,
TVMArgs args,
TVMRetValue* rv,
FUnwrapRemoteObject funwrap,
const PackedFunc* fwrap);
/*!
* \brief Copy bytes into remote array content.
Expand Down

0 comments on commit 279a8eb

Please sign in to comment.