
Fix grpc bugs #7435

Merged 12 commits on Jan 15, 2018
2 changes: 1 addition & 1 deletion cmake/external/grpc.cmake
@@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
GIT_REPOSITORY "https://github.com/grpc/grpc.git"
GIT_TAG "v1.7.x"
GIT_TAG "v1.8.x"
PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
14 changes: 10 additions & 4 deletions paddle/operators/detail/grpc_client.cc
@@ -96,7 +96,6 @@ bool RPCClient::wait() {
}

if (!Proceed()) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
}
@@ -110,22 +109,25 @@ bool RPCClient::Proceed() {

// request counts.
if (!cq_.Next(&tag, &ok)) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
req_count_--;

GPR_ASSERT(ok);
PADDLE_ENFORCE(tag);

// TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag);
if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String();
Contributor: Only log one time for the error.
Contributor Author: Done.

LOG(ERROR) << "grpc error:" << c->status_.error_message();
delete c;
return true;
return false;
}

c->Process();
delete c;
req_count_--;
return true;
}
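
The thread above asks for the CompletionQueue error to be logged only once; after this change Proceed() owns the logging and simply reports failure to its caller. A hypothetical, simplified sketch of how wait() and Proceed() fit together is below; the loop condition is illustrative and this is not the verbatim merged code.

// Hypothetical, simplified sketch; not the verbatim merged code.
bool RPCClient::wait() {
  while (req_count_ > 0) {
    // Proceed() drains one completion-queue event, decrements req_count_,
    // and logs any CompletionQueue or grpc error itself, so wait() only
    // needs to propagate the failure.
    if (!Proceed()) {
      return false;
    }
  }
  return true;
}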

@@ -135,8 +137,12 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
return it->second;
}

grpc::ChannelArguments args;
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

auto ch = std::shared_ptr<grpc::Channel>(
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));

channels_[ep] = ch;
return ch;
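The GetChannel() change above switches from grpc::CreateChannel to grpc::CreateCustomChannel so that channel arguments can raise gRPC's fairly small default message-size limits, which a large serialized variable can easily exceed. A minimal self-contained sketch of the same pattern (the function name is illustrative):

#include <limits>
#include <memory>
#include <string>

#include <grpc++/grpc++.h>

// Build an insecure channel whose send/receive message sizes are unbounded,
// mirroring the arguments set in RPCClient::GetChannel().
std::shared_ptr<grpc::Channel> MakeBigMessageChannel(const std::string& ep) {
  grpc::ChannelArguments args;
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
  return grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
}
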
35 changes: 19 additions & 16 deletions paddle/operators/detail/grpc_server.cc
@@ -28,12 +28,15 @@ class RequestBase {
public:
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq)
: service_(service), cq_(cq), status_(PROCESS) {}
: service_(service), cq_(cq), status_(PROCESS) {
assert(cq_);
Contributor: Use PADDLE_ENFORCE instead of assert.
Contributor Author: Done.

}
virtual ~RequestBase() {}
virtual void Process() { assert(false); }

CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { assert(false); }

protected:
grpc::ServerContext ctx_;
@@ -56,12 +59,15 @@ class RequestSend final : public RequestBase {

virtual ~RequestSend() {}

virtual std::string GetReqName() { return request_.varname(); }

virtual void Process() {
MessageWithName msg_with_name =
std::make_pair(request_.varname(), std::move(request_));
queue_->Push(std::move(msg_with_name));
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
}

protected:
@@ -81,13 +87,16 @@ class RequestGet final : public RequestBase {

virtual ~RequestGet() {}

virtual std::string GetReqName() { return request_.varname(); }

virtual void Process() {
// proc request.
std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name);
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
}

protected:
@@ -100,6 +109,8 @@ class RequestGet final : public RequestBase {
void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_);

cq_send_ = builder.AddCompletionQueue();
@@ -159,18 +170,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(4) << "create Requestget status:" << get->Status();
}

void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
delete last;
last = NULL;
return;
}

last->SetStatus(FINISH);
return;
}

void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
@@ -184,15 +183,20 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
break;
}

assert(tag);
Contributor: Use PADDLE_ENFORCE.
Contributor Author: Done.

if (wait && !done_) {
Wait();
}

RequestBase* base = (RequestBase*)tag;
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
if (!ok) {
VLOG(4) << cq_name << " recv no regular event";
LOG(WARNING) << cq_name << " recv no regular event:argument name"
<< base->GetReqName();
// FIXME(gongwb): delete the old one?
Contributor: This comment does not make things clear.
Contributor Author: There is no more context available when ok != true.

TryToRegisterNewOne();
delete base;
Contributor: Then there is no other place to release this memory.
Contributor Author: I'm not sure whether it's a grpc bug or a bug in our application; when base was deleted here, I often hit an error.

continue;
}

@@ -201,7 +205,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
VLOG(4) << cq_name << " status:" << base->Status();
TryToRegisterNewOne();
base->Process();
SetFinishOrDelete(base);
break;
}
case FINISH: {
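For context, the HandleRequest() changes follow the advice in the linked TensorFlow issue and grpc-io thread: when cq->Next() hands back an event with ok == false, the call did not complete normally and carries no usable request data, so the handler re-arms a fresh request and releases the tag. A simplified sketch of that event loop is below; RequestBase, CallStatus, PROCESS and FINISH refer to the classes above, RegisterNewOne stands in for TryToRegisterNewSendOne/TryToRegisterNewGetOne, and this is not the exact merged code.

#include <functional>
#include <string>

#include <glog/logging.h>
#include <grpc++/grpc++.h>

// Simplified sketch of the async-server event loop; not the exact merged
// HandleRequest(). RegisterNewOne re-arms a pending request of the same kind.
void HandleLoop(grpc::ServerCompletionQueue* cq, const std::string& cq_name,
                std::function<void()> RegisterNewOne) {
  void* tag = nullptr;
  bool ok = false;
  while (cq->Next(&tag, &ok)) {
    auto* base = static_cast<RequestBase*>(tag);
    if (!ok) {
      // The call was cancelled or never completed normally; nothing beyond
      // the tag is usable, so log it, re-arm a fresh request, and free it.
      LOG(WARNING) << cq_name << " recv no regular event: " << base->GetReqName();
      RegisterNewOne();
      delete base;
      continue;
    }
    switch (base->Status()) {
      case PROCESS:
        RegisterNewOne();  // keep one pending request of this kind outstanding
        base->Process();   // Process() calls Finish(), moving the state to FINISH
        break;
      case FINISH:
        delete base;       // the reply has been sent; release the slot
        break;
    }
  }
}
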
1 change: 0 additions & 1 deletion paddle/operators/detail/grpc_server.h
@@ -60,7 +60,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
void SetFinishOrDelete(RequestBase *&last);
void ShutdownQueue();

private:
2 changes: 2 additions & 0 deletions paddle/operators/recv_op.cc
@@ -100,6 +100,8 @@ class RecvOp : public framework::OperatorBase {
// TODO(gongwb): simplify this loop.
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
VLOG(4) << "param_count:" << param_count
<< " trainer_count:" << trainer_count;
Contributor: Reduce VLOG appearances.
Contributor Author: Done.

for (size_t i = 0; i < param_count * trainer_count; ++i) {
// blocking get one var from client.
const detail::MessageWithName &v = rpc_service_->Get();
5 changes: 4 additions & 1 deletion paddle/operators/send_op.cc
@@ -48,7 +48,10 @@ class SendOp : public framework::OperatorBase {
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}

client_.wait();
if (!client_.wait()) {
LOG(ERROR) << "send op exit";
Contributor: This log is too simple.
Contributor Author (@gongweibao, Jan 15, 2018): Detailed logs are already written by the functions this calls.

exit(1);
Contributor: Do not use exit in operators; use PADDLE_ENFORCE.
Contributor Author: Done.

}
}

private:
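Per the review thread above, the final version is expected to report the failure through PADDLE_ENFORCE rather than calling exit() inside an operator; a one-line sketch of that suggestion (the message text is illustrative, not the merged wording):

// Sketch of the reviewer's suggestion; the error message is illustrative.
PADDLE_ENFORCE(client_.wait(), "SendOp RPC client reported an error");
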
@@ -0,0 +1,169 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
Contributor: Can you please split the fix and the book dist sample into two PRs?
Contributor Author: Done.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import sys

import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import os
import sys

TRAINERS = 2
BATCH_SIZE = 128
PASS_NUM = 30


def resnet_cifar10(input, depth=32):
def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=tmp, act=act)

def shortcut(input, ch_in, ch_out, stride):
if ch_in != ch_out:
return conv_bn_layer(input, ch_out, 1, stride, 0, None)
else:
return input

def basicblock(input, ch_in, ch_out, stride):
tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None)
short = shortcut(input, ch_in, ch_out, stride)
return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')

def layer_warp(block_func, input, ch_in, ch_out, count, stride):
tmp = block_func(input, ch_in, ch_out, stride)
for i in range(1, count):
tmp = block_func(tmp, ch_out, ch_out, 1)
return tmp

assert (depth - 2) % 6 == 0
n = (depth - 2) / 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
return pool


def vgg16_bn_drop(input):
def conv_block(input, num_filter, groups, dropouts):
return fluid.nets.img_conv_group(
input=input,
pool_size=2,
pool_stride=2,
conv_num_filter=[num_filter] * groups,
conv_filter_size=3,
conv_act='relu',
conv_with_batchnorm=True,
conv_batchnorm_drop_rate=dropouts,
pool_type='max')

conv1 = conv_block(input, 64, 2, [0.3, 0])
conv2 = conv_block(conv1, 128, 2, [0.4, 0])
conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0])
conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0])
conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0])

drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
fc1 = fluid.layers.fc(input=drop, size=512, act=None)
bn = fluid.layers.batch_norm(input=fc1, act='relu')
drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
fc2 = fluid.layers.fc(input=drop2, size=512, act=None)
return fc2


classdim = 10
data_shape = [3, 32, 32]

images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

net_type = "vgg"
if len(sys.argv) >= 2:
net_type = sys.argv[1]

if net_type == "vgg":
print("train vgg net")
net = vgg16_bn_drop(images)
elif net_type == "resnet":
print("train resnet")
net = resnet_cifar10(images, 32)
else:
raise ValueError("%s network is not supported" % net_type)

predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimize_ops, params_grads = optimizer.minimize(avg_cost)

accuracy = fluid.evaluator.Accuracy(input=predict, label=label)

train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=128 * 10),
batch_size=BATCH_SIZE)

place = fluid.CPUPlace()
exe = fluid.Executor(place)

t = fluid.DistributeTranspiler()
# list of all parameter server endpoints, used for splitting parameters
pserver_endpoints = os.getenv("PSERVERS")
# server endpoint for current node
current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver
t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)

if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
print("start pserver at:", current_endpoint)
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
elif training_role == "TRAINER":
print("start trainer")
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM):
accuracy.reset(exe)
for data in train_reader():
loss, acc = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics)
pass_acc = accuracy.eval(exe)
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
pass_acc))
# this model is slow, so if we can train two mini-batches, we consider it to be working properly.
else:
print("environment var TRAINER_ROLE should be TRAINER os PSERVER")
exit(1)