Skip to content

Commit

Permalink
[RPC] Fix the multihop cpu case (apache#5522)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and Trevor Morris committed Jun 8, 2020
1 parent 8e84f3d commit fbe298a
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 13 deletions.
27 changes: 16 additions & 11 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,20 +390,18 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;

char* data_ptr;
auto* sess = GetServingSession();

if (ctx.device_type == kDLCPU) {
// When session is local, we can directly treat handle
// as the cpu pointer without allocating a temp space.
if (ctx.device_type == kDLCPU &&
sess->IsLocalSession() &&
DMLC_IO_NO_ENDIAN_SWAP) {
data_ptr = reinterpret_cast<char*>(handle) + offset;
// endian aware handling
if (!DMLC_IO_NO_ENDIAN_SWAP) {
char* temp = this->ArenaAlloc<char>(num_bytes);
std::memcpy(temp, data_ptr, num_bytes);
dmlc::ByteSwap(temp, elem_bytes, num_bytes / elem_bytes);
data_ptr = temp;
}
} else {
try {
data_ptr = this->ArenaAlloc<char>(num_bytes);
GetServingSession()->CopyFromRemote(
sess->CopyFromRemote(
reinterpret_cast<void*>(handle), offset,
data_ptr, 0,
num_bytes, ctx, type_hint);
Expand Down Expand Up @@ -440,8 +438,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
this->Read(&type_hint);

size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
auto* sess = GetServingSession();

if (ctx.device_type == kDLCPU) {
// When session is local, we can directly treat handle
// as the cpu pointer without allocating a temp space.
if (ctx.device_type == kDLCPU && sess->IsLocalSession()) {
char* dptr = reinterpret_cast<char*>(handle) + offset;
this->ReadArray(dptr, num_bytes);

Expand All @@ -457,7 +458,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
}

try {
GetServingSession()->CopyToRemote(
sess->CopyToRemote(
temp_data, 0,
reinterpret_cast<void*>(handle), offset,
num_bytes, ctx, type_hint);
Expand Down Expand Up @@ -1046,6 +1047,10 @@ class RPCClientSession : public RPCSession,
return this;
}

bool IsLocalSession() const final {
return false;
}

private:
std::shared_ptr<RPCEndpoint> endpoint_;
};
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/rpc/rpc_local_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class LocalSession : public RPCSession {

DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) final;

bool IsLocalSession() const final {
return true;
}

protected:
/*!
* \brief Internal implementation of GetFunction.
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/rpc/rpc_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,18 @@ class RPCSession {
*/
virtual DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) = 0;

/*!
* \brief Whether the session is a local session and we can directly
* the data handle returned by the session and treat it as pointer
* to the local memory.
*
* This information is useful for RPC server to directly copy into the
* local memory without creating a temporary buffer.
*
* \return Whether it is a local session.
*/
virtual bool IsLocalSession() const = 0;

/*!
* \return The session table index of the session.
*/
Expand Down
9 changes: 7 additions & 2 deletions tests/python/unittest/test_runtime_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,13 @@ def test_rpc_remote_module():
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = te.create_schedule(B.op)

server = rpc.Server("localhost")
client = rpc.connect(server.host, server.port)
server0 = rpc.Server("localhost", key="x0")
server1 = rpc.Server("localhost", key="x1")

client = rpc.connect(
server0.host, server0.port, key="x0",
session_constructor_args=[
"rpc.Connect", server1.host, server1.port, "x1"])

def check_remote(remote):
if not tvm.runtime.enabled("llvm"):
Expand Down

0 comments on commit fbe298a

Please sign in to comment.