-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathsa_glu.py
60 lines (49 loc) · 2.31 KB
/
sa_glu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
from models.layers import Embedding, MLP, SelfAttnLayer
class SA_GLUModel(nn.Module):
    """
    Model:  Self Attention + Gated Linear Unit
    Ref:    Y N. Dauphin, et al. Language Modeling with Gated Convolutional Networks, 2017

    Two self-attention branches are computed over the same field embeddings;
    one branch (plus bias) is gated element-wise by the sigmoid of the other
    (plus bias), and the flattened result is passed to an MLP. When
    ``ensemble`` is truthy, a second embedding + MLP branch is evaluated on
    the raw input and both outputs are mixed by a 2->1 linear layer.
    """

    def __init__(self, nfield, nfeat, nemb, mlp_layers, mlp_hid,
                 dropout, ensemble, deep_layers, deep_hid):
        """
        :param nfield:      number of fields F (sizes the MLP input as F*E)
        :param nfeat:       forwarded to ``Embedding`` (presumably the feature
                            vocabulary size -- confirm against models.layers)
        :param nemb:        embedding size E
        :param mlp_layers:  forwarded to ``MLP`` for the GLU branch
        :param mlp_hid:     forwarded to ``MLP`` for the GLU branch
        :param dropout:     dropout probability, used on the flattened GLU
                            output and forwarded to both MLPs
        :param ensemble:    if truthy, build and use the extra deep branch
        :param deep_layers: forwarded to ``MLP`` for the deep branch
        :param deep_hid:    forwarded to ``MLP`` for the deep branch
        """
        super().__init__()
        self.nfield, self.nfeat, self.nemb = nfield, nfeat, nemb
        self.ensemble = ensemble
        self.dropout = nn.Dropout(p=dropout)

        # embedding
        self.embedding = Embedding(nfeat, nemb)
        # NOTE: emb_bn is not used by forward() in this class; it is kept so
        # that parameter names / state_dict keys stay unchanged.
        self.emb_bn = nn.BatchNorm1d(nfield)

        # self attn: "w" is the value branch, "v" is the gate branch
        self.self_attn_w = SelfAttnLayer(nemb)
        self.self_attn_v = SelfAttnLayer(nemb)
        self.w_b = nn.Parameter(torch.zeros(nemb,))
        self.v_b = nn.Parameter(torch.zeros(nemb,))

        # MLP
        self.mlp = MLP(nfield*nemb, mlp_layers, mlp_hid, dropout)

        if ensemble:
            self.deep_embedding = Embedding(nfeat, nemb)
            self.deep_mlp = MLP(nfield*nemb, deep_layers, deep_hid, dropout)
            self.ensemble_layer = nn.Linear(2, 1)
            # weights 0.5 and bias 0: the ensemble starts out as the plain
            # average of the two branch outputs
            nn.init.constant_(self.ensemble_layer.weight, 0.5)
            nn.init.constant_(self.ensemble_layer.bias, 0.)

    def forward(self, x):
        """
        :param x:   {'id': LongTensor B*F, 'value': FloatTensor B*F}
        :return:    y of size B, Regression and Classification (+sigmoid)

        The caller's ``x`` and ``x['value']`` are left unmodified: values are
        clamped to [0.001, 1.] on a copy.
        """
        # Out-of-place clamp on a shallow copy of the dict, so the caller's
        # batch is not silently overwritten (and no in-place op is applied to
        # a tensor the caller may still need, e.g. a leaf requiring grad).
        x = {**x, 'value': x['value'].clamp(0.001, 1.)}

        x_emb = self.embedding(x)                           # B*F*E
        # the attention layers return an indexable result; only item 0 is used
        xw = self.self_attn_w(x_emb)[0]+self.w_b            # B*F*E
        xv = self.self_attn_v(x_emb)[0]+self.v_b            # B*F*E
        glu = xw * torch.sigmoid(xv)                        # B*F*E
        # reshape (not view) so a non-contiguous layout does not raise
        glu = self.dropout(glu.reshape(xw.size(0), -1))     # B*(FxE)
        y = self.mlp(glu)                                   # B*1

        if self.ensemble:
            deep_emb = self.deep_embedding(x)
            y_deep = self.deep_mlp(
                deep_emb.reshape(-1, self.nfield*self.nemb))    # B*1

            y = torch.cat([y, y_deep], dim=1)               # B*2
            y = self.ensemble_layer(y)                      # B*1

        return y.squeeze(1)                                 # B