Skip to content

Commit

Permalink
address comments and fix error
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Apr 16, 2021
1 parent 8ec0fea commit 10045f3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
11 changes: 7 additions & 4 deletions src/runtime/crt/host/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
}
}

uint8_t memory[1024 * 1024];
uint8_t memory[512 * 1024];

static char** g_argv = NULL;

Expand All @@ -135,10 +135,13 @@ int main(int argc, char** argv) {
CHECK_EQ(TVMGraphExecutorModule_Register(), kTvmErrorNoError,
"failed to register GraphExecutor TVMModule");
#endif

int error = TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server, 0);

int error = TVMFuncRegisterGlobal("tvm.testing.reset_server",
(TVMFunctionHandle)&testonly_reset_server, 0);
if (error) {
fprintf(stderr, "utvm runtime: internal error (error#: %d) registering global packedfunc; exiting\n", error);
fprintf(stderr,
"utvm runtime: internal error (error#: %x) registering global packedfunc; exiting\n",
error);
return 2;
}

Expand Down
39 changes: 27 additions & 12 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,8 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)

uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes)
<< "CopyToRemote: overflow in tensor size: (" << to->byte_offset << ", " << nbytes << ", "
<< tensor_total_size_bytes << ")";
<< "CopyToRemote: overflow in tensor size: (byte_offset=" << to->byte_offset
<< ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";

uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(to, code, nbytes);
uint64_t packet_nbytes = overhead + nbytes;
Expand All @@ -822,8 +822,8 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes

uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*from));
ICHECK_LE(from->byte_offset + nbytes, tensor_total_size_bytes)
<< "CopyFromRemote: overflow in tensor size: (" << from->byte_offset << ", " << nbytes << ", "
<< tensor_total_size_bytes << ")";
<< "CopyFromRemote: overflow in tensor size: (byte_offset=" << from->byte_offset
<< ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";

uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(from, code, nbytes);
uint64_t packet_nbytes = overhead;
Expand Down Expand Up @@ -983,16 +983,21 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
const uint64_t block_size = GetRPCMaxTransferSize() - overhead;
uint64_t block_count = 0;
const uint64_t num_blocks = nbytes / block_size;
void* from_bytes;

for (block_count = 0; block_count < num_blocks; block_count++) {
remote_to->byte_offset = block_count * block_size;
endpoint_->CopyToRemote(local_from_bytes, remote_to, block_size);
from_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
endpoint_->CopyToRemote(from_bytes, remote_to, block_size);
}

const uint64_t remainder_bytes = nbytes % block_size;
if (remainder_bytes != 0) {
remote_to->byte_offset = block_count * block_size;
endpoint_->CopyToRemote(local_from_bytes, remote_to, remainder_bytes);
from_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
endpoint_->CopyToRemote(from_bytes, remote_to, remainder_bytes);
}
}

Expand All @@ -1002,16 +1007,21 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
const uint64_t block_size = GetRPCMaxTransferSize() - overhead;
uint64_t block_count = 0;
const uint64_t num_blocks = nbytes / block_size;
void* to_bytes;

for (block_count = 0; block_count < num_blocks; block_count++) {
remote_from->byte_offset = block_count * block_size;
endpoint_->CopyFromRemote(remote_from, local_to_bytes, block_size);
to_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
endpoint_->CopyFromRemote(remote_from, to_bytes, block_size);
}

const uint64_t remainder_bytes = nbytes % block_size;
if (remainder_bytes != 0) {
remote_from->byte_offset = block_count * block_size;
endpoint_->CopyFromRemote(remote_from, local_to_bytes, remainder_bytes);
to_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
endpoint_->CopyFromRemote(remote_from, to_bytes, remainder_bytes);
}
}

Expand Down Expand Up @@ -1071,22 +1081,27 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
private:
void RPCMaxTransferRemoteReturnValue(TVMArgs args) {
// Use args[1] as return value, args[0] is tcode
rpc_chunk_max_size_bytes_ = (uint64_t)args[1];
// Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc
rpc_chunk_max_size_bytes_ = (int64_t)args[1];
}

uint64_t GetRPCMaxTransferSize() {
if (rpc_chunk_max_size_bytes_ > 0) {
return (uint64_t)rpc_chunk_max_size_bytes_;
}

PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetCRTMaxPacketSize");
if (rpc_func == nullptr) {
rpc_chunk_max_size_bytes_ = kRPCMaxTransferSizeBytesDefault;
rpc_chunk_max_size_bytes_ = (int64_t)kRPCMaxTransferSizeBytesDefault;
} else {
CallFunc(rpc_func, nullptr, nullptr, 0,
[this](TVMArgs args) { RPCMaxTransferRemoteReturnValue(args); });
}
return rpc_chunk_max_size_bytes_;
return (uint64_t)rpc_chunk_max_size_bytes_;
}

std::shared_ptr<RPCEndpoint> endpoint_;
uint64_t rpc_chunk_max_size_bytes_;
int64_t rpc_chunk_max_size_bytes_ = -1;
};

std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
Expand Down

0 comments on commit 10045f3

Please sign in to comment.