Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RPC] microtvm: fix RPC large transfer size issue #7838

Merged
merged 12 commits into from
Apr 22, 2021
Merged
24 changes: 14 additions & 10 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,9 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {
RPCCode code = RPCCode::kCopyToRemote;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes);
const uint64_t block_size = GetRPCMaxTransferSize() - overhead;
uint64_t rpc_max_size = GetRPCMaxTransferSize();
ICHECK_GT(rpc_max_size - overhead, 0) << "CopyToRemote: Invalid block size!";
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved
const uint64_t block_size = rpc_max_size - overhead;
uint64_t block_count = 0;
const uint64_t num_blocks = nbytes / block_size;
void* from_bytes;
Expand All @@ -1004,7 +1006,9 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final {
RPCCode code = RPCCode::kCopyFromRemote;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_from, code, nbytes);
const uint64_t block_size = GetRPCMaxTransferSize() - overhead;
uint64_t rpc_max_size = GetRPCMaxTransferSize();
ICHECK_GT(rpc_max_size - overhead, 0) << "CopyFromRemote: Invalid block size!";
const uint64_t block_size = rpc_max_size - overhead;
uint64_t block_count = 0;
const uint64_t num_blocks = nbytes / block_size;
void* to_bytes;
Expand Down Expand Up @@ -1079,12 +1083,6 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
bool IsLocalSession() const final { return false; }

private:
void RPCMaxTransferRemoteReturnValue(TVMArgs args) {
// Use args[1] as return value, args[0] is tcode
// Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc
rpc_chunk_max_size_bytes_ = (int64_t)args[1];
}

uint64_t GetRPCMaxTransferSize() {
if (rpc_chunk_max_size_bytes_ > 0) {
return (uint64_t)rpc_chunk_max_size_bytes_;
Expand All @@ -1094,8 +1092,14 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
if (rpc_func == nullptr) {
rpc_chunk_max_size_bytes_ = (int64_t)kRPCMaxTransferSizeBytesDefault;
} else {
CallFunc(rpc_func, nullptr, nullptr, 0,
[this](TVMArgs args) { RPCMaxTransferRemoteReturnValue(args); });
CallFunc(rpc_func, nullptr, nullptr, 0, [this](TVMArgs args) {
// Use args[1] as return value, args[0] is tcode
// Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc
rpc_chunk_max_size_bytes_ = (int64_t)args[1];
ICHECK_GT(rpc_chunk_max_size_bytes_, 0)
<< "RPC max transfer size is <= 0! (remote value = " << rpc_chunk_max_size_bytes_
<< ")";
});
}
return (uint64_t)rpc_chunk_max_size_bytes_;
}
Expand Down
1 change: 1 addition & 0 deletions src/runtime/rpc/rpc_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) {
* \param to DLTensor to copy.
* \param code RPCCode for this transfer.
* \param nbytes Number of bytes to transfer.
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved
* \return The remote-copy packet overhead size.
*/
uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes);

Expand Down