-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathmodel_search.py
262 lines (227 loc) · 9.45 KB
/
model_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import torch
import torch.nn as nn
import torch.nn.functional as F
from operations import *
from torch.autograd import Variable
from genotypes import PRIMITIVES
from genotypes import Genotype
import numpy as np
class MixedOp(nn.Module):
    """A searchable edge holding one candidate module per name in ``PRIMITIVES``.

    While the edge is undecided, ``forward`` returns the weighted sum of every
    candidate's output. Once an operation index has been chosen for the edge,
    only that single candidate is evaluated.
    """

    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        candidates = []
        for name in PRIMITIVES:
            module = OPS[name](C, stride, False)
            if 'pool' in name:
                # Pooling candidates get a trailing non-affine BatchNorm2d.
                module = nn.Sequential(module, nn.BatchNorm2d(C, affine=False))
            candidates.append(module)
        self._ops = nn.ModuleList(candidates)

    def forward(self, x, weights, selected_idx=None):
        """Run the edge on ``x``.

        Args:
            x: input feature map.
            weights: mixing weights zipped with the candidates in
                ``PRIMITIVES`` order; ignored when ``selected_idx`` is given.
            selected_idx: index of the chosen candidate, or ``None`` while the
                edge is still a mixture.
        """
        if selected_idx is not None:
            # Edge already discretized: the unchosen candidates are skipped.
            return self._ops[selected_idx](x)
        return sum(weight * candidate(x) for weight, candidate in zip(weights, self._ops))
class Cell(nn.Module):
    """Search cell with ``steps`` intermediate nodes, each fed by all earlier states.

    Node ``i`` receives one ``MixedOp`` edge from each of the ``2 + i`` states
    before it: the two preprocessed cell inputs plus nodes ``0..i-1``. The
    cell therefore owns ``sum(2 + i for i in range(steps))`` edges, stored
    flat in ``self._ops`` in node-major order.

    Args:
        steps: number of intermediate nodes.
        multiplier: how many of the last states are concatenated as output.
        C_prev_prev, C_prev: channel counts of the two cell inputs.
        C: channel count used inside the cell.
        reduction: if true, edges leaving the two cell inputs use stride 2.
        reduction_prev: if true, input ``s0`` goes through ``FactorizedReduce``
            instead of a 1x1 ``ReLUConvBN``.
    """

    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        self.reduction = reduction
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier
        self._ops = nn.ModuleList()
        # Never populated in this file; presumably kept for attribute
        # compatibility with other code -- TODO confirm before removing.
        self._bns = nn.ModuleList()
        for i in range(self._steps):
            for j in range(2 + i):
                # Only edges leaving the two cell inputs (j < 2) downsample.
                stride = 2 if reduction and j < 2 else 1
                self._ops.append(MixedOp(C, stride))

    def forward(self, s0, s1, weights, selected_idxs=None):
        """Compute the cell output.

        Args:
            s0, s1: outputs of the two preceding cells (or the stem).
            weights: per-edge mixing weights, indexed like ``self._ops``; only
                read for undecided edges.
            selected_idxs: per-edge decisions, indexed like ``self._ops``:
                ``-1`` means undecided (mixture), ``PRIMITIVES.index('none')``
                means pruned, and any other value is the index of the chosen
                op. ``None`` treats every edge as undecided.

        Returns:
            Concatenation along ``dim=1`` of the last ``multiplier`` states.
        """
        if selected_idxs is None:
            # The default used to crash on indexing; treat all edges as mixed.
            selected_idxs = [-1] * len(self._ops)
        none_idx = PRIMITIVES.index('none')

        states = [self.preprocess0(s0), self.preprocess1(s1)]
        offset = 0
        for _ in range(self._steps):
            o_list = []
            for j, h in enumerate(states):
                edge = offset + j
                choice = selected_idxs[edge]
                if choice == -1:  # undecided: mixture of all candidates
                    o_list.append(self._ops[edge](h, weights[edge]))
                elif choice == none_idx:  # pruned edge contributes nothing
                    continue
                else:  # decided: run only the chosen op
                    o_list.append(self._ops[edge](h, None, choice))
            # Edges of the next node start after all edges of this one.
            offset += len(states)
            states.append(sum(o_list))
        return torch.cat(states[-self._multiplier:], dim=1)
class Network(nn.Module):
    """Supernet for the cell-based architecture search.

    Stacks ``layers`` cells. The cells at depth ``layers // 3`` and
    ``2 * layers // 3`` are reduction cells, and each doubles the channel
    count. All normal cells share ``alphas_normal``, and all reduction cells
    share ``alphas_reduce``. These architecture weights live in plain Python
    lists of tensors rather than registered ``nn.Parameter``s, so they are
    not part of ``state_dict()`` and are not moved by ``.to()``/``.cuda()``.

    Edge-decision state starts as ``None`` and is presumably set by the
    search driver (TODO confirm):

    ``normal_selected_idxs`` / ``reduce_selected_idxs``
        Per-edge decisions. ``-1`` means undecided,
        ``PRIMITIVES.index('none')`` means pruned, and any other value is the
        chosen op index.
    ``normal_candidate_flags`` / ``reduce_candidate_flags``
        Truthy while an edge is still a candidate (undecided).
    """

    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
        super(Network, self).__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        self._steps = steps
        self._multiplier = multiplier
        # Stored so that new() can rebuild an identically-shaped network.
        self._stem_multiplier = stem_multiplier
        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            # Reduction cells at 1/3 and 2/3 depth; each doubles the channels.
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            # A cell concatenates `multiplier` states of C_curr channels each.
            C_prev_prev, C_prev = C_prev, multiplier * C_curr
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        self._initialize_alphas()
        self.normal_selected_idxs = None
        self.reduce_selected_idxs = None
        self.normal_candidate_flags = None
        self.reduce_candidate_flags = None

    def new(self):
        """Return a fresh CUDA network with this network's architecture state.

        The copy is built with the same constructor hyper-parameters. Its
        operation weights are freshly initialized and its alphas are value
        copies of this model's. The selected-index and candidate-flag
        attributes reference the same objects as this model's.
        """
        # Instances pickled before `_stem_multiplier` existed fall back to
        # the constructor default.
        stem_multiplier = getattr(self, '_stem_multiplier', 3)
        model_new = Network(self._C, self._num_classes, self._layers, self._criterion,
                            steps=self._steps, multiplier=self._multiplier,
                            stem_multiplier=stem_multiplier).cuda()
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)
        model_new.normal_selected_idxs = self.normal_selected_idxs
        model_new.reduce_selected_idxs = self.reduce_selected_idxs
        model_new.normal_candidate_flags = self.normal_candidate_flags
        model_new.reduce_candidate_flags = self.reduce_candidate_flags
        return model_new

    def forward(self, input):
        """Return classification logits for ``input`` (3-channel images)."""
        # The per-edge softmax depends only on the cell type, so compute it
        # once per pass instead of once per cell.
        weights_normal = [F.softmax(a, dim=-1) for a in self.alphas_normal]
        weights_reduce = [F.softmax(a, dim=-1) for a in self.alphas_reduce]
        s0 = s1 = self.stem(input)
        for cell in self.cells:
            if cell.reduction:
                weights, selected_idxs = weights_reduce, self.reduce_selected_idxs
            else:
                weights, selected_idxs = weights_normal, self.normal_selected_idxs
            s0, s1 = s1, cell(s0, s1, weights, selected_idxs)
        out = self.global_pooling(s1)
        return self.classifier(out.view(out.size(0), -1))

    def _loss(self, input, target):
        """Apply the stored criterion to the logits for ``input``."""
        logits = self(input)
        return self._criterion(logits, target)

    def _initialize_alphas(self):
        """Create one small random alpha vector per edge for each cell type."""
        num_ops = len(PRIMITIVES)
        num_edges = sum(2 + i for i in range(self._steps))
        self.alphas_normal = []
        self.alphas_reduce = []
        # Normal and reduce draws stay interleaved so seeded runs reproduce
        # the same initial alphas.
        for _ in range(num_edges):
            self.alphas_normal.append((1e-3 * torch.randn(num_ops).cuda()).requires_grad_())
            self.alphas_reduce.append((1e-3 * torch.randn(num_ops).cuda()).requires_grad_())
        self._arch_parameters = [
            self.alphas_normal,
            self.alphas_reduce,
        ]

    def arch_parameters(self):
        """Return all alpha tensors: normal edges first, then reduce edges."""
        return self.alphas_normal + self.alphas_reduce  # concat lists

    def check_edges(self, flags, selected_idxs, reduction=False):
        """Prune leftover candidates of every node that has enough decided edges.

        For each node with at least 2 decided incoming edges (flag false),
        every remaining candidate edge is processed as follows:

        * its flag is cleared;
        * its index is set to ``PRIMITIVES.index('none')``;
        * its alpha (reduce list if ``reduction`` else normal) stops
          requiring grad.

        ``flags`` and ``selected_idxs`` are modified in place and also
        returned.
        """
        none_idx = PRIMITIVES.index('none')
        alphas = self.alphas_reduce if reduction else self.alphas_normal
        max_num_edges = 2
        start = 0
        for n in range(2, 2 + self._steps):
            end = start + n
            num_selected_edges = torch.sum(1 - flags[start:end].int())
            if num_selected_edges >= max_num_edges:
                for j in range(start, end):
                    if flags[j]:
                        flags[j] = False
                        selected_idxs[j] = none_idx  # pruned edges
                        alphas[j].requires_grad = False
            start = end
        return flags, selected_idxs

    def parse_gene(self, selected_idxs):
        """Convert fully decided edges into a list of ``(op_name, input_state)``.

        Pruned ('none') edges are skipped.

        Raises:
            Exception: if any edge is still undecided (``-1``).
        """
        none_idx = PRIMITIVES.index('none')
        gene = []
        start = 0
        for n in range(2, 2 + self._steps):
            end = start + n
            for j in range(start, end):
                if selected_idxs[j] == none_idx:
                    continue
                if selected_idxs[j] == -1:
                    raise Exception("Contain undecided edges")
                gene.append((PRIMITIVES[selected_idxs[j]], j - start))
            start = end
        return gene

    def parse_gene_force(self, flags, selected_idxs, alphas):
        """Build a gene even when some edges are still undecided.

        For each node, the free slots (2 minus the node's decided edges) go
        to the candidate edges (flag true) with the largest summed softmax
        probability over non-'none' ops. Each winning undecided edge gets its
        highest-alpha non-'none' op; losing undecided edges are dropped.
        Decided edges are emitted as-is, and pruned edges are skipped.

        Raises:
            Exception: if a node has no free slot but still has an undecided
                edge.
        """
        none_idx = PRIMITIVES.index('none')
        # Every op column except 'none' (== [1, 2, ...] when 'none' is first).
        op_cols = [k for k in range(len(PRIMITIVES)) if k != none_idx]
        max_num_edges = 2
        mat = F.softmax(torch.stack(alphas, dim=0), dim=-1).detach()
        importance = torch.sum(mat[:, op_cols], dim=-1)
        # Decided edges must never win a slot: mask them to -inf.
        candidates = flags.bool().to(importance.device)
        masked_importance = torch.where(
            candidates, importance, torch.full_like(importance, float('-inf')))
        gene = []
        start = 0
        for n in range(2, 2 + self._steps):
            end = start + n
            num_selected_edges = int(torch.sum(1 - flags[start:end].int()))
            num_edges_to_select = max_num_edges - num_selected_edges
            if num_edges_to_select > 0:
                top = torch.topk(masked_importance[start:end], k=num_edges_to_select).indices
                post_select_edges = set((top + start).tolist())
            else:
                post_select_edges = set()
            for j in range(start, end):
                if selected_idxs[j] == none_idx:
                    continue
                if selected_idxs[j] == -1:
                    if num_edges_to_select <= 0:
                        raise Exception("Unknown errors")
                    if j in post_select_edges:
                        idx = op_cols[int(torch.argmax(alphas[j][op_cols]))]
                        gene.append((PRIMITIVES[idx], j - start))
                else:
                    gene.append((PRIMITIVES[selected_idxs[j]], j - start))
            start = end
        return gene

    def get_genotype(self, force=False):
        """Return the ``Genotype`` for the current normal and reduce decisions.

        With ``force=True``, undecided edges are resolved by
        ``parse_gene_force``. Otherwise ``parse_gene`` is used, which raises
        if any edge is undecided. Both cell types concatenate the last
        ``multiplier`` states.
        """
        if force:
            gene_normal = self.parse_gene_force(self.normal_candidate_flags,
                                                self.normal_selected_idxs,
                                                self.alphas_normal)
            gene_reduce = self.parse_gene_force(self.reduce_candidate_flags,
                                                self.reduce_selected_idxs,
                                                self.alphas_reduce)
        else:
            gene_normal = self.parse_gene(self.normal_selected_idxs)
            gene_reduce = self.parse_gene(self.reduce_selected_idxs)
        n = 2
        concat = range(n + self._steps - self._multiplier, self._steps + n)
        genotype = Genotype(
            normal=gene_normal, normal_concat=concat,
            reduce=gene_reduce, reduce_concat=concat
        )
        return genotype