This repository was archived by the owner on Feb 12, 2022. It is now read-only.

Modifications to work with pytorch 0.4 #43

Closed
wants to merge 1 commit into from
66 changes: 36 additions & 30 deletions embed_regularize.py
@@ -1,41 +1,47 @@
 import numpy as np
 
 import torch
-from torch.autograd import Variable
+from torch.nn import functional as F
 
+
 def embedded_dropout(embed, words, dropout=0.1, scale=None):
-  if dropout:
-    mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
-    mask = Variable(mask)
-    masked_embed_weight = mask * embed.weight
-  else:
-    masked_embed_weight = embed.weight
-  if scale:
-    masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
-
-  padding_idx = embed.padding_idx
-  if padding_idx is None:
-      padding_idx = -1
-  X = embed._backend.Embedding.apply(words, masked_embed_weight,
-    padding_idx, embed.max_norm, embed.norm_type,
-    embed.scale_grad_by_freq, embed.sparse
-  )
-  return X
+    if dropout:
+        mask = embed.weight.data.new().resize_(
+            (embed.weight.size(0), 1)).bernoulli_(
+                1 - dropout).expand_as(embed.weight) / (1 - dropout)
+        masked_embed_weight = mask * embed.weight
+    else:
+        masked_embed_weight = embed.weight
+    if scale:
+        masked_embed_weight = scale.expand_as(
+            masked_embed_weight) * masked_embed_weight
+
+    padding_idx = embed.padding_idx
+    if padding_idx is None:
+        padding_idx = -1
+
+    X = F.embedding(
+        words, masked_embed_weight,
+        padding_idx, embed.max_norm,
+        embed.norm_type,
+        embed.scale_grad_by_freq, embed.sparse
+    )
+    return X
 
 if __name__ == '__main__':
-  V = 50
-  h = 4
-  bptt = 10
-  batch_size = 2
+    V = 50
+    h = 4
+    bptt = 10
+    batch_size = 2
 
-  embed = torch.nn.Embedding(V, h)
+    embed = torch.nn.Embedding(V, h)
 
-  words = np.random.random_integers(low=0, high=V-1, size=(batch_size, bptt))
-  words = torch.LongTensor(words)
-  words = Variable(words)
+    words = np.random.random_integers(
+        low=0, high=V - 1, size=(batch_size, bptt))
+    words = torch.LongTensor(words)
 
-  origX = embed(words)
-  X = embedded_dropout(embed, words)
+    origX = embed(words)
+    X = embedded_dropout(embed, words)
 
-  print(origX)
-  print(X)
+    print(origX)
+    print(X)
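
A minimal usage sketch of the updated function under PyTorch 0.4+ (the sizes below are illustrative, mirroring the __main__ block above):

import torch
import torch.nn as nn
from embed_regularize import embedded_dropout

# Embedding dropout zeroes whole rows of embed.weight (one row per word type),
# rescales the surviving rows, then looks the indices up via F.embedding.
embed = nn.Embedding(50, 4)                              # V=50 words, h=4 dims
words = torch.randint(0, 50, (2, 10), dtype=torch.long)  # (batch, bptt) indices
X = embedded_dropout(embed, words, dropout=0.1)
print(X.shape)  # torch.Size([2, 10, 4])
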
11 changes: 6 additions & 5 deletions finetune.py
@@ -5,7 +5,6 @@
 np.random.seed(331)
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 
 import data
 import model
@@ -104,15 +103,16 @@
 
 def evaluate(data_source, batch_size=10):
     # Turn on evaluation mode which disables dropout.
-    if args.model == 'QRNN': model.reset()
+    if args.model == 'QRNN':
+        model.reset()
     model.eval()
     total_loss = 0
     ntokens = len(corpus.dictionary)
     hidden = model.init_hidden(batch_size)
     for i in range(0, data_source.size(0) - 1, args.bptt):
         data, targets = get_batch(data_source, i, args, evaluation=True)
         output, hidden = model(data, hidden)
-        output_flat = output.view(-1, ntokens)
+        output_flat = model.decoder(output).view(-1, ntokens)
         total_loss += len(data) * criterion(output_flat, targets).data
         hidden = repackage_hidden(hidden)
     return total_loss[0] / len(data_source)
@@ -144,6 +144,7 @@ def train():
         optimizer.zero_grad()
 
         output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
+        output = model.decoder(output)
         raw_loss = criterion(output.view(-1, ntokens), targets)
 
         loss = raw_loss
@@ -154,7 +155,7 @@ def train():
         loss.backward()
 
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
-        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
+        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
         optimizer.step()
 
         total_loss += raw_loss.data
@@ -175,7 +176,7 @@ def train():
 
 # Load the best saved model.
 with open(args.save, 'rb') as f:
-    model = torch.load(f)
+    model = torch.load(f)[0]
 
 
 # Loop over epochs.
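
The main API subtlety in this file is the gradient-clipping call: PyTorch 0.4 deprecates clip_grad_norm in favor of clip_grad_norm_, whose trailing underscore marks it as in-place. A self-contained sketch of the pattern (the tiny linear model is only for illustration):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(4, 10), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
# Renamed in 0.4: clip the gradient norm in place before the optimizer step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.25)
optimizer.step()
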
13 changes: 7 additions & 6 deletions main.py
@@ -4,7 +4,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 
 import data
 import model
@@ -166,7 +165,7 @@ def evaluate(data_source, batch_size=10):
         output, hidden = model(data, hidden)
         total_loss += len(data) * criterion(model.decoder.weight, model.decoder.bias, output, targets).data
         hidden = repackage_hidden(hidden)
-    return total_loss[0] / len(data_source)
+    return total_loss.item() / len(data_source)
 
 
 def train():
@@ -203,15 +202,18 @@ def train():
         # Temporal Activation Regularization (slowness)
         if args.beta: loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
         loss.backward()
+        # print([(name, abs(p.grad).sum().item())
+        #        for name, p in model.named_parameters()])
 
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
-        if args.clip: torch.nn.utils.clip_grad_norm(params, args.clip)
+        if args.clip:
+            torch.nn.utils.clip_grad_norm_(params, args.clip)
         optimizer.step()
 
         total_loss += raw_loss.data
         optimizer.param_groups[0]['lr'] = lr2
         if batch % args.log_interval == 0 and batch > 0:
-            cur_loss = total_loss[0] / args.log_interval
+            cur_loss = total_loss.item() / args.log_interval
             elapsed = time.time() - start_time
             print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                     'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
@@ -244,12 +246,11 @@ def train():
             for prm in model.parameters():
                 tmp[prm] = prm.data.clone()
                 prm.data = optimizer.state[prm]['ax'].clone()
-
             val_loss2 = evaluate(val_data)
             print('-' * 89)
             print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                 'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
-                epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss), val_loss / math.log(2)))
+                epoch, (time.time() - epoch_start_time), val_loss2, math.exp(val_loss2), val_loss2 / math.log(2)))
             print('-' * 89)
 
             if val_loss2 < stored_loss:
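
Several of these lines replace total_loss[0] with total_loss.item(): after the Tensor/Variable merge in 0.4, an accumulated loss is a 0-dimensional tensor and can no longer be indexed with [0]. A minimal illustration:

import torch

total_loss = torch.tensor(0.)
for step_loss in [0.7, 0.5, 0.3]:
    total_loss += step_loss           # stays a 0-dim tensor
cur_loss = total_loss.item() / 3      # .item() extracts the Python float
print(round(cur_loss, 4))             # 0.5
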
14 changes: 6 additions & 8 deletions model.py
@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 
 from embed_regularize import embedded_dropout
 from locked_dropout import LockedDropout
 from weight_drop import WeightDrop
 
+
 class RNNModel(nn.Module):
     """Container module with an encoder, a recurrent module, and a decoder."""
 
@@ -30,7 +30,7 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, dropouth=
             self.rnns = [QRNNLayer(input_size=ninp if l == 0 else nhid, hidden_size=nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), save_prev_x=True, zoneout=0, window=2 if l == 0 else 1, output_gate=True) for l in range(nlayers)]
             for rnn in self.rnns:
                 rnn.linear = WeightDrop(rnn.linear, ['weight'], dropout=wdrop)
-            print(self.rnns)
+
         self.rnns = torch.nn.ModuleList(self.rnns)
         self.decoder = nn.Linear(nhid, ntoken)
 
@@ -71,19 +71,17 @@ def forward(self, input, hidden, return_h=False):
         #emb = self.idrop(emb)
 
         emb = self.lockdrop(emb, self.dropouti)
-
         raw_output = emb
         new_hidden = []
         #raw_output, hidden = self.rnn(emb, hidden)
         raw_outputs = []
         outputs = []
         for l, rnn in enumerate(self.rnns):
-            current_input = raw_output
             raw_output, new_h = rnn(raw_output, hidden[l])
             new_hidden.append(new_h)
             raw_outputs.append(raw_output)
             if l != self.nlayers - 1:
-                #self.hdrop(raw_output)
+                # self.hdrop(raw_output)
                 raw_output = self.lockdrop(raw_output, self.dropouth)
                 outputs.append(raw_output)
         hidden = new_hidden
@@ -99,9 +97,9 @@ def forward(self, input, hidden, return_h=False):
     def init_hidden(self, bsz):
         weight = next(self.parameters()).data
         if self.rnn_type == 'LSTM':
-            return [(Variable(weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_()),
-                    Variable(weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_()))
+            return [(weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_(),
+                    weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_())
                     for l in range(self.nlayers)]
         elif self.rnn_type == 'QRNN' or self.rnn_type == 'GRU':
-            return [Variable(weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_())
+            return [weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_()
                     for l in range(self.nlayers)]
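
init_hidden now returns plain tensors: with Variable folded into Tensor in 0.4, weight.new(...).zero_() already yields something the RNN can consume, on the same dtype and device as the model's parameters. A stand-alone sketch of that idiom (the layer sizes here are arbitrary):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=4, num_layers=1)
weight = next(lstm.parameters()).data
bsz = 2
# New tensors inherit dtype/device from `weight`; no Variable wrapper needed.
hidden = (weight.new(1, bsz, 4).zero_(), weight.new(1, bsz, 4).zero_())
out, hidden = lstm(torch.randn(5, bsz, 4), hidden)
print(out.shape)  # torch.Size([5, 2, 4])
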
16 changes: 10 additions & 6 deletions utils.py
@@ -1,12 +1,15 @@
-from torch.autograd import Variable
+import torch
 
+
 def repackage_hidden(h):
-    """Wraps hidden states in new Variables, to detach them from their history."""
-    if type(h) == Variable:
-        return Variable(h.data)
+    """Wraps hidden states in new Tensors,
+    to detach them from their history."""
+    if isinstance(h, torch.Tensor):
+        return h.detach()
     else:
         return tuple(repackage_hidden(v) for v in h)
 
+
 def batchify(data, bsz, args):
     # Work out how cleanly we can divide the dataset into bsz parts.
     nbatch = data.size(0) // bsz
@@ -18,8 +21,9 @@ def batchify(data, bsz, args):
         data = data.cuda()
     return data
 
+
 def get_batch(source, i, args, seq_len=None, evaluation=False):
     seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i)
-    data = Variable(source[i:i+seq_len], volatile=evaluation)
-    target = Variable(source[i+1:i+1+seq_len].view(-1))
+    data = source[i:i+seq_len]
+    target = source[i+1:i+1+seq_len].view(-1)
     return data, target
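
get_batch simply drops the Variable wrapper and the volatile flag; in PyTorch 0.4 the usual way to get volatile-style, graph-free evaluation is to run the forward pass under torch.no_grad(). Shown below as a general pattern, not as part of this diff:

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
x = torch.randn(2, 8)
with torch.no_grad():      # replaces volatile=True from pre-0.4 code
    y = model(x)
print(y.requires_grad)     # False: no autograd graph was built
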