forked from wei-tim/YOWO
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cfam.py
120 lines (64 loc) · 2.7 KB
/
cfam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
We thank CASIA IVA for sharing their code (https://github.com/junfu1115/DANet),
on top of which we have built our code.
"""
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
class CAM_Module(nn.Module):
    """Channel attention module, adapted from DANet.

    Computes a C x C attention map between the channels of the input, uses it
    to mix the channels, and blends the mixed result back into the input
    through a learnable scalar ``gamma``. Because ``gamma`` is initialised to
    zero, a freshly constructed module returns its input unchanged.

    Args:
        in_dim: number of input channels. It is only stored (as ``chanel_in``)
            and is not used to build any weights.
    """

    def __init__(self, in_dim):
        super().__init__()
        # Attribute name kept as-is (including the spelling) for compatibility.
        self.chanel_in = in_dim
        # Learnable residual scale; zero-init makes the block start as identity.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply channel attention.

        Args:
            x: 4-D feature map of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W): ``gamma * attended + x``. The attention
            map itself is not returned.
        """
        batch, channels, height, width = x.size()
        # One (B, C, H*W) view serves as query, key and value.
        flat = x.view(batch, channels, -1)
        # Channel-to-channel similarity, shape (B, C, C).
        energy = torch.bmm(flat, flat.permute(0, 2, 1))
        # As in DANet, each row is subtracted from its own maximum before the
        # softmax, so less similar channels receive more weight.
        row_max = torch.max(energy, -1, keepdim=True)[0]
        attention = self.softmax(row_max.expand_as(energy) - energy)
        attended = torch.bmm(attention, flat).view(batch, channels, height, width)
        return self.gamma * attended + x
class CFAMBlock(nn.Module):
    """Fuse a feature map's channels and apply channel attention.

    Pipeline: 1x1 conv -> 3x3 conv -> channel attention (``CAM_Module``) ->
    3x3 conv -> dropout + 1x1 projection to ``out_channels``. Each conv except
    the last is followed by BatchNorm and ReLU. Spatial size is preserved
    (1x1 kernels, or 3x3 with padding=1).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the final 1x1 projection.
        inter_channels: width of the intermediate layers (default 1024, the
            value previously hard-coded).
        dropout: channel-dropout probability before the projection
            (default 0.1, the value previously hard-coded).
    """

    def __init__(self, in_channels, out_channels, inter_channels=1024, dropout=0.1):
        super(CFAMBlock, self).__init__()
        # Submodule names are unchanged so existing state_dict keys still load.
        self.conv_bn_relu1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
        )
        self.conv_bn_relu2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
        )
        self.sc = CAM_Module(inter_channels)
        self.conv_bn_relu3 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
        )
        self.conv_out = nn.Sequential(
            nn.Dropout2d(p=dropout, inplace=False),
            nn.Conv2d(inter_channels, out_channels, 1),
        )

    def forward(self, x):
        """Map a (B, in_channels, H, W) tensor to (B, out_channels, H, W)."""
        x = self.conv_bn_relu1(x)
        x = self.conv_bn_relu2(x)
        x = self.sc(x)
        x = self.conv_bn_relu3(x)
        output = self.conv_out(x)
        return output
if __name__ == "__main__":
    # Smoke test: push a random feature map through the block and print the
    # model and its output shape. Use CUDA when it is available and fall back
    # to CPU otherwise; the original unconditional .cuda() crashed on CPU-only
    # machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = torch.randn(18, 2473, 7, 7, device=device)
    in_channels = data.size(1)
    out_channels = 145
    model = CFAMBlock(in_channels, out_channels).to(device)
    print(model)
    output = model(data)
    print(output.size())