-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix grpc bugs #7435
Fix grpc bugs #7435
Changes from 4 commits
bbd7455
98870fe
7e1388b
78f4ebf
1f40d6f
f029ff7
b811458
8a2978e
df51ac8
69cff8a
7e2bea9
a291062
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,9 @@ class RequestBase { | |
public: | ||
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service, | ||
grpc::ServerCompletionQueue* cq) | ||
: service_(service), cq_(cq), status_(PROCESS) {} | ||
: service_(service), cq_(cq), status_(PROCESS) { | ||
assert(cq_); | ||
} | ||
virtual ~RequestBase() {} | ||
virtual void Process() { assert(false); } | ||
|
||
|
@@ -100,6 +102,8 @@ class RequestGet final : public RequestBase { | |
void AsyncGRPCServer::RunSyncUpdate() { | ||
grpc::ServerBuilder builder; | ||
builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); | ||
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max()); | ||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); | ||
builder.RegisterService(&service_); | ||
|
||
cq_send_ = builder.AddCompletionQueue(); | ||
|
@@ -162,6 +166,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { | |
void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) { | ||
std::unique_lock<std::mutex> lock(cq_mutex_); | ||
if (is_shut_down_) { | ||
VLOG(4) << "delete Requestget status:" << last->Status(); | ||
delete last; | ||
last = NULL; | ||
return; | ||
|
@@ -184,13 +189,14 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, | |
break; | ||
} | ||
|
||
assert(tag); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
if (wait && !done_) { | ||
Wait(); | ||
} | ||
|
||
RequestBase* base = (RequestBase*)tag; | ||
if (!ok) { | ||
VLOG(4) << cq_name << " recv no regular event"; | ||
LOG(WARNING) << cq_name << " recv no regular event"; | ||
TryToRegisterNewOne(); | ||
delete base; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No other places to release this memory then. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure it's a grpc bug or it's our application bug. |
||
continue; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -100,6 +100,8 @@ class RecvOp : public framework::OperatorBase { | |
// TODO(gognwb): simply this loop. | ||
// Get from multiple trainers, we don't care about order in which | ||
// the gradient arrives, just add suffix 0~n then average the gradient. | ||
VLOG(4) << "param_count:" << param_count | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reduce There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
<< " trainer_count:" << trainer_count; | ||
for (size_t i = 0; i < param_count * trainer_count; ++i) { | ||
// blocking get one var from client. | ||
const detail::MessageWithName &v = rpc_service_->Get(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please split the fix and the book dist sample in two PRs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import sys | ||
|
||
import paddle.v2 as paddle | ||
import paddle.v2.fluid as fluid | ||
import os | ||
import sys | ||
|
||
TRAINERS = 1 | ||
BATCH_SIZE = 128 | ||
PASS_NUM = 1 | ||
|
||
|
||
def resnet_cifar10(input, depth=32): | ||
def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'): | ||
tmp = fluid.layers.conv2d( | ||
input=input, | ||
filter_size=filter_size, | ||
num_filters=ch_out, | ||
stride=stride, | ||
padding=padding, | ||
act=None, | ||
bias_attr=False) | ||
return fluid.layers.batch_norm(input=tmp, act=act) | ||
|
||
def shortcut(input, ch_in, ch_out, stride): | ||
if ch_in != ch_out: | ||
return conv_bn_layer(input, ch_out, 1, stride, 0, None) | ||
else: | ||
return input | ||
|
||
def basicblock(input, ch_in, ch_out, stride): | ||
tmp = conv_bn_layer(input, ch_out, 3, stride, 1) | ||
tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None) | ||
short = shortcut(input, ch_in, ch_out, stride) | ||
return fluid.layers.elementwise_add(x=tmp, y=short, act='relu') | ||
|
||
def layer_warp(block_func, input, ch_in, ch_out, count, stride): | ||
tmp = block_func(input, ch_in, ch_out, stride) | ||
for i in range(1, count): | ||
tmp = block_func(tmp, ch_out, ch_out, 1) | ||
return tmp | ||
|
||
assert (depth - 2) % 6 == 0 | ||
n = (depth - 2) / 6 | ||
conv1 = conv_bn_layer( | ||
input=input, ch_out=16, filter_size=3, stride=1, padding=1) | ||
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) | ||
res2 = layer_warp(basicblock, res1, 16, 32, n, 2) | ||
res3 = layer_warp(basicblock, res2, 32, 64, n, 2) | ||
pool = fluid.layers.pool2d( | ||
input=res3, pool_size=8, pool_type='avg', pool_stride=1) | ||
return pool | ||
|
||
|
||
def vgg16_bn_drop(input): | ||
def conv_block(input, num_filter, groups, dropouts): | ||
return fluid.nets.img_conv_group( | ||
input=input, | ||
pool_size=2, | ||
pool_stride=2, | ||
conv_num_filter=[num_filter] * groups, | ||
conv_filter_size=3, | ||
conv_act='relu', | ||
conv_with_batchnorm=True, | ||
conv_batchnorm_drop_rate=dropouts, | ||
pool_type='max') | ||
|
||
conv1 = conv_block(input, 64, 2, [0.3, 0]) | ||
conv2 = conv_block(conv1, 128, 2, [0.4, 0]) | ||
conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0]) | ||
conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0]) | ||
conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) | ||
|
||
drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5) | ||
fc1 = fluid.layers.fc(input=drop, size=512, act=None) | ||
bn = fluid.layers.batch_norm(input=fc1, act='relu') | ||
drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5) | ||
fc2 = fluid.layers.fc(input=drop2, size=512, act=None) | ||
return fc2 | ||
|
||
|
||
classdim = 10 | ||
data_shape = [3, 32, 32] | ||
|
||
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') | ||
label = fluid.layers.data(name='label', shape=[1], dtype='int64') | ||
|
||
net_type = "vgg" | ||
if len(sys.argv) >= 2: | ||
net_type = sys.argv[1] | ||
|
||
if net_type == "vgg": | ||
print("train vgg net") | ||
net = vgg16_bn_drop(images) | ||
elif net_type == "resnet": | ||
print("train resnet") | ||
net = resnet_cifar10(images, 32) | ||
else: | ||
raise ValueError("%s network is not supported" % net_type) | ||
|
||
predict = fluid.layers.fc(input=net, size=classdim, act='softmax') | ||
cost = fluid.layers.cross_entropy(input=predict, label=label) | ||
avg_cost = fluid.layers.mean(x=cost) | ||
|
||
optimizer = fluid.optimizer.Adam(learning_rate=0.001) | ||
optimize_ops, params_grads = optimizer.minimize(avg_cost) | ||
|
||
accuracy = fluid.evaluator.Accuracy(input=predict, label=label) | ||
|
||
train_reader = paddle.batch( | ||
paddle.reader.shuffle( | ||
paddle.dataset.cifar.train10(), buf_size=128 * 10), | ||
batch_size=BATCH_SIZE) | ||
|
||
place = fluid.CPUPlace() | ||
exe = fluid.Executor(place) | ||
|
||
t = fluid.DistributeTranspiler() | ||
# all parameter server endpoints list for spliting parameters | ||
pserver_endpoints = os.getenv("PSERVERS") | ||
# server endpoint for current node | ||
current_endpoint = os.getenv("SERVER_ENDPOINT") | ||
# run as trainer or parameter server | ||
training_role = os.getenv("TRAINING_ROLE", | ||
"TRAINER") # get the training role: trainer/pserver | ||
t.transpile( | ||
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS) | ||
|
||
if training_role == "PSERVER": | ||
if not current_endpoint: | ||
print("need env SERVER_ENDPOINT") | ||
exit(1) | ||
print("start pserver at:", current_endpoint) | ||
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops) | ||
exe.run(fluid.default_startup_program()) | ||
exe.run(pserver_prog) | ||
elif training_role == "TRAINER": | ||
print("start trainer") | ||
feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) | ||
exe.run(fluid.default_startup_program()) | ||
for pass_id in range(PASS_NUM): | ||
accuracy.reset(exe) | ||
for data in train_reader(): | ||
loss, acc = exe.run(fluid.default_main_program(), | ||
feed=feeder.feed(data), | ||
fetch_list=[avg_cost] + accuracy.metrics) | ||
pass_acc = accuracy.eval(exe) | ||
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( | ||
pass_acc)) | ||
# this model is slow, so if we can train two mini batch, we think it works properly. | ||
else: | ||
print("environment var TRAINER_ROLE should be TRAINER os PSERVER") | ||
exit(1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_ENFORCE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.