forked from thuiar/MMSA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MLMF.py
221 lines (190 loc) · 9.02 KB
/
MLMF.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""
paper: Efficient Low-rank Multimodal Fusion with Modality-Specific Factors
ref: https://github.com/Justin1904/Low-rank-Multimodal-Fusion
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_normal_
__all__ = ['MLMF']
class SubNet(nn.Module):
    '''
    Pre-fusion feed-forward encoder used in LMF for the audio and video inputs:
    batch norm -> dropout -> three ReLU-activated linear layers.
    '''

    def __init__(self, in_size, hidden_size, dropout):
        '''
        Args:
            in_size: number of input features
            hidden_size: width of every linear layer (and of the output)
            dropout: probability used by the dropout applied after batch norm
        Output:
            (return value in forward) a tensor of shape (batch_size, hidden_size)
        '''
        super().__init__()
        self.norm = nn.BatchNorm1d(in_size)
        self.drop = nn.Dropout(p=dropout)
        # Only the first layer changes the feature width.
        self.linear_1 = nn.Linear(in_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        '''
        Args:
            x: tensor of shape (batch_size, in_size)
        '''
        hidden = self.drop(self.norm(x))
        for layer in (self.linear_1, self.linear_2, self.linear_3):
            # In-place ReLU overwrites the freshly computed linear output.
            hidden = F.relu(layer(hidden), inplace=True)
        return hidden
class TextSubNet(nn.Module):
    '''
    The LSTM-based subnetwork that is used in LMF for text
    '''

    def __init__(self, in_size, hidden_size, out_size, num_layers=1, dropout=0.2, bidirectional=False):
        '''
        Args:
            in_size: input dimension
            hidden_size: hidden layer dimension
            out_size: dimension of the projected output
            num_layers: specify the number of layers of LSTMs.
            dropout: dropout probability
            bidirectional: specify usage of bidirectional LSTM
        Output:
            (return value in forward) a tensor of shape (batch_size, out_size)
        '''
        super(TextSubNet, self).__init__()
        if num_layers == 1:
            # nn.LSTM only applies dropout between stacked layers and warns
            # when a non-zero value is given for a single layer.
            # NOTE(review): this also sets the probability of self.dropout
            # below to 0, so a single-layer net applies no dropout at all.
            # Kept as-is to preserve training behaviour -- confirm intent.
            dropout = 0.0
        self.rnn = nn.LSTM(in_size, hidden_size, num_layers=num_layers, dropout=dropout,
                           bidirectional=bidirectional, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # A bidirectional LSTM yields one final hidden state per direction;
        # both are concatenated in forward, so the projection input doubles.
        num_directions = 2 if bidirectional else 1
        self.linear_1 = nn.Linear(hidden_size * num_directions, out_size)

    def forward(self, x):
        '''
        Args:
            x: tensor of shape (batch_size, sequence_len, in_size)
        Returns:
            tensor of shape (batch_size, out_size)
        '''
        _, (h_n, _) = self.rnn(x)
        # h_n is (num_layers * num_directions, batch_size, hidden_size).
        # Index the last layer explicitly instead of calling squeeze(), which
        # would also drop the batch axis when batch_size == 1 and would leave
        # the layer axis in place for stacked / bidirectional LSTMs.
        if self.rnn.bidirectional:
            h = torch.cat((h_n[-2], h_n[-1]), dim=-1)
        else:
            h = h_n[-1]
        h = self.dropout(h)
        y_1 = self.linear_1(h)
        return y_1
class MLMF(nn.Module):
    '''
    Multi-task Low-rank Multimodal Fusion

    Encodes text, audio and video with separate sub-networks, predicts one
    output per modality from its own encoding, and predicts a fused output
    with low-rank multimodal fusion over the three encodings.
    '''

    def __init__(self, args):
        '''
        Args:
            args: configuration object exposing the following attributes
                feature_dims - length-3 sequence (text_dim, audio_dim, video_dim)
                hidden_dims - length-3 sequence (text_hidden, audio_hidden, video_hidden);
                    the text sub-network output size is text_hidden // 2
                output_dim - int, size of the fused output
                rank - int, rank of the low-rank fusion factors
                use_softmax - whether to apply softmax to the fused output
                dropouts - length-4 sequence (audio_dropout, video_dropout, text_dropout, post_fusion_dropout)
                post_dropouts - length-3 sequence (text, audio, video) for the unimodal heads
                post_text_dim, post_audio_dim, post_video_dim - int, hidden sizes of the unimodal heads
        Output:
            (return value in forward) a dict of features and predictions
        '''
        super(MLMF, self).__init__()

        # feature/hidden dimensions are specified in the order of text, audio and video
        self.text_in, self.audio_in, self.video_in = args.feature_dims
        self.text_hidden, self.audio_hidden, self.video_hidden = args.hidden_dims
        self.text_out = self.text_hidden // 2
        self.output_dim = args.output_dim
        self.rank = args.rank
        self.use_softmax = args.use_softmax

        # NOTE: post_fusion_prob is unpacked but not used within this class.
        self.audio_prob, self.video_prob, self.text_prob, self.post_fusion_prob = args.dropouts
        self.post_text_prob, self.post_audio_prob, self.post_video_prob = args.post_dropouts

        self.post_text_dim = args.post_text_dim
        self.post_audio_dim = args.post_audio_dim
        self.post_video_dim = args.post_video_dim

        # define the pre-fusion subnetworks
        self.audio_subnet = SubNet(self.audio_in, self.audio_hidden, self.audio_prob)
        self.video_subnet = SubNet(self.video_in, self.video_hidden, self.video_prob)
        self.text_subnet = TextSubNet(self.text_in, self.text_hidden, self.text_out, dropout=self.text_prob)

        # modality-specific low-rank factors; the "+ 1" matches the constant 1
        # that forward prepends to every encoding
        self.audio_factor = Parameter(torch.empty(self.rank, self.audio_hidden + 1, self.output_dim))
        self.video_factor = Parameter(torch.empty(self.rank, self.video_hidden + 1, self.output_dim))
        self.text_factor = Parameter(torch.empty(self.rank, self.text_out + 1, self.output_dim))

        # define the classify layer for text
        self.post_text_dropout = nn.Dropout(p=self.post_text_prob)
        self.post_text_layer_1 = nn.Linear(self.text_out, self.post_text_dim)
        self.post_text_layer_2 = nn.Linear(self.post_text_dim, self.post_text_dim)
        self.post_text_layer_3 = nn.Linear(self.post_text_dim, 1)

        # define the classify layer for audio
        self.post_audio_dropout = nn.Dropout(p=self.post_audio_prob)
        self.post_audio_layer_1 = nn.Linear(self.audio_hidden, self.post_audio_dim)
        self.post_audio_layer_2 = nn.Linear(self.post_audio_dim, self.post_audio_dim)
        self.post_audio_layer_3 = nn.Linear(self.post_audio_dim, 1)

        # define the classify layer for video
        self.post_video_dropout = nn.Dropout(p=self.post_video_prob)
        self.post_video_layer_1 = nn.Linear(self.video_hidden, self.post_video_dim)
        self.post_video_layer_2 = nn.Linear(self.post_video_dim, self.post_video_dim)
        self.post_video_layer_3 = nn.Linear(self.post_video_dim, 1)

        # weights that combine the rank components, plus an output bias
        self.fusion_weights = Parameter(torch.empty(1, self.rank))
        self.fusion_bias = Parameter(torch.empty(1, self.output_dim))

        # init the factors
        xavier_normal_(self.audio_factor)
        xavier_normal_(self.video_factor)
        xavier_normal_(self.text_factor)
        xavier_normal_(self.fusion_weights)
        self.fusion_bias.data.fill_(0)

    def forward(self, text_x, audio_x, video_x):
        '''
        Args:
            text_x: tensor of shape (batch_size, sequence_len, text_in)
            audio_x: tensor of shape (batch_size, audio_in); a singleton axis 1 is squeezed away
            video_x: tensor of shape (batch_size, video_in); a singleton axis 1 is squeezed away
        Returns:
            dict with the unimodal encodings ('Feature_t', 'Feature_a',
            'Feature_v'), the per-rank fused tensor ('Feature_f'), the fused
            prediction ('M') and the unimodal predictions ('T', 'A', 'V').
        '''
        audio_x = audio_x.squeeze(1)
        video_x = video_x.squeeze(1)

        audio_h = self.audio_subnet(audio_x)
        video_h = self.video_subnet(video_x)
        text_h = self.text_subnet(text_x)

        # text
        x_t = self.post_text_dropout(text_h)
        x_t = F.relu(self.post_text_layer_1(x_t), inplace=True)
        x_t = F.relu(self.post_text_layer_2(x_t), inplace=True)
        output_text = self.post_text_layer_3(x_t)

        # audio
        x_a = self.post_audio_dropout(audio_h)
        x_a = F.relu(self.post_audio_layer_1(x_a), inplace=True)
        x_a = F.relu(self.post_audio_layer_2(x_a), inplace=True)
        output_audio = self.post_audio_layer_3(x_a)

        # video (feed the dropped-out features, like the text and audio heads)
        x_v = self.post_video_dropout(video_h)
        x_v = F.relu(self.post_video_layer_1(x_v), inplace=True)
        x_v = F.relu(self.post_video_layer_2(x_v), inplace=True)
        output_video = self.post_video_layer_3(x_v)

        batch_size = audio_h.shape[0]

        # Low-rank multimodal fusion. Rather than building the full outer
        # product, each encoding (with a constant 1 prepended) is projected
        # by its own factor and the projections are multiplied elementwise.
        add_one = audio_h.new_ones((batch_size, 1))
        _audio_h = torch.cat((add_one, audio_h), dim=1)
        _video_h = torch.cat((add_one, video_h), dim=1)
        _text_h = torch.cat((add_one, text_h), dim=1)

        # each projection has shape (rank, batch_size, output_dim)
        fusion_audio = torch.matmul(_audio_h, self.audio_factor)
        fusion_video = torch.matmul(_video_h, self.video_factor)
        fusion_text = torch.matmul(_text_h, self.text_factor)

        # fusion
        fusion_zy = fusion_audio * fusion_video * fusion_text
        fused = fusion_zy.permute(1, 0, 2)  # (batch_size, rank, output_dim)

        # Learned weighted sum over the rank axis instead of a plain sum.
        # (1, rank) @ (batch_size, rank, output_dim) -> (batch_size, 1, output_dim);
        # squeeze only that singleton axis so batch_size == 1 and
        # output_dim == 1 are handled without relying on broadcasting.
        output = torch.matmul(self.fusion_weights, fused).squeeze(1) + self.fusion_bias
        output = output.view(-1, self.output_dim)
        if self.use_softmax:
            output = F.softmax(output, dim=-1)

        # Drop singleton axes as before, but never the batch axis.
        feature_f = fused.squeeze()
        if batch_size == 1:
            feature_f = feature_f.unsqueeze(0)

        res = {
            'Feature_t': text_h,
            'Feature_a': audio_h,
            'Feature_v': video_h,
            'Feature_f': feature_f,
            'M': output,
            'T': output_text,
            'A': output_audio,
            'V': output_video
        }
        return res