
Commit ddb1eab

Added support of Horovod (#1195)
* [WIP] Horovod comp model
* [WIP] Horovod comp model - Implemented spawn - Added comp model tests
* Refactored test_utils.py into 3 files - we can better test new coming comp models
* [WIP] Run horovod tests
* [WIP] Horovod comp model + tests
* autopep8 fix
* [WIP] More tests
* Updated utils tests
* autopep8 fix
* [WIP] more tests
* Updated tests and code and cifar10 example
* autopep8 fix
* Fixed failing CI and updated code
* autopep8 fix
* Fixes failing test
* Fixed bug with new/old hvd API and the config
* Added metric tests
* Formatting and docs updated
* Updated frequency test
* Fixed formatting and a typo in idist.model_name docs
* Fixed failing test
* Docs updates and updated auto methods according to horovod API
* autopep8 fix
* Cosmetics

Co-authored-by: AutoPEP8 <>
1 parent 6faa6ac commit ddb1eab

33 files changed (+1077, −68 lines)

.circleci/config.yml

Lines changed: 85 additions & 1 deletion
@@ -4,7 +4,11 @@ parameters:
   pytorch_stable_image:
     type: string
     # https://hub.docker.com/r/pytorch/pytorch/tags
-    default: "pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime"
+    default: "pytorch/pytorch:1.5.1-cuda10.1-cudnn7-runtime"
+  pytorch_stable_image_devel:
+    type: string
+    # https://hub.docker.com/r/pytorch/pytorch/tags
+    default: "pytorch/pytorch:1.5.1-cuda10.1-cudnn7-devel"
   workingdir:
     type: string
     default: "/tmp/ignite"
@@ -40,6 +44,12 @@ pull_pytorch_stable_image: &pull_pytorch_stable_image
       command: |
         docker pull << pipeline.parameters.pytorch_stable_image >>

+pull_pytorch_stable_devel_image: &pull_pytorch_stable_devel_image
+  - run:
+      name: Pull PyTorch Stable Develop Image
+      command: |
+        docker pull << pipeline.parameters.pytorch_stable_image_devel >>
+

 run_pytorch_container: &run_pytorch_container
   - run:
@@ -51,6 +61,17 @@ run_pytorch_container: &run_pytorch_container
         docker exec -it pthd nvidia-smi
         docker exec -it pthd ls

+
+run_pytorch_devel_container: &run_pytorch_devel_container
+  - run:
+      name: Start Pytorch dev container
+      environment:
+        wd: << pipeline.parameters.workingdir >>
+      command: |
+        docker run --gpus=all --rm -itd --shm-size 16G -v ${wd}:/ignite -w /ignite --name pthd << pipeline.parameters.pytorch_stable_image_devel >>
+        docker exec -it pthd nvidia-smi
+        docker exec -it pthd ls
+
 install_dependencies: &install_dependencies
   - run:
       name: Install dependencies
@@ -194,6 +215,68 @@ jobs:
         docker exec -it pthd /bin/bash -c "${test_cmd} --num_epochs=7 ${resume_opt}"


+  two_gpus_hvd_tests:
+    <<: *two_gpus
+
+    working_directory: << pipeline.parameters.workingdir >>
+
+    steps:
+      - checkout
+      - <<: *pull_pytorch_stable_devel_image
+      - <<: *run_pytorch_devel_container
+      - <<: *install_dependencies
+      - run:
+          name: "Install Horovod with NCCL GPU ops"
+          command: |
+
+            # Following https://github.com/horovod/horovod/blob/master/Dockerfile.test.gpu
+            # and https://github.com/horovod/horovod/issues/1944#issuecomment-628192778
+            docker exec -it pthd /bin/bash -c "apt-get update && apt-get install -y git"
+            docker exec -it pthd /bin/bash -c "git clone --recursive https://github.com/horovod/horovod.git /horovod && cd /horovod && python setup.py sdist"
+            docker exec -it pthd /bin/bash -c "conda install -y cmake=3.16 nccl=2.5 -c conda-forge"
+            docker exec -it pthd /bin/bash -c 'cd /horovod && HOROVOD_GPU_OPERATIONS=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITHOUT_MPI=1 HOROVOD_WITH_PYTORCH=1 pip install -v $(ls /horovod/dist/horovod-*.tar.gz) && ldconfig'
+            docker exec -it pthd horovodrun --check-build
+
+      - run:
+          name: Run 1 Node 2 GPUs Unit Tests
+          command: |
+            export test_cmd='sh tests/run_gpu_tests.sh'
+            docker exec -it pthd /bin/bash -c "${test_cmd}"
+            # no CUDA devices Horovod tests
+            export test_cmd='CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --cov-report xml -vvv tests/ -m distributed'
+            docker exec -it pthd /bin/bash -c "${test_cmd}"
+
+      - run:
+          name: Codecov upload
+          command: |
+            bash <(curl -s https://codecov.io/bash) -Z -F gpu-2-hvd
+
+      - run:
+          name: "Check CIFAR10 using horovodrun"
+          command: |
+            docker exec -it pthd pip install fire
+            export example_path="examples/contrib/cifar10"
+            # initial run
+            export stop_cmd="--stop_iteration=500"
+            export test_cmd="cd ${example_path} && CI=1 horovodrun -np 2 python -u main.py run --backend=horovod"
+            docker exec -it pthd /bin/bash -c "${test_cmd} ${stop_cmd}"
+            # resume
+            export resume_opt="--resume-from=/tmp/output-cifar10/resnet18_backend-horovod-2_stop-on-500/training_checkpoint_400.pt"
+            docker exec -it pthd /bin/bash -c "${test_cmd} --num_epochs=7 ${resume_opt}"
+
+      - run:
+          name: "Check CIFAR10 using spawn"
+          command: |
+            export example_path="examples/contrib/cifar10"
+            # initial run
+            export stop_cmd="--stop_iteration=500"
+            export test_cmd="cd ${example_path} && CI=1 python -u main.py run --backend=horovod --nproc_per_node=2"
+            docker exec -it pthd /bin/bash -c "${test_cmd} ${stop_cmd}"
+            # resume
+            export resume_opt="--resume-from=/tmp/output-cifar10/resnet18_backend-horovod-2_stop-on-500/training_checkpoint_400.pt"
+            docker exec -it pthd /bin/bash -c "${test_cmd} --num_epochs=7 ${resume_opt}"
+
+
 # -------------------------------------------------------------------------------------
 # Workflows
 # -------------------------------------------------------------------------------------
@@ -204,3 +287,4 @@ workflows:
       - one_gpu_tests
       - two_gpus_tests
       - two_gpus_check_dist_cifar10_example
+      - two_gpus_hvd_tests
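The new `two_gpus_hvd_tests` job builds Horovod from source with NCCL GPU ops, runs the GPU test suite, and then re-runs the `-m distributed` tests with `CUDA_VISIBLE_DEVICES=""` to also cover the CPU path. As a rough illustration of what those distributed tests exercise, here is a minimal, hypothetical smoke script (not part of the commit), assuming Horovod with PyTorch support is installed:

```python
# hvd_smoke_check.py -- hypothetical helper, not part of the commit
import ignite.distributed as idist


def _check(local_rank):
    # Each spawned worker should report the Horovod backend and a world size of 2.
    assert idist.backend() == "horovod"
    assert idist.get_world_size() == 2
    print(local_rank, idist.get_rank(), idist.device())


if __name__ == "__main__":
    # Spawns 2 workers; with CUDA_VISIBLE_DEVICES="" idist.device() should fall back to CPU.
    idist.spawn("horovod", _check, args=(), nproc_per_node=2)
```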

examples/contrib/cifar10/README.md

Lines changed: 12 additions & 2 deletions
@@ -54,9 +54,19 @@ or
 python -u main.py run --backend="nccl" --nproc_per_node=2
 ```

-If user would like to provide already downloaded dataset, the path can be setup in parameters as
+##### Using [Horovod](https://horovod.readthedocs.io/en/latest/index.html) as distributed backend
+
+Please, make sure to have Horovod installed before running.
+
+Let's start training on a single node with 2 gpus:
 ```bash
---data_path="/path/to/cifar10/"
+# horovodrun
+horovodrun -np=2 python -u main.py run --backend="horovod"
+```
+or
+```bash
+# using function spawn inside the code
+python -u main.py run --backend="horovod" --nproc_per_node=2
 ```

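For reference, the spawn mode above (`--nproc_per_node=2`) launches the workers from inside `main.py` via `ignite.distributed`. A minimal sketch of that pattern, assuming Horovod with PyTorch support is installed; the `training` function and its empty config are placeholders, not the example's actual code:

```python
import ignite.distributed as idist


def training(local_rank, config):
    # Runs in each of the 2 spawned workers with the "horovod" backend initialized.
    print(idist.get_rank(), ":", idist.backend(), idist.device())


if __name__ == "__main__":
    # Equivalent in spirit to: python -u main.py run --backend="horovod" --nproc_per_node=2
    with idist.Parallel(backend="horovod", nproc_per_node=2) as parallel:
        parallel.run(training, config={})
```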

ignite/distributed/auto.py

Lines changed: 26 additions & 6 deletions
@@ -8,6 +8,7 @@
 from torch.utils.data.sampler import Sampler

 from ignite.distributed import utils as idist
+from ignite.distributed.comp_models import horovod as idist_hvd
 from ignite.distributed.comp_models import native as idist_native
 from ignite.distributed.comp_models import xla as idist_xla
 from ignite.utils import setup_logger
@@ -130,6 +131,7 @@ def auto_model(model: nn.Module) -> nn.Module:
     - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
     - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
     - wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available.
+    - broadcast the initial variable states from rank 0 to all other processes if Horovod distributed framework is used.

     Examples:

@@ -166,13 +168,19 @@ def auto_model(model: nn.Module) -> nn.Module:

     # distributed data parallel model
     if idist.get_world_size() > 1:
-        if idist.backend() == idist_native.NCCL:
+        bnd = idist.backend()
+        if idist.has_native_dist_support and bnd == idist_native.NCCL:
             lrank = idist.get_local_rank()
             logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
             model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank,])
-        elif idist.backend() == idist_native.GLOO:
+        elif idist.has_native_dist_support and bnd == idist_native.GLOO:
             logger.info("Apply torch DistributedDataParallel on model")
             model = torch.nn.parallel.DistributedDataParallel(model)
+        elif idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
+            import horovod.torch as hvd
+
+            logger.info("Broadcast the initial variable states from rank 0 to all other processes")
+            hvd.broadcast_parameters(model.state_dict(), root_rank=0)

     # not distributed but multiple GPUs reachable so data parallel model
     elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
@@ -187,14 +195,18 @@ def auto_optim(optimizer: Optimizer) -> Optimizer:
     all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

     Internally, this method is no-op for non-distributed and torch native distributed configuration.
+
     For XLA distributed configuration, we create a new class that inherits from provided optimizer.
     The goal is to override the `step()` method with specific `xm.optimizer_step`_ implementation.

+    For Horovod distributed configuration, optimizer is wrapped with Horovod Distributed Optimizer and
+    its state is broadcasted from rank 0 to all other processes.
+
     Examples:

     .. code-block:: python

-        import ignite.distribted as idist
+        import ignite.distributed as idist

         optimizer = idist.auto_optim(optimizer)

@@ -208,11 +220,19 @@ def auto_optim(optimizer: Optimizer) -> Optimizer:
     .. _xm.optimizer_step: http://pytorch.org/xla/release/1.5/index.html#torch_xla.core.xla_model.optimizer_step

     """
-    if not (idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU):
+    bnd = idist.backend()
+    if idist.has_xla_support and bnd == idist_xla.XLA_TPU:
+        cls = type(optimizer.__class__.__name__, (optimizer.__class__,), dict(_XLADistributedOptimizer.__dict__))
+        return cls(optimizer)
+
+    if idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
+        import horovod.torch as hvd
+
+        optimizer = hvd.DistributedOptimizer(optimizer)
+        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
         return optimizer

-    cls = type(optimizer.__class__.__name__, (optimizer.__class__,), dict(_XLADistributedOptimizer.__dict__))
-    return cls(optimizer)
+    return optimizer


 class DistributedProxySampler(DistributedSampler):
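Taken together, with the Horovod backend active `auto_model` broadcasts the initial parameters from rank 0 and `auto_optim` wraps the optimizer in `hvd.DistributedOptimizer` and broadcasts its state. A minimal usage sketch, assuming Horovod with PyTorch support is installed; the toy model, data and SGD settings are illustrative only:

```python
import torch
import torch.nn as nn
import torch.optim as optim

import ignite.distributed as idist


def training(local_rank, config):
    # auto_model: moves the model to idist.device(); under Horovod, parameters are broadcast from rank 0
    model = idist.auto_model(nn.Linear(10, 2))
    # auto_optim: under Horovod, wraps with hvd.DistributedOptimizer and broadcasts optimizer state from rank 0
    optimizer = idist.auto_optim(optim.SGD(model.parameters(), lr=0.01))

    x = torch.rand(4, 10, device=idist.device())
    loss = model(x).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(idist.get_rank(), "one step done, loss =", loss.item())


if __name__ == "__main__":
    with idist.Parallel(backend="horovod", nproc_per_node=2) as parallel:
        parallel.run(training, config={})
```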

ignite/distributed/comp_models/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -1,4 +1,5 @@
 from ignite.distributed.comp_models.base import _SerialModel
+from ignite.distributed.comp_models.horovod import has_hvd_support
 from ignite.distributed.comp_models.native import has_native_dist_support
 from ignite.distributed.comp_models.xla import has_xla_support

@@ -15,6 +16,10 @@ def setup_available_computation_models():
         from ignite.distributed.comp_models.xla import _XlaDistModel

         models.append(_XlaDistModel)
+    if has_hvd_support:
+        from ignite.distributed.comp_models.horovod import _HorovodDistModel
+
+        models.append(_HorovodDistModel)

     return tuple(models)
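Since `_HorovodDistModel` is registered only when `has_hvd_support` is true, whether the `horovod` backend is visible at runtime depends on `horovod.torch` being importable. A quick check, sketched under the assumption that these helpers are re-exported from `ignite.distributed`:

```python
import ignite.distributed as idist

# True only if horovod.torch can be imported in this environment
print(idist.has_hvd_support)
# "horovod" appears here only when the Horovod computation model was registered
print(idist.available_backends())
```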

ignite/distributed/comp_models/base.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ class _SerialModel(ComputationModel):
     """

     name = "serial"
-    available_backends = tuple()
+    available_backends = ()

     def get_local_rank(self) -> int:
         return 0
