model.py
import torch
from torch import nn


class LangModelWithDense(nn.Module):
    def __init__(self, lang_model, emb_size, num_classes, fine_tune):
        """
        Create the model.

        :param lang_model: The language model.
        :param emb_size: The size of the contextualized embeddings.
        :param num_classes: The number of classes.
        :param fine_tune: Whether to fine-tune or freeze the language model's weights.
        """
        super().__init__()

        self.num_classes = num_classes
        self.fine_tune = fine_tune

        self.lang_model = lang_model
        self.linear = nn.Linear(emb_size, num_classes)
        self.dropout = nn.Dropout(0.1)
    def forward(self, x, mask):
        """
        Forward function of the model.

        :param x: The inputs. Shape: [batch_size, seq_len].
        :param mask: The attention mask. Ones are unmasked, zeros are masked.
                     Shape: [batch_size, seq_len].
        :return: The logits. Shape: [batch_size, seq_len, num_classes].

        Example:
            model = LangModelWithDense(...)
            x = torch.tensor([[2, 2], [1, 3]])
            mask = torch.tensor([[1, 1], [1, 0]])
            logits = model(x, mask)
        """
        batch_size = x.shape[0]
        seq_len = x.shape[1]

        # frozen language model: run it without tracking gradients,
        # so its weights are not updated
        if not self.fine_tune:
            with torch.no_grad():
                self.lang_model.eval()
                embeddings = self.lang_model(x, attention_mask=mask)[0]
        # fine-tuning: gradients flow through the language model,
        # so its weights are updated during training
        else:
            embeddings = self.lang_model(x, attention_mask=mask)[0]

        # tensor that collects the output for each token, allocated on the
        # same device as the embeddings. Shape: [batch_size, seq_len, num_classes]
        logits = torch.zeros((batch_size, seq_len, self.num_classes),
                             device=embeddings.device)

        # feed each token's embedding through the dense layer and store the result
        for i in range(seq_len):
            # the logits for a single token. Shape: [batch_size, num_classes]
            logit = self.dropout(self.linear(embeddings[:, i, :]))
            logits[:, i, :] = logit

        return logits
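
A minimal usage sketch follows. It assumes a Hugging Face transformers encoder serves as the language model; the checkpoint name (bert-base-uncased), the embedding size of 768, and the five classes are illustrative placeholders, not values taken from this repository.

# Usage sketch. Assumption: a Hugging Face transformers encoder is the backbone;
# the checkpoint, emb_size, and num_classes below are illustrative only.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
lang_model = AutoModel.from_pretrained("bert-base-uncased")

# the dense head maps each 768-dim token embedding to 5 class logits
model = LangModelWithDense(lang_model, emb_size=768, num_classes=5, fine_tune=False)

encoded = tokenizer(["a short example sentence"], return_tensors="pt", padding=True)
logits = model(encoded["input_ids"], encoded["attention_mask"])
print(logits.shape)  # torch.Size([1, seq_len, 5])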