Skip to content

Commit

Permalink
fix grpc bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
gongweibao committed Jan 15, 2018
1 parent 1f40d6f commit f029ff7
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 23 deletions.
14 changes: 10 additions & 4 deletions paddle/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ bool RPCClient::wait() {
}

if (!Proceed()) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
}
Expand All @@ -110,22 +109,25 @@ bool RPCClient::Proceed() {

// request counts.
if (!cq_.Next(&tag, &ok)) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
req_count_--;

GPR_ASSERT(ok);
PADDLE_ENFORCE(tag);

// TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag);
if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String();
LOG(ERROR) << "grpc error:" << c->status_.error_message();
delete c;
return true;
return false;
}

c->Process();
delete c;
req_count_--;
return true;
}

Expand All @@ -135,8 +137,12 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
return it->second;
}

grpc::ChannelArguments args;
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

auto ch = std::shared_ptr<grpc::Channel>(
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));

channels_[ep] = ch;
return ch;
Expand Down
27 changes: 10 additions & 17 deletions paddle/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class RequestBase {

CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { assert(false); }

protected:
grpc::ServerContext ctx_;
Expand All @@ -58,6 +59,8 @@ class RequestSend final : public RequestBase {

virtual ~RequestSend() {}

virtual std::string GetReqName() { return request_.varname(); }

virtual void Process() {
MessageWithName msg_with_name =
std::make_pair(request_.varname(), std::move(request_));
Expand All @@ -84,6 +87,8 @@ class RequestGet final : public RequestBase {

virtual ~RequestGet() {}

virtual std::string GetReqName() { return request_.varname(); }

virtual void Process() {
// proc request.
std::string var_name = request_.varname();
Expand Down Expand Up @@ -165,19 +170,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(4) << "create Requestget status:" << get->Status();
}

void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(4) << "delete Requestget status:" << last->Status();
delete last;
last = NULL;
return;
}

last->SetStatus(FINISH);
return;
}

void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
Expand All @@ -197,10 +189,12 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}

RequestBase* base = (RequestBase*)tag;
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event";
TryToRegisterNewOne();
delete base;
LOG(WARNING) << cq_name << " recv no regular event:argument name"
<< base->GetReqName();
// FIXME(gongwb): delete this one? register new one?
continue;
}

Expand All @@ -209,7 +203,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
VLOG(4) << cq_name << " status:" << base->Status();
TryToRegisterNewOne();
base->Process();
// SetFinishOrDelete(base);
break;
}
case FINISH: {
Expand Down
1 change: 0 additions & 1 deletion paddle/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
void SetFinishOrDelete(RequestBase *&last);
void ShutdownQueue();

private:
Expand Down
5 changes: 4 additions & 1 deletion paddle/operators/send_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ class SendOp : public framework::OperatorBase {
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}

client_.wait();
if (!client_.wait()) {
LOG(ERROR) << "send op exit";
exit(1);
}
}

private:
Expand Down

0 comments on commit f029ff7

Please sign in to comment.