This repository has been archived by the owner on Feb 12, 2022. It is now read-only.

Mini: Don't include CuPy dependency if CUDA not available #4

Open
wants to merge 3 commits into base: master
71 changes: 71 additions & 0 deletions examples/multigpu_dataparallel.py
@@ -0,0 +1,71 @@
import time

import numpy as np

import torch
import torch.nn as nn

class Model(nn.Module):

    def __init__(self, hidden_size=1024, parallel=True, layers=3, vocab=100):
        super().__init__()

        self.embedding = nn.Embedding(vocab, hidden_size)

        from torchqrnn import QRNN
        self.rnn = QRNN(hidden_size, hidden_size, num_layers=layers)
        #self.rnn = nn.LSTM(hidden_size, hidden_size)
        # Note: we tell DataParallel to split on the second dimension as RNNs are batch second by default in PyTorch
        if parallel: self.rnn = nn.DataParallel(self.rnn, dim=1)

    def forward(self, x):
        x = self.embedding(x)
        out, hidden = self.rnn(x)
        return out[:-1]

H = 256
SEQ = 100
BATCH = 64

H = 1024
SEQ = 500
BATCH = 128

LOOPS = 500

np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

x = torch.autograd.Variable(torch.LongTensor(np.random.randint(0, 100, [BATCH, SEQ])))
x = x.cuda()

np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

print('Single')
model = Model(H, parallel=False)
model = model.cuda()
# Call once to compile CUDA kernel / set up new GPUs
model(x)
start = time.time()
for _ in range(LOOPS): y = model(x)
print('Time:', time.time() - start)
del model

np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

print('Multi')
model = Model(H, parallel=True)
model = model.cuda()
# Call once to compile CUDA kernel / set up new GPUs
model(x)
start = time.time()
for _ in range(LOOPS): y2 = model(x)
print('Time:', time.time() - start)

print('Difference:')
print((y - y2).sum())
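The comment in Model.__init__ above is the crux of the example: QRNN, like PyTorch's built-in RNNs, takes input shaped (seq_len, batch, hidden), so nn.DataParallel must scatter along dimension 1 rather than its default dimension 0. A minimal standalone sketch of that behaviour (the toy module and shapes here are illustrative, not part of the PR):

import torch
import torch.nn as nn

class EchoShape(nn.Module):
    # Toy module that reports the shape of the chunk each replica receives.
    def forward(self, x):
        print('replica saw', tuple(x.shape))
        return x

seq_len, batch, features = 10, 8, 4
x = torch.zeros(seq_len, batch, features)  # batch-second layout, as RNNs expect

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # dim=1 splits the batch (second) dimension across GPUs; the default dim=0
    # would wrongly split the time dimension instead.
    model = nn.DataParallel(EchoShape(), dim=1).cuda()
    model(x.cuda())
else:
    EchoShape()(x)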
3 changes: 1 addition & 2 deletions setup.py
@@ -4,6 +4,5 @@
     name='PyTorch-QRNN',
     version='0.1',
     packages=['torchqrnn',],
-    license='BSD 3-Clause License',
-    long_description=open('README.md').read(),
+    license='BSD 3-Clause License'
 )
47 changes: 26 additions & 21 deletions torchqrnn/forget_mult.py
@@ -1,7 +1,8 @@
 import math
 import torch
 from torch.autograd import Variable
-from cupy.cuda import function
+if torch.cuda.is_available():
+    from cupy.cuda import function
 from pynvrtc.compiler import Program
 from collections import namedtuple

@@ -89,29 +90,32 @@ def forward(self, f, x, hidden_init=None):


 class GPUForgetMult(torch.autograd.Function):
-    forget_mult = None
-    bwd_forget_mult = None
-    stream = None
+    configured_gpus = {}
+    ptx = None
     def __init__(self):
         super(GPUForgetMult, self).__init__()
-        if not self.forget_mult or not self.bwd_forget_mult:
-            GPUForgetMult.compile()
-
-    @staticmethod
-    def compile():
-        program = Program(kernel.encode(), 'recurrent_forget_mult.cu'.encode())
-        ptx = program.compile()
-
-        m = function.Module()
-        m.load(bytes(ptx.encode()))
-
-        GPUForgetMult.forget_mult = m.get_function('recurrent_forget_mult')
-        GPUForgetMult.bwd_forget_mult = m.get_function('bwd_recurrent_forget_mult')
-
-        Stream = namedtuple('Stream', ['ptr'])
-        GPUForgetMult.stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+
+    def compile(self):
+        if self.ptx is None:
+            program = Program(kernel.encode(), 'recurrent_forget_mult.cu'.encode())
+            GPUForgetMult.ptx = program.compile()
+
+        if torch.cuda.current_device() not in GPUForgetMult.configured_gpus:
+            m = function.Module()
+            m.load(bytes(self.ptx.encode()))
+
+            self.forget_mult = m.get_function('recurrent_forget_mult')
+            self.bwd_forget_mult = m.get_function('bwd_recurrent_forget_mult')
+
+            Stream = namedtuple('Stream', ['ptr'])
+            self.stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+
+            GPUForgetMult.configured_gpus[torch.cuda.current_device()] = (self.forget_mult, self.bwd_forget_mult, self.stream)
+
+        self.forget_mult, self.bwd_forget_mult, self.stream = GPUForgetMult.configured_gpus[torch.cuda.current_device()]
 
     def forward(self, f, x, hidden_init=None):
+        self.compile()
         seq_size, batch_size, hidden_size = f.size()
         result = f.new(seq_size + 1, batch_size, hidden_size)
         # We only zero the result array (result[0]) if we don't set a hidden initial state
@@ -127,6 +131,7 @@ def forward(self, f, x, hidden_init=None):
         return result[1:, :, :]
 
     def backward(self, grad_h):
+        self.compile()
        f, x, hidden_init = self.saved_tensors
        h = self.result
        ###
@@ -228,7 +233,7 @@ def forward(self, f, x, hidden_init=None, use_cuda=True):
     print('=-=-' * 5)
     residual = (resulta - resultb)
     print(residual.abs().sum().data[0])
-
+    # Had to loosen gradient checking, potentially due to general floating point badness?
     from torch.autograd import gradcheck
     inputs = [forget, a, last_h]
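Taken together, the forget_mult.py changes do two things: the cupy import is only attempted when torch.cuda.is_available(), so a CPU-only install no longer needs CuPy at import time, and kernel setup moves from a one-off @staticmethod into an instance-level compile() that compiles the PTX once and then loads and caches the kernel per CUDA device, which is what the multi-GPU DataParallel example relies on. A condensed sketch of the same pattern in isolation (the helper below is illustrative, not the library's API):

import torch

if torch.cuda.is_available():
    # GPU-only dependencies are only touched when CUDA is actually present.
    from cupy.cuda import function
    from pynvrtc.compiler import Program

_ptx = None              # PTX is device-independent, so compile it once
_configured_gpus = {}    # device index -> loaded kernel function

def get_kernel(source, name):
    global _ptx
    if _ptx is None:
        program = Program(source.encode(), (name + '.cu').encode())
        _ptx = program.compile()
    device = torch.cuda.current_device()
    if device not in _configured_gpus:
        # Loading the module must happen once per device; cache the result.
        m = function.Module()
        m.load(bytes(_ptx.encode()))
        _configured_gpus[device] = m.get_function(name)
    return _configured_gpus[device]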
2 changes: 1 addition & 1 deletion torchqrnn/qrnn.py
@@ -35,7 +35,7 @@ def __init__(self, input_size, hidden_size=None, save_prev_x=False, zoneout=0, w
         assert window in [1, 2], "This QRNN implementation currently only handles convolutional window of size 1 or size 2"
         self.window = window
         self.input_size = input_size
-        self.hidden_size = hidden_size if hidden_size else hidden_size
+        self.hidden_size = hidden_size if hidden_size else input_size
         self.zoneout = zoneout
         self.save_prev_x = save_prev_x
         self.prevX = None
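The one-line qrnn.py change fixes a copy-paste slip: hidden_size if hidden_size else hidden_size can only ever yield hidden_size itself, so leaving the argument at its default of None never fell back to input_size as intended. A tiny plain-Python illustration of the difference:

input_size, hidden_size = 256, None

# Old expression: the fallback branch returns the same (None) value.
old = hidden_size if hidden_size else hidden_size
# Fixed expression: an omitted hidden_size now defaults to input_size.
new = hidden_size if hidden_size else input_size

print(old, new)  # None 256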