-
-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
backbones.py
executable file
·89 lines (78 loc) · 2.57 KB
/
backbones.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch.nn as nn
from torchvision import models
# Supported ResNet variants, mapped to their torchvision constructors.
# The keys are the names ResNetBackbone looks up in this table.
resnet_dict = dict(
    resnet18=models.resnet18,
    resnet34=models.resnet34,
    resnet50=models.resnet50,
    resnet101=models.resnet101,
    resnet152=models.resnet152,
)
def get_backbone(name):
    """Build a feature-extractor backbone from its name.

    Matching is case-insensitive:
      * any name containing ``"resnet"`` -> ``ResNetBackbone``
      * ``"alexnet"``                    -> ``AlexNetBackbone``
      * ``"dann"``                       -> ``DaNNBackbone``

    Args:
        name: Backbone identifier, e.g. ``"resnet50"``, ``"alexnet"``,
            ``"dann"``.

    Returns:
        The constructed backbone module; every supported backbone exposes
        ``output_num()`` giving its feature dimensionality.

    Raises:
        ValueError: If ``name`` matches none of the supported backbones.
    """
    key = name.lower()
    if "resnet" in key:
        # Hand over the normalized key: the check above is case-insensitive,
        # so e.g. "ResNet50" must not later miss the lower-case table lookup.
        return ResNetBackbone(key)
    if key == "alexnet":
        return AlexNetBackbone()
    if key == "dann":
        return DaNNBackbone()
    # Fail loudly instead of silently returning None to the caller.
    raise ValueError(
        f"Unsupported backbone {name!r}; expected a ResNet variant, "
        "'alexnet' or 'dann'."
    )
class DaNNBackbone(nn.Module):
    """Single-hidden-layer MLP backbone: Linear -> Dropout(0.5) -> ReLU.

    Each input sample is flattened, so any tensor whose non-batch dimensions
    multiply to ``n_input`` is accepted. The default ``n_input`` is
    presumably sized for 3x224x224 images (224*224*3).

    Args:
        n_input: Number of features per flattened input sample.
        n_hidden: Width of the hidden layer; also the output feature size.
    """

    def __init__(self, n_input=224*224*3, n_hidden=256):
        super().__init__()
        self.layer_input = nn.Linear(n_input, n_hidden)
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()
        self._feature_dim = n_hidden

    def forward(self, x):
        """Return ``(batch, n_hidden)`` features for a batch-first input ``x``."""
        # flatten() works on non-contiguous tensors (permuted/sliced inputs)
        # and on empty batches, where view(x.size(0), -1) raises.
        x = x.flatten(1)
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        return x

    def output_num(self):
        """Return the dimensionality of the features produced by forward()."""
        return self._feature_dim
# Pretrained AlexNet with its final classifier layer removed.
class AlexNetBackbone(nn.Module):
    """Feature extractor built from torchvision's pretrained AlexNet.

    Keeps the convolutional ``features``, the ``avgpool`` and the first six
    classifier layers (indices 0-5); the last classifier layer (index 6) is
    dropped. ``forward`` therefore returns the activations that would feed
    that layer, and ``output_num()`` is its ``in_features``.
    """

    def __init__(self):
        super().__init__()
        model_alexnet = models.alexnet(pretrained=True)
        self.features = model_alexnet.features
        # Reuse the model's own pooling layer, as its forward does, so the
        # flattened size matches classifier[1] for any input resolution
        # instead of only for inputs yielding exactly 6x6 feature maps.
        # (A pooling layer holds no parameters, so state_dict is unchanged.)
        self.avgpool = model_alexnet.avgpool
        self.classifier = nn.Sequential()
        # The "classifier<i>" names are kept so existing checkpoints still load.
        for i in range(6):
            self.classifier.add_module(
                "classifier" + str(i), model_alexnet.classifier[i])
        self._feature_dim = model_alexnet.classifier[6].in_features

    def forward(self, x):
        """Return ``(batch, output_num())`` features for an image batch ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = x.flatten(1)
        x = self.classifier(x)
        return x

    def output_num(self):
        """Return the dimensionality of the features produced by forward()."""
        return self._feature_dim
class ResNetBackbone(nn.Module):
    """Pretrained torchvision ResNet with the final ``fc`` layer removed.

    ``forward`` runs the stem, ``layer1``-``layer4`` and the model's
    ``avgpool``, then flattens to ``(batch, output_num())``, where
    ``output_num()`` is the original ``fc`` layer's ``in_features``.

    Args:
        network_type: Name of a ResNet variant listed in ``resnet_dict``
            (e.g. ``"resnet50"``); matched case-insensitively.

    Raises:
        KeyError: If ``network_type`` is not a supported ResNet variant.
    """

    def __init__(self, network_type):
        super().__init__()
        key = network_type.lower()
        if key not in resnet_dict:
            # Same exception type as a plain dict lookup, but with a
            # message listing the valid choices.
            raise KeyError(
                f"Unsupported ResNet variant {network_type!r}; "
                f"choose one of {list(resnet_dict)}")
        resnet = resnet_dict[key](pretrained=True)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool
        self._feature_dim = resnet.fc.in_features
        # Only the submodules above are kept; the fc layer is not registered.
        del resnet

    def forward(self, x):
        """Return ``(batch, output_num())`` features for an image batch ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # flatten() also copes with empty batches, where view(0, -1) raises.
        x = x.flatten(1)
        return x

    def output_num(self):
        """Return the dimensionality of the features produced by forward()."""
        return self._feature_dim