Fix grpc bugs #7435
@@ -28,12 +28,15 @@ class RequestBase {
  public:
   explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
                        grpc::ServerCompletionQueue* cq)
-      : service_(service), cq_(cq), status_(PROCESS) {}
+      : service_(service), cq_(cq), status_(PROCESS) {
+    assert(cq_);
+  }
   virtual ~RequestBase() {}
   virtual void Process() { assert(false); }

   CallStatus Status() { return status_; }
   void SetStatus(CallStatus status) { status_ = status; }
+  virtual std::string GetReqName() { assert(false); }

  protected:
   grpc::ServerContext ctx_;
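One subtlety in the defaults above: `Process()` and `GetReqName()` use `assert(false)` as their bodies, and `GetReqName()` returns `std::string`. If the server is ever built with `NDEBUG`, the assert compiles away and a call to an un-overridden `GetReqName()` runs off the end of a value-returning function, which is undefined behavior. A minimal sketch of a variant that stays deterministic in release builds (illustrative only, not part of the patch):

#include <cstdlib>
#include <string>

// Illustrative stand-in for RequestBase's default methods.
class RequestBaseSketch {
 public:
  virtual ~RequestBaseSketch() {}
  // std::abort() is not removed by NDEBUG, and since it never returns
  // there is no missing std::string return value.
  virtual std::string GetReqName() { std::abort(); }
  virtual void Process() { std::abort(); }
};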
@@ -56,12 +59,15 @@ class RequestSend final : public RequestBase {

   virtual ~RequestSend() {}

+  virtual std::string GetReqName() { return request_.varname(); }

   virtual void Process() {
     MessageWithName msg_with_name =
         std::make_pair(request_.varname(), std::move(request_));
     queue_->Push(std::move(msg_with_name));
+    // TODO(gongwb): check var's info.
     responder_.Finish(reply_, grpc::Status::OK, this);
+    status_ = FINISH;
   }

  protected:
@@ -81,13 +87,16 @@ class RequestGet final : public RequestBase {

   virtual ~RequestGet() {}

+  virtual std::string GetReqName() { return request_.varname(); }

   virtual void Process() {
     // proc request.
     std::string var_name = request_.varname();
     auto* var = scope_->FindVar(var_name);
     SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
+    // TODO(gongwb): check var's info.
     responder_.Finish(reply_, grpc::Status::OK, this);
+    status_ = FINISH;
   }

  protected:
@@ -100,6 +109,8 @@ void AsyncGRPCServer::RunSyncUpdate() {
 void AsyncGRPCServer::RunSyncUpdate() {
   grpc::ServerBuilder builder;
   builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
+  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
+  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
   builder.RegisterService(&service_);

   cq_send_ = builder.AddCompletionQueue();
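The two `SetMax*MessageSize` calls matter because gRPC caps inbound messages at 4 MB by default, which serialized tensors easily exceed. A self-contained sketch of the same server setup; the `BuildServer` helper and its signature are illustrative, not from the codebase:

#include <limits>
#include <memory>
#include <string>

#include <grpc++/grpc++.h>

std::unique_ptr<grpc::Server> BuildServer(
    const std::string& address, grpc::Service* service,
    std::unique_ptr<grpc::ServerCompletionQueue>* cq) {
  grpc::ServerBuilder builder;
  builder.AddListeningPort(address, grpc::InsecureServerCredentials());
  // Lift gRPC's default 4 MB cap so large tensors pass through intact.
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
  builder.RegisterService(service);
  *cq = builder.AddCompletionQueue();  // queue for async request handling
  return builder.BuildAndStart();
}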
@@ -159,18 +170,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
   VLOG(4) << "create Requestget status:" << get->Status();
 }

-void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
-  std::unique_lock<std::mutex> lock(cq_mutex_);
-  if (is_shut_down_) {
-    delete last;
-    last = NULL;
-    return;
-  }
-
-  last->SetStatus(FINISH);
-  return;
-}
-
 void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
                                     std::string cq_name,
                                     std::function<void()> TryToRegisterNewOne) {
@@ -184,15 +183,20 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
       break;
     }

+    assert(tag);

     if (wait && !done_) {
       Wait();
     }

     RequestBase* base = (RequestBase*)tag;
     // reference:
     // https://github.com/tensorflow/tensorflow/issues/5596
     // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
     if (!ok) {
-      VLOG(4) << cq_name << " recv no regular event";
+      LOG(WARNING) << cq_name << " recv no regular event:argument name"
+                   << base->GetReqName();
+      // FIXME(gongwb): delete the old one?
+      TryToRegisterNewOne();
+      delete base;
       continue;
     }

Review, on the FIXME comment: This comment does not make things clear.
Reply: Can't get more context when …
Review, on delete base: No other places to release this memory then.
Reply: I'm not sure it's a grpc bug or it's our application bug.
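The linked tensorflow issue and grpc-io thread describe the semantics this fix relies on: `CompletionQueue::Next()` can hand back an event with `ok == false` (for example on client cancellation or shutdown), yet ownership of the tag still returns to the application, so the request object must be re-registered and freed rather than leaked. A condensed sketch of that pattern; the minimal `RequestBase` and the free-standing `RegisterNewSlot()` are stand-ins for the diff's richer class and its bound `TryToRegisterNewOne`:

#include <grpc++/grpc++.h>

// Minimal stand-ins so the sketch is self-contained.
class RequestBase {
 public:
  virtual ~RequestBase() {}
};

void RegisterNewSlot() { /* stub for the bound TryToRegisterNewOne */ }

void HandleLoop(grpc::ServerCompletionQueue* cq) {
  void* tag = nullptr;
  bool ok = false;
  // Next() blocks for the next event; it returns false only once the
  // queue has fully drained after Shutdown().
  while (cq->Next(&tag, &ok)) {
    auto* base = static_cast<RequestBase*>(tag);
    if (!ok) {
      // Cancelled or shutting down: this tag will never complete
      // normally, so keep a request slot pending and free the dead one.
      RegisterNewSlot();
      delete base;
      continue;
    }
    // ok == true: dispatch on the request's status as in the diff.
  }
}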
@@ -201,7 +205,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
         VLOG(4) << cq_name << " status:" << base->Status();
         TryToRegisterNewOne();
         base->Process();
-        SetFinishOrDelete(base);
         break;
       }
       case FINISH: {
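With `SetFinishOrDelete` removed, each request becomes a plain two-state machine: `Process()` sends the reply and flips `status_` to `FINISH` itself, and the FINISH event is where the object is released. A toy model of that lifecycle (the FINISH body is inferred from the surrounding code, since the diff cuts off before it):

#include <cassert>
#include <iostream>

// Toy model of the request lifecycle after the patch: PROCESS handles
// the call and flips to FINISH; the next event for the same tag frees it.
enum CallStatus { PROCESS, FINISH };

class Slot {
 public:
  CallStatus Status() const { return status_; }
  void Process() {
    std::cout << "reply sent\n";  // stands in for responder_.Finish(...)
    status_ = FINISH;             // Process() now owns this transition
  }

 private:
  CallStatus status_ = PROCESS;
};

int main() {
  Slot* slot = new Slot;
  for (int event = 0; event < 2; ++event) {  // two CQ events per call
    switch (slot->Status()) {
      case PROCESS:
        slot->Process();
        break;
      case FINISH:
        delete slot;  // second event: release the finished request
        slot = nullptr;
        break;
    }
  }
  assert(slot == nullptr);
  return 0;
}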
@@ -100,6 +100,8 @@ class RecvOp : public framework::OperatorBase {
     // TODO(gongwb): simplify this loop.
     // Get from multiple trainers, we don't care about order in which
     // the gradient arrives, just add suffix 0~n then average the gradient.
+    VLOG(4) << "param_count:" << param_count
+            << " trainer_count:" << trainer_count;
     for (size_t i = 0; i < param_count * trainer_count; ++i) {
       // blocking get one var from client.
       const detail::MessageWithName &v = rpc_service_->Get();
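`rpc_service_->Get()` blocks until `RequestSend::Process` has pushed a variable into the shared queue, which is why this loop can simply pop `param_count * trainer_count` messages regardless of arrival order. A minimal sketch of such a blocking queue; the real implementation lives elsewhere in the codebase, and this stand-in only shows the handoff:

#include <condition_variable>
#include <deque>
#include <mutex>
#include <utility>

template <typename T>
class BlockingQueue {
 public:
  void Push(T&& v) {
    {
      std::lock_guard<std::mutex> guard(mu_);
      queue_.push_back(std::move(v));
    }
    cv_.notify_one();  // wake one blocked consumer
  }

  // Blocks until an element is available, then hands it out.
  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    T v = std::move(queue_.front());
    queue_.pop_front();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};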
@@ -48,7 +48,10 @@ class SendOp : public framework::OperatorBase {
       client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
     }

-    client_.wait();
+    if (!client_.wait()) {
+      LOG(ERROR) << "send op exit";
+      exit(1);
+    }
   }

  private:

Review, on the LOG(ERROR): This log is too simple.
Reply: Detailed logs have already been written by the functions it calls.
Review, on exit(1): Do not use …
Reply: Done.
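`client_.wait()` now reports failure instead of returning void, which is what lets SendOp bail out. A hypothetical sketch of how such a wait could aggregate per-RPC results from the client's completion queue; `ClientCall`, its `status` member, and `WaitAll` are illustrative names, not the codebase's actual types, and the short caller-side log mirrors the reply above (details are logged lower down):

#include <glog/logging.h>
#include <grpc++/grpc++.h>

// Hypothetical per-RPC bookkeeping used as the completion-queue tag.
struct ClientCall {
  grpc::Status status;
};

// Drain `pending` completions and fold their statuses into one bool.
bool WaitAll(grpc::CompletionQueue* cq, int pending) {
  bool all_ok = true;
  void* tag = nullptr;
  bool ok = false;
  while (pending > 0 && cq->Next(&tag, &ok)) {
    auto* call = static_cast<ClientCall*>(tag);
    if (!ok || !call->status.ok()) {
      LOG(ERROR) << "rpc failed: " << call->status.error_message();
      all_ok = false;  // keep draining so no completion is leaked
    }
    delete call;
    --pending;
  }
  return all_ok;
}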
@@ -0,0 +1,169 @@
Review, on the new file: Can you please split the fix and the book dist sample into two PRs?
Reply: Done.

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import sys

import paddle.v2 as paddle
import paddle.v2.fluid as fluid

TRAINERS = 2
BATCH_SIZE = 128
PASS_NUM = 30


def resnet_cifar10(input, depth=32):
    def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
        tmp = fluid.layers.conv2d(
            input=input,
            filter_size=filter_size,
            num_filters=ch_out,
            stride=stride,
            padding=padding,
            act=None,
            bias_attr=False)
        return fluid.layers.batch_norm(input=tmp, act=act)

    def shortcut(input, ch_in, ch_out, stride):
        if ch_in != ch_out:
            return conv_bn_layer(input, ch_out, 1, stride, 0, None)
        else:
            return input

    def basicblock(input, ch_in, ch_out, stride):
        tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
        tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None)
        short = shortcut(input, ch_in, ch_out, stride)
        return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')

    def layer_warp(block_func, input, ch_in, ch_out, count, stride):
        tmp = block_func(input, ch_in, ch_out, stride)
        for i in range(1, count):
            tmp = block_func(tmp, ch_out, ch_out, 1)
        return tmp

    assert (depth - 2) % 6 == 0
    n = (depth - 2) / 6
    conv1 = conv_bn_layer(
        input=input, ch_out=16, filter_size=3, stride=1, padding=1)
    res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
    res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
    res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
    pool = fluid.layers.pool2d(
        input=res3, pool_size=8, pool_type='avg', pool_stride=1)
    return pool


def vgg16_bn_drop(input):
    def conv_block(input, num_filter, groups, dropouts):
        return fluid.nets.img_conv_group(
            input=input,
            pool_size=2,
            pool_stride=2,
            conv_num_filter=[num_filter] * groups,
            conv_filter_size=3,
            conv_act='relu',
            conv_with_batchnorm=True,
            conv_batchnorm_drop_rate=dropouts,
            pool_type='max')

    conv1 = conv_block(input, 64, 2, [0.3, 0])
    conv2 = conv_block(conv1, 128, 2, [0.4, 0])
    conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0])
    conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0])
    conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0])

    drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
    fc1 = fluid.layers.fc(input=drop, size=512, act=None)
    bn = fluid.layers.batch_norm(input=fc1, act='relu')
    drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
    fc2 = fluid.layers.fc(input=drop2, size=512, act=None)
    return fc2


classdim = 10
data_shape = [3, 32, 32]

images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

net_type = "vgg"
if len(sys.argv) >= 2:
    net_type = sys.argv[1]

if net_type == "vgg":
    print("train vgg net")
    net = vgg16_bn_drop(images)
elif net_type == "resnet":
    print("train resnet")
    net = resnet_cifar10(images, 32)
else:
    raise ValueError("%s network is not supported" % net_type)

predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimize_ops, params_grads = optimizer.minimize(avg_cost)

accuracy = fluid.evaluator.Accuracy(input=predict, label=label)

train_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.cifar.train10(), buf_size=128 * 10),
    batch_size=BATCH_SIZE)

place = fluid.CPUPlace()
exe = fluid.Executor(place)

t = fluid.DistributeTranspiler()
# all parameter server endpoints list for splitting parameters
pserver_endpoints = os.getenv("PSERVERS")
# server endpoint for current node
current_endpoint = os.getenv("SERVER_ENDPOINT")
# get the training role: trainer/pserver
training_role = os.getenv("TRAINING_ROLE", "TRAINER")
t.transpile(
    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)

if training_role == "PSERVER":
    if not current_endpoint:
        print("need env SERVER_ENDPOINT")
        exit(1)
    print("start pserver at:", current_endpoint)
    pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
    exe.run(fluid.default_startup_program())
    exe.run(pserver_prog)
elif training_role == "TRAINER":
    print("start trainer")
    feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
    exe.run(fluid.default_startup_program())
    for pass_id in range(PASS_NUM):
        accuracy.reset(exe)
        for data in train_reader():
            loss, acc = exe.run(fluid.default_main_program(),
                                feed=feeder.feed(data),
                                fetch_list=[avg_cost] + accuracy.metrics)
            pass_acc = accuracy.eval(exe)
            print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" +
                  str(pass_acc))
        # This model is slow; if it can train a couple of mini-batches,
        # we consider it to be working properly.
else:
    print("environment var TRAINING_ROLE should be TRAINER or PSERVER")
    exit(1)
Review: Only log one time for the error.
Reply: Done.
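For reference, the script above decides its role entirely from environment variables: PSERVERS holds the comma-separated list of all parameter-server endpoints, SERVER_ENDPOINT holds the current node's own endpoint, and TRAINING_ROLE selects TRAINER or PSERVER. A typical single-machine try-out (the endpoint value is illustrative) would launch one pserver with TRAINING_ROLE=PSERVER, PSERVERS=127.0.0.1:6174, and SERVER_ENDPOINT=127.0.0.1:6174, and then start two trainer processes with TRAINING_ROLE=TRAINER, matching TRAINERS = 2 in the script.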