model.py
import torch.nn as nn

from basis import BasisEmbedding


class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self,
                 ntoken,
                 ninp,
                 nhid,
                 nlayers,
                 criterion=None,
                 dropout=0.5,
                 tie_weights=False,
                 basis=0,
                 num_clusters=400,
                 ):
        super(RNNModel, self).__init__()
        self.nhid = nhid
        self.nlayers = nlayers
        self.drop = nn.Dropout(dropout)
        # Use a basis embedding when `basis` is non-zero, otherwise fall back
        # to a standard dense embedding table.
        if basis != 0:
            self.encoder = BasisEmbedding(
                ntoken, ninp, basis, num_clusters,
            )
        else:
            self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(
            ninp, nhid, nlayers, dropout=dropout, batch_first=True)
        # The criterion maps RNN outputs and targets to a scalar loss.
        self.criterion = criterion

    def forward(self, input, target, lengths=None):
        emb = self.drop(self.encoder(input))
        output, unused_hidden = self.rnn(emb)
        output = self.drop(output)
        loss = self.criterion(output, target, lengths)
        return loss
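
The module computes the loss inside forward, so the criterion is expected to accept (output, target, lengths) and to act as the decoder/output layer. The following is a minimal usage sketch under that assumption; FlatCrossEntropy and the hyperparameter values are illustrative placeholders, not part of this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F


class FlatCrossEntropy(nn.Module):
    """Hypothetical criterion: projects RNN states to the vocabulary and
    averages token-level cross-entropy, ignoring `lengths`."""

    def __init__(self, nhid, ntoken):
        super().__init__()
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, output, target, lengths=None):
        logits = self.decoder(output)                 # (batch, seq, ntoken)
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               target.reshape(-1))


# Placeholder sizes; RNNModel is the class defined above.
ntoken, ninp, nhid, nlayers = 10000, 256, 512, 2
model = RNNModel(ntoken, ninp, nhid, nlayers,
                 criterion=FlatCrossEntropy(nhid, ntoken))

tokens = torch.randint(0, ntoken, (8, 35))            # (batch, seq); batch_first=True
loss = model(tokens[:, :-1], tokens[:, 1:])           # predict the next token
loss.backward()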