Added multi-gpu support via distributed wrapper #252

Merged (2 commits) on Feb 22, 2018

README.md (14 additions, 0 deletions)

@@ -163,6 +163,20 @@ python train.py --tensorboard --logdir log_dir/ # Make sure the Tensorboard inst

For both visualisation tools, you can add your own name to the run by changing the `--id` parameter when training.

## Multi-GPU Training

We support multi-GPU training via the distributed parallel wrapper (see [here](https://github.com/NVIDIA/sentiment-discovery/blob/master/analysis/scale.md) and [here](https://github.com/SeanNaren/deepspeech.pytorch/issues/211) for why we don't use DataParallel).

To use multi-GPU:

```
python -m multiproc python train.py --visdom --cuda # Add your parameters as normal, multiproc will scale to all GPUs automatically
```

multiproc will write a separate log file (`GPU_<rank>.log`) for every process other than the main process.

We suggest using the Gloo backend, which falls back to TCP if InfiniBand isn't available. NCCL2 is also supported as a backend. More information [here](http://pytorch.org/docs/master/distributed.html#distributed-basics).
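
For reference, the snippet below is roughly what the distributed setup boils down to (a minimal sketch; the backend, address and literal rank/world size are illustrative, and in practice multiproc together with the script's distributed arguments supplies them):

```
import torch.distributed as dist

# One process per GPU; multiproc normally fills in the rank and world size.
dist.init_process_group(backend='gloo',                    # or 'nccl'
                        init_method='tcp://127.0.0.1:1550',
                        world_size=1, rank=0)
```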

### Noise Augmentation/Injection

There is support for two different types of noise: noise augmentation and noise injection.

benchmark.py (38 additions, 14 deletions)

@@ -3,26 +3,44 @@
import time
import torch
from torch.autograd import Variable
from tqdm import tqdm
from warpctc_pytorch import CTCLoss
from tqdm import trange
from model import DeepSpeech, supported_rnns
import torch.distributed as dist
import torch.utils.data.distributed

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=32, help='Size of input')
parser.add_argument('--seconds', type=int, default=15,
                    help='The size of the fake input in seconds using default stride of 0.01, '
                         '15s is usually the maximum duration')
parser.add_argument('--dry-runs', type=int, default=20, help='Dry runs before measuring performance')
parser.add_argument('--runs', type=int, default=20, help='How many benchmark runs to measure performance')
parser.add_argument('--dry-runs', type=int, default=2, help='Dry runs before measuring performance')
parser.add_argument('--runs', type=int, default=5, help='How many benchmark runs to measure performance')
parser.add_argument('--labels-path', default='labels.json', help='Path to the labels to infer over in the model')
parser.add_argument('--hidden-size', default=800, type=int, help='Hidden size of RNNs')
parser.add_argument('--hidden-layers', default=5, type=int, help='Number of RNN layers')
parser.add_argument('--rnn-type', default='gru', help='Type of the RNN. rnn|gru|lstm are supported')
parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate')
parser.add_argument('--window-size', default=.02, type=float, help='Window size for spectrogram in seconds')
parser.add_argument('--num-samples', default=1024, type=int, help='Number of samples to go through')
parser.add_argument('--dist_url', default='tcp://127.0.0.1:1550', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='gloo', type=str, help='distributed backend')
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--rank', default=0, type=int, help='The rank of this process')
args = parser.parse_args()

input = torch.randn(args.batch_size, 1, 161, args.seconds * 100).cuda()
args.distributed = args.world_size > 1
if args.distributed:
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)

if args.distributed:
    input_data = torch.randn(int(args.num_samples / args.world_size), 1, 161, args.seconds * 100).cuda()
else:
    input_data = torch.randn(args.num_samples, 1, 161, args.seconds * 100).cuda()
input_data = torch.chunk(input_data, int(len(input_data) / args.batch_size))

rnn_type = args.rnn_type.lower()
assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
@@ -44,22 +62,26 @@
parameters = model.parameters()
optimizer = torch.optim.SGD(parameters, lr=3e-4,
                            momentum=0.9, nesterov=True)
model = torch.nn.DataParallel(model).cuda()
if args.distributed:
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
else:
    model = torch.nn.DataParallel(model).cuda()

criterion = CTCLoss()

seconds = int(args.seconds)
batch_size = int(args.batch_size)


def iteration(input_data):
def iteration(data):
    target = torch.IntTensor(int(batch_size * ((seconds * 100) / 2))).fill_(1)  # targets, align half of the audio
    target_size = torch.IntTensor(batch_size).fill_(int((seconds * 100) / 2))
    input_percentages = torch.IntTensor(batch_size).fill_(1)

    inputs = Variable(input_data, requires_grad=False)
    inputs = Variable(data, requires_grad=False)
    target_sizes = Variable(target_size, requires_grad=False)
    targets = Variable(target, requires_grad=False)
    start = time.time()
    out = model(inputs)
    out = out.transpose(0, 1)  # TxNxH

@@ -72,26 +94,28 @@ def iteration(input_data):
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()
    end = time.time()
    del loss
    del out
    return start, end


def run_benchmark(input_data):
def run_benchmark():
print("Running dry runs...")
for n in trange(args.dry_runs):
iteration(input_data)
for data in tqdm(input_data, total=len(input_data)):
iteration(data)

print("\n Running measured runs...")
running_time = 0
for n in trange(args.runs):
start, end = iteration(input_data)
running_time += end - start
start_time = time.time()
for data in tqdm(input_data, total=len(input_data)):
iteration(data)
end_time = time.time()
running_time += (end_time - start_time)

return running_time / float(args.runs)


run_time = run_benchmark(input)
run_time = run_benchmark()

print("\n Average run time: %.2fs" % run_time)

data/data_loader.py (43 additions, 1 deletion)

@@ -1,13 +1,17 @@
import os
import subprocess
from tempfile import NamedTemporaryFile

from torch.distributed import get_rank
from torch.distributed import get_world_size
from torch.utils.data.sampler import Sampler

import librosa
import numpy as np
import scipy.signal
import torch
import torchaudio
import math
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

@@ -219,10 +223,48 @@ def __iter__(self):
    def __len__(self):
        return len(self.bins)

    def shuffle(self):
    def shuffle(self, epoch):
        np.random.shuffle(self.bins)


class DistributedBucketingSampler(Sampler):
    def __init__(self, data_source, batch_size=1, num_replicas=None, rank=None):
        """
        Samples batches assuming they are in order of size to batch similarly sized samples together.
        """
        super(DistributedBucketingSampler, self).__init__(data_source)
        if num_replicas is None:
            num_replicas = get_world_size()
        if rank is None:
            rank = get_rank()
        self.data_source = data_source
        self.ids = list(range(0, len(data_source)))
        self.batch_size = batch_size
        self.bins = [self.ids[i:i + batch_size] for i in range(0, len(self.ids), batch_size)]
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.bins) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        offset = self.rank
        # add extra samples to make it evenly divisible
        bins = self.bins + self.bins[:(self.total_size - len(self.bins))]
        assert len(bins) == self.total_size
        samples = bins[offset::self.num_replicas]  # Get every Nth bin, starting from rank
        return iter(samples)

    def __len__(self):
        return self.num_samples

    def shuffle(self, epoch):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(epoch)
        bin_ids = list(torch.randperm(len(self.bins), generator=g))
        self.bins = [self.bins[i] for i in bin_ids]


def get_audio_length(path):
    output = subprocess.check_output(['soxi -D \"%s\"' % path.strip()], shell=True)
    return float(output)
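
A sketch of how the new sampler might be driven (the dummy dataset, batch size and rank values are placeholders; the real training code uses the repository's audio dataset and loader):

```
import torch
from torch.utils.data import DataLoader, Dataset
from data.data_loader import DistributedBucketingSampler

class DummySet(Dataset):                      # stand-in for the real audio dataset
    def __len__(self): return 256
    def __getitem__(self, i): return torch.randn(161, 100), i

world_size, rank = 2, 0                       # normally taken from the multiproc launcher
dataset = DummySet()
sampler = DistributedBucketingSampler(dataset, batch_size=8,
                                      num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0)

for epoch in range(3):
    sampler.shuffle(epoch)                    # same permutation on every rank, different slice per rank
    for batch in loader:
        pass                                  # training step goes here
```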

data/distributed.py (new file, 65 additions)

@@ -0,0 +1,65 @@
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module

'''
This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
launcher included with this example. It assumes that your run is using multiprocess with 1
GPU/process, that the model is on the correct device, and that torch.set_device has been
used to set the device.

Parameters are broadcasted to the other processes on initialization of DistributedDataParallel,
and will be allreduced at the finish of the backward pass.
'''


class DistributedDataParallel(Module):
    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False

        self.module = module

        for p in self.module.state_dict().values():
            if not torch.is_tensor(p):
                continue
            if dist._backend == dist.dist_backend.NCCL:
                assert p.is_cuda, "NCCL backend only supports model parameters to be on GPU."
            dist.broadcast(p, 0)

        def allreduce_params():
            if (self.needs_reduction):
                self.needs_reduction = False
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                if self.warn_on_half:
                    if torch.cuda.HalfTensor in buckets:
                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                              " It is recommended to use the NCCL backend in this case.")
                        self.warn_on_half = False

                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        for param in list(self.module.parameters()):
            def allreduce_hook(*unused):
                param._execution_engine.queue_callback(allreduce_params)

            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def forward(self, *inputs, **kwargs):
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)
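
A sketch of how this wrapper is meant to be applied, mirroring the benchmark.py change above (the stand-in module, backend and address are illustrative; each process would normally be started by multiproc with its own rank):

```
import torch
import torch.distributed as dist
import torch.nn as nn
from data.distributed import DistributedDataParallel

# One process per GPU; with world_size=1 this also runs standalone.
dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:1550',
                        world_size=1, rank=0)
torch.cuda.set_device(0)                 # the wrapper assumes the device is already set

model = nn.Linear(10, 10).cuda()         # stand-in for the real DeepSpeech model
model = DistributedDataParallel(model)   # broadcasts parameters from rank 0 on construction

out = model(torch.randn(4, 10).cuda())   # forward marks gradients for reduction
out.sum().backward()                     # gradients are all-reduced across processes after backward
```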

multiproc.py (new file, 33 additions)

@@ -0,0 +1,33 @@
import torch
import sys
import subprocess

argslist = list(sys.argv)[1:]
world_size = torch.cuda.device_count()

if '--world-size' in argslist:
    argslist[argslist.index('--world-size') + 1] = str(world_size)
else:
    argslist.append('--world-size')
    argslist.append(str(world_size))

workers = []

for i in range(world_size):
    if '--rank' in argslist:
        argslist[argslist.index('--rank') + 1] = str(i)
    else:
        argslist.append('--rank')
        argslist.append(str(i))
    if '--gpu-rank' in argslist:
        argslist[argslist.index('--gpu-rank') + 1] = str(i)
    else:
        argslist.append('--gpu-rank')
        argslist.append(str(i))
    stdout = None if i == 0 else open("GPU_" + str(i) + ".log", "w")
    print(argslist)
    p = subprocess.Popen([str(sys.executable)] + argslist, stdout=stdout, stderr=stdout)
    workers.append(p)

for p in workers:
    p.wait()