diff --git a/examples/multigpu_dataparallel.py b/examples/multigpu_dataparallel.py
new file mode 100644
index 0000000..e484a3c
--- /dev/null
+++ b/examples/multigpu_dataparallel.py
@@ -0,0 +1,71 @@
+import time
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+
+class Model(nn.Module):
+
+    def __init__(self, hidden_size=1024, parallel=True, layers=3, vocab=100):
+        super().__init__()
+
+        self.embedding = nn.Embedding(vocab, hidden_size)
+
+        from torchqrnn import QRNN
+        self.rnn = QRNN(hidden_size, hidden_size, num_layers=layers)
+        #self.rnn = nn.LSTM(hidden_size, hidden_size)
+        # Note: we tell DataParallel to split on the second dimension as RNNs are batch second by default in PyTorch
+        if parallel: self.rnn = nn.DataParallel(self.rnn, dim=1)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        out, hidden = self.rnn(x)
+        return out[:-1]
+
+H = 256
+SEQ = 100
+BATCH = 64
+
+H = 1024
+SEQ = 500
+BATCH = 128
+
+LOOPS = 500
+
+np.random.seed(42)
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+
+x = torch.autograd.Variable(torch.LongTensor(np.random.randint(0, 100, [BATCH, SEQ])))
+x = x.cuda()
+
+np.random.seed(42)
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+
+print('Single')
+model = Model(H, parallel=False)
+model = model.cuda()
+# Call once to compile CUDA kernel / set up new GPUs
+model(x)
+start = time.time()
+for _ in range(LOOPS): y = model(x)
+print('Time:', time.time() - start)
+del model
+
+np.random.seed(42)
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+
+print('Multi')
+model = Model(H, parallel=True)
+model = model.cuda()
+# Call once to compile CUDA kernel / set up new GPUs
+model(x)
+start = time.time()
+for _ in range(LOOPS): y2 = model(x)
+print('Time:', time.time() - start)
+
+print('Difference:')
+print((y - y2).sum())
diff --git a/setup.py b/setup.py
index f64736e..b2ba705 100644
--- a/setup.py
+++ b/setup.py
@@ -4,6 +4,5 @@
     name='PyTorch-QRNN',
     version='0.1',
     packages=['torchqrnn',],
-    license='BSD 3-Clause License',
-    long_description=open('README.md').read(),
+    license='BSD 3-Clause License'
 )
diff --git a/torchqrnn/forget_mult.py b/torchqrnn/forget_mult.py
index 156db5c..de26046 100644
--- a/torchqrnn/forget_mult.py
+++ b/torchqrnn/forget_mult.py
@@ -1,7 +1,8 @@
 import math
 import torch
 from torch.autograd import Variable
-from cupy.cuda import function
+if torch.cuda.is_available():
+    from cupy.cuda import function
 from pynvrtc.compiler import Program
 from collections import namedtuple
 
@@ -89,29 +90,32 @@ def forward(self, f, x, hidden_init=None):
 
 
 class GPUForgetMult(torch.autograd.Function):
-    forget_mult = None
-    bwd_forget_mult = None
-    stream = None
+    configured_gpus = {}
+    ptx = None
     def __init__(self):
         super(GPUForgetMult, self).__init__()
-        if not self.forget_mult or not self.bwd_forget_mult:
-            GPUForgetMult.compile()
-
-    @staticmethod
-    def compile():
-        program = Program(kernel.encode(), 'recurrent_forget_mult.cu'.encode())
-        ptx = program.compile()
-
-        m = function.Module()
-        m.load(bytes(ptx.encode()))
-
-        GPUForgetMult.forget_mult = m.get_function('recurrent_forget_mult')
-        GPUForgetMult.bwd_forget_mult = m.get_function('bwd_recurrent_forget_mult')
-
-        Stream = namedtuple('Stream', ['ptr'])
-        GPUForgetMult.stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+
+    def compile(self):
+        if self.ptx is None:
+            program = Program(kernel.encode(), 'recurrent_forget_mult.cu'.encode())
+            GPUForgetMult.ptx = program.compile()
+
+        if torch.cuda.current_device() not in GPUForgetMult.configured_gpus:
+            m = function.Module()
+            m.load(bytes(self.ptx.encode()))
+
+            self.forget_mult = m.get_function('recurrent_forget_mult')
+            self.bwd_forget_mult = m.get_function('bwd_recurrent_forget_mult')
+
+            Stream = namedtuple('Stream', ['ptr'])
+            self.stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+
+            GPUForgetMult.configured_gpus[torch.cuda.current_device()] = (self.forget_mult, self.bwd_forget_mult, self.stream)
+
+        self.forget_mult, self.bwd_forget_mult, self.stream = GPUForgetMult.configured_gpus[torch.cuda.current_device()]
 
     def forward(self, f, x, hidden_init=None):
+        self.compile()
         seq_size, batch_size, hidden_size = f.size()
         result = f.new(seq_size + 1, batch_size, hidden_size)
         # We only zero the result array (result[0]) if we don't set a hidden initial state
@@ -127,6 +131,7 @@ def forward(self, f, x, hidden_init=None):
         return result[1:, :, :]
 
     def backward(self, grad_h):
+        self.compile()
         f, x, hidden_init = self.saved_tensors
         h = self.result
         ###
@@ -228,7 +233,7 @@ def forward(self, f, x, hidden_init=None, use_cuda=True):
     print('=-=-' * 5)
     residual = (resulta - resultb)
     print(residual.abs().sum().data[0])
-
+    
     # Had to loosen gradient checking, potentially due to general floating point badness?
     from torch.autograd import gradcheck
     inputs = [forget, a, last_h]
diff --git a/torchqrnn/qrnn.py b/torchqrnn/qrnn.py
index 71f3cda..ca5a96b 100644
--- a/torchqrnn/qrnn.py
+++ b/torchqrnn/qrnn.py
@@ -35,7 +35,7 @@ def __init__(self, input_size, hidden_size=None, save_prev_x=False, zoneout=0, w
         assert window in [1, 2], "This QRNN implementation currently only handles convolutional window of size 1 or size 2"
         self.window = window
         self.input_size = input_size
-        self.hidden_size = hidden_size if hidden_size else hidden_size
+        self.hidden_size = hidden_size if hidden_size else input_size
         self.zoneout = zoneout
         self.save_prev_x = save_prev_x
         self.prevX = None
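
For reference, a minimal usage sketch of the DataParallel-wrapped QRNN that the new examples/multigpu_dataparallel.py exercises. The tensor sizes and the multi-GPU setup below are illustrative assumptions, not taken from the diff; the one detail mirrored from it is that QRNN inputs are (seq_len, batch, hidden), so DataParallel has to scatter along dim=1.

import torch
import torch.nn as nn
from torch.autograd import Variable
from torchqrnn import QRNN

# Illustrative sizes (assumptions); needs at least one CUDA device, and two or
# more GPUs for DataParallel to actually split the batch across replicas.
seq_len, batch, hidden = 20, 8, 64

# QRNN consumes (seq_len, batch, hidden) tensors, i.e. batch on the second
# dimension, so DataParallel must scatter/gather along dim=1 rather than the
# default dim=0.
rnn = nn.DataParallel(QRNN(hidden, hidden, num_layers=2), dim=1).cuda()

x = Variable(torch.rand(seq_len, batch, hidden)).cuda()
output, hidden_state = rnn(x)

# Each replica calls GPUForgetMult.compile() on its own device; the compiled
# kernels and stream are cached per device id in GPUForgetMult.configured_gpus,
# so subsequent forward/backward calls reuse them instead of recompiling.
print(output.size())  # (seq_len, batch, hidden)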