diff --git a/benchmark/fluid/kube_gen_job.py b/benchmark/fluid/kube_gen_job.py index 39ba207fd96f71..9da8a69af1d7b6 100644 --- a/benchmark/fluid/kube_gen_job.py +++ b/benchmark/fluid/kube_gen_job.py @@ -49,7 +49,7 @@ def parse_args(): parser.add_argument( '--fluid', default=1, type=int, help='whether is fluid job') parser.add_argument( - '--rdma', action='store_ture', help='whether mount rdma libs') + '--rdma', action='store_true', help='whether mount rdma libs') parser.add_argument( '--disttype', default="pserver", diff --git a/paddle/fluid/inference/analysis/data_flow_graph.h b/paddle/fluid/inference/analysis/data_flow_graph.h index 9f6ce40ede2524..913e344d371ddf 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.h +++ b/paddle/fluid/inference/analysis/data_flow_graph.h @@ -21,7 +21,10 @@ limitations under the License. */ #include #include +#include #include +#include +#include #include "paddle/fluid/inference/analysis/graph_traits.h" #include "paddle/fluid/inference/analysis/node.h" diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc index 60f159da914051..dcee75cee50ede 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc @@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) { LOG(INFO) << graph.nodes.size(); } -} // analysis -} // inference -} // paddle +}; // namespace analysis +}; // namespace inference +}; // namespace paddle diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc index f848a7d1add79c..9f67c989cca4a9 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" +#include #include +#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" + namespace paddle { namespace inference { namespace analysis { diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h index cd0d4fabaafe84..33517e57becdff 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h @@ -19,6 +19,8 @@ #pragma once +#include + #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/pass.h" diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc index 851c98bef305fa..817d32c92cdbdc 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc @@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) { LOG(INFO) << '\n' << graph.DotString(); } -} // analysis -} // inference -} // paddle +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index 24ea9a4bae7132..153dca576bd673 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -50,7 +50,7 @@ struct DataTypeNamer { return dic_.at(x); } - const std::string &repr(size_t &hash) const { + const std::string &repr(size_t &hash) const { // NOLINT PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation"); return dic_.at(hash); } @@ -62,7 +62,9 @@ struct DataTypeNamer { SET_TYPE(float); } - std::unordered_map dic_; + std::unordered_map + dic_; }; #undef SET_TYPE diff --git a/paddle/fluid/inference/analysis/pass.h b/paddle/fluid/inference/analysis/pass.h index 5c89b1304d84ab..aa0e8667b5e4a9 100644 --- a/paddle/fluid/inference/analysis/pass.h +++ b/paddle/fluid/inference/analysis/pass.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" diff --git a/paddle/fluid/inference/analysis/subgraph_splitter.h b/paddle/fluid/inference/analysis/subgraph_splitter.h index ed90a0dcf31e15..a31afbe6933da8 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter.h +++ b/paddle/fluid/inference/analysis/subgraph_splitter.h @@ -18,6 +18,8 @@ limitations under the License. */ #pragma once +#include + #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/node.h" diff --git a/paddle/fluid/inference/analysis/ut_helper.h b/paddle/fluid/inference/analysis/ut_helper.h index c86083d1215392..722fa99a48a5f2 100644 --- a/paddle/fluid/inference/analysis/ut_helper.h +++ b/paddle/fluid/inference/analysis/ut_helper.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include #include +#include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index 37fcb5c50309db..e46c577cdae145 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -19,6 +19,9 @@ limitations under the License. */ #pragma once +#include +#include + #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/analysis/helper.h" @@ -58,7 +61,7 @@ class TRTConvertValidation { public: TRTConvertValidation() = delete; - TRTConvertValidation(int batch_size, int workspace_size = 1 << 10) { + explicit TRTConvertValidation(int batch_size, int workspace_size = 1024) { // create engine. engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_)); engine_->InitNetwork(); diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index b9a66474c9afc2..cf20530513cf6c 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1,6 +1,7 @@ if(WITH_DISTRIBUTE) grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc - grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) + request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor + selected_rows memory) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index f7ce7786874285..da9ca1a0c1d550 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { } bool RPCClient::Wait() { + VLOG(3) << "RPCClient begin Wait()" + << " req_count_:" << req_count_; if (req_count_ <= 0) { return true; } diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 361cc24b5ba11e..e73756d89004bc 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/detail/grpc_server.h" - #include #include -using ::grpc::ServerAsyncResponseWriter; +#include "paddle/fluid/operators/detail/grpc_server.h" -DEFINE_int32(rpc_server_handle_send_threads, 20, - "Number of threads used to handle send at rpc server."); -DEFINE_int32(rpc_server_handle_get_threads, 20, - "Number of threads used to handle get at rpc server."); -DEFINE_int32(rpc_server_handle_prefetch_threads, 1, - "Number of threads used to handle prefetch at rpc server."); +using ::grpc::ServerAsyncResponseWriter; namespace paddle { namespace operators { @@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH }; class RequestBase { public: explicit RequestBase(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, bool sync_mode, - const platform::DeviceContext* dev_ctx) + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) : service_(service), cq_(cq), - sync_mode_(sync_mode), status_(PROCESS), - dev_ctx_(dev_ctx) { + request_handler_(request_handler), + req_id_(req_id) { PADDLE_ENFORCE(cq_); } virtual ~RequestBase() {} - virtual void Process() { assert(false); } + virtual void Process() = 0; CallStatus Status() { return status_; } void SetStatus(CallStatus status) { status_ = status; } - virtual std::string GetReqName() { - assert(false); - return ""; - } + virtual std::string GetReqName() = 0; protected: ::grpc::ServerContext ctx_; GrpcService::AsyncService* service_; ::grpc::ServerCompletionQueue* cq_; - const bool sync_mode_; CallStatus status_; - const platform::DeviceContext* dev_ctx_; + RequestHandler* request_handler_; + int req_id_; }; class RequestSend final : public RequestBase { public: explicit RequestSend(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, bool sync_mode, - framework::Scope* scope, ReceivedQueue* queue, - const platform::DeviceContext* dev_ctx, int req_id) - : RequestBase(service, cq, sync_mode, dev_ctx), - queue_(queue), - responder_(&ctx_), - req_id_(req_id) { - if (sync_mode_) { - request_.reset(new VariableResponse(scope, dev_ctx_, false)); - } else { - request_.reset(new VariableResponse(scope, dev_ctx_, true)); - } + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) + : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { + request_.reset(new VariableResponse(request_handler->scope(), + request_handler->dev_ctx(), + !request_handler->sync_mode())); int method_id = static_cast(detail::GrpcMethod::kSendVariable); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, @@ -87,12 +71,17 @@ class RequestSend final : public RequestBase { virtual ~RequestSend() {} - virtual std::string GetReqName() { return request_->Varname(); } + std::string GetReqName() override { return request_->Varname(); } + + void Process() override { + std::string varname = GetReqName(); + VLOG(3) << "RequestSend var_name:" << varname; - virtual void Process() { - std::string var_name = GetReqName(); - VLOG(3) << "RequestSend " << var_name; - queue_->Push(std::make_pair(var_name, request_)); + auto scope = request_->GetMutableLocalScope(); + auto invar = request_->GetVar(); + framework::Variable* outvar = nullptr; + + request_handler_->Handle(varname, scope, invar, &outvar); status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, @@ -102,105 +91,85 @@ class RequestSend final : public RequestBase { protected: sendrecv::VoidMessage reply_; std::shared_ptr request_; - ReceivedQueue* queue_; ServerAsyncResponseWriter responder_; - int req_id_; }; class RequestGet final : public RequestBase { public: explicit RequestGet(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, bool sync_mode, - framework::Scope* scope, - const platform::DeviceContext* dev_ctx, - framework::BlockingQueue* queue, - int req_id) - : RequestBase(service, cq, sync_mode, dev_ctx), - responder_(&ctx_), - scope_(scope), - queue_(queue), - req_id_(req_id) { + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) + : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { auto method_id = static_cast(detail::GrpcMethod::kGetVariable); service_->RequestAsyncUnary( method_id, &ctx_, &request_, &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id_))); + reinterpret_cast(static_cast(req_id))); } virtual ~RequestGet() {} - virtual std::string GetReqName() { return request_.varname(); } + std::string GetReqName() override { return request_.varname(); } - virtual void Process() { + void Process() override { // proc request. - std::string var_name = request_.varname(); - VLOG(3) << "RequestGet " << var_name; - auto* var = scope_->FindVar(var_name); + std::string varname = request_.varname(); + VLOG(3) << "RequestGet " << varname; + + auto scope = request_handler_->scope(); + auto invar = scope->FindVar(varname); + framework::Variable* outvar = nullptr; - if (var_name != FETCH_BARRIER_MESSAGE) { - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); + request_handler_->Handle(varname, scope, invar, &outvar); + + if (outvar) { + SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), + &reply_); } status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, reinterpret_cast(static_cast(req_id_))); - - if (var_name == FETCH_BARRIER_MESSAGE) { - sendrecv::VariableMessage msg; - MessageWithName msg_with_name = std::make_pair(var_name, msg); - queue_->Push(msg_with_name); - } } protected: sendrecv::VariableMessage request_; ::grpc::ByteBuffer reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; - framework::Scope* scope_; - framework::BlockingQueue* queue_; - int req_id_; }; class RequestPrefetch final : public RequestBase { public: explicit RequestPrefetch(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, bool sync_mode, - framework::Scope* scope, - const platform::DeviceContext* dev_ctx, - framework::Executor* executor, - framework::ProgramDesc* program, - framework::ExecutorPrepareContext* prefetch_ctx, - int req_id) - : RequestBase(service, cq, sync_mode, dev_ctx), + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) + : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_), - scope_(scope), - executor_(executor), - program_(program), - prefetch_ctx_(prefetch_ctx), - req_id_(req_id) { - // prefetch always create a new sub scope - request_.reset(new VariableResponse(scope, dev_ctx_, true)); + local_scope_(nullptr) { + request_.reset(new VariableResponse(request_handler->scope(), + request_handler->dev_ctx(), true)); int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, - reinterpret_cast(static_cast(req_id_))); + reinterpret_cast(static_cast(req_id))); } virtual ~RequestPrefetch() {} - virtual std::string GetReqName() { return request_->Varname(); } + std::string GetReqName() override { return request_->Varname(); } - virtual void Process() { + void Process() override { // prefetch process... + std::string varname = request_->OutVarname(); + VLOG(3) << "RequestPrefetch " << varname; + + auto scope = request_->GetMutableLocalScope(); + auto invar = scope->FindVar(varname); + framework::Variable* outvar = nullptr; - std::string var_name = request_->OutVarname(); - VLOG(3) << "RequestPrefetch " << var_name; - auto var_desc = program_->Block(0).FindVar(var_name); - framework::Scope* local_scope = request_->GetMutableLocalScope(); - auto* var = local_scope->FindVar(var_name); - InitializeVariable(var, var_desc->GetType()); - executor_->RunPreparedContext(prefetch_ctx_, local_scope); + request_handler_->Handle(varname, scope, invar, &outvar); - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); + SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), + &reply_); status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, @@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase { std::shared_ptr request_; ::grpc::ByteBuffer reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; - framework::Scope* scope_; - framework::Executor* executor_; - framework::ProgramDesc* program_; - framework::ExecutorPrepareContext* prefetch_ctx_; - int req_id_; + framework::Scope* local_scope_; }; -void AsyncGRPCServer::WaitClientGet(int count) { - int fetch_barriers = 0; - while (fetch_barriers < count) { - auto msg = var_get_queue_.Pop(); - if (msg.first == FETCH_BARRIER_MESSAGE) { - fetch_barriers++; - } - } -} - void AsyncGRPCServer::WaitServerReady() { + VLOG(3) << "AsyncGRPCServer is wait server ready"; std::unique_lock lock(this->mutex_ready_); condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); + VLOG(3) << "AsyncGRPCServer WaitSeverReady"; } -void AsyncGRPCServer::RunSyncUpdate() { +void AsyncGRPCServer::StartServer() { ::grpc::ServerBuilder builder; - builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(), + builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(), &selected_port_); + builder.SetMaxSendMessageSize(std::numeric_limits::max()); builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); builder.RegisterService(&service_); - cq_send_ = builder.AddCompletionQueue(); - cq_get_ = builder.AddCompletionQueue(); - cq_prefetch_ = builder.AddCompletionQueue(); + for (auto t : rpc_call_map_) { + rpc_cq_[t.first].reset(builder.AddCompletionQueue().release()); + } server_ = builder.BuildAndStart(); - LOG(INFO) << "Server listening on " << address_ + LOG(INFO) << "Server listening on " << bind_address_ << " selected port: " << selected_port_; - std::function send_register = std::bind( - &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1); - std::function get_register = std::bind( - &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1); - std::function prefetch_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this, - std::placeholders::_1); + std::function f = + std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this, + std::placeholders::_1, std::placeholders::_2); - for (int i = 0; i < kSendReqsBufSize; ++i) { - TryToRegisterNewSendOne(i); - } - for (int i = 0; i < kGetReqsBufSize; ++i) { - TryToRegisterNewGetOne(i); - } - for (int i = 0; i < kPrefetchReqsBufSize; ++i) { - TryToRegisterNewPrefetchOne(i); - } + for (auto& t : rpc_call_map_) { + auto& rpc_name = t.first; + auto& cq = rpc_cq_[rpc_name]; + auto threadnum = rpc_thread_num_[rpc_name]; + auto& reqs = rpc_reqs_[rpc_name]; - for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { - t_sends_.emplace_back( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, - cq_send_.get(), "cq_send", send_register))); - } - for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { - t_gets_.emplace_back( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, - cq_get_.get(), "cq_get", get_register))); - } - for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { - t_prefetchs_.emplace_back(new std::thread( - std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), - "cq_prefetch", prefetch_register))); + reqs.reserve(kRequestBufSize); + + for (int i = 0; i < kRequestBufSize; i++) { + TryToRegisterNewOne(rpc_name, i); + } + + for (int i = 0; i < threadnum; i++) { + rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind( + &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f))); + VLOG(3) << t.first << " creates threads!"; + } } + { std::lock_guard lock(this->mutex_ready_); ready_ = 1; } condition_ready_.notify_all(); + // wait server server_->Wait(); - for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { - t_sends_[i]->join(); - } - for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { - t_gets_[i]->join(); - } - for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { - t_prefetchs_[i]->join(); + + for (auto& t : rpc_threads_) { + auto& threads = t.second; + for (size_t i = 0; i < threads.size(); ++i) { + threads[i]->join(); + VLOG(3) << t.first << " threads ends!"; + } } } void AsyncGRPCServer::ShutdownQueue() { - std::unique_lock lock(cq_mutex_); - cq_send_->Shutdown(); - cq_get_->Shutdown(); - cq_prefetch_->Shutdown(); + for (auto& t : rpc_cq_) { + t.second->Shutdown(); + VLOG(3) << t.first << " shutdown!"; + } } -// This URL explains why shutdown is complicate: -void AsyncGRPCServer::ShutDown() { +void AsyncGRPCServer::ShutDownImpl() { + std::unique_lock lock(cq_mutex_); is_shut_down_ = true; ShutdownQueue(); + + VLOG(3) << "server_ shutdown!"; server_->Shutdown(); } -void AsyncGRPCServer::TryToRegisterNewSendOne(int i) { +void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, + int req_id) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; return; } - RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_, - scope_, &var_recv_queue_, dev_ctx_, i); - send_reqs_[i] = static_cast(send); - VLOG(4) << "Create RequestSend status:" << send->Status(); -} -void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) { - std::unique_lock lock(cq_mutex_); - if (is_shut_down_) { - VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; - return; + VLOG(4) << "register send rpc_name:" << rpc_name + << ", handler:" << rpc_call_map_[kRequestSend]; + + auto& reqs = rpc_reqs_[rpc_name]; + auto& handler = rpc_call_map_[rpc_name]; + auto& cq = rpc_cq_[rpc_name]; + + RequestBase* b = nullptr; + if (rpc_name == kRequestSend) { + b = new RequestSend(&service_, cq.get(), handler, req_id); + } else if (rpc_name == kRequestGet) { + b = new RequestGet(&service_, cq.get(), handler, req_id); + } else if (rpc_name == kRequestPrefetch) { + b = new RequestPrefetch(&service_, cq.get(), handler, req_id); + } else { + PADDLE_ENFORCE(false, "not surpported rpc"); } - RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, - dev_ctx_, &var_get_queue_, req_id); - get_reqs_[req_id] = static_cast(get); - VLOG(4) << "Create RequestGet status:" << get->Status(); -} -void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) { - std::unique_lock lock(cq_mutex_); - if (is_shut_down_) { - VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; - return; - } - RequestPrefetch* prefetch = new RequestPrefetch( - &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, - program_, prefetch_ctx_.get(), req_id); - prefetch_reqs_[req_id] = static_cast(prefetch); + reqs[req_id] = b; - VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); + VLOG(4) << "Create RequestSend status:" << b->Status(); } -// FIXME(typhoonzero): change cq_name to enum. void AsyncGRPCServer::HandleRequest( - ::grpc::ServerCompletionQueue* cq, const std::string& cq_name, - std::function TryToRegisterNewOne) { + ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name, + std::function TryToRegisterNewOne) { void* tag = NULL; bool ok = false; while (true) { - VLOG(3) << "HandleRequest for " << cq_name << " wait Next"; + VLOG(3) << "HandleRequest " << rpc_name << " wait next"; if (!cq->Next(&tag, &ok)) { - LOG(INFO) << cq_name << " CompletionQueue shutdown!"; + LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!"; break; } - VLOG(3) << "HandleRequest for " << cq_name << " get Next"; - int req_id = static_cast(reinterpret_cast(tag)); - if (sync_mode_) { - // FIXME(typhoonzero): de-couple the barriers with recv_op - if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); - if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); - VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond"; - } + int req_id = static_cast(reinterpret_cast(tag)); + VLOG(3) << "HandleRequest " << rpc_name << ", req_id:" << req_id + << " get next"; + auto& reqs = rpc_reqs_[rpc_name]; RequestBase* base = nullptr; { - std::lock_guard l(cq_mutex_); - if (cq_name == "cq_get") { - base = get_reqs_[req_id]; - } else if (cq_name == "cq_send") { - base = send_reqs_[req_id]; - } else if (cq_name == "cq_prefetch") { - base = prefetch_reqs_[req_id]; - } + PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize); + std::unique_lock lock(cq_mutex_); + base = reqs[req_id]; } + // reference: // https://github.com/tensorflow/tensorflow/issues/5596 // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I if (!ok) { - LOG(WARNING) << cq_name << " recv no regular event:argument name[" + LOG(WARNING) << "completion queue:" << rpc_name + << " recv no regular event:argument name[" << base->GetReqName() << "]"; - TryToRegisterNewOne(req_id); + TryToRegisterNewOne(rpc_name, req_id); delete base; continue; } + VLOG(3) << "queue id:" << rpc_name << ", req_id:" << req_id + << ", status:" << base->Status(); + switch (base->Status()) { case PROCESS: { base->Process(); - VLOG(4) << cq_name << " PROCESS status:" << base->Status(); break; } case FINISH: { - TryToRegisterNewOne(req_id); - VLOG(4) << cq_name << " FINISH status:" << base->Status(); + TryToRegisterNewOne(rpc_name, req_id); delete base; break; } @@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest( } } -void AsyncGRPCServer::WaitCond(int cond) { - std::unique_lock lock(this->barrier_mutex_); - barrier_condition_.wait(lock, - [=] { return this->barrier_cond_step_ == cond; }); -} - -void AsyncGRPCServer::SetCond(int cond) { - { - std::lock_guard lock(this->barrier_mutex_); - barrier_cond_step_ = cond; - } - barrier_condition_.notify_all(); -} - } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index bdff9801a92869..d1fcbc414f123c 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include +#include #include #include // NOLINT #include @@ -28,6 +30,8 @@ limitations under the License. */ #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/operators/detail/grpc_service.h" +#include "paddle/fluid/operators/detail/request_handler.h" +#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" @@ -37,106 +41,48 @@ namespace paddle { namespace operators { namespace detail { -typedef std::pair> - ReceivedMessage; -typedef framework::BlockingQueue ReceivedQueue; - -typedef std::pair MessageWithName; class RequestBase; -class AsyncGRPCServer final { +class AsyncGRPCServer final : public RPCServer { public: - explicit AsyncGRPCServer(const std::string &address, bool sync_mode) - : address_(address), sync_mode_(sync_mode), ready_(0) {} - - ~AsyncGRPCServer() {} - void WaitServerReady(); - void RunSyncUpdate(); - - // functions to sync server barrier status. - void WaitCond(int cond); - void SetCond(int cond); - void WaitClientGet(int count); - - void SetScope(framework::Scope *scope) { scope_ = scope; } - - void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; } - - void SetProgram(framework::ProgramDesc *program) { program_ = program; } - - void SetExecutor(framework::Executor *executor) { executor_ = executor; } - - void SetPrefetchPreparedCtx( - std::unique_ptr prepared) { - prefetch_ctx_.reset(prepared.release()); - } - - int GetSelectedPort() const { return selected_port_; } - - const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } + explicit AsyncGRPCServer(const std::string& address, int client_num) + : RPCServer(address, client_num), ready_(0) {} - void Push(const std::string &msg_name) { - this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr)); - } + virtual ~AsyncGRPCServer() {} + void WaitServerReady() override; + void StartServer() override; - void ShutDown(); + private: + void HandleRequest( + ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name, + std::function TryToRegisterNewOne); - protected: - void HandleRequest(::grpc::ServerCompletionQueue *cq, - const std::string &cq_name, - std::function TryToRegisterNewOne); - void TryToRegisterNewSendOne(int req_id); - void TryToRegisterNewGetOne(int req_id); - void TryToRegisterNewPrefetchOne(int req_id); + void TryToRegisterNewOne(const std::string& rpc_name, int req_id); void ShutdownQueue(); + void ShutDownImpl() override; private: - static const int kSendReqsBufSize = 100; - static const int kGetReqsBufSize = 100; - static const int kPrefetchReqsBufSize = 10; + static const int kRequestBufSize = 100; std::mutex cq_mutex_; volatile bool is_shut_down_ = false; - std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; - std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; - std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_; - - RequestBase *send_reqs_[kSendReqsBufSize]; - RequestBase *get_reqs_[kGetReqsBufSize]; - RequestBase *prefetch_reqs_[kPrefetchReqsBufSize]; GrpcService::AsyncService service_; std::unique_ptr<::grpc::Server> server_; - std::string address_; - const bool sync_mode_; - framework::Scope *scope_; - const platform::DeviceContext *dev_ctx_; - - // received variable from RPC, operators fetch variable from this queue. - framework::BlockingQueue var_get_queue_; - // client send variable to this queue. - ReceivedQueue var_recv_queue_; - // condition of the sub program std::mutex barrier_mutex_; mutable int barrier_cond_step_; std::condition_variable barrier_condition_; - std::vector> t_sends_; - std::vector> t_gets_; - std::vector> t_prefetchs_; - - std::unique_ptr t_prefetch_; - - std::unique_ptr prefetch_ctx_; - framework::ProgramDesc *program_; - framework::Executor *executor_; - int selected_port_; - std::mutex mutex_ready_; std::condition_variable condition_ready_; + int ready_; + + std::map> rpc_cq_; + std::map>> rpc_threads_; + std::map> rpc_reqs_; }; }; // namespace detail diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 350a7ee1234da5..f97f638701cfb2 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -24,13 +24,16 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/detail/request_handler_impl.h" + namespace framework = paddle::framework; namespace platform = paddle::platform; namespace detail = paddle::operators::detail; USE_OP(lookup_table); -std::unique_ptr rpc_service_; +std::unique_ptr g_rpc_service; +std::unique_ptr g_req_handler; framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { auto root_block = program->MutableBlock(0); @@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, } } -void StartServer(const std::string& endpoint) { - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true)); +void StartServer() { framework::ProgramDesc program; framework::Scope scope; platform::CPUPlace place; @@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) { auto prepared = exe.Prepare(program, block->ID()); InitTensorsOnServer(&scope, &place, 10); - rpc_service_->SetProgram(&program); - rpc_service_->SetPrefetchPreparedCtx(std::move(prepared)); - rpc_service_->SetDevCtx(&ctx); - rpc_service_->SetScope(&scope); - rpc_service_->SetExecutor(&exe); + g_req_handler->SetProgram(&program); + g_req_handler->SetPrefetchPreparedCtx(std::move(prepared)); + g_req_handler->SetDevCtx(&ctx); + g_req_handler->SetScope(&scope); + g_req_handler->SetExecutor(&exe); + + g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get()); + g_req_handler->SetRPCServer(g_rpc_service.get()); + + std::thread server_thread( + std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); - rpc_service_->RunSyncUpdate(); + // FIXME(gongwb): don't use hard time. + sleep(10); + LOG(INFO) << "got nccl id and stop server..."; + g_rpc_service->ShutDown(); + server_thread.join(); } -TEST(PREFETCH, DISABLED_CPU) { - // start up a server instance backend - std::thread server_thread(StartServer, "127.0.0.1:8889"); - sleep(2); +TEST(PREFETCH, CPU) { + g_req_handler.reset(new detail::RequestPrefetchHandler(true)); + g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1)); + + std::thread server_thread(StartServer); + g_rpc_service->WaitServerReady(); + + detail::RPCClient client; + int port = g_rpc_service->GetSelectedPort(); + std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); + framework::Scope scope; platform::CPUPlace place; platform::CPUDeviceContext ctx(place); - // create var on local scope - int64_t rows_numel = 5; - InitTensorsOnClient(&scope, &place, rows_numel); - std::string in_var_name("ids"); - std::string out_var_name("out"); - - auto client = detail::RPCClient::GetInstance(); - client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, - out_var_name); - client->Wait(); - - auto var = scope.Var(out_var_name); - auto value = var->GetMutable()->value(); - auto ptr = value.mutable_data(place); - - rpc_service_->ShutDown(); - server_thread.join(); - rpc_service_.reset(nullptr); - - for (int64_t i = 0; i < rows_numel; ++i) { - EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast(i * 2)); + { + // create var on local scope + int64_t rows_numel = 5; + InitTensorsOnClient(&scope, &place, rows_numel); + std::string in_var_name("ids"); + std::string out_var_name("out"); + + client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name); + client.Wait(); + auto var = scope.Var(out_var_name); + auto value = var->GetMutable()->value(); + auto ptr = value.mutable_data(place); + + for (int64_t i = 0; i < rows_numel; ++i) { + EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast(i * 2)); + } } + + server_thread.join(); + LOG(INFO) << "begin reset"; + g_rpc_service.reset(nullptr); + g_req_handler.reset(nullptr); } diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h new file mode 100644 index 00000000000000..4bc5e7f10ee2a8 --- /dev/null +++ b/paddle/fluid/operators/detail/request_handler.h @@ -0,0 +1,127 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" + +namespace paddle { +namespace operators { +namespace detail { + +constexpr char kRequestSend[] = "RequestSend"; +constexpr char kRequestGet[] = "RequestGet"; +constexpr char kRequestPrefetch[] = "RequestPrefetch"; + +class RPCServer; + +class RequestHandler { + public: + explicit RequestHandler(bool sync_mode) + : sync_mode_(sync_mode), + dev_ctx_(nullptr), + executor_(nullptr), + scope_(nullptr), + program_(nullptr), + rpc_server_(nullptr) {} + + virtual ~RequestHandler() {} + + // Set attributes. + void SetScope(framework::Scope* scope) { scope_ = scope; } + void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + void SetProgram(framework::ProgramDesc* program) { program_ = program; } + void SetExecutor(framework::Executor* executor) { executor_ = executor; } + void SetPrefetchPreparedCtx( + std::unique_ptr prepared) { + prefetch_ctx_.reset(prepared.release()); + } + + // Used for async. + void SetGradToPreparedCtx( + std::unordered_map< + std::string, std::shared_ptr>* g) { + grad_to_prepared_ctx_ = g; + } + + void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; } + + // Get attributes. + bool sync_mode() { return sync_mode_; } + framework::Scope* scope() { return scope_; } + const platform::DeviceContext* dev_ctx() { return dev_ctx_; } + framework::ExecutorPrepareContext* prefetch_ctx() { + return prefetch_ctx_.get(); + } + framework::ProgramDesc* program() { return program_; } + framework::Executor* executor() { return executor_; } + std::vector& sparse_vars() { return sparse_vars_; } + + // This function processes user's rpc request. + // The implemention is in request_handler_impl. + // example: + // std::string varname = request_.varname(); + // + // auto scope = request_handler_->scope(); + // auto invar = scope->FindVar(varname); + // framework::Variable* outvar = nullptr; + // + // request_handler_->Handle(varname, scope, invar, &outvar); + // if (outvar) { + // SerializeToByteBuffer(varname, outvar, + // *request_handler_->dev_ctx(), &reply_); + // } + virtual bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, + framework::Variable** outvar) = 0; + + protected: + const bool sync_mode_; + + const platform::DeviceContext* dev_ctx_; + framework::Executor* executor_; + framework::Scope* scope_; + framework::ProgramDesc* program_; + std::unique_ptr prefetch_ctx_; + + // Used for async. + std::unordered_map>* + grad_to_prepared_ctx_; + + // Record received sparse variables, so that + // we could reset those after execute optimize program + std::vector sparse_vars_; + RPCServer* rpc_server_; + + std::mutex sparse_var_mutex_; +}; + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc new file mode 100644 index 00000000000000..f16c06d52f4fb8 --- /dev/null +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/framework/blocking_queue.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/detail/request_handler_impl.h" +#include "paddle/fluid/operators/detail/rpc_server.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/detail/variable_response.h" + +namespace paddle { +namespace operators { +namespace detail { + +bool RequestSendHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar) { + VLOG(4) << "RequestSendHandler:" << varname; + + // Async + if (!sync_mode_) { + try { + executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), + scope); + } catch (std::exception& e) { + LOG(ERROR) << "async: run sub program error " << e.what(); + return false; + } + return true; + } + + // Sync + if (varname == BATCH_BARRIER_MESSAGE) { + VLOG(3) << "sync: recv batch barrier message"; + rpc_server_->IncreaseBatchBarrier(kRequestSend); + } else { + VLOG(3) << "sync: received var_name: " << varname; + if (sync_mode_) { + rpc_server_->WaitCond(kRequestSend); + } + + if (invar == nullptr) { + LOG(ERROR) << "sync: Can not find server side var: " << varname; + PADDLE_THROW("sync: Can not find server side var"); + return false; + } + + if (invar->IsType()) { + std::unique_lock lock(sparse_var_mutex_); + sparse_vars_.push_back(invar); + } + } + + return true; +} + +bool RequestGetHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar) { + VLOG(4) << "RequestGetHandler:" << varname; + + if (varname != FETCH_BARRIER_MESSAGE) { + if (sync_mode_) { + rpc_server_->WaitCond(kRequestGet); + } + *outvar = scope_->FindVar(varname); + return true; + } + + // FETCH_BARRIER_MESSAGE + if (sync_mode_) { + VLOG(3) << "sync: recv fetch barrier message"; + rpc_server_->IncreaseBatchBarrier(kRequestGet); + } + + return true; +} + +bool RequestPrefetchHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar) { + VLOG(4) << "RequestPrefetchHandler " << varname; + + auto var_desc = program_->Block(0).FindVar(varname); + *outvar = scope->FindVar(varname); + InitializeVariable(*outvar, var_desc->GetType()); + executor_->RunPreparedContext(prefetch_ctx_.get(), scope); + + return true; +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/detail/request_handler_impl.h new file mode 100644 index 00000000000000..8d0c62232b68ad --- /dev/null +++ b/paddle/fluid/operators/detail/request_handler_impl.h @@ -0,0 +1,64 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/detail/request_handler.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" + +namespace paddle { +namespace operators { +namespace detail { + +class RequestSendHandler final : public RequestHandler { + public: + explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {} + virtual ~RequestSendHandler() {} + bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, framework::Variable** outvar) override; +}; + +class RequestGetHandler final : public RequestHandler { + public: + explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {} + virtual ~RequestGetHandler() {} + bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, framework::Variable** outvar) override; +}; + +class RequestPrefetchHandler final : public RequestHandler { + public: + explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {} + virtual ~RequestPrefetchHandler() {} + bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, framework::Variable** outvar) override; +}; + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/rpc_server.cc b/paddle/fluid/operators/detail/rpc_server.cc new file mode 100644 index 00000000000000..448763372a8c22 --- /dev/null +++ b/paddle/fluid/operators/detail/rpc_server.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/fluid/operators/detail/rpc_server.h" + +namespace paddle { +namespace operators { +namespace detail { + +void RPCServer::ShutDown() { + LOG(INFO) << "RPCServer ShutDown "; + ShutDownImpl(); + + exit_flag_ = true; + barrier_cond_.notify_all(); + rpc_cond_.notify_all(); +} + +void RPCServer::SavePort() const { + auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); + std::ofstream port_file; + port_file.open(file_path); + port_file << selected_port_; + port_file.close(); + VLOG(4) << "selected port written to " << file_path; +} + +void RPCServer::WaitBarrier(const std::string& rpc_name) { + std::unique_lock lock(this->mutex_); + barrier_cond_.wait(lock, [=] { + return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load()); + }); + + VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name]; +} + +void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { + VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; + int b = 0; + { + std::unique_lock lock(mutex_); + b = ++barrier_counter_[rpc_name]; + } + + VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name + << ", barrier_count:" << b << ", fan_in" << client_num_; + + if (b >= client_num_) { + barrier_cond_.notify_all(); + } +} + +void RPCServer::ResetBarrierCounter() { + VLOG(3) << "RPCServer ResetBarrierCounter "; + std::unique_lock lock(mutex_); + for (auto& t : barrier_counter_) { + t.second = 0; + } +} + +void RPCServer::RegisterRPC(const std::string& rpc_name, + RequestHandler* handler, int thread_num) { + rpc_call_map_[rpc_name] = handler; + rpc_thread_num_[rpc_name] = thread_num; + + static int cond = -1; + rpc_cond_map_[rpc_name] = ++cond; + VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler + << ", cond:" << rpc_cond_map_[rpc_name]; +} + +void RPCServer::SetCond(const std::string& rpc_name) { + VLOG(3) << "RPCServer SetCond " << rpc_name; + { + std::unique_lock lock(mutex_); + cur_cond_ = rpc_cond_map_[rpc_name]; + } + + rpc_cond_.notify_all(); +} + +void RPCServer::WaitCond(const std::string& rpc_name) { + VLOG(3) << "RPCServer WaitCond " << rpc_name; + int cond = 0; + { + std::unique_lock lock(mutex_); + cond = rpc_cond_map_[rpc_name]; + } + + std::unique_lock lock(mutex_); + rpc_cond_.wait( + lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); }); +} + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/rpc_server.h b/paddle/fluid/operators/detail/rpc_server.h new file mode 100644 index 00000000000000..c2e7ae706c9dc6 --- /dev/null +++ b/paddle/fluid/operators/detail/rpc_server.h @@ -0,0 +1,91 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include "paddle/fluid/operators/detail/request_handler.h" + +namespace paddle { +namespace operators { +namespace detail { + +class RPCServer { + public: + explicit RPCServer(const std::string& address, int client_num) + : cur_cond_(0), + bind_address_(address), + exit_flag_(false), + selected_port_(0), + client_num_(client_num) {} + + virtual ~RPCServer() {} + virtual void StartServer() = 0; + virtual void WaitServerReady() = 0; + + void ShutDown(); + + bool IsExit() { return exit_flag_.load(); } + + int GetSelectedPort() const { return selected_port_; } + void SavePort() const; + + // RegisterRPC, register the rpc method name to a handler + // class, and auto generate a condition id for this call + // to be used for the barrier. + void RegisterRPC(const std::string& rpc_name, RequestHandler* handler, + int thread_num = 5); + + // Wait util all the clients have reached the barrier for one + // rpc method. This function should be called in the + // RequestHandler if you want to run the server/client in a + // synchronous mode. + void WaitBarrier(const std::string& rpc_name); + + void SetCond(const std::string& rpc_name); + void WaitCond(const std::string& rpc_name); + void IncreaseBatchBarrier(const std::string rpc_name); + void ResetBarrierCounter(); + + protected: + virtual void ShutDownImpl() = 0; + + private: + std::mutex mutex_; + std::unordered_map barrier_counter_; + std::condition_variable barrier_cond_; + + std::unordered_map rpc_cond_map_; + std::atomic cur_cond_; + std::condition_variable rpc_cond_; + + protected: + std::string bind_address_; + std::atomic exit_flag_; + int selected_port_; + + const int client_num_; + + std::unordered_map rpc_call_map_; + std::unordered_map rpc_thread_num_; + friend class RequestHandler; +}; + +}; // namespace detail +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h index bf624da2a6c264..69cfd784f8dd4f 100644 --- a/paddle/fluid/operators/detail/variable_response.h +++ b/paddle/fluid/operators/detail/variable_response.h @@ -67,8 +67,8 @@ class VariableResponse { framework::Scope* GetMutableLocalScope() const { return local_scope_; } - inline std::string Varname() { return meta_.varname(); } - inline std::string OutVarname() { return meta_.out_varname(); } + inline std::string Varname() const { return meta_.varname(); } + inline std::string OutVarname() const { return meta_.out_varname(); } // should call parse first. framework::Variable* GetVar() { diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc index a5678f63466d36..4bce2d322d8251 100644 --- a/paddle/fluid/operators/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/gen_nccl_id_op.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/platform/nccl_helper.h" namespace paddle { @@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase { // NOTE: Can not use unique_ptr here because the default // deleter will call GRPC Server's base class's dtor and // that will cause a wired crash. - detail::AsyncGRPCServer rpc_service(endpoint, true); + detail::RequestSendHandler rpc_h(true); + detail::AsyncGRPCServer rpc_service(endpoint, 1); + rpc_service.RegisterRPC(detail::kRequestSend, &rpc_h); + rpc_h.SetRPCServer(&rpc_service); + framework::ProgramDesc empty_program; framework::Executor executor(dev_ctx.GetPlace()); - rpc_service.SetScope(scope); - rpc_service.SetDevCtx(&dev_ctx); - rpc_service.SetProgram(&empty_program); - rpc_service.SetExecutor(&executor); + rpc_h.SetScope(scope); + rpc_h.SetDevCtx(&dev_ctx); + rpc_h.SetProgram(&empty_program); + rpc_h.SetExecutor(&executor); std::thread server_thread( - std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service)); - rpc_service.SetCond(0); + std::bind(&detail::AsyncGRPCServer::StartServer, &rpc_service)); + rpc_service.SetCond(detail::kRequestSend); VLOG(3) << "start getting nccl id from trainer 0..."; - auto recv = rpc_service.Get(); + rpc_service.WaitBarrier(detail::kRequestSend); VLOG(3) << "got nccl id and stop server..."; rpc_service.ShutDown(); VLOG(3) << "rpc server stopped"; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index df5f229acd75ee..71e75c25321812 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -19,14 +19,16 @@ limitations under the License. */ #include // NOLINT #include +#include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { -void RunServer(std::shared_ptr service) { - service->RunSyncUpdate(); +void RunServer(std::shared_ptr service) { + service->StartServer(); VLOG(4) << "RunServer thread end"; } static void split(const std::string &str, char sep, @@ -67,8 +69,6 @@ static void ParallelExecuteBlocks( for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); } -std::atomic_int ListenAndServOp::selected_port_{0}; - ListenAndServOp::ListenAndServOp(const std::string &type, const framework::VariableNameMap &inputs, const framework::VariableNameMap &outputs, @@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type, ListenAndServOp::~ListenAndServOp() { Stop(); } void ListenAndServOp::Stop() { - rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); rpc_service_->ShutDown(); server_thread_->join(); auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); @@ -87,26 +86,13 @@ void ListenAndServOp::Stop() { void ListenAndServOp::SavePort() const { // NOTE: default write file to /tmp/paddle.selected_port - selected_port_ = rpc_service_->GetSelectedPort(); - auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); - std::ofstream port_file; - port_file.open(file_path); - port_file << selected_port_.load(); - port_file.close(); - VLOG(4) << "selected port written to " << file_path; -} - -void ListenAndServOp::WaitServerReady() { - while (selected_port_.load() == 0) { - } + rpc_service_->SavePort(); } void ListenAndServOp::RunSyncLoop(framework::Executor *executor, framework::ProgramDesc *program, framework::Scope *recv_scope, framework::BlockDesc *prefetch_block) const { - auto fan_in = Attr("Fanin"); - size_t num_blocks = program->Size(); PADDLE_ENFORCE_GE(num_blocks, 2, "server program should have at least 2 blocks"); @@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, optimize_prepared.begin(), std::shared_ptr(nullptr)); - bool exit_flag = false; + rpc_service_->ResetBarrierCounter(); // Record received sparse variables, so that // we could reset those after execute optimize program std::vector sparse_vars; - while (!exit_flag && !SignalHandler::IsProgramExit()) { + while (true) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. - rpc_service_->SetCond(0); - size_t recv_var_cnt = 0; - int batch_barrier = 0; - while (batch_barrier != fan_in) { - const detail::ReceivedMessage v = rpc_service_->Get(); - auto recv_var_name = v.first; - if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { - LOG(INFO) << "received terminate message and exit"; - exit_flag = true; - break; - } else if (recv_var_name == BATCH_BARRIER_MESSAGE) { - VLOG(3) << "recv batch barrier message"; - batch_barrier++; - continue; - } else { - VLOG(3) << "received grad: " << recv_var_name; - recv_var_cnt++; - auto var = v.second->GetVar(); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << recv_var_name; - PADDLE_THROW("Can not find server side var"); - } - if (var->IsType()) { - sparse_vars.push_back(var); - } - } - } - if (exit_flag) { - rpc_service_->SetCond(1); - rpc_service_->ShutDown(); + rpc_service_->SetCond(detail::kRequestSend); + rpc_service_->WaitBarrier(detail::kRequestSend); + + if (rpc_service_->IsExit()) { + LOG(WARNING) << "get exit!rpc_processor break!"; + rpc_service_->SetCond(detail::kRequestGet); break; } // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads // and this will still work. - // The optimize blocks which have the same parent ID would run parallel // TODO(Yancey1989): need to use ParallelExecutor for future int32_t last_parent_blkid = program->Block(1).Parent(); @@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, var->GetMutable()->mutable_rows()->clear(); } - rpc_service_->SetCond(1); - // FIXME(typhoonzero): use another condition to sync wait clients get. - rpc_service_->WaitClientGet(fan_in); - sparse_vars.clear(); + rpc_service_->SetCond(detail::kRequestGet); + rpc_service_->WaitBarrier(detail::kRequestGet); + rpc_service_->ResetBarrierCounter(); } // while(true) } -static void AsyncUpdateThread( - const std::string &var_name, const bool &exit_flag, - const std::shared_ptr &queue, - framework::Executor *executor, - framework::ExecutorPrepareContext *prepared) { - VLOG(3) << "update thread for " << var_name << " started"; - while (!exit_flag && !SignalHandler::IsProgramExit()) { - const detail::ReceivedMessage v = queue->Pop(); - if (SignalHandler::IsProgramExit()) { - VLOG(3) << "update thread for " << var_name << " exit"; - break; - } - auto recv_var_name = v.first; - VLOG(4) << "async update " << recv_var_name; - auto var = v.second->GetVar(); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << recv_var_name; - PADDLE_THROW("Can not find server side var"); - } - auto fs = framework::Async([var_name, &executor, &v, prepared] { - try { - executor->RunPreparedContext(prepared, - v.second->GetMutableLocalScope()); - } catch (const std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - }); - fs.wait(); - } -} - void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, framework::ProgramDesc *program) const { VLOG(3) << "RunAsyncLoop in"; // grad name to block id std::unordered_map grad_to_block_id; std::unordered_map id_to_grad; - std::unordered_map> - grad_to_queue; auto grad_to_block_id_str = Attr>("grad_to_block_id"); @@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1]; PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0); + int block_id = std::stoi(pieces[1]); grad_to_block_id[pieces[0]] = block_id; - std::shared_ptr queue = - std::make_shared(); - grad_to_queue[pieces[0]] = queue; - // record blocking queue in SignalHandler - SignalHandler::RegisterBlockingQueue(queue); id_to_grad[block_id] = pieces[0]; } size_t num_blocks = program->Size(); @@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i]; } - bool exit_flag = false; + request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); + request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); + request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); - VLOG(3) << "start async optimize threads"; - std::vector> fs; - for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) { - std::string grad_name = iter->first; - VLOG(3) << "create async update thread for " << grad_name; - fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor, - &grad_to_queue, &grad_to_prepared_ctx]() { - AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name], - executor, grad_to_prepared_ctx[grad_name].get()); - })); - } VLOG(3) << "RunAsyncLoop into while"; - while (!exit_flag && !SignalHandler::IsProgramExit()) { - const detail::ReceivedMessage v = rpc_service_->Get(); - auto recv_var_name = v.first; - if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { - LOG(INFO) << "received terminate message and exit"; - exit_flag = true; + while (true) { + if (rpc_service_->IsExit()) { + LOG(INFO) << "get exit!rpc_processor break!"; break; - } else { - VLOG(3) << "received grad: " << recv_var_name; - grad_to_queue[recv_var_name]->Push(v); } - if (exit_flag) { - rpc_service_->ShutDown(); - break; - } + sleep(1); } // while(true) } +static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope, + platform::DeviceContext *dev_ctx, + framework::Executor *executor, + framework::ProgramDesc *program, + framework::ExecutorPrepareContext *prefetch_ctx, + detail::RPCServer *rpc_server) { + h->SetScope(scope); + h->SetDevCtx(dev_ctx); + h->SetExecutor(executor); + h->SetProgram(program); + h->SetPrefetchPreparedCtx(std::move( + std::unique_ptr(prefetch_ctx))); + h->SetRPCServer(rpc_server); +} + void ListenAndServOp::RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const { // Mark this as PS that it should decide profiling by listening from trainer. @@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, framework::Scope &recv_scope = scope.NewScope(); bool sync_mode = Attr("sync_mode"); + auto fan_in = Attr("Fanin"); PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode)); + LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in + << ", end_point:" << endpoint; + + // request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode)); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, fan_in)); + request_send_handler_.reset(new detail::RequestSendHandler(sync_mode)); + request_get_handler_.reset(new detail::RequestGetHandler(sync_mode)); + request_prefetch_handler_.reset( + new detail::RequestPrefetchHandler(sync_mode)); + + rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); + rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); + rpc_service_->RegisterRPC(detail::kRequestPrefetch, + request_prefetch_handler_.get()); auto *optimize_block = Attr(kOptimizeBlock); auto *prefetch_block = Attr(kPrefetchBlock); auto *program = optimize_block->Program(); framework::Executor executor(dev_place); - // prepare rpc_service - rpc_service_->SetScope(&recv_scope); - rpc_service_->SetDevCtx(&dev_ctx); - rpc_service_->SetProgram(program); - rpc_service_->SetExecutor(&executor); - // prepare for prefetch VLOG(3) << "prefetch block id is " << prefetch_block->ID(); auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); - rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared)); + + auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, + &dev_ctx, &executor, program, prefetch_prepared.release(), + rpc_service_.get()); + + f(request_send_handler_.get()); + f(request_get_handler_.get()); + f(request_prefetch_handler_.get()); // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); @@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, signal(SIGTERM, SignalHandler::StopAndExit); // Write to a file of server selected port for python use. - std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port", - static_cast(::getpid())); SavePort(); if (sync_mode) { RunSyncLoop(&executor, program, &recv_scope, prefetch_block); @@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { } }; -bool SignalHandler::program_exit_flag_ = false; - -SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{}; - void SignalHandler::StopAndExit(int signal_num) { VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit"; - - program_exit_flag_ = true; - - // awake all blocking queues - for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin(); - iter != blocking_queue_set_.end(); iter++) { - iter->get()->Push( - std::make_pair(std::string(LISTEN_TERMINATE_MESSAGE), nullptr)); - } - - exit(EXIT_SUCCESS); -} - -void SignalHandler::RegisterBlockingQueue(BlockingQueue &queue) { - blocking_queue_set_.insert(queue); + exit(0); } } // namespace operators diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index 6f868369dcf206..87952cb0e68359 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -23,7 +23,8 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/operators/detail/request_handler.h" +#include "paddle/fluid/operators/detail/rpc_server.h" namespace paddle { namespace operators { @@ -31,7 +32,7 @@ namespace operators { constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kPrefetchBlock[] = "PrefetchBlock"; -void RunServer(std::shared_ptr service); +void RunServer(std::shared_ptr service); class ListenAndServOp : public framework::OperatorBase { public: @@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase { void SavePort() const; - void WaitServerReady(); - - int GetSelectedPort() { return selected_port_; } + int GetSelectedPort() { return rpc_service_->GetSelectedPort(); } void Stop() override; void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override; - static void ResetPort() { selected_port_ = 0; } - protected: - mutable std::shared_ptr rpc_service_; + mutable std::shared_ptr rpc_service_; + mutable std::shared_ptr request_send_handler_; + mutable std::shared_ptr request_get_handler_; + mutable std::shared_ptr request_prefetch_handler_; + mutable std::shared_ptr server_thread_; - // FIXME(wuyi): it's static so that the operator can be cloned. - static std::atomic_int selected_port_; }; class SignalHandler { - public: - typedef std::shared_ptr BlockingQueue; - typedef std::unordered_set BlockingQueueSet; - public: static void StopAndExit(int signal_num); - static void RegisterBlockingQueue(BlockingQueue&); - - static inline bool IsProgramExit() { return program_exit_flag_; } - private: - static bool program_exit_flag_; - - static BlockingQueueSet blocking_queue_set_; - DISABLE_COPY_AND_ASSIGN(SignalHandler); }; diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 2c77ee2e2792d6..bcd8e81609a37c 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase { auto rpc_client = detail::RPCClient::GetInstance(); + VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; + // need to wait before sending send_barrier message PADDLE_ENFORCE(rpc_client->Wait()); if (sync_mode) { diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc index 719f039a0f5fcd..a845ba2eb038fa 100644 --- a/paddle/fluid/operators/test_send_nccl_id.cc +++ b/paddle/fluid/operators/test_send_nccl_id.cc @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" @@ -35,42 +37,44 @@ namespace m = paddle::operators::math; namespace detail = paddle::operators::detail; namespace string = paddle::string; -std::unique_ptr rpc_service; +std::unique_ptr g_rpc_service; +std::unique_ptr g_req_handler; -void StartServer(std::atomic* initialized) { +void StartServer() { f::Scope scope; p::CPUPlace place; scope.Var(NCCL_ID_VARNAME); p::DeviceContextPool& pool = p::DeviceContextPool::Instance(); auto& dev_ctx = *pool.Get(p::CPUPlace()); - rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true)); - f::ProgramDesc empty_program; f::Executor executor(dev_ctx.GetPlace()); - rpc_service->SetScope(&scope); - rpc_service->SetDevCtx(&dev_ctx); - rpc_service->SetProgram(&empty_program); - rpc_service->SetExecutor(&executor); + g_req_handler->SetScope(&scope); + g_req_handler->SetDevCtx(&dev_ctx); + g_req_handler->SetProgram(&empty_program); + g_req_handler->SetExecutor(&executor); + + g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get()); + g_req_handler->SetRPCServer(g_rpc_service.get()); std::thread server_thread( - std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get())); - *initialized = true; - rpc_service->SetCond(0); - auto recv = rpc_service->Get(); + std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); + + g_rpc_service->SetCond(detail::kRequestSend); + std::cout << "before WaitFanInOfSend" << std::endl; + g_rpc_service->WaitBarrier(detail::kRequestSend); + LOG(INFO) << "got nccl id and stop server..."; - rpc_service->ShutDown(); + g_rpc_service->ShutDown(); server_thread.join(); } -TEST(SendNcclId, DISABLED_Normal) { - std::atomic initialized{false}; - std::thread server_thread(StartServer, &initialized); - while (!initialized) { - } - // wait server to start - // sleep(2); - rpc_service->WaitServerReady(); +TEST(SendNcclId, GrpcServer) { + g_req_handler.reset(new detail::RequestSendHandler(true)); + g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1)); + + std::thread server_thread(StartServer); + g_rpc_service->WaitServerReady(); f::Scope scope; p::CPUPlace place; @@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) { auto& dev_ctx = *pool.Get(p::CPUPlace()); auto var = scope.Var(NCCL_ID_VARNAME); - // var->SetType(f::proto::VarType_Type_RAW); auto id = var->GetMutable(); p::dynload::ncclGetUniqueId(id); - int port = rpc_service->GetSelectedPort(); + int port = g_rpc_service->GetSelectedPort(); + std::string ep = string::Sprintf("127.0.0.1:%d", port); detail::RPCClient client; - + LOG(INFO) << "connect to server" << ep; client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); client.Wait(); + client.AsyncSendBatchBarrier(ep); + client.Wait(); + server_thread.join(); - auto* ptr = rpc_service.release(); - delete ptr; + g_rpc_service.reset(nullptr); + g_req_handler.reset(nullptr); } diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 09367889a95179..6f8e3f22db54d1 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include // NOLINT #include #include