'''
Author: fuchy@stu.pku.edu.cn
Date: 2021-09-17 23:30:48
LastEditTime: 2021-12-02 22:18:56
LastEditors: FCY SR
Description: attentionModel
FilePath: /compression/attentionModel.py
'''
import torch
import torch.nn as nn
import math
import copy
from networkTool import device

class SelfMultiheadAttention(nn.Module):
    def __init__(self, emsize, nhead, dropout=0.5):
        super(SelfMultiheadAttention, self).__init__()
        self.nhead = nhead                       # number of attention heads (e.g. 4)
        self.head_size = emsize // nhead
        assert self.head_size * nhead == emsize, "embed_dim must be divisible by num_heads"
        self.all_head_size = int(self.nhead * self.head_size)
        self.mlpKey = nn.Linear(emsize, self.all_head_size)
        self.mlpQuery = nn.Linear(emsize, self.all_head_size)
        self.mlpValue = nn.Linear(emsize, self.all_head_size)
        self.dropout = nn.Dropout(dropout)
    # Split the projected Key/Query/Value tensors into per-head slices for multi-head attention.
    def slice(self, x, dim):
        new_x_shape = x.size()[:-1] + (self.nhead, self.head_size)
        x = x.view(*new_x_shape)
        if dim == 3:
            x = x.permute(0, 2, 1, 3)
        elif dim == 4:
            x = x.permute(0, 1, 3, 2, 4)
        else:
            assert 0, "unsupported input dimension"
        return x
    # em.shape = [bptt, batch_size, emsize]; mask.shape = [bptt, bptt]
    def forward(self, em, mask):
        em = em.transpose(0, 1).contiguous()     # -> [batch_size, bptt, emsize]
        Key = self.slice(self.mlpKey(em), em.dim())
        Query = self.slice(self.mlpQuery(em), em.dim())
        Value = self.slice(self.mlpValue(em), em.dim())
        # Scaled dot-product attention per head.
        attention_score = torch.matmul(Query, Key.transpose(-1, -2)) / math.sqrt(self.head_size)
        # mask is additive: 0 where attention is allowed, -inf above the diagonal, e.g.
        # [[0,-inf,-inf,...],[0,0,-inf,...],[0,0,0,...]]; score shape e.g. torch.Size([32, 4, 256, 256])
        attention_score = attention_score + mask
        attention_map = self.dropout(nn.Softmax(dim=-1)(attention_score))
        context = torch.matmul(attention_map, Value)
        if context.dim() == 4:
            context = context.permute(0, 2, 1, 3).contiguous()
        elif context.dim() == 5:
            context = context.permute(0, 1, 3, 2, 4).contiguous()
        context_shape = context.size()[:-2] + (self.all_head_size,)
        context = context.view(*context_shape)
        context = context.transpose(0, 1).contiguous()   # -> [bptt, batch_size, emsize]
        return context
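
# Illustrative sketch (not part of the original module): the additive mask consumed by
# SelfMultiheadAttention.forward is a causal mask of shape [bptt, bptt], with 0 where
# attention is allowed and -inf above the diagonal so that softmax assigns zero weight
# to future positions. The helper name below is hypothetical.
def _example_causal_mask(bptt):
    mask = torch.full((bptt, bptt), float('-inf'), device=device)
    return torch.triu(mask, diagonal=1)   # 0 on/below the diagonal, -inf strictly above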

class TransformerLayer(nn.Module):
    def __init__(self, ninp, nhead, nhid, dropout=0.1):
        super(TransformerLayer, self).__init__()
        self.MultiAttention = SelfMultiheadAttention(emsize=ninp, nhead=nhead)
        self.linear1 = nn.Linear(ninp, nhid)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(nhid, ninp)
        self.norm1 = nn.LayerNorm(ninp, eps=1e-5)   # It will affect parallel coding
        self.norm2 = nn.LayerNorm(ninp, eps=1e-5)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    # src is the integration of a leaf node and its ancestors.
    def forward(self, src, src_mask):
        src2 = self.MultiAttention(src, src_mask)   # multi-head self-attention
        src = self.dropout1(src2) + src             # residual connection
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
        src = src + self.dropout2(src2)             # residual connection
        src = self.norm2(src)
        return src

class TransformerModule(nn.Module):
    def __init__(self, layer, nlayers):
        super(TransformerModule, self).__init__()
        self.layers = torch.nn.ModuleList([copy.deepcopy(layer) for i in range(nlayers)])

    def forward(self, src, src_mask):
        output = src
        for mod in self.layers:
            output = mod(output, src_mask=src_mask)
        return output
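
# Minimal usage sketch (assumption-based; the sizes below are illustrative, not values
# fixed by this file): stack TransformerLayer blocks with TransformerModule and run one
# forward pass with a [bptt, bptt] causal mask.
if __name__ == '__main__':
    bptt, batch_size, ninp, nhead, nhid, nlayers = 256, 2, 128, 4, 300, 3   # hypothetical sizes
    layer = TransformerLayer(ninp=ninp, nhead=nhead, nhid=nhid, dropout=0.1)
    model = TransformerModule(layer, nlayers).to(device)
    src = torch.randn(bptt, batch_size, ninp, device=device)    # [bptt, batch_size, emsize]
    mask = torch.triu(torch.full((bptt, bptt), float('-inf'), device=device), diagonal=1)
    out = model(src, mask)
    print(out.shape)    # expected: torch.Size([256, 2, 128]) = [bptt, batch_size, ninp]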