[API] unified API for custom kvstores (apache#17010)
* abstract kvstore api

* add test

* reorg folder

* add split kvstore.py to kvstore and base

* fix import

* add horovod class

* add registration

* add unit test for kvstore base

* add nightly test

* fix pushpull

* fix test

* simplify API

* Trainer API

* fix a bug

* Fix typo

* update horovod tutorial

* better error message

* fix incorrect usage of name

* add capability query

* unit test for test kvstore

* add trainer test

* refactor

* rename

* remove horovod example

* revert horovod example

* rename test_kvstore_custom.py

* more tests

* more tests for teststore

* fix type name

* fix lint

* fix lint

* fix lint

* Update dist_device_sync_kvstore_custom.py

* address CR

* add optimizer test

* add optimizer test
eric-haibin-lin committed Dec 17, 2019
1 parent 814be59 commit f86a8d1
Showing 12 changed files with 955 additions and 151 deletions.
1 change: 1 addition & 0 deletions ci/docker/runtime_functions.sh
@@ -1418,6 +1418,7 @@ integrationtest_ubuntu_gpu_dist_kvstore() {
     export DMLC_LOG_STACK_TRACE_DEPTH=10
     cd tests/nightly/
     ../../tools/launch.py -n 4 --launcher local python dist_device_sync_kvstore.py
+    ../../tools/launch.py -n 4 --launcher local python dist_device_sync_kvstore_custom.py
     ../../tools/launch.py -n 4 --launcher local python dist_sync_kvstore.py --type=init_gpu
     popd
 }
24 changes: 12 additions & 12 deletions python/mxnet/__init__.py
@@ -27,6 +27,10 @@
 from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
 from .util import is_np_array, np_array, use_np_array, use_np
 from . import base
+
+# version info
+__version__ = base.__version__
+
 from . import contrib
 from . import ndarray
 from . import ndarray as nd
@@ -59,8 +63,6 @@
 from . import callback
 # from . import misc
 from . import lr_scheduler
-# use mx.kv as short for kvstore
-from . import kvstore as kv
 # Runtime compile module
 from . import rtc
 # Attribute scope to add attributes to symbolic graphs
@@ -84,22 +86,20 @@
 from . import test_utils

 from . import rnn

 from . import gluon
-
-# Dynamic library module should be done after ndarray and symbol are initialized
-from . import library
-from . import tvmop
-
-__version__ = base.__version__
-
-# Dist kvstore module which launches a separate process when role is set to "server".
-# This should be done after other modules are initialized.
+# With the native kvstore module (such as 'dist_sync_device'), the module launches a separate
+# process when role is set to "server". This should be done after other modules are initialized.
 # Otherwise this may result in errors when unpickling custom LR scheduler/optimizers.
 # For example, the LRScheduler in gluoncv depends on a specific version of MXNet, and
 # checks the __version__ attr of MXNet, which is not set on kvstore server due to the
 # fact that kvstore-server module is imported before the __version__ attr is set.
-from . import kvstore_server
+# use mx.kv as short for kvstore
+from . import kvstore as kv
+
+# Dynamic library module should be done after ndarray and symbol are initialized
+from . import library
+from . import tvmop

 from . import numpy_op_signature
 from . import numpy_dispatch_protocol
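With this ordering, __version__ is assigned immediately after `base` is imported, and the `mx.kv` shorthand is imported last, after everything the kvstore server process may need to unpickle. A minimal sanity check, assuming a build that includes this change:

import mxnet as mx

# __version__ is now set before the kvstore modules are imported, so it is
# visible on the kvstore server process as well.
print(mx.__version__)

# mx.kv remains the shorthand for the kvstore package.
kv = mx.kv.create('local')
print(type(kv))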
50 changes: 37 additions & 13 deletions python/mxnet/gluon/trainer.py
@@ -23,6 +23,7 @@
 from .. import optimizer as opt
 from ..model import _create_kvstore, _create_sparse_kvstore
 from .parameter import ParameterDict, Parameter
+from ..kvstore import KVStore

 class Trainer(object):
     """Applies an `Optimizer` on a set of Parameters. Trainer should
@@ -153,9 +154,10 @@ def _init_params(self):
             else:
                 param_arrays = param._check_and_get(param._data, list)
                 idx = self._param2idx[param.name]
-                self._kvstore.init(idx, param_arrays[0])
-                if param._stype == 'default':
-                    self._kvstore.pull(idx, param_arrays, priority=-idx)
+                if param._stype != 'default':
+                    self._kvstore.init(idx, param_arrays[0])
+                else:
+                    self._kvstore.broadcast(idx, param_arrays[0], param_arrays)

         self._params_to_init = params_to_init
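For dense parameters the old init + pull pair is replaced by a single broadcast. A minimal single-process sketch of the two paths, assuming the positional (key, value, out) signature used in the Trainer code above; the key and shapes are illustrative:

import mxnet as mx

kv = mx.kv.create('local')
value = mx.nd.ones((2, 3))
outs = [mx.nd.zeros((2, 3)), mx.nd.zeros((2, 3))]  # e.g. one array per device

# old path: register the value, then pull it into every output array
# kv.init(0, value)
# kv.pull(0, outs)

# new path: one call registers the value and copies it to every output
kv.broadcast(0, value, outs)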

@@ -218,6 +220,10 @@ def _init_kvstore(self):
                    raise ValueError("Cannot set update_on_kvstore=False on dist kvstore "
                                     "when sparse gradients are present.")
                update_on_kvstore = config['update_on_kvstore']
+            # raise err if a custom kvstore is used for sparse training
+            if not isinstance(kvstore, KVStore):
+                raise ValueError("Cannot use {} for multi-device training with sparse gradients"
+                                 .format(type(kvstore)))

         else:
             # Training with dense weight and dense gradients.
@@ -234,6 +240,12 @@
                                     "when training in async mode.")
             if config['update_on_kvstore'] is not None:
                 update_on_kvstore = config['update_on_kvstore']
+            # raise err if update_on_kvstore is set to True with kvstores that do not support optimizers
+            if update_on_kvstore and not type(kvstore).is_capable('optimizer'):
+                if config['update_on_kvstore']:
+                    raise ValueError("Please set update_on_kvstore=False "
+                                     "when training with {}".format(type(kvstore)))
+                update_on_kvstore = False

         # set grad compression and optimizers
         if kvstore:
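Instead of checking the kvstore's type name, the Trainer now asks the class whether it can run optimizers through a capability query. A small sketch, assuming the 'optimizer' capability string used above; which stores report it depends on the backend:

import mxnet as mx

kv = mx.kv.create('local')
# The native kvstore is expected to report that it can run optimizers on the
# store, so update_on_kvstore=True remains allowed for it.
print(type(kv).is_capable('optimizer'))

# A store that returns False here makes the Trainer fall back to
# update_on_kvstore=False and apply the optimizer on the workers.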
@@ -357,14 +369,30 @@ def allreduce_grads(self):
         self._allreduce_grads()

     def _allreduce_grads(self):
-        if self._kvstore:
-            for i, param in enumerate(self._params):
-                if param.grad_req != 'null':
+        # nothing to reduce
+        if not self._kvstore:
+            return
+        for i, param in enumerate(self._params):
+            if param.grad_req != 'null':

-                    self._kvstore.push(i, param.list_grad(), priority=-i)
-                    if not self._update_on_kvstore:
-                        self._kvstore.pull(i, param.list_grad(), priority=-i,
+                grad_list = param.list_grad()
+                # sparse gradients, call push and pull separately
+                if grad_list[0].stype != 'default':
+                    self._kvstore.push(i, grad_list, priority=-i)
+                    if param._stype == 'default':
+                        if self._update_on_kvstore:
+                            pull_list = param.list_data()
+                        else:
+                            pull_list = param.list_grad()
+                        self._kvstore.pull(i, pull_list, priority=-i,
                                            ignore_sparse=self._distributed)
+                else:
+                    # allreduce dense gradients if not update_on_kvstore,
+                    # otherwise push dense gradients, pull dense weights
+                    if self._update_on_kvstore:
+                        self._kvstore.pushpull(i, grad_list, out=param.list_data(), priority=-i)
+                    else:
+                        self._kvstore.pushpull(i, grad_list, priority=-i)

     def update(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update.
@@ -419,10 +447,6 @@ def _update(self, ignore_stale_grad=False):
                                    %(param.name, str(data.context)))

             if self._kvstore and self._update_on_kvstore:
-                if param._stype == 'default':
-                    # 'row_sparse' parameters are not pulled immediately - they're pulled
-                    # in `Block.forward`
-                    self._kvstore.pull(i, param.list_data(), priority=-i)
                 continue

             for upd, arr, grad in zip(updates, param.list_data(), param.list_grad()):
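For dense gradients the push/pull pair is folded into a single pushpull: with update_on_kvstore the gradients are pushed and the stored weights pulled back out, otherwise the call reduces the gradient list in place. A minimal single-process sketch, with an illustrative key and shapes:

import mxnet as mx

kv = mx.kv.create('local')
grads = [mx.nd.ones((2, 3))]     # per-device gradient list
weights = [mx.nd.zeros((2, 3))]  # per-device weight list

kv.init('w0', weights[0])

# update_on_kvstore=False path: aggregate the gradients in place
kv.pushpull('w0', grads)

# update_on_kvstore=True path: push gradients, pull the stored value into the weights
kv.pushpull('w0', grads, out=weights)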
24 changes: 24 additions & 0 deletions python/mxnet/kvstore/__init__.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Key-value store for distributed communication"""
from .kvstore import *
from .base import *
from .kvstore_server import *
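The new package splits the native KVStore frontend from a KVStoreBase interface that custom backends (such as the Horovod-based store exercised by the new tests) can implement and register. Below is a hedged sketch of such a store, built only around the calls the Trainer makes above (broadcast, pushpull, is_capable); the authoritative abstract-method set and registration hook live in python/mxnet/kvstore/base.py, so treat the exact names here as illustrative:

import mxnet as mx
from mxnet.kvstore import KVStoreBase


@KVStoreBase.register
class MemoryStore(KVStoreBase):
    """Toy single-process store that keeps values in a dict."""

    def __init__(self):
        self._store = {}

    def broadcast(self, key, value, out, priority=0):
        # register the initial value, then copy it into every output array
        self._store[key] = value.copyto(mx.cpu())
        outs = out if isinstance(out, (list, tuple)) else [out]
        for o in outs:
            self._store[key].copyto(o)

    def pushpull(self, key, value, out=None, priority=0):
        # aggregate the pushed values, then write the result to the outputs
        values = value if isinstance(value, (list, tuple)) else [value]
        reduced = values[0].copyto(mx.cpu())
        for v in values[1:]:
            reduced += v.as_in_context(mx.cpu())
        self._store[key] = reduced
        outs = values if out is None else (out if isinstance(out, (list, tuple)) else [out])
        for o in outs:
            reduced.copyto(o)

    @staticmethod
    def is_capable(capability):
        # this toy store cannot run optimizers on the store side
        return False

    @property
    def rank(self):
        return 0

    @property
    def num_workers(self):
        return 1

Once registered, the intent is that the kvstore factory used by the Gluon Trainer can locate the store by name; the unit tests added in this commit (test_kvstore_custom.py) are the reference for the exact registration and creation flow.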
