Commit

Adding bidirectional functionality to IndRNN class and parameter to activate on test problems. Updates #2
Stefan Oehmcke committed Jun 1, 2018
1 parent a8a2584 commit a2a8448
Showing 3 changed files with 79 additions and 31 deletions.
11 changes: 8 additions & 3 deletions addition_test.py
@@ -27,6 +27,8 @@
help='disables CUDA training')
parser.add_argument('--batch-norm', action='store_true', default=False,
help='enable frame-wise batch normalization after each layer')
parser.add_argument('--bidirectional', action='store_true', default=False,
help='enable bidirectional processing')
parser.add_argument('--log-interval', type=int, default=100,
help='after how many iterations to report performance')
parser.add_argument('--model', type=str, default="IndRNN",
@@ -46,14 +48,17 @@ class Net(nn.Module):
def __init__(self, input_size, hidden_size, n_layer=2):
super(Net, self).__init__()
self.indrnn = IndRNN(
input_size, hidden_size, n_layer, batch_norm=args.batch_norm,
input_size, hidden_size,
n_layer, batch_norm=args.batch_norm,
bidirectional=args.bidirectional,
hidden_max_abs=RECURRENT_MAX)
self.lin = nn.Linear(hidden_size, 1)
self.lin = nn.Linear(
hidden_size*2 if args.bidirectional else hidden_size, 1)
self.lin.bias.data.fill_(.1)
self.lin.weight.data.normal_(0, .01)

def forward(self, x, hidden=None):
y = self.indrnn(x, hidden)
y, _ = self.indrnn(x, hidden)
return self.lin(y[-1]).squeeze(1)
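With these changes the addition-problem model unpacks the (output, hidden) pair now returned by IndRNN and doubles the width of its linear head when both directions are active. A minimal usage sketch, assuming indrnn.py from this commit is importable; TIME_STEPS, BATCH_SIZE, HIDDEN_SIZE and the two-feature input encoding are illustrative assumptions, not taken from the diff:

# Sketch only -- not part of the committed files.
import torch
import torch.nn as nn
from indrnn import IndRNN

TIME_STEPS, BATCH_SIZE, HIDDEN_SIZE = 100, 50, 128   # illustrative values
RECURRENT_MAX = pow(2, 1 / TIME_STEPS)               # same convention as the test scripts

rnn = IndRNN(2, HIDDEN_SIZE, n_layer=2,
             bidirectional=True, hidden_max_abs=RECURRENT_MAX)
lin = nn.Linear(HIDDEN_SIZE * 2, 1)                  # doubled input width for two directions

x = torch.randn(TIME_STEPS, BATCH_SIZE, 2)           # time-first input (batch_first=False)
y, _ = rnn(x)                                        # forward() now returns (outputs, hidden)
pred = lin(y[-1]).squeeze(1)                         # last time step -> shape (BATCH_SIZE,)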


73 changes: 54 additions & 19 deletions indrnn.py
@@ -177,12 +177,16 @@ class IndRNN(nn.Module):
"""

def __init__(self, input_size, hidden_size, n_layer=1, batch_norm=False,
batch_first=False, **kwargs):
batch_first=False, bidirectional=False, **kwargs):
super(IndRNN, self).__init__()
self.hidden_size = hidden_size
self.batch_norm = batch_norm
self.n_layer = n_layer
self.batch_first = batch_first
self.bidirectional = bidirectional

num_directions = 2 if self.bidirectional else 1

if batch_first:
self.time_index = 1
self.batch_index = 0
@@ -194,19 +198,25 @@ def __init__(self, input_size, hidden_size, n_layer=1, batch_norm=False,

cells = []
for i in range(n_layer):
if i == 0:
cells += [IndRNNCell(input_size, hidden_size, **kwargs)]
else:
cells += [IndRNNCell(hidden_size, hidden_size, **kwargs)]
directions = []
for dir in range(num_directions):
if i == 0:
directions += [IndRNNCell(input_size, hidden_size, **kwargs)]
else:
directions += [IndRNNCell(hidden_size*num_directions, hidden_size, **kwargs)]
cells += [nn.ModuleList(directions)]
self.cells = nn.ModuleList(cells)

if batch_norm:
bns = []
for i in range(n_layer):
bns += [nn.BatchNorm1d(hidden_size)]
directions = []
for dir in range(num_directions):
directions += [nn.BatchNorm1d(hidden_size)]
bns += [nn.ModuleList(directions)]
self.bns = nn.ModuleList(bns)

h0 = torch.zeros(hidden_size)
h0 = torch.zeros(hidden_size*num_directions)
self.register_buffer('h0', torch.autograd.Variable(h0))


@@ -216,16 +226,41 @@ def _gather_batch_first(self, x, index):
def _gather_time_first(self, x, index):
return x[index]


def forward(self, x, hidden=None):
for i, cell in enumerate(self.cells):
cell.check_bounds()
hx = self.h0.unsqueeze(0).expand(x.size(self.batch_index), self.hidden_size).contiguous()
outputs = []
for t in range(x.size(self.time_index)):
x_t = self._gather(x, t)
hx = cell(x_t, hx)
if self.batch_norm:
hx = self.bns[i](hx)
outputs += [hx]
x = torch.stack(outputs, self.time_index)
return x.squeeze(2)
num_directions = 2 if self.bidirectional else 1

for i, directions in enumerate(self.cells):
hx = self.h0.unsqueeze(0).expand(
x.size(self.batch_index),
self.hidden_size*num_directions).contiguous()

x_n = []
for dir, cell in enumerate(directions):
hx_cell = hx[
:, self.hidden_size*dir: self.hidden_size*(dir+1)]
cell.check_bounds()
outputs = []
hiddens = []
r = range(x.size(self.time_index))
if dir == 1:
r = reversed(r)
for t in r:
x_t = self._gather(x, t)
hx_cell = cell(x_t, hx_cell)
if self.batch_norm:
hx_cell = self.bns[i][dir](hx_cell)
outputs += [hx_cell]
x_cell = torch.stack(outputs, self.time_index)
if dir == 1:
flip(x_cell, self.time_index)
x_n += [x_cell]
hiddens += [hx_cell]
x = torch.cat(x_n, -1)
return x.squeeze(2), torch.cat(hiddens, -1)

def flip(x, dim):
indices = [slice(None)] * x.dim()
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
dtype=torch.long, device=x.device)
return x[tuple(indices)]
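In the rewritten forward pass each layer owns one IndRNNCell per direction: the second cell walks the sequence in reverse, the flip helper above reverses a tensor along a chosen axis, and the per-direction outputs are concatenated on the feature dimension before feeding the next layer. A small shape-check sketch under the same assumption that indrnn.py from this commit is importable (sizes are illustrative):

# Shape-check sketch -- illustrative, not part of the diff.
import torch
from indrnn import IndRNN

seq_len, batch, hidden = 10, 4, 8
rnn = IndRNN(3, hidden, n_layer=2, bidirectional=True,
             hidden_max_abs=pow(2, 1 / seq_len))

y, h = rnn(torch.randn(seq_len, batch, 3))
# y[..., :hidden] comes from the forward-direction cells,
# y[..., hidden:] from the reverse-direction cells.
print(y.shape)  # torch.Size([10, 4, 16])
print(h.shape)  # torch.Size([4, 8])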
26 changes: 17 additions & 9 deletions seq_mnist_test.py
@@ -24,13 +24,15 @@
parser.add_argument('--no-batch-norm', action='store_true', default=False,
help='disable frame-wise batch normalization after each layer')
parser.add_argument('--log_epoch', type=int, default=1,
help='after how many iterations to report performance')


help='after how many epochs to report performance')
parser.add_argument('--log_iteration', type=int, default=-1,
help='after how many iterations to report performance, deactivates with -1 (default: -1)')
parser.add_argument('--bidirectional', action='store_true', default=False,
help='enable bidirectional processing')
parser.add_argument('--batch-size', type=int, default=256,
help='input batch size for training (default: 256)')
parser.add_argument('--max-steps', type=int, default=10000,
help='input batch size for training (default: 10000)')
help='max iterations of training (default: 10000)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
@@ -49,15 +51,16 @@ def __init__(self, input_size, hidden_size, n_layer=2):
super(Net, self).__init__()
self.indrnn = IndRNN(
input_size, hidden_size, n_layer, batch_norm=args.batch_norm,
hidden_max_abs=RECURRENT_MAX,
batch_first=True)
self.lin = nn.Linear(hidden_size, 10)
hidden_max_abs=RECURRENT_MAX, batch_first=True,
bidirectional=args.bidirectional)
self.lin = nn.Linear(
hidden_size*2 if args.bidirectional else hidden_size, 10)
self.lin.bias.data.fill_(.1)
self.lin.weight.data.normal_(0, .01)

def forward(self, x, hidden=None):
y = self.indrnn(x, hidden)
return self.lin(y[:, -1])
return self.lin(y[:, -1]).squeeze(1)


def main():
@@ -86,8 +89,13 @@ def main():
loss = F.cross_entropy(out, target)
loss.backward()
optimizer.step()
losses.append(loss.data.cpu()[0])
losses.append(loss.data.cpu().item())
step += 1

if step % args.log_iteration == 0 and args.log_iteration != -1:
print(
"\tStep {} cross_entropy {}".format(
step, np.mean(losses)))
if step >= args.max_steps:
break
if epochs % args.log_epoch == 0:
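For the sequential-MNIST setup the network is built with batch_first=True, so the time axis sits at position 1 and the classifier reads the last step as y[:, -1]; as in the addition test, the linear layer doubles its input width when --bidirectional is set. A batch-first sketch with the same hedges (indrnn.py importable; the per-step input size of 1 pixel and the hidden size are assumptions, batch size 256 is the script's default):

# Batch-first sketch -- illustrative, not part of the diff.
import torch
import torch.nn as nn
from indrnn import IndRNN

batch, seq_len, hidden = 256, 784, 128   # 784 = 28*28 pixels per MNIST image
RECURRENT_MAX = pow(2, 1 / seq_len)

rnn = IndRNN(1, hidden, n_layer=2, batch_first=True,
             bidirectional=True, hidden_max_abs=RECURRENT_MAX)
lin = nn.Linear(hidden * 2, 10)

x = torch.randn(batch, seq_len, 1)   # one pixel per time step
y, _ = rnn(x)                        # forward() returns (outputs, hidden) here too
logits = lin(y[:, -1])               # (batch, 10) class scores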
