Skip to content

Commit

Permalink
Fix grpc bugs (#7435)
Browse files Browse the repository at this point in the history
Fix grpc bugs
  • Loading branch information
gongweibao authored Jan 15, 2018
1 parent 448fee3 commit 535fefb
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 25 deletions.
2 changes: 1 addition & 1 deletion cmake/external/grpc.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
GIT_REPOSITORY "https://github.com/grpc/grpc.git"
GIT_TAG "v1.7.x"
GIT_TAG "v1.8.x"
PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
Expand Down
16 changes: 11 additions & 5 deletions paddle/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}

bool RPCClient::wait() {
bool RPCClient::Wait() {
bool ok = true;

while (true) {
Expand All @@ -96,7 +96,6 @@ bool RPCClient::wait() {
}

if (!Proceed()) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
}
Expand All @@ -110,22 +109,25 @@ bool RPCClient::Proceed() {

// request counts.
if (!cq_.Next(&tag, &ok)) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
req_count_--;

GPR_ASSERT(ok);
PADDLE_ENFORCE(tag);

// TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag);
if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String()
<< " grpc error:" << c->status_.error_message();
delete c;
return true;
return false;
}

c->Process();
delete c;
req_count_--;
return true;
}

Expand All @@ -135,8 +137,12 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
return it->second;
}

grpc::ChannelArguments args;
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

auto ch = std::shared_ptr<grpc::Channel>(
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));

channels_[ep] = ch;
return ch;
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/detail/grpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool wait();
bool Wait();

private:
bool Proceed();
Expand Down
35 changes: 19 additions & 16 deletions paddle/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ class RequestBase {
public:
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq)
: service_(service), cq_(cq), status_(PROCESS) {}
: service_(service), cq_(cq), status_(PROCESS) {
PADDLE_ENFORCE(cq_);
}
virtual ~RequestBase() {}
virtual void Process() { assert(false); }

CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { assert(false); }

protected:
grpc::ServerContext ctx_;
Expand All @@ -56,12 +59,14 @@ class RequestSend final : public RequestBase {

virtual ~RequestSend() {}

virtual std::string GetReqName() { return request_.varname(); }

virtual void Process() {
MessageWithName msg_with_name =
std::make_pair(request_.varname(), std::move(request_));
queue_->Push(std::move(msg_with_name));
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
}

protected:
Expand All @@ -81,13 +86,16 @@ class RequestGet final : public RequestBase {

virtual ~RequestGet() {}

virtual std::string GetReqName() { return request_.varname(); }

virtual void Process() {
// proc request.
std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name);
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
}

protected:
Expand All @@ -100,6 +108,8 @@ class RequestGet final : public RequestBase {
void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_);

cq_send_ = builder.AddCompletionQueue();
Expand Down Expand Up @@ -159,18 +169,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(4) << "create Requestget status:" << get->Status();
}

void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
delete last;
last = NULL;
return;
}

last->SetStatus(FINISH);
return;
}

void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
Expand All @@ -184,13 +182,19 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
break;
}

PADDLE_ENFORCE(tag);
if (wait && !done_) {
Wait();
}

RequestBase* base = (RequestBase*)tag;
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) {
VLOG(4) << cq_name << " recv no regular event";
LOG(WARNING) << cq_name << " recv no regular event:argument name"
<< base->GetReqName();
TryToRegisterNewOne();
delete base;
continue;
Expand All @@ -201,7 +205,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
VLOG(4) << cq_name << " status:" << base->Status();
TryToRegisterNewOne();
base->Process();
SetFinishOrDelete(base);
break;
}
case FINISH: {
Expand Down
1 change: 0 additions & 1 deletion paddle/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
void SetFinishOrDelete(RequestBase *&last);
void ShutdownQueue();

private:
Expand Down
2 changes: 2 additions & 0 deletions paddle/operators/recv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class RecvOp : public framework::OperatorBase {
rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
VLOG(4) << "param_count:" << param_count
<< " trainer_count:" << trainer_count;
while (!exit_flag) {
// TODO(gognwb): simply this loop.
// Get from multiple trainers, we don't care about order in which
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/send_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class SendOp : public framework::OperatorBase {
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}

client_.wait();
PADDLE_ENFORCE(client_.Wait());
}

private:
Expand Down

0 comments on commit 535fefb

Please sign in to comment.