-
Notifications
You must be signed in to change notification settings - Fork 3
/
layers.py
232 lines (188 loc) · 8.48 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dropout import RNNDropout, EmbeddingDropout, WeightDropout
from utils import repackage_hidden
class RNNModel(nn.Module):
    """
    Wrapper that chains an encoder and a decoder into a language model.

    Args:
        encoder: module exposing ``embeddings`` (an embedding layer) and
            ``reset_hidden()``. Everything its forward returns is unpacked
            positionally into the decoder.
        decoder: module exposing a linear layer named ``fc1``.
        tie_weights: if True, ``decoder.fc1.weight`` becomes the very same
            Parameter object as ``encoder.embeddings.weight``.
        initrange: half-width of the uniform range used to initialise the
            embedding and decoder weights. The decoder bias is zeroed.
    """

    def __init__(self, encoder, decoder, tie_weights=True, initrange=0.1):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        # Parameter-name substrings ordered from input side to output side;
        # unfreeze(ix) re-enables every group from index ix onwards.
        self.groups = ['encoder', 'rnn.0', 'rnn.1', 'rnn.2', 'decoder']
        embedding_weight = self.encoder.embeddings.weight
        if tie_weights:
            self.decoder.fc1.weight = embedding_weight
        low, high = -initrange, initrange
        # The initialisation order matters for RNG reproducibility. When the
        # weights are tied, the second uniform_ call re-draws the shared matrix.
        embedding_weight.data.uniform_(low, high)
        self.decoder.fc1.bias.data.zero_()
        self.decoder.fc1.weight.data.uniform_(low, high)

    def reset_hidden(self):
        """Clear the recurrent state stored on the encoder."""
        self.encoder.reset_hidden()

    def _set_trainable(self, flag):
        # Shared by freeze() / unfreeze_all(): toggle gradients on every parameter.
        for param in self.parameters():
            param.requires_grad = flag

    def freeze(self):
        """Disable gradients for all parameters."""
        self._set_trainable(False)

    def unfreeze(self, ix):
        """
        Enable gradients for parameters whose name contains any group in ``self.groups[ix:]``.

        Matching is by plain substring on the qualified parameter name.
        Parameters that do not match are left unchanged.
        """
        wanted = self.groups[ix:]
        for name, param in self.named_parameters():
            if any(group in name for group in wanted):
                param.requires_grad = True

    def unfreeze_all(self):
        """Enable gradients for all parameters."""
        self._set_trainable(True)

    def forward(self, x, **kwargs):
        """Encode ``x`` and pass every encoder output, plus ``kwargs``, to the decoder."""
        encoded = self.encoder(x)
        return self.decoder(*encoded, **kwargs)
class RNNClassifier(nn.Module):
    """
    Wrapper that chains an encoder and a classification decoder. Used for ULMFiT.

    The encoder's forward must return exactly four values
    ``(out, hidden, raw_outputs, dropped_outputs)``. Only ``out`` and the last
    element of ``hidden`` are forwarded to the decoder.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        # Parameter-name substrings ordered from input side to output side;
        # unfreeze(ix) re-enables every group from index ix onwards.
        self.groups = ['encoder', 'rnn.0', 'rnn.1', 'rnn.2', 'decoder']

    def reset_hidden(self):
        """Clear the recurrent state stored on the encoder."""
        self.encoder.reset_hidden()

    def _set_trainable(self, flag):
        # Shared by freeze() / unfreeze_all(): toggle gradients on every parameter.
        for param in self.parameters():
            param.requires_grad = flag

    def freeze(self):
        """Disable gradients for all parameters."""
        self._set_trainable(False)

    def unfreeze(self, ix):
        """
        Enable gradients for parameters whose name contains any group in ``self.groups[ix:]``.

        Matching is by plain substring on the qualified parameter name.
        Parameters that do not match are left unchanged.
        """
        wanted = self.groups[ix:]
        for name, param in self.named_parameters():
            if any(group in name for group in wanted):
                param.requires_grad = True

    def unfreeze_all(self):
        """Enable gradients for all parameters."""
        self._set_trainable(True)

    def forward(self, x):
        """Return the decoder's output for ``x``."""
        out, hidden, _, _ = self.encoder(x)
        return self.decoder(out, hidden[-1])
class AWDLSTMEncoder(nn.Module):
    """
    AWD-LSTM Encoder as proposed by Merity et al. (2017)

    Stacks ``num_layers`` single-layer ``nn.LSTM`` modules, each wrapped in
    ``WeightDropout``. Dropout is applied at three points:
    - ``EmbeddingDropout`` on the embedding layer;
    - ``RNNDropout`` on the embedded input;
    - ``RNNDropout`` between consecutive LSTM layers.

    Hidden/cell states are kept on the module between calls and detached from
    the previous batch's graph before reuse.

    Args:
        vocab_sz: vocabulary size of the embedding layer.
        emb_dim: embedding size (input size of the first LSTM layer).
        hidden_dim: output size of the intermediate LSTM layers.
        num_layers: number of stacked LSTM layers.
        emb_dp: dropout value passed to ``EmbeddingDropout``.
        weight_dp: dropout value passed to each ``WeightDropout``.
        input_dp: dropout value for the ``RNNDropout`` on the embedded input.
        hidden_dp: dropout value for the ``RNNDropout`` between layers.
        tie_weights: if True, the last layer outputs ``emb_dim`` instead of
            ``hidden_dim``, so a decoder can share the embedding matrix.
        padding_idx: padding index given to ``nn.Embedding``.
    """

    def __init__(self, vocab_sz, emb_dim, hidden_dim, num_layers=1, emb_dp=0.1, weight_dp=0.5,
                 input_dp=0.3, hidden_dp=0.3, tie_weights=True, padding_idx=1):
        super(AWDLSTMEncoder, self).__init__()
        self.embeddings = nn.Embedding(vocab_sz, emb_dim, padding_idx=padding_idx)
        self.emb_dp = EmbeddingDropout(self.embeddings, emb_dp)
        self.rnn = nn.ModuleList([
            nn.LSTM(
                emb_dim if layer == 0 else hidden_dim,
                # With tied weights the last layer must emit emb_dim features.
                (hidden_dim if layer != num_layers - 1 else emb_dim) if tie_weights else hidden_dim,
            )
            for layer in range(num_layers)
        ])
        self.weight_dp = nn.ModuleList([WeightDropout(rnn, weight_dp) for rnn in self.rnn])
        self.hidden_dp = RNNDropout(hidden_dp)
        self.input_dp = RNNDropout(input_dp)
        self.hidden, self.cell = None, None

    def init_hidden(self, bs):
        """
        Return zero (hidden, cell) states for a batch of size ``bs``.

        Each is a list with one tensor of shape (1, bs, hidden_size) per
        layer. The tensors take the dtype/device of the module's first parameter.
        """
        weight = next(self.parameters())
        hidden = [weight.new_zeros(1, bs, self.rnn[i].hidden_size) for i in range(len(self.rnn))]
        cell = [weight.new_zeros(1, bs, self.rnn[i].hidden_size) for i in range(len(self.rnn))]
        return hidden, cell

    def reset_hidden(self):
        """Drop the stored states; the next forward call starts from zeros."""
        self.hidden, self.cell = None, None

    def forward(self, x):
        """
        Run token ids ``x`` through the embedding and the stacked LSTMs.

        Args:
            x: 2-D tensor of token ids, (seq_len, batch_size). The LSTMs are
               not ``batch_first``.

        Returns:
            out: output of the last LSTM layer (no dropout applied to it here).
            hidden: list with one hidden-state tensor per layer; use the last.
            raw_output: list of every layer's output before inter-layer dropout.
            dropped_output: list of layer outputs after ``RNNDropout``. It holds
                one fewer item than ``raw_output``, because the last layer's
                output is not dropped here.
        """
        _, bs = x.shape
        # Start from zero states on the first call, after reset_hidden(), or
        # when the batch size changed (e.g. a smaller final batch); stale
        # states of another batch size would make the LSTM raise. Otherwise
        # detach the stored states so backprop does not run through previous
        # batches.
        if self.hidden is None or self.cell is None or self.hidden[0].size(1) != bs:
            self.hidden, self.cell = self.init_hidden(bs)
        else:
            self.hidden = [repackage_hidden(h) for h in self.hidden]
            self.cell = [repackage_hidden(c) for c in self.cell]
        # emb_dp wraps the embedding layer and is applied to the token ids directly.
        out = self.emb_dp(x)
        raw_output = []
        dropped_output = []
        out = self.input_dp(out)
        last = len(self.rnn) - 1
        for i in range(len(self.rnn)):
            out, (self.hidden[i], self.cell[i]) = self.weight_dp[i](out, (self.hidden[i], self.cell[i]))
            raw_output.append(out)
            if i < last:
                out = self.hidden_dp(out)
                dropped_output.append(out)
        return out, self.hidden, raw_output, dropped_output
class LSTMEncoder(nn.Module):
    """
    Basic LSTM Encoder for Language Modeling

    The pipeline is embedding -> dropout -> multi-layer ``nn.LSTM`` -> dropout.
    The (hidden, cell) states are stored on the module between calls. Before
    reuse they are passed through ``repackage_hidden``, presumably to detach
    them from the previous batch's graph.
    """

    def __init__(self, vocab_sz, emb_dim, hidden_dim, num_layers=1, dropout=0.5):
        super(LSTMEncoder, self).__init__()
        self.embeddings = nn.Embedding(vocab_sz, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)
        self.hidden, self.cell = None, None

    def init_hidden(self, bs):
        """Return zero (hidden, cell) tensors of shape (num_layers, bs, hidden_size)."""
        weight = next(self.parameters())
        nlayers = self.rnn.num_layers
        nhid = self.rnn.hidden_size
        return weight.new_zeros(nlayers, bs, nhid), weight.new_zeros(nlayers, bs, nhid)

    def reset_hidden(self):
        """Drop the stored states; the next forward call starts from zeros."""
        self.hidden, self.cell = None, None

    def forward(self, x, lens=None):
        """
        Encode token ids ``x``.

        Args:
            x: 2-D tensor of token ids, (seq_len, batch_size). The LSTM is
               not ``batch_first``.
            lens: accepted for interface compatibility; currently unused.

        Returns:
            (out, hidden, cell): the dropped-out LSTM output and the new states.
        """
        _, bs = x.shape
        # Use fresh zero states on the first call, after reset_hidden(), or
        # when the batch size differs from the stored states (e.g. a smaller
        # final batch). Feeding mismatched states to nn.LSTM would raise.
        if self.hidden is None or self.cell is None or self.hidden.size(1) != bs:
            self.hidden, self.cell = self.init_hidden(bs)
        else:
            self.hidden, self.cell = repackage_hidden((self.hidden, self.cell))
        out = self.dropout(self.embeddings(x))
        out, (self.hidden, self.cell) = self.rnn(out, (self.hidden, self.cell))
        out = self.dropout(out)
        return out, self.hidden, self.cell
class DropoutLinearDecoder(nn.Module):
    """
    Linear decoder preceded by RNN dropout on the encoder output. Used with AWD LSTM.

    ``forward`` takes the four values an AWD-LSTM-style encoder returns
    ``(out, hidden, raw, dropped)``. It returns the logits alone, or, with
    ``return_states=True``, the tuple ``(logits, hidden, raw, dropped)``.
    """

    def __init__(self, hidden_dim, vocab_sz, out_dp=0.4):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, vocab_sz)
        self.out_dp = RNNDropout(out_dp)

    def forward(self, out, hidden, raw, dropped, return_states=False):
        regularized = self.out_dp(out)
        # Append the dropped-out final output so ``dropped`` ends up with as many
        # entries as ``raw``. NOTE: this mutates the list the caller passed in.
        dropped.append(regularized)
        logits = self.fc1(regularized)
        if not return_states:
            return logits
        return logits, hidden, raw, dropped
class LinearDecoder(nn.Module):
    """
    Plain linear output layer.

    Any extra positional or keyword arguments are ignored. This lets the
    decoder be called with the whole tuple an encoder returns, using only the
    first element.
    """

    def __init__(self, hidden_dim, vocab_sz):
        super().__init__()
        # Kept as ``fc1``: other code accesses the projection by this name.
        self.fc1 = nn.Linear(hidden_dim, vocab_sz)

    def forward(self, out, *args, **kwargs):
        """Project ``out`` to ``vocab_sz`` outputs; ``args``/``kwargs`` are unused."""
        projected = self.fc1(out)
        return projected
class ConcatPoolingDecoder(nn.Module):
    """
    Concat Pooling Decoder from Howard & Ruder (2018)

    Mean- and max-pools the encoder output over its first dimension and
    optionally prepends ``hidden[-1]``. The concatenation then goes through:
    - BatchNorm -> Dropout -> Linear -> ReLU;
    - BatchNorm -> Dropout -> Linear.

    Args:
        hidden_dim: feature size of ``out`` (and of ``hidden[-1]``).
        bneck_dim: width of the intermediate (bottleneck) layer.
        out_dim: number of outputs.
        dropout_pool: dropout rate after the first BatchNorm.
        dropout_proj: dropout rate after the second BatchNorm.
        include_hidden: whether ``hidden[-1]`` is concatenated with the pools.
    """

    def __init__(self, hidden_dim, bneck_dim, out_dim, dropout_pool=0.2, dropout_proj=0.1, include_hidden=True):
        super().__init__()
        pooled_dim = hidden_dim * (3 if include_hidden else 2)
        self.bn1 = nn.BatchNorm1d(pooled_dim)
        self.bn2 = nn.BatchNorm1d(bneck_dim)
        self.linear1 = nn.Linear(pooled_dim, bneck_dim)
        self.linear2 = nn.Linear(bneck_dim, out_dim)
        self.dropout_pool = nn.Dropout(dropout_pool)
        self.dropout_proj = nn.Dropout(dropout_proj)
        self.include_hidden = include_hidden

    def forward(self, out, hidden):
        """
        Args:
            out: 3-D tensor (seq_len, batch, hidden_dim); pooling runs over dim 0.
            hidden: indexable; ``hidden[-1]`` must be (batch, hidden_dim) when
                ``include_hidden`` is set.
        """
        _, batch_size, _ = out.shape
        # (seq, batch, feat) -> (batch, feat, seq): 1-D pooling runs over the last axis.
        over_time = out.permute(1, 2, 0)
        features = [
            F.adaptive_avg_pool1d(over_time, 1).view(batch_size, -1),
            F.adaptive_max_pool1d(over_time, 1).view(batch_size, -1),
        ]
        if self.include_hidden:
            features.insert(0, hidden[-1])
        pooled = torch.cat(features, dim=1)
        bottleneck = torch.relu(self.linear1(self.dropout_pool(self.bn1(pooled))))
        return self.linear2(self.dropout_proj(self.bn2(bottleneck)))