Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX] Fix race condition in kvstore.pushpull (#17007)
Browse files Browse the repository at this point in the history
* add back gluon test

* fix typo

* change back gpu ctx

* also handle the case there some are pull and some are pushpull

* fix typo
  • Loading branch information
eric-haibin-lin committed Dec 11, 2019
1 parent 04ebe45 commit 05af5c4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 27 deletions.
35 changes: 24 additions & 11 deletions src/kvstore/kvstore_dist_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,21 +364,34 @@ class KVStoreDistServer {
if (log_verbose_) {
LOG(INFO) << "sent response to " << update_buf->request.size() << " workers";
}
/**
* Request can be for either push, pull or pushpull
* If pull flag is set, respond immediately with the updated values
* Otherwise, only send the notification
*/
bool has_pull = false;
for (const auto& req : update_buf->request) {
/**
* Request can be for either push, pull or pushpull
* If pull flag is set, respond immediately with the updated values
* Otherwise, only send the notification
*/
if (req.pull) {
DefaultStorageResponse(type, key, req, req_data, server);
} else {
has_pull = has_pull || req.pull;
}
if (has_pull) {
// if there is a pull request, perform WaitToRead() once before DefaultStorageResponse
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
stored.WaitToRead();
for (const auto& req : update_buf->request) {
if (req.pull) {
DefaultStorageResponse(type, key, req, req_data, server);
}
}
update_buf->request.clear();
} else {
// otherwise, send response directly
for (const auto& req : update_buf->request) {
server->Response(req);
}
update_buf->request.clear();
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
stored.WaitToRead();
}
update_buf->request.clear();
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
stored.WaitToRead();
} else {
update_buf->merged.WaitToRead();
}
Expand Down
35 changes: 19 additions & 16 deletions tests/nightly/dist_device_sync_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def check_diff_to_scalar(A, x, rank=None):
def init_kv():
# init kv dns keys
kv.init(keys, [mx.nd.ones(shape)] * len(keys))
kv.init('9', mx.nd.ones(shape))
kv.init('10', mx.nd.ones(shape))
kv.init('99', mx.nd.ones(big_shape))
kv.init('100', mx.nd.ones(big_shape))
# worker info
my_rank = kv.rank
nworker = kv.num_workers
Expand All @@ -55,33 +58,30 @@ def init_kv():
def test_sync_push_pull():
kv, my_rank, nworker = init_kv()
num_gpus = 2
def check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False):
def check_default_keys(kv, my_rank, nworker, nrepeat=3):
# checks pull after push in loop, because behavior during
# consecutive pushes doesn't offer any guarantees
for i in range(offset, nrepeat):
for i in range(nrepeat):
scale = my_rank + 1
num = (nworker + 1) * nworker * rate * num_gpus / 2 * (i + 1) + 1

arr = [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
val = mx.nd.zeros(shape)
if use_pushpull:
kv.pushpull('3', arr, out=val)
else:
kv.push('3', arr)
kv.pull('3', out=val)
kv.push('9', arr)
kv.pull('9', out=val)
check_diff_to_scalar(val, num)
kv.pushpull('10', arr, out=val)
check_diff_to_scalar(val, num)

big_arr = [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
big_val = mx.nd.zeros(big_shape)
if use_pushpull:
kv.pushpull('99', big_arr, out=big_val)
else:
kv.push('99', big_arr)
kv.pull('99', out=big_val)
kv.push('99', big_arr)
kv.pull('99', out=big_val)
check_diff_to_scalar(big_val, num)
kv.pushpull('100', big_arr, out=big_val)
check_diff_to_scalar(big_val, num)

check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False)
check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=3, use_pushpull=True)
check_default_keys(kv, my_rank, nworker, nrepeat=3)
print('worker ' + str(my_rank) + ' is done')

def test_sync_init():
Expand All @@ -106,10 +106,12 @@ def check_trainer_kv_update(update_on_kv):
x = params.get('x', shape=(10,1), lr_mult=1.0)
params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
try:
trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv)
trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1},
kvstore=kv, update_on_kvstore=update_on_kv)
trainer._init_kvstore()
assert trainer._kv_initialized
assert trainer._update_on_kvstore is True
if update_on_kv is not None:
assert trainer._update_on_kvstore is update_on_kv
except ValueError:
assert update_on_kv is False

Expand All @@ -122,3 +124,4 @@ def check_trainer_kv_update(update_on_kv):
if __name__ == "__main__":
test_sync_init()
test_sync_push_pull()
test_gluon_trainer_type()

0 comments on commit 05af5c4

Please sign in to comment.