Skip to content

Commit

Permalink
fix rpc for microtvm
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Apr 14, 2021
1 parent 303a2f0 commit 75c1c0a
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 13 deletions.
24 changes: 21 additions & 3 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,14 @@ static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index
}

int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
return FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name,
out);
tvm_crt_error_t to_return =
FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, out);
// For compatibility with C++
if (to_return == kTvmErrorFunctionNameNotFound) {
*out = NULL;
to_return = kTvmErrorNoError;
}
return to_return;
}

int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports,
Expand Down Expand Up @@ -343,7 +349,6 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r
if (to_return == kTvmErrorFunctionNameNotFound) {
to_return = kTvmErrorNoError;
}

return to_return;
}

Expand Down Expand Up @@ -372,6 +377,15 @@ int TVMFuncFree(TVMFunctionHandle func) {

int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code);

// Sends maximum transfer size for RPC.
int RPCGetTransferMaxSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value,
int* ret_type_codes) {
ret_value[0].v_int64 = TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES;
ret_type_codes[0] = kTVMArgInt;
return 0;
}

tvm_crt_error_t TVMInitializeRuntime() {
int idx = 0;
tvm_crt_error_t error = kTvmErrorNoError;
Expand Down Expand Up @@ -412,6 +426,10 @@ tvm_crt_error_t TVMInitializeRuntime() {
error = TVMFuncRegisterGlobal("runtime.RPCTimeEvaluator", &RPCTimeEvaluator, 0);
}

if (error == kTvmErrorNoError) {
error = TVMFuncRegisterGlobal("tvm.rpc.server.GetTransferMaxSize", &RPCGetTransferMaxSize, 0);
}

if (error != kTvmErrorNoError) {
TVMPlatformMemoryFree(registry_backing_memory, dev);
TVMPlatformMemoryFree(func_registry_memory, dev);
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/crt/host/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256

/*! Maximum packet size, in bytes, including the length header. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8 * 1024

/*! \brief Maximum length of a PackedFunc function name. */
#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30

/*! Size of the global function for max RPC transfer, in bytes. */
#define TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES 2048

// #define TVM_CRT_FRAMER_ENABLE_LOGS

#endif // TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_
2 changes: 1 addition & 1 deletion src/runtime/crt/host/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
}
}

uint8_t memory[512 * 1024];
uint8_t memory[2048 * 1024];

static char** g_argv = NULL;

Expand Down
47 changes: 39 additions & 8 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
}

/*!
* \brief Recive incoming packed seq from the stream.
* \brief Receive incoming packed seq from the stream.
* \return The received argments.
* \note The TVMArgs is available until we switchstate.
*/
Expand Down Expand Up @@ -369,7 +369,6 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
*/
void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) {
TVMArgs args = RecvPackedSeq();

if (code == RPCCode::kException) {
// switch to the state before sending exception.
this->SwitchToState(kRecvPacketNumBytes);
Expand Down Expand Up @@ -801,14 +800,14 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = RPCCode::kCopyToRemote;

uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*to));
ICHECK_EQ(nbytes, num_data_bytes);
uint64_t tensor_max_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
ICHECK_LE(to->byte_offset + nbytes, tensor_max_size_bytes) << "Overflow in tensor size.";

uint64_t to_data = reinterpret_cast<uint64_t>(to->data);
uint64_t to_data = reinterpret_cast<uint64_t>(static_cast<char*>(to->data) + to->byte_offset);
uint64_t shape_bytes = to->ndim * sizeof(int64_t);
uint64_t packet_nbytes = sizeof(code) + sizeof(to_data) + sizeof(to->device) + sizeof(to->ndim) +
sizeof(to->dtype) + sizeof(to->byte_offset) + shape_bytes +
sizeof(nbytes) + num_data_bytes;
sizeof(nbytes) + nbytes;

handler_->Write(packet_nbytes);
handler_->Write(code);
Expand Down Expand Up @@ -968,7 +967,10 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
/*!
* \brief param endpoint The client endpoint of the session.
*/
explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {}
explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {
// update max transfer size if not set already.
SetRPCMaxTransferSize();
}

// function overrides
PackedFuncHandle GetFunction(const std::string& name) final {
Expand All @@ -981,7 +983,20 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
}

void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {
endpoint_->CopyToRemote(local_from_bytes, remote_to, nbytes);
uint64_t block_size = (uint64_t)rpc_chunk_max_size_bytes_;
uint64_t block_count = 0;
uint64_t num_blocks = nbytes / block_size;

for (block_count = 0; block_count < num_blocks; block_count++) {
remote_to->byte_offset = block_count * block_size;
endpoint_->CopyToRemote(local_from_bytes, remote_to, block_size);
}

uint64_t remainder_bytes = nbytes % block_size;
if (remainder_bytes != 0) {
remote_to->byte_offset = block_count * block_size;
endpoint_->CopyToRemote(local_from_bytes, remote_to, remainder_bytes);
}
}

void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final {
Expand Down Expand Up @@ -1042,7 +1057,23 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
bool IsLocalSession() const final { return false; }

private:
void RPCMaxTransferRemoteReturnValue(TVMArgs args) {
// Use args[1] as return value, args[0] is tcode
rpc_chunk_max_size_bytes_ = (int64_t)args[1];
}

void SetRPCMaxTransferSize() {
PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetTransferMaxSize");
if (rpc_func == nullptr) {
rpc_chunk_max_size_bytes_ = kRPCMaxTransferSizeDefault;
return;
}
CallFunc(rpc_func, nullptr, nullptr, 0,
[this](TVMArgs args) { RPCMaxTransferRemoteReturnValue(args); });
}

std::shared_ptr<RPCEndpoint> endpoint_;
int64_t rpc_chunk_max_size_bytes_;
};

std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/rpc/rpc_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ const int kRPCSuccess = kRPCMagic + 0;
// cannot found matched key in server
const int kRPCMismatch = kRPCMagic + 2;

// When tvm.rpc.server.GetTransferMaxSize global function is not registered.
const int kRPCMaxTransferSizeDefault = 128000;

/*! \brief Enumeration code for the RPC tracker */
enum class TrackerCode : int {
kFail = -1,
Expand Down

0 comments on commit 75c1c0a

Please sign in to comment.