crnn.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import models

# pretrained features
FEATURES = {}

# output dimensionality for supported architectures
OUTPUT_DIM = {
    'resnet_cifar': 512,
    'densenet_cifar': 342,
    'densenet121': 384,
    'mobilenetv2_cifar': 1280,
    'shufflenetv2_cifar': 1024,
}
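
# Illustrative note (an assumption, not part of the original file): FEATURES is
# expected to map an architecture name to the path of a saved state_dict for the
# feature extractor, which init_network() loads when pretrained=True, e.g.
# (hypothetical path):
#   FEATURES['densenet_cifar'] = 'path/to/densenet_cifar_features.pth'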


class CRNN(nn.Module):

    def __init__(self, features, meta):
        super(CRNN, self).__init__()
        self.features = nn.Sequential(*features)
        # pool the height dimension to 1, keep the full width as the sequence axis
        self.avgpool = nn.AdaptiveAvgPool2d((1, None))
        self.classifier = nn.Linear(meta['output_dim'], meta['num_classes'])
        self.meta = meta

    def forward(self, x):
        # x -> features
        out = self.features(x)

        # features -> pool -> flatten -> classifier -> log-softmax
        out = self.avgpool(out)
        # (N, C, 1, W) -> (W, N, C): one prediction per feature-map column
        out = out.permute(3, 0, 1, 2).view(out.size(3), out.size(0), -1)
        out = self.classifier(out)
        out = F.log_softmax(out, dim=2)

        return out

    def __repr__(self):
        tmpstr = super(CRNN, self).__repr__()[:-1]
        tmpstr += self.meta_repr()
        tmpstr = tmpstr + ')'
        return tmpstr

    def meta_repr(self):
        tmpstr = ' (' + 'meta' + '): dict( \n'  # + self.meta.__repr__() + '\n'
        tmpstr += ' architecture: {}\n'.format(self.meta['architecture'])
        tmpstr += ' output dim: {}\n'.format(self.meta['output_dim'])
        tmpstr += ' classes: {}\n'.format(self.meta['num_classes'])
        tmpstr += ' mean: {}\n'.format(self.meta['mean'])
        tmpstr += ' std: {}\n'.format(self.meta['std'])
        tmpstr = tmpstr + ' )\n'
        return tmpstr
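
# Shape sketch for CRNN.forward (illustrative only, not part of the original file;
# W' denotes the feature-map width left after the convolutional backbone):
#   x          : (N, 3, H, W)   input batch
#   features   : (N, C, H', W') with C == meta['output_dim']
#   avgpool    : (N, C, 1, W')  height pooled to 1, width kept as the sequence axis
#   permute    : (W', N, C)     sequence-first layout
#   classifier : (W', N, num_classes)
#   log_softmax: per-step log-probabilities, the layout expected by nn.CTCLoss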

def init_network(params):
    # parse params with default values
    architecture = params.get('architecture', 'densenet_cifar')
    num_classes = params.get('num_classes', 11)
    mean = params.get('mean', [0.396, 0.576, 0.562])
    std = params.get('std', [0.154, 0.128, 0.130])
    pretrained = params.get('pretrained', False)

    # get output dimensionality size
    dim = OUTPUT_DIM[architecture]

    # loading network
    if pretrained:
        if architecture not in FEATURES:
            # initialize with network pretrained on imagenet in pytorch
            net_in = getattr(models, architecture)(pretrained=True)
        else:
            # initialize with random weights, later on we will fill features with custom pretrained network
            net_in = getattr(models, architecture)(pretrained=False)
    else:
        # initialize with random weights
        net_in = getattr(models, architecture)(pretrained=False)

    # initialize features
    # take only convolutions for features,
    # always ends with ReLU to make last activations non-negative
    if architecture.startswith('resnet'):
        features = list(net_in.children())[:-2]
    elif architecture.startswith('densenet'):
        features = list(net_in.features.children())
        features.append(nn.ReLU(inplace=True))
    elif architecture.startswith('mobilenetv2'):
        features = list(net_in.children())[:-2]
    elif architecture.startswith('shufflenetv2'):
        features = list(net_in.children())[:-2]
    else:
        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))

    # create meta information to be stored in the network
    meta = {
        'architecture': architecture,
        'num_classes': num_classes,
        'mean': mean,
        'std': std,
        'output_dim': dim,
    }

    # create a generic crnn network
    net = CRNN(features, meta)

    # initialize features with custom pretrained network if needed
    if pretrained and architecture in FEATURES:
        print(">> {}: for '{}' custom pretrained features '{}' are used".format(
            os.path.basename(__file__), architecture, os.path.basename(FEATURES[architecture])))
        net.features.load_state_dict(torch.load(FEATURES[architecture]))

    return net
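

# Minimal usage sketch (not part of the original file). It assumes the local
# `models` package exposes the architectures listed in OUTPUT_DIM and that a
# random-weight smoke test is sufficient; the input size below is an arbitrary
# example, not a requirement of the code above.
if __name__ == '__main__':
    net = init_network({'architecture': 'densenet_cifar', 'num_classes': 11})
    dummy = torch.randn(2, 3, 32, 96)   # (N, C, H, W) batch of two dummy crops
    log_probs = net(dummy)              # (W', N, num_classes) per-step log-probabilities
    print(net)
    print('output shape:', tuple(log_probs.shape))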