From a2a84485429cc407c1c609843c232c2f80c02e82 Mon Sep 17 00:00:00 2001
From: Stefan Oehmcke
Date: Fri, 1 Jun 2018 13:25:52 +0200
Subject: [PATCH] Add bidirectional functionality to the IndRNN class and a
 parameter to enable it in the test problems

Updates #2
---
 addition_test.py  | 11 +++++--
 indrnn.py         | 73 +++++++++++++++++++++++++++++++++++------------
 seq_mnist_test.py | 26 +++++++++++------
 3 files changed, 79 insertions(+), 31 deletions(-)

diff --git a/addition_test.py b/addition_test.py
index 864dedf..5785c95 100644
--- a/addition_test.py
+++ b/addition_test.py
@@ -27,6 +27,8 @@
                     help='disables CUDA training')
 parser.add_argument('--batch-norm', action='store_true', default=False,
                     help='enable frame-wise batch normalization after each layer')
+parser.add_argument('--bidirectional', action='store_true', default=False,
+                    help='enable bidirectional processing')
 parser.add_argument('--log-interval', type=int, default=100,
                     help='after how many iterations to report performance')
 parser.add_argument('--model', type=str, default="IndRNN",
@@ -46,14 +48,17 @@ class Net(nn.Module):
     def __init__(self, input_size, hidden_size, n_layer=2):
         super(Net, self).__init__()
         self.indrnn = IndRNN(
-            input_size, hidden_size, n_layer, batch_norm=args.batch_norm,
+            input_size, hidden_size,
+            n_layer, batch_norm=args.batch_norm,
+            bidirectional=args.bidirectional,
             hidden_max_abs=RECURRENT_MAX)
-        self.lin = nn.Linear(hidden_size, 1)
+        self.lin = nn.Linear(
+            hidden_size*2 if args.bidirectional else hidden_size, 1)
         self.lin.bias.data.fill_(.1)
         self.lin.weight.data.normal_(0, .01)

     def forward(self, x, hidden=None):
-        y = self.indrnn(x, hidden)
+        y, _ = self.indrnn(x, hidden)
         return self.lin(y[-1]).squeeze(1)


diff --git a/indrnn.py b/indrnn.py
index 69816bc..0093264 100644
--- a/indrnn.py
+++ b/indrnn.py
@@ -177,12 +177,16 @@ class IndRNN(nn.Module):
     """

     def __init__(self, input_size, hidden_size, n_layer=1, batch_norm=False,
-                 batch_first=False, **kwargs):
+                 batch_first=False, bidirectional=False, **kwargs):
         super(IndRNN, self).__init__()
         self.hidden_size = hidden_size
         self.batch_norm = batch_norm
         self.n_layer = n_layer
         self.batch_first = batch_first
+        self.bidirectional = bidirectional
+
+        num_directions = 2 if self.bidirectional else 1
+
         if batch_first:
             self.time_index = 1
             self.batch_index = 0
@@ -194,19 +198,25 @@ def __init__(self, input_size, hidden_size, n_layer=1, batch_norm=False,

         cells = []
         for i in range(n_layer):
-            if i == 0:
-                cells += [IndRNNCell(input_size, hidden_size, **kwargs)]
-            else:
-                cells += [IndRNNCell(hidden_size, hidden_size, **kwargs)]
+            directions = []
+            for dir in range(num_directions):
+                if i == 0:
+                    directions += [IndRNNCell(input_size, hidden_size, **kwargs)]
+                else:
+                    directions += [IndRNNCell(hidden_size*num_directions, hidden_size, **kwargs)]
+            cells += [nn.ModuleList(directions)]
         self.cells = nn.ModuleList(cells)

         if batch_norm:
             bns = []
             for i in range(n_layer):
-                bns += [nn.BatchNorm1d(hidden_size)]
+                directions = []
+                for dir in range(num_directions):
+                    directions += [nn.BatchNorm1d(hidden_size)]
+                bns += [nn.ModuleList(directions)]
             self.bns = nn.ModuleList(bns)

-        h0 = torch.zeros(hidden_size)
+        h0 = torch.zeros(hidden_size*num_directions)
         self.register_buffer('h0', torch.autograd.Variable(h0))


@@ -216,16 +226,41 @@ def _gather_batch_first(self, x, index):
     def _gather_time_first(self, x, index):
         return x[index]

+
     def forward(self, x, hidden=None):
-        for i, cell in enumerate(self.cells):
-            cell.check_bounds()
-            hx = self.h0.unsqueeze(0).expand(x.size(self.batch_index), self.hidden_size).contiguous()
-            outputs = []
-            for t in range(x.size(self.time_index)):
-                x_t = self._gather(x, t)
-                hx = cell(x_t, hx)
-                if self.batch_norm:
-                    hx = self.bns[i](hx)
-                outputs += [hx]
-            x = torch.stack(outputs, self.time_index)
-        return x.squeeze(2)
+        num_directions = 2 if self.bidirectional else 1
+
+        for i, directions in enumerate(self.cells):
+            hx = self.h0.unsqueeze(0).expand(
+                x.size(self.batch_index),
+                self.hidden_size*num_directions).contiguous()
+
+            x_n = []
+            hiddens = []
+            for dir, cell in enumerate(directions):
+                hx_cell = hx[
+                    :, self.hidden_size*dir: self.hidden_size*(dir+1)]
+                cell.check_bounds()
+                outputs = []
+                r = range(x.size(self.time_index))
+                if dir == 1:
+                    r = reversed(r)
+                for t in r:
+                    x_t = self._gather(x, t)
+                    hx_cell = cell(x_t, hx_cell)
+                    if self.batch_norm:
+                        hx_cell = self.bns[i][dir](hx_cell)
+                    outputs += [hx_cell]
+                x_cell = torch.stack(outputs, self.time_index)
+                if dir == 1:
+                    x_cell = flip(x_cell, self.time_index)
+                x_n += [x_cell]
+                hiddens += [hx_cell]
+            x = torch.cat(x_n, -1)
+        return x.squeeze(2), torch.cat(hiddens, -1)
+
+def flip(x, dim):
+    indices = [slice(None)] * x.dim()
+    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
+                                dtype=torch.long, device=x.device)
+    return x[tuple(indices)]
diff --git a/seq_mnist_test.py b/seq_mnist_test.py
index dcf78ac..4275260 100644
--- a/seq_mnist_test.py
+++ b/seq_mnist_test.py
@@ -24,13 +24,15 @@
 parser.add_argument('--no-batch-norm', action='store_true', default=False,
                     help='disable frame-wise batch normalization after each layer')
 parser.add_argument('--log_epoch', type=int, default=1,
-                    help='after how many iterations to report performance')
-
-
+                    help='after how many epochs to report performance')
+parser.add_argument('--log_iteration', type=int, default=-1,
+                    help='after how many iterations to report performance; set -1 to disable (default: -1)')
+parser.add_argument('--bidirectional', action='store_true', default=False,
+                    help='enable bidirectional processing')
 parser.add_argument('--batch-size', type=int, default=256,
                     help='input batch size for training (default: 256)')
 parser.add_argument('--max-steps', type=int, default=10000,
-                    help='input batch size for training (default: 10000)')
+                    help='maximum number of training iterations (default: 10000)')

 args = parser.parse_args()
 args.cuda = not args.no_cuda and torch.cuda.is_available()
@@ -49,15 +51,16 @@ def __init__(self, input_size, hidden_size, n_layer=2):
         super(Net, self).__init__()
         self.indrnn = IndRNN(
             input_size, hidden_size, n_layer, batch_norm=args.batch_norm,
-            hidden_max_abs=RECURRENT_MAX,
-            batch_first=True)
-        self.lin = nn.Linear(hidden_size, 10)
+            hidden_max_abs=RECURRENT_MAX, batch_first=True,
+            bidirectional=args.bidirectional)
+        self.lin = nn.Linear(
+            hidden_size*2 if args.bidirectional else hidden_size, 10)
         self.lin.bias.data.fill_(.1)
         self.lin.weight.data.normal_(0, .01)

     def forward(self, x, hidden=None):
-        y = self.indrnn(x, hidden)
-        return self.lin(y[:, -1])
+        y, _ = self.indrnn(x, hidden)
+        return self.lin(y[:, -1]).squeeze(1)


 def main():
@@ -86,8 +89,13 @@ def main():
             loss = F.cross_entropy(out, target)
             loss.backward()
             optimizer.step()
-            losses.append(loss.data.cpu()[0])
+            losses.append(loss.data.cpu().item())
             step += 1
+
+            if step % args.log_iteration == 0 and args.log_iteration != -1:
+                print(
+                    "\tStep {} cross_entropy {}".format(
+                        step, np.mean(losses)))
             if step >= args.max_steps:
                 break
         if epochs % args.log_epoch == 0:
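
Not part of the patch: a minimal usage sketch of the bidirectional mode added
above, assuming the patched indrnn.py is importable; the hidden_max_abs value
below is only illustrative (the test scripts pass their RECURRENT_MAX constant).

    import torch
    from indrnn import IndRNN

    seq_len, batch, input_size, hidden_size = 20, 4, 2, 16

    # With bidirectional=True the output feature dimension doubles, which is
    # why the test scripts size their nn.Linear layers with hidden_size * 2.
    rnn = IndRNN(input_size, hidden_size, n_layer=2,
                 bidirectional=True, hidden_max_abs=1.0)

    x = torch.randn(seq_len, batch, input_size)  # time-first (batch_first=False)
    out, h = rnn(x)

    print(out.shape)  # expected: (seq_len, batch, 2 * hidden_size)
    print(h.shape)    # last layer's final hidden state: (batch, 2 * hidden_size)

On the test problems the new behaviour is switched on from the command line,
e.g. "python addition_test.py --bidirectional" or
"python seq_mnist_test.py --bidirectional --log_iteration 100".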