Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix grpc bugs #7435

Merged
merged 12 commits into from
Jan 15, 2018
2 changes: 1 addition & 1 deletion cmake/external/grpc.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
GIT_REPOSITORY "https://github.com/grpc/grpc.git"
GIT_TAG "v1.7.x"
GIT_TAG "v1.8.x"
PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
Expand Down
16 changes: 11 additions & 5 deletions paddle/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}

bool RPCClient::wait() {
bool RPCClient::Wait() {
bool ok = true;

while (true) {
Expand All @@ -96,7 +96,6 @@ bool RPCClient::wait() {
}

if (!Proceed()) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
}
Expand All @@ -110,22 +109,25 @@ bool RPCClient::Proceed() {

// request counts.
if (!cq_.Next(&tag, &ok)) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
req_count_--;

GPR_ASSERT(ok);
PADDLE_ENFORCE(tag);

// TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag);
if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String()
<< " grpc error:" << c->status_.error_message();
delete c;
return true;
return false;
}

c->Process();
delete c;
req_count_--;
return true;
}

Expand All @@ -135,8 +137,12 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
return it->second;
}

grpc::ChannelArguments args;
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

auto ch = std::shared_ptr<grpc::Channel>(
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));

channels_[ep] = ch;
return ch;
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/detail/grpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool wait();
bool Wait();

private:
bool Proceed();
Expand Down
35 changes: 19 additions & 16 deletions paddle/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ class RequestBase {
public:
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq)
: service_(service), cq_(cq), status_(PROCESS) {}
: service_(service), cq_(cq), status_(PROCESS) {
PADDLE_ENFORCE(cq_);
}
virtual ~RequestBase() {}
// Default handler. Concrete request types override this; reaching the base
// version trips the assert in debug builds.
virtual void Process() { assert(false); }

// Accessors for the call state stored in status_.
CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }

// Returns the name associated with this request. Concrete request types
// override this; the base version asserts in debug builds. The explicit
// return keeps the function well-defined when assert() is compiled out
// (NDEBUG): flowing off the end of a non-void function is undefined
// behavior.
virtual std::string GetReqName() {
  assert(false);
  return "";
}

protected:
grpc::ServerContext ctx_;
Expand All @@ -56,12 +59,14 @@ class RequestSend final : public RequestBase {

virtual ~RequestSend() {}

// Returns the variable name carried by the received request message.
virtual std::string GetReqName() { return request_.varname(); }

// Handles one send request: pairs the request message with its variable
// name, pushes that pair onto queue_, answers the caller with an OK status
// and marks this call as FINISH.
virtual void Process() {
// request_ is handed off with std::move, so it should not be read again
// after this statement.
MessageWithName msg_with_name =
std::make_pair(request_.varname(), std::move(request_));
queue_->Push(std::move(msg_with_name));
// TODO(gongwb): check var's info.
// NOTE(review): `this` is passed as the last argument, presumably the
// completion-queue tag that HandleRequest casts back to RequestBase*.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
}

protected:
Expand All @@ -81,13 +86,16 @@ class RequestGet final : public RequestBase {

virtual ~RequestGet() {}

// Returns the variable name carried by the received request message.
virtual std::string GetReqName() { return request_.varname(); }

// Handles one get request: looks the requested variable up in scope_,
// serializes it into reply_, answers the caller with an OK status and marks
// this call as FINISH.
virtual void Process() {
// proc request.
std::string var_name = request_.varname();
// NOTE(review): var is passed on without a null check -- confirm that
// FindVar cannot return null for names a client may request.
auto* var = scope_->FindVar(var_name);
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
// TODO(gongwb): check var's info.
// NOTE(review): `this` is passed as the last argument, presumably the
// completion-queue tag that HandleRequest casts back to RequestBase*.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
}

protected:
Expand All @@ -100,6 +108,8 @@ class RequestGet final : public RequestBase {
void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_);

cq_send_ = builder.AddCompletionQueue();
Expand Down Expand Up @@ -159,18 +169,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(4) << "create Requestget status:" << get->Status();
}

// Either finalizes or discards a request, depending on the shutdown flag.
// `last` is taken by reference to a pointer so the caller's pointer can be
// cleared when the request is deleted here.
void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
// Hold cq_mutex_ while reading is_shut_down_ and touching the request.
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
// Shut down: free the request and null the caller's pointer.
delete last;
last = NULL;
return;
}

// Not shut down: mark the request as FINISH.
last->SetStatus(FINISH);
return;
}

void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
Expand All @@ -184,13 +182,19 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
break;
}

PADDLE_ENFORCE(tag);
if (wait && !done_) {
Wait();
}

RequestBase* base = (RequestBase*)tag;
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) {
VLOG(4) << cq_name << " recv no regular event";
LOG(WARNING) << cq_name << " recv no regular event:argument name"
<< base->GetReqName();
TryToRegisterNewOne();
delete base;
continue;
Expand All @@ -201,7 +205,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
VLOG(4) << cq_name << " status:" << base->Status();
TryToRegisterNewOne();
base->Process();
SetFinishOrDelete(base);
break;
}
case FINISH: {
Expand Down
1 change: 0 additions & 1 deletion paddle/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
void SetFinishOrDelete(RequestBase *&last);
void ShutdownQueue();

private:
Expand Down
2 changes: 2 additions & 0 deletions paddle/operators/recv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class RecvOp : public framework::OperatorBase {
rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
VLOG(4) << "param_count:" << param_count
<< " trainer_count:" << trainer_count;
while (!exit_flag) {
// TODO(gongwb): simplify this loop.
// Get from multiple trainers, we don't care about order in which
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/send_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class SendOp : public framework::OperatorBase {
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}

client_.wait();
PADDLE_ENFORCE(client_.Wait());
}

private:
Expand Down