-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathcontextual_layer.py
55 lines (47 loc) · 2.33 KB
/
contextual_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
'''
Original source code: https://github.com/weizheliu/Context-Aware-Crowd-Counting
'''
import torch.nn as nn
import torch
from torch.nn import functional as F
from torchvision import models
class ContextualModule(nn.Module):
    """Context-aware fusion of multi-scale pooled features.

    The input map is adaptively average-pooled to several square grid sizes.
    Each pooled map is projected by a bias-free 1x1 convolution and bilinearly
    upsampled back to the input's spatial size. Every upsampled "scale
    feature" gets an element-wise gate
    ``sigmoid(weight_net(feats - scale_feature))``, and the scale features are
    averaged with those gates as weights. The weighted average is concatenated
    with the original input along dim 1 and reduced to ``out_features``
    channels by a 1x1 bottleneck followed by ReLU.

    Args:
        features: Number of channels of the input feature map.
        out_features: Number of output channels of the bottleneck.
        sizes: Output grid sizes of the adaptive average pools, one scale
            branch per entry. Any non-empty iterable of ints is accepted.

    Raises:
        ValueError: If ``sizes`` is empty.
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super().__init__()
        # Materialize once so generators work and emptiness can be checked.
        sizes = tuple(sizes)
        if not sizes:
            raise ValueError("sizes must contain at least one pooling size")
        # Registration order (scales, bottleneck, relu, weight_net) is kept
        # as-is so seeded parameter initialization does not change.
        self.scales = nn.ModuleList([self._make_scale(features, size) for size in sizes])
        # The bottleneck sees the fused context features concatenated with the
        # raw input, hence twice the input channel count.
        self.bottleneck = nn.Conv2d(features * 2, out_features, kernel_size=1)
        self.relu = nn.ReLU()
        self.weight_net = nn.Conv2d(features, features, kernel_size=1)
        self._initialize_weights()

    def __make_weight(self, feature, scale_feature):
        """Return the sigmoid gate for one upsampled scale feature.

        The gate is computed from the contrast ``feature - scale_feature``
        through the 1x1 ``weight_net`` convolution, so it has the same shape
        as ``feature`` and values in [0, 1].
        """
        return torch.sigmoid(self.weight_net(feature - scale_feature))

    def _make_scale(self, features, size):
        """Build one scale branch: adaptive avg-pool to (size, size), then a 1x1 conv."""
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        """Fuse multi-scale context into ``feats``.

        Args:
            feats: 4-D tensor; dims 2 and 3 are used as the spatial size and
                dim 1 must equal ``features`` (assumed NCHW layout).

        Returns:
            ReLU-activated tensor with ``out_features`` channels at the same
            spatial size as ``feats``.
        """
        h, w = feats.size(2), feats.size(3)
        # F.interpolate with align_corners=False is the non-deprecated,
        # numerically identical replacement for F.upsample(mode='bilinear').
        multi_scales = [
            F.interpolate(stage(feats), size=(h, w), mode='bilinear', align_corners=False)
            for stage in self.scales
        ]
        weights = [self.__make_weight(feats, scale_feature) for scale_feature in multi_scales]
        # Gate-weighted average over however many scales are configured; the
        # previous expression hard-coded exactly four. Summation order is the
        # same left-to-right order as before.
        weighted_sum = sum(scale * weight for scale, weight in zip(multi_scales, weights))
        fused = weighted_sum / sum(weights)
        bottle = self.bottleneck(torch.cat([fused, feats], 1))
        return self.relu(bottle)

    def _initialize_weights(self):
        """Initialize conv weights ~ N(0, 0.01) with zero bias; BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
class ContextualEncoder(nn.Module):
    """Apply a ContextualModule to the last of four feature maps.

    ``forward`` takes four feature maps, named like VGG layers
    (conv2_2, conv3_3, conv4_3, conv5_3); presumably they come from a VGG
    backbone, so confirm this against the caller. The first three are returned
    unchanged. The last is replaced by the output of the contextual module.

    Args:
        channels: Channel count of the last feature map. The contextual
            module maps it to the same number of channels. Defaults to 512,
            the value that was previously hard-coded.
    """

    def __init__(self, channels=512):
        super().__init__()
        self.can = ContextualModule(channels, channels)

    def forward(self, *feats):
        """Return ``(conv2_2, conv3_3, conv4_3, can(conv5_3))``.

        Raises:
            ValueError: If the number of feature maps is not exactly four.
        """
        if len(feats) != 4:
            raise ValueError(
                f"ContextualEncoder expects 4 feature maps, got {len(feats)}"
            )
        conv2_2, conv3_3, conv4_3, conv5_3 = feats
        return conv2_2, conv3_3, conv4_3, self.can(conv5_3)