From 75c1c0adcffb4d6c9ff55718a0d798cc1d3f936c Mon Sep 17 00:00:00 2001
From: Mehrdad Hessar
Date: Tue, 30 Mar 2021 13:16:50 -0700
Subject: [PATCH] fix rpc for microtvm

---
Review notes (commentary only; ignored by git am / git apply):
  * RPCClientSession::CopyToRemote now advances the local source pointer for
    every chunk (previously each chunk re-sent the first block of the buffer),
    rejects a non-positive chunk size before dividing by it, and restores the
    caller's byte_offset when done.
  * rpc_chunk_max_size_bytes_ now has a default member initializer.
  * TVM_CRT_MAX_PACKET_SIZE_BYTES is parenthesized so it expands safely inside
    larger expressions.
  * Template arguments in the rpc_endpoint.cc hunks were missing from the text
    this patch was recovered from and have been reconstructed from context; the
    pointee type in the to_data cast (uint8_t*) and the indentation of context
    lines are assumptions -- confirm against the tree before applying.
  * The post-image blob ids on the "index" lines were not recomputed.
  * TODO confirm: RPCEndpoint::CopyToRemote folds byte_offset into to_data; if
    byte_offset is also serialized later in that function, verify that the
    server does not apply the offset twice.
  * TODO confirm: TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES is only added to the host
    crt_config.h here; any other crt_config.h built against
    crt_runtime_api.c would need the same definition.

 src/runtime/crt/common/crt_runtime_api.c | 24 ++++++++++--
 src/runtime/crt/host/crt_config.h        |  5 ++-
 src/runtime/crt/host/main.cc             |  2 +-
 src/runtime/rpc/rpc_endpoint.cc          | 50 ++++++++++++++++++++----
 src/runtime/rpc/rpc_endpoint.h           |  3 ++
 5 files changed, 71 insertions(+), 13 deletions(-)

diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c
index c8044b49a8d0..c53c8cad8119 100644
--- a/src/runtime/crt/common/crt_runtime_api.c
+++ b/src/runtime/crt/common/crt_runtime_api.c
@@ -298,8 +298,14 @@ static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index
 }
 
 int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
-  return FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name,
-                                   out);
+  tvm_crt_error_t to_return =
+      FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, out);
+  // For compatibility with C++
+  if (to_return == kTvmErrorFunctionNameNotFound) {
+    *out = NULL;
+    to_return = kTvmErrorNoError;
+  }
+  return to_return;
 }
 
 int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports,
@@ -343,7 +349,6 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r
   if (to_return == kTvmErrorFunctionNameNotFound) {
     to_return = kTvmErrorNoError;
   }
-
   return to_return;
 }
 
@@ -372,6 +377,15 @@ int TVMFuncFree(TVMFunctionHandle func) {
 
 int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
                      int* ret_type_code);
+
+// Sends maximum transfer size for RPC.
+int RPCGetTransferMaxSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value,
+                          int* ret_type_codes) {
+  ret_value[0].v_int64 = TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES;
+  ret_type_codes[0] = kTVMArgInt;
+  return 0;
+}
+
 tvm_crt_error_t TVMInitializeRuntime() {
   int idx = 0;
   tvm_crt_error_t error = kTvmErrorNoError;
@@ -412,6 +426,10 @@ tvm_crt_error_t TVMInitializeRuntime() {
     error = TVMFuncRegisterGlobal("runtime.RPCTimeEvaluator", &RPCTimeEvaluator, 0);
   }
 
+  if (error == kTvmErrorNoError) {
+    error = TVMFuncRegisterGlobal("tvm.rpc.server.GetTransferMaxSize", &RPCGetTransferMaxSize, 0);
+  }
+
   if (error != kTvmErrorNoError) {
     TVMPlatformMemoryFree(registry_backing_memory, dev);
     TVMPlatformMemoryFree(func_registry_memory, dev);
diff --git a/src/runtime/crt/host/crt_config.h b/src/runtime/crt/host/crt_config.h
index 109abaf04083..e6987d96bb84 100644
--- a/src/runtime/crt/host/crt_config.h
+++ b/src/runtime/crt/host/crt_config.h
@@ -46,11 +46,14 @@
 #define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256
 
 /*! Maximum packet size, in bytes, including the length header. */
-#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000
+#define TVM_CRT_MAX_PACKET_SIZE_BYTES (8 * 1024)
 
 /*! \brief Maximum length of a PackedFunc function name. */
 #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30
 
+/*! Maximum RPC transfer size, in bytes, reported by tvm.rpc.server.GetTransferMaxSize. */
+#define TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES 2048
+
 // #define TVM_CRT_FRAMER_ENABLE_LOGS
 
 #endif  // TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_
diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc
index e64455417928..c56d3fb3768a 100644
--- a/src/runtime/crt/host/main.cc
+++ b/src/runtime/crt/host/main.cc
@@ -110,7 +110,7 @@ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
 }
 }
 
-uint8_t memory[512 * 1024];
+uint8_t memory[2048 * 1024];
 
 static char** g_argv = NULL;
 
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index b5768146b3f7..48e403384f33 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -330,7 +330,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
   }
 
   /*!
-   * \brief Recive incoming packed seq from the stream.
+   * \brief Receive incoming packed seq from the stream.
    * \return The received argments.
    * \note The TVMArgs is available until we switchstate.
    */
@@ -369,7 +369,6 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
    */
   void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) {
     TVMArgs args = RecvPackedSeq();
-
     if (code == RPCCode::kException) {
       // switch to the state before sending exception.
       this->SwitchToState(kRecvPacketNumBytes);
@@ -801,14 +800,14 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)
   std::lock_guard<std::mutex> lock(mutex_);
   RPCCode code = RPCCode::kCopyToRemote;
 
-  uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*to));
-  ICHECK_EQ(nbytes, num_data_bytes);
+  uint64_t tensor_max_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
+  ICHECK_LE(to->byte_offset + nbytes, tensor_max_size_bytes) << "Overflow in tensor size.";
 
-  uint64_t to_data = reinterpret_cast<uint64_t>(to->data);
+  uint64_t to_data = reinterpret_cast<uint64_t>(static_cast<uint8_t*>(to->data) + to->byte_offset);
   uint64_t shape_bytes = to->ndim * sizeof(int64_t);
   uint64_t packet_nbytes = sizeof(code) + sizeof(to_data) + sizeof(to->device) + sizeof(to->ndim) +
                            sizeof(to->dtype) + sizeof(to->byte_offset) + shape_bytes +
-                           sizeof(nbytes) + num_data_bytes;
+                           sizeof(nbytes) + nbytes;
 
   handler_->Write(packet_nbytes);
   handler_->Write(code);
@@ -968,7 +967,10 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
   /*!
    * \brief param endpoint The client endpoint of the session.
    */
-  explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {}
+  explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {
+    // update max transfer size if not set already.
+    SetRPCMaxTransferSize();
+  }
 
   // function overrides
   PackedFuncHandle GetFunction(const std::string& name) final {
@@ -981,7 +983,23 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
   }
 
   void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {
-    endpoint_->CopyToRemote(local_from_bytes, remote_to, nbytes);
+    // Split the copy into chunks no larger than the negotiated maximum transfer size.
+    ICHECK_GT(rpc_chunk_max_size_bytes_, 0) << "Invalid RPC max transfer size.";
+    const uint64_t block_size = static_cast<uint64_t>(rpc_chunk_max_size_bytes_);
+    uint8_t* from_bytes = static_cast<uint8_t*>(local_from_bytes);
+    // NOTE(review): chunk offsets start at 0, so a non-zero incoming byte_offset is ignored.
+    const uint64_t saved_byte_offset = remote_to->byte_offset;
+
+    for (uint64_t sent_bytes = 0; sent_bytes < nbytes; sent_bytes += block_size) {
+      const uint64_t remaining_bytes = nbytes - sent_bytes;
+      const uint64_t chunk_bytes = remaining_bytes < block_size ? remaining_bytes : block_size;
+      // Advance the remote offset and the local source pointer together.
+      remote_to->byte_offset = sent_bytes;
+      endpoint_->CopyToRemote(from_bytes + sent_bytes, remote_to, chunk_bytes);
+    }
+
+    // Do not leave the caller's tensor descriptor pointing at the last chunk.
+    remote_to->byte_offset = saved_byte_offset;
   }
 
   void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final {
@@ -1042,7 +1060,23 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
   bool IsLocalSession() const final { return false; }
 
  private:
+  void RPCMaxTransferRemoteReturnValue(TVMArgs args) {
+    // Use args[1] as return value, args[0] is tcode
+    rpc_chunk_max_size_bytes_ = (int64_t)args[1];
+  }
+
+  void SetRPCMaxTransferSize() {
+    PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetTransferMaxSize");
+    if (rpc_func == nullptr) {
+      rpc_chunk_max_size_bytes_ = kRPCMaxTransferSizeDefault;
+      return;
+    }
+    CallFunc(rpc_func, nullptr, nullptr, 0,
+             [this](TVMArgs args) { RPCMaxTransferRemoteReturnValue(args); });
+  }
+
   std::shared_ptr<RPCEndpoint> endpoint_;
+  int64_t rpc_chunk_max_size_bytes_{kRPCMaxTransferSizeDefault};
 };
 
 std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h
index cd3c9b2bec72..1fcdcf6400ac 100644
--- a/src/runtime/rpc/rpc_endpoint.h
+++ b/src/runtime/rpc/rpc_endpoint.h
@@ -48,6 +48,9 @@ const int kRPCSuccess = kRPCMagic + 0;
 // cannot found matched key in server
 const int kRPCMismatch = kRPCMagic + 2;
 
+// When tvm.rpc.server.GetTransferMaxSize global function is not registered.
+const int kRPCMaxTransferSizeDefault = 128000;
+
 /*! \brief Enumeration code for the RPC tracker */
 enum class TrackerCode : int {
   kFail = -1,