forked from swathikirans/ego-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
objectAttentionModelConvLSTM.py
42 lines (39 loc) · 1.94 KB
/
objectAttentionModelConvLSTM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import resnetMod
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from MyConvLSTMCell import *
class attentionModel(nn.Module):
    """Recurrent video classifier over per-frame backbone features.

    Each time step is passed through a ``resnetMod.resnet34`` backbone. The
    backbone features are fed to a ``MyConvLSTMCell`` (presumably a
    convolutional LSTM — confirm against MyConvLSTMCell). Before that, they are
    optionally re-weighted by a spatial attention map derived from the
    backbone's own fc weights, in the style of a class activation map. The
    second element of the final recurrent state is average-pooled and
    classified.

    Args:
        num_classes: number of outputs of the final linear classifier.
        mem_size: number of channels of the recurrent state.
        attention: 1 to weight the frame features with the attention map
            before the recurrent cell; 0 to feed the raw ``feature_conv``.
            Any other value makes ``forward`` raise ``ValueError``.
    """

    def __init__(self, num_classes=61, mem_size=512, attention=1):
        super(attentionModel, self).__init__()
        self.num_classes = num_classes
        self.attention = attention
        # NOTE(review): the meaning of the two positional flags is defined in
        # resnetMod (presumably pretrained / no-BN output) — confirm there.
        self.resNet = resnetMod.resnet34(True, True)
        self.mem_size = mem_size
        # Backbone fc weight matrix (one row per backbone class). forward()
        # picks the row of the top-scoring class to build the attention map.
        self.weight_softmax = self.resNet.fc.weight
        # 512 input channels: assumes the backbone's feature maps have 512
        # channels — TODO confirm against resnetMod.
        self.lstm_cell = MyConvLSTMCell(512, mem_size)
        # Kernel 7 collapses the 7x7 recurrent state to 1x1 before the fc.
        self.avgpool = nn.AvgPool2d(7)
        self.dropout = nn.Dropout(0.7)
        self.fc = nn.Linear(mem_size, self.num_classes)
        self.classifier = nn.Sequential(self.dropout, self.fc)

    def forward(self, inputVariable):
        """Run the recurrent cell over a frame sequence and classify the result.

        Args:
            inputVariable: tensor with time on dim 0 and batch on dim 1; each
                ``inputVariable[t]`` is passed directly to the backbone
                (presumably (batch, C, H, W) frames that yield 7x7 feature
                maps — confirm against the data pipeline).

        Returns:
            ``(feats, feats1)``. ``feats`` holds the classifier outputs
            (dropout + linear) with shape (batch, num_classes). ``feats1``
            holds the average-pooled second element of the final recurrent
            state, with shape (batch, mem_size).

        Raises:
            ValueError: if ``self.attention`` is neither 0 nor 1.
        """
        if self.attention not in (0, 1):
            # Previously such values skipped every recurrent update silently
            # and classified an all-zero state.
            raise ValueError(
                "attention must be 0 or 1, got {!r}".format(self.attention))
        batch_size = inputVariable.size(1)
        # Allocate the initial state on the input's device and dtype instead of
        # hard-coding .cuda(), so the model runs on CPU and on any GPU.
        # The 7x7 spatial size matches nn.AvgPool2d(7) above.
        state = (inputVariable.new_zeros((batch_size, self.mem_size, 7, 7)),
                 inputVariable.new_zeros((batch_size, self.mem_size, 7, 7)))
        for t in range(inputVariable.size(0)):
            logit, feature_conv, feature_convNBN = self.resNet(inputVariable[t])
            if self.attention == 1:
                bz, nc, h, w = feature_conv.size()
                feature_conv1 = feature_conv.view(bz, nc, h * w)
                # Top backbone class per sample; argmax avoids sorting the
                # full logit vector.
                class_idx = logit.argmax(dim=1)
                # Score each spatial position against that class's fc weight
                # row: (bz, 1, nc) x (bz, nc, h*w) -> (bz, 1, h*w).
                cam = torch.bmm(self.weight_softmax[class_idx].unsqueeze(1),
                                feature_conv1)
                # Softmax over spatial positions, so each map sums to 1.
                attentionMAP = torch.softmax(cam.squeeze(1), dim=1).view(bz, 1, h, w)
                attentionFeat = feature_convNBN * attentionMAP.expand_as(feature_convNBN)
                state = self.lstm_cell(attentionFeat, state)
            else:
                state = self.lstm_cell(feature_conv, state)
        feats1 = self.avgpool(state[1]).view(state[1].size(0), -1)
        feats = self.classifier(feats1)
        return feats, feats1