diff --git a/README.md b/README.md
index 4e670c26..c34a24bf 100644
--- a/README.md
+++ b/README.md
@@ -163,6 +163,20 @@ python train.py --tensorboard --logdir log_dir/ # Make sure the Tensorboard inst
 
 For both visualisation tools, you can add your own name to the run by changing the `--id` parameter when training.
 
+## Multi-GPU Training
+
+We support multi-GPU training via a distributed data-parallel wrapper (see [here](https://github.com/NVIDIA/sentiment-discovery/blob/master/analysis/scale.md) and [here](https://github.com/SeanNaren/deepspeech.pytorch/issues/211) for why we don't use DataParallel).
+
+To use multi-GPU:
+
+```
+python -m multiproc train.py --visdom --cuda # Add your parameters as normal, multiproc will scale to all GPUs automatically
+```
+
+multiproc will open a separate log file for each process other than the main process.
+
+We suggest using the Gloo backend, which falls back to TCP if InfiniBand isn't available. NCCL2 can also be used as a backend. More information [here](http://pytorch.org/docs/master/distributed.html#distributed-basics).
+
 ### Noise Augmentation/Injection
 
 There is support for two different types of noise; noise augmentation and noise injection.
diff --git a/benchmark.py b/benchmark.py
index b3ed05a2..7e7f4998 100644
--- a/benchmark.py
+++ b/benchmark.py
@@ -3,26 +3,44 @@
 import time
 
 import torch
 from torch.autograd import Variable
+from tqdm import tqdm
 from warpctc_pytorch import CTCLoss
 from tqdm import trange
 
 from model import DeepSpeech, supported_rnns
+import torch.distributed as dist
+import torch.utils.data.distributed
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--batch-size', type=int, default=32, help='Size of input')
 parser.add_argument('--seconds', type=int, default=15,
                     help='The size of the fake input in seconds using default stride of 0.01, '
                          '15s is usually the maximum duration')
-parser.add_argument('--dry-runs', type=int, default=20, help='Dry runs before measuring performance')
-parser.add_argument('--runs', type=int, default=20, help='How many benchmark runs to measure performance')
+parser.add_argument('--dry-runs', type=int, default=2, help='Dry runs before measuring performance')
+parser.add_argument('--runs', type=int, default=5, help='How many benchmark runs to measure performance')
 parser.add_argument('--labels-path', default='labels.json', help='Path to the labels to infer over in the model')
 parser.add_argument('--hidden-size', default=800, type=int, help='Hidden size of RNNs')
 parser.add_argument('--hidden-layers', default=5, type=int, help='Number of RNN layers')
 parser.add_argument('--rnn-type', default='gru', help='Type of the RNN. rnn|gru|lstm are supported')
 parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate')
 parser.add_argument('--window-size', default=.02, type=float, help='Window size for spectrogram in seconds')
+parser.add_argument('--num-samples', default=1024, type=int, help='Number of samples to go through')
+parser.add_argument('--dist_url', default='tcp://127.0.0.1:1550', type=str,
+                    help='url used to set up distributed training')
+parser.add_argument('--dist_backend', default='gloo', type=str, help='distributed backend')
+parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+parser.add_argument('--rank', default=0, type=int, help='The rank of this process')
 args = parser.parse_args()
 
-input = torch.randn(args.batch_size, 1, 161, args.seconds * 100).cuda()
+args.distributed = args.world_size > 1
+if args.distributed:
+    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                            world_size=args.world_size, rank=args.rank)
+
+if args.distributed:
+    input_data = torch.randn(int(args.num_samples / args.world_size), 1, 161, args.seconds * 100).cuda()
+else:
+    input_data = torch.randn(args.num_samples, 1, 161, args.seconds * 100).cuda()
+input_data = torch.chunk(input_data, int(len(input_data) / args.batch_size))
 
 rnn_type = args.rnn_type.lower()
 assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
@@ -44,22 +62,26 @@
 parameters = model.parameters()
 optimizer = torch.optim.SGD(parameters, lr=3e-4, momentum=0.9, nesterov=True)
 
-model = torch.nn.DataParallel(model).cuda()
+if args.distributed:
+    model.cuda()
+    model = torch.nn.parallel.DistributedDataParallel(model)
+else:
+    model = torch.nn.DataParallel(model).cuda()
+
 criterion = CTCLoss()
 
 seconds = int(args.seconds)
 batch_size = int(args.batch_size)
 
 
-def iteration(input_data):
+def iteration(data):
     target = torch.IntTensor(int(batch_size * ((seconds * 100) / 2))).fill_(1)  # targets, align half of the audio
     target_size = torch.IntTensor(batch_size).fill_(int((seconds * 100) / 2))
     input_percentages = torch.IntTensor(batch_size).fill_(1)
-    inputs = Variable(input_data, requires_grad=False)
+    inputs = Variable(data, requires_grad=False)
     target_sizes = Variable(target_size, requires_grad=False)
     targets = Variable(target, requires_grad=False)
-    start = time.time()
     out = model(inputs)
     out = out.transpose(0, 1)  # TxNxH
@@ -72,26 +94,28 @@ def iteration(input_data):
     loss.backward()
     optimizer.step()
     torch.cuda.synchronize()
-    end = time.time()
     del loss
     del out
-    return start, end
 
 
-def run_benchmark(input_data):
+def run_benchmark():
     print("Running dry runs...")
     for n in trange(args.dry_runs):
-        iteration(input_data)
+        for data in tqdm(input_data, total=len(input_data)):
+            iteration(data)
     print("\n Running measured runs...")
     running_time = 0
     for n in trange(args.runs):
-        start, end = iteration(input_data)
-        running_time += end - start
+        start_time = time.time()
+        for data in tqdm(input_data, total=len(input_data)):
+            iteration(data)
+        end_time = time.time()
+        running_time += (end_time - start_time)
     return running_time / float(args.runs)
 
 
-run_time = run_benchmark(input)
+run_time = run_benchmark()
 print("\n Average run time: %.2fs" % run_time)
diff --git a/data/data_loader.py b/data/data_loader.py
index 90cb9d3c..cfa345a1 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -1,6 +1,9 @@
 import os
 import subprocess
 from tempfile import NamedTemporaryFile
+
+from torch.distributed import get_rank
+from torch.distributed import get_world_size
 from torch.utils.data.sampler import Sampler
 
 import librosa
@@ -8,6 +11,7 @@
 import scipy.signal
 import torch
 import torchaudio
+import math
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset
 
@@ -219,10 +223,48 @@ def __iter__(self):
     def __len__(self):
         return len(self.bins)
 
-    def shuffle(self):
+    def shuffle(self, epoch):
         np.random.shuffle(self.bins)
 
 
+class DistributedBucketingSampler(Sampler):
+    def __init__(self, data_source, batch_size=1, num_replicas=None, rank=None):
+        """
+        Samples batches assuming they are in order of size to batch similarly sized samples together.
+        """
+        super(DistributedBucketingSampler, self).__init__(data_source)
+        if num_replicas is None:
+            num_replicas = get_world_size()
+        if rank is None:
+            rank = get_rank()
+        self.data_source = data_source
+        self.ids = list(range(0, len(data_source)))
+        self.batch_size = batch_size
+        self.bins = [self.ids[i:i + batch_size] for i in range(0, len(self.ids), batch_size)]
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.num_samples = int(math.ceil(len(self.bins) * 1.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+
+    def __iter__(self):
+        offset = self.rank
+        # add extra samples to make it evenly divisible
+        bins = self.bins + self.bins[:(self.total_size - len(self.bins))]
+        assert len(bins) == self.total_size
+        samples = bins[offset::self.num_replicas]  # Get every Nth bin, starting from rank
+        return iter(samples)
+
+    def __len__(self):
+        return self.num_samples
+
+    def shuffle(self, epoch):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(epoch)
+        bin_ids = list(torch.randperm(len(self.bins), generator=g))
+        self.bins = [self.bins[i] for i in bin_ids]
+
+
 def get_audio_length(path):
     output = subprocess.check_output(['soxi -D \"%s\"' % path.strip()], shell=True)
     return float(output)
diff --git a/data/distributed.py b/data/distributed.py
new file mode 100644
index 00000000..cab6c05f
--- /dev/null
+++ b/data/distributed.py
@@ -0,0 +1,65 @@
+import torch
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+import torch.distributed as dist
+from torch.nn.modules import Module
+
+'''
+This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
+launcher included with this example. It assumes that your run is using multiprocess with 1
+GPU/process, that the model is on the correct device, and that torch.set_device has been
+used to set the device.
+
+Parameters are broadcasted to the other processes on initialization of DistributedDataParallel,
+and will be allreduced at the finish of the backward pass.
+'''
+
+
+class DistributedDataParallel(Module):
+    def __init__(self, module):
+        super(DistributedDataParallel, self).__init__()
+        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+
+        self.module = module
+
+        for p in self.module.state_dict().values():
+            if not torch.is_tensor(p):
+                continue
+            if dist._backend == dist.dist_backend.NCCL:
+                assert p.is_cuda, "NCCL backend only supports model parameters to be on GPU."
+            dist.broadcast(p, 0)
+
+        def allreduce_params():
+            if (self.needs_reduction):
+                self.needs_reduction = False
+                buckets = {}
+                for param in self.module.parameters():
+                    if param.requires_grad and param.grad is not None:
+                        tp = type(param.data)
+                        if tp not in buckets:
+                            buckets[tp] = []
+                        buckets[tp].append(param)
+                if self.warn_on_half:
+                    if torch.cuda.HalfTensor in buckets:
+                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
+                              " It is recommended to use the NCCL backend in this case.")
+                        self.warn_on_half = False
+
+                for tp in buckets:
+                    bucket = buckets[tp]
+                    grads = [param.grad.data for param in bucket]
+                    coalesced = _flatten_dense_tensors(grads)
+                    dist.all_reduce(coalesced)
+                    coalesced /= dist.get_world_size()
+                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                        buf.copy_(synced)
+
+        for param in list(self.module.parameters()):
+            def allreduce_hook(*unused):
+                param._execution_engine.queue_callback(allreduce_params)
+
+            if param.requires_grad:
+                param.register_hook(allreduce_hook)
+
+    def forward(self, *inputs, **kwargs):
+        self.needs_reduction = True
+        return self.module(*inputs, **kwargs)
diff --git a/multiproc.py b/multiproc.py
new file mode 100644
index 00000000..a78e8a08
--- /dev/null
+++ b/multiproc.py
@@ -0,0 +1,33 @@
+import torch
+import sys
+import subprocess
+
+argslist = list(sys.argv)[1:]
+world_size = torch.cuda.device_count()
+
+if '--world-size' in argslist:
+    argslist[argslist.index('--world-size') + 1] = str(world_size)
+else:
+    argslist.append('--world-size')
+    argslist.append(str(world_size))
+
+workers = []
+
+for i in range(world_size):
+    if '--rank' in argslist:
+        argslist[argslist.index('--rank') + 1] = str(i)
+    else:
+        argslist.append('--rank')
+        argslist.append(str(i))
+    if '--gpu-rank' in argslist:
+        argslist[argslist.index('--gpu-rank') + 1] = str(i)
+    else:
+        argslist.append('--gpu-rank')
+        argslist.append(str(i))
+    stdout = None if i == 0 else open("GPU_" + str(i) + ".log", "w")
+    print(argslist)
+    p = subprocess.Popen([str(sys.executable)] + argslist, stdout=stdout, stderr=stdout)
+    workers.append(p)
+
+for p in workers:
+    p.wait()
diff --git a/train.py b/train.py
index 5f92832b..afce058f 100644
--- a/train.py
+++ b/train.py
@@ -4,11 +4,14 @@
 import os
 import time
 
-import torch
-from tqdm import tqdm
+import torch.distributed as dist
+import torch.utils.data.distributed
 from torch.autograd import Variable
+from tqdm import tqdm
 from warpctc_pytorch import CTCLoss
-from data.data_loader import AudioDataLoader, SpectrogramDataset, BucketingSampler
+
+from data.data_loader import AudioDataLoader, SpectrogramDataset, BucketingSampler, DistributedBucketingSampler
+from data.distributed import DistributedDataParallel
 from decoder import GreedyDecoder
 from model import DeepSpeech, supported_rnns
@@ -59,6 +62,15 @@
                     help='Turn off shuffling and sample from dataset based on sequence length (smallest to largest)')
 parser.add_argument('--no-bidirectional', dest='bidirectional', action='store_false', default=True,
                     help='Turn off bi-directional RNNs, introduces lookahead convolution')
+parser.add_argument('--dist-url', default='tcp://127.0.0.1:1550', type=str,
+                    help='url used to set up distributed training')
+parser.add_argument('--dist-backend', default='gloo', type=str, help='distributed backend')
+parser.add_argument('--world-size', default=1, type=int,
+                    help='number of distributed processes')
+parser.add_argument('--rank', default=0, type=int,
+                    help='The rank of this process')
+parser.add_argument('--gpu-rank', default=None,
+                    help='If using distributed parallel for multi-gpu, sets the GPU for the process')
 
 torch.manual_seed(123456)
 torch.cuda.manual_seed_all(123456)
@@ -89,19 +101,27 @@ def update(self, val, n=1):
 
 if __name__ == '__main__':
     args = parser.parse_args()
+    args.distributed = args.world_size > 1
+    main_proc = True
+    if args.distributed:
+        if args.gpu_rank:
+            torch.cuda.set_device(int(args.gpu_rank))
+        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                world_size=args.world_size, rank=args.rank)
+        main_proc = args.rank == 0  # Only the first proc should save models
     save_folder = args.save_folder
 
     loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor(
         args.epochs)
     best_wer = None
-    if args.visdom:
+    if args.visdom and main_proc:
         from visdom import Visdom
 
         viz = Visdom()
         opts = dict(title=args.id, ylabel='', xlabel='Epoch', legend=['Loss', 'WER', 'CER'])
         viz_window = None
         epochs = torch.arange(1, args.epochs + 1)
-    if args.tensorboard:
+    if args.tensorboard and main_proc:
         try:
             os.makedirs(args.log_dir)
         except OSError as e:
@@ -151,7 +171,7 @@ def update(self, val, n=1):
             avg_loss = int(package.get('avg_loss', 0))
             loss_results, cer_results, wer_results = package['loss_results'], package[
                 'cer_results'], package['wer_results']
-            if args.visdom and \
+            if main_proc and args.visdom and \
                     package[
                         'loss_results'] is not None and start_epoch > 0:  # Add previous scores to visdom graph
                 x_axis = epochs[0:start_epoch]
@@ -163,7 +183,7 @@ def update(self, val, n=1):
                     Y=y_axis,
                     opts=opts,
                 )
-            if args.tensorboard and \
+            if main_proc and args.tensorboard and \
                     package[
                         'loss_results'] is not None and start_epoch > 0:  # Previous scores to tensorboard logs
                 for i in range(start_epoch):
@@ -202,7 +222,11 @@ def update(self, val, n=1):
                                        normalize=True, augment=args.augment)
     test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels,
                                       normalize=True, augment=False)
-    train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
+    if not args.distributed:
+        train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
+    else:
+        train_sampler = DistributedBucketingSampler(train_dataset, batch_size=args.batch_size,
+                                                    num_replicas=args.world_size, rank=args.rank)
     train_loader = AudioDataLoader(train_dataset,
                                    num_workers=args.num_workers, batch_sampler=train_sampler)
     test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
@@ -210,10 +234,13 @@ def update(self, val, n=1):
 
     if not args.no_shuffle and start_epoch != 0:
         print("Shuffling batches for the following epochs")
-        train_sampler.shuffle()
+        train_sampler.shuffle(start_epoch)
 
-    if args.cuda:
+    if args.cuda and not args.distributed:
         model = torch.nn.DataParallel(model).cuda()
+    elif args.cuda and args.distributed:
+        model.cuda()
+        model = DistributedDataParallel(model)
 
     print(model)
     print("Number of parameters: %d" % DeepSpeech.get_param_size(model))
@@ -280,7 +307,7 @@ def update(self, val, n=1):
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                 (epoch + 1), (i + 1), len(train_sampler), batch_time=batch_time,
                 data_time=data_time, loss=losses))
-            if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0:
+            if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0 and main_proc:
                 file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth.tar' % (save_folder, epoch + 1, i + 1)
                 print("Saving checkpoint model to %s" % file_path)
                 torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, iteration=i,
@@ -344,7 +371,7 @@ def update(self, val, n=1):
               'Average CER {cer:.3f}\t'.format(
             epoch + 1, wer=wer, cer=cer))
 
-        if args.visdom:
+        if args.visdom and main_proc:
            x_axis = epochs[0:epoch + 1]
            y_axis = torch.stack((loss_results[0:epoch + 1], wer_results[0:epoch + 1], cer_results[0:epoch + 1]), dim=1)
            if viz_window is None:
@@ -360,7 +387,7 @@ def update(self, val, n=1):
                     win=viz_window,
                     update='replace',
                 )
-        if args.tensorboard:
+        if args.tensorboard and main_proc:
             values = {
                 'Avg Train Loss': avg_loss,
                 'Avg WER': wer,
@@ -372,7 +399,7 @@ def update(self, val, n=1):
                     tag = tag.replace('.', '/')
                     tensorboard_writer.add_histogram(tag, to_np(value), epoch + 1)
                     tensorboard_writer.add_histogram(tag + '/grad', to_np(value.grad), epoch + 1)
-        if args.checkpoint:
+        if args.checkpoint and main_proc:
             file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
             torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
                                             wer_results=wer_results, cer_results=cer_results),
@@ -383,7 +410,7 @@ def update(self, val, n=1):
         optimizer.load_state_dict(optim_state)
         print('Learning rate annealed to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))
 
-        if best_wer is None or best_wer > wer:
+        if (best_wer is None or best_wer > wer) and main_proc:
             print("Found better validated model, saving to %s" % args.model_path)
             torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
                                             wer_results=wer_results, cer_results=cer_results)
@@ -393,4 +420,4 @@ def update(self, val, n=1):
         avg_loss = 0
         if not args.no_shuffle:
             print("Shuffling batches...")
-            train_sampler.shuffle()
+            train_sampler.shuffle(epoch)
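
To make the batch sharding in the new `DistributedBucketingSampler` concrete, here is a standalone sketch (not part of the patch; the dataset size, batch size and world size are made up) of the padding and striding its `__iter__` performs:

```python
# Sketch of DistributedBucketingSampler's sharding: bins are padded so every
# rank sees the same number of batches, then each rank takes every
# world_size-th bin starting at its own rank.
import math

ids = list(range(10))        # dataset indices, already sorted by duration
batch_size = 2
world_size = 3               # e.g. one process per GPU

bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]
num_samples = int(math.ceil(len(bins) * 1.0 / world_size))
total_size = num_samples * world_size

padded = bins + bins[:(total_size - len(bins))]
for rank in range(world_size):
    print(rank, padded[rank::world_size])
# 0 [[0, 1], [6, 7]]
# 1 [[2, 3], [8, 9]]
# 2 [[4, 5], [0, 1]]
```

Similarly, based on `multiproc.py` above, a machine with two visible GPUs turns `python -m multiproc train.py --visdom --cuda` into two child processes: `train.py --visdom --cuda --world-size 2 --rank 0 --gpu-rank 0` (inheriting the parent's stdout) and the same command with `--rank 1 --gpu-rank 1`, whose output goes to `GPU_1.log`.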