-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
370 lines (308 loc) · 13 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
import json
import os
import time
from django.conf import settings
from cnntools import caffefileproc
from cnntools.common_utils import ensuredir
from cnntools.models import CaffeCNNTrainingRun
from cnntools.snapshot_utils import upload_snapshot
from cnntools.utils import (add_caffe_to_path, add_to_path,
random_file_from_content)
def _load_training_run(caffe_cnn_trrun_id):
    """Load the recorded outputs and iteration limit of a training run.

    ``outputs_json`` is decoded into a list of
    ``{output_num: {iteration: value}}`` dicts. JSON object keys are strings,
    so they are converted back to ``int`` keys and the values to ``float``.

    Returns:
        An ``(outputs, max_iter)`` tuple, where ``max_iter`` is the row's
        ``max_iteration``.
    """
    run = CaffeCNNTrainingRun.objects.get(id=caffe_cnn_trrun_id)
    outputs = [
        {
            int(output_num): {
                int(iteration): float(value)
                for iteration, value in series.items()
            }
            for output_num, series in per_net.items()
        }
        for per_net in json.loads(run.outputs_json)
    ]
    return outputs, run.max_iteration
def _refresh_training_run(caffe_cnn_trrun_id, outputs, output_names, itnum,
                          max_iter):
    """Persist the current training progress on the run's database row.

    ``outputs`` and ``output_names`` are stored JSON-encoded; ``itnum`` is
    written to ``final_iteration`` and ``max_iter`` to ``max_iteration``.
    """
    # Serialize first, so an encoding failure happens before the row is
    # fetched.
    encoded_outputs = json.dumps(outputs)
    encoded_names = json.dumps(output_names)

    run = CaffeCNNTrainingRun.objects.get(id=caffe_cnn_trrun_id)
    run.outputs_json, run.output_names_json = encoded_outputs, encoded_names
    run.final_iteration, run.max_iteration = itnum, max_iter
    run.save()
def _refresh_solverfile(caffe_cnn_trrun_id, solver_file_content):
    """Replace the stored solver-file snapshot of a training run.

    The previous file is deleted with ``save=False``, so the row is saved
    only once, after the file built from ``solver_file_content`` has been
    attached.
    """
    run = CaffeCNNTrainingRun.objects.get(id=caffe_cnn_trrun_id)

    # Drop the old file before attaching the replacement.
    previous_snapshot = run.solver_file_snapshot
    previous_snapshot.delete(save=False)

    run.solver_file_snapshot = random_file_from_content(solver_file_content)
    run.save()
def extract_batchsize_testsetsize(model_file_content):
    """Find the test batch size and test-set size in a model definition.

    Only ``ImageData``, ``MultiImageData`` and ``Python`` layers are
    considered, and only when their first ``include`` rule has phase TEST.
    The first such layer that yields both a batch size and a source file
    wins. The test-set size is the number of lines of that source file,
    resolved relative to ``settings.CAFFE_ROOT``.

    For ``Python`` layers, ``python_param.param_str`` must be a JSON object
    with ``batch_size`` and ``source`` keys; a layer that cannot be parsed
    is skipped.

    Returns:
        A ``(batch_size, testset_size)`` tuple, or ``(None, None)`` when no
        usable layer is found.
    """
    from caffe.proto import caffe_pb2

    model_params = caffefileproc.parse_model_definition_file_content(
        model_file_content)
    layer_types = ('ImageData', 'MultiImageData', 'Python')

    for layer in model_params.layer:
        if layer.type not in layer_types:
            continue
        if not layer.include or layer.include[0].phase != caffe_pb2.TEST:
            continue

        if layer.type == 'Python':
            try:
                params = json.loads(layer.python_param.param_str)
                # Assign both together so a missing 'source' cannot leave a
                # half-initialized batch_size behind.
                batch_size, source = params['batch_size'], params['source']
            except Exception as e:
                # Deliberately best-effort: any malformed param_str just
                # disqualifies this layer. The two-part format reproduces
                # the output of the former `print a, b` statement.
                print('%s %s' % ('Failed to parse param_str: ', e))
                print('Skipping this layer.')
                continue
        elif layer.type == 'ImageData':
            batch_size = layer.image_data_param.batch_size
            source = layer.image_data_param.source
        else:  # 'MultiImageData'
            batch_size = layer.multi_image_data_param.batch_size
            source = layer.multi_image_data_param.source

        # Note that the source should be relative to the caffe root path!
        source_path = os.path.join(settings.CAFFE_ROOT, source)
        testset_size = len(caffefileproc.freadlines(source_path))
        return batch_size, testset_size

    return None, None
def setup_solverfile(model_name, model_file_content, solver_file_content,
                     options, caffe_cnn_trrun_id, device_id):
    """Write the model and solver files for a training run and configure Caffe.

    A run directory ``<CAFFE_ROOT>/training_runs/<run id>-<model_name>`` is
    created. The model definition is written there as ``train_val.prototxt``
    and the adjusted solver as ``solver.prototxt``.

    Args:
        model_name: Name used in the run directory and snapshot prefix.
        model_file_content: Text of the network definition.
        solver_file_content: Text of the solver definition to start from.
        options: Dict of overrides. ``base_lr``, ``debug_info``,
            ``weight_decay`` and ``max_iter`` replace the solver's values
            when present and not ``None``. ``verbose`` enables progress
            messages. A truthy ``cpu`` forces CPU mode. All keys are
            optional.
        caffe_cnn_trrun_id: Integer id of the training run.
        device_id: GPU device passed to ``caffe.set_device`` in GPU mode.

    Returns:
        A ``(solver_params, solverfile_path)`` tuple.
    """
    verbose = options.get('verbose', False)

    root_path = os.path.join(
        settings.CAFFE_ROOT,
        'training_runs',
        '-'.join(['%d' % caffe_cnn_trrun_id, model_name]),
    )
    ensuredir(root_path)
    trainfile_path = os.path.join(root_path, 'train_val.prototxt')
    solverfile_path = os.path.join(root_path, 'solver.prototxt')

    # Save the model_file_content to a file, so Caffe can read it
    with open(trainfile_path, 'w') as f:
        f.write(model_file_content)

    solver_params = caffefileproc.parse_solver_file_content(solver_file_content)
    solver_params.net = trainfile_path

    # Apply the caller's overrides; absent or None means "keep the solver
    # file's value". debug_info is useful to switch on when debugging.
    for key in ('base_lr', 'debug_info', 'weight_decay', 'max_iter'):
        value = options.get(key)
        if value is not None:
            setattr(solver_params, key, value)

    snapshot_path = os.path.join(root_path, 'snapshots')
    ensuredir(snapshot_path)
    solver_params.snapshot_prefix = os.path.join(
        snapshot_path,
        'train_{}-base_lr{}'.format(model_name, solver_params.base_lr),
    )

    # Compute the proper test_iter
    batch_size, testset_size = extract_batchsize_testsetsize(model_file_content)
    if batch_size and testset_size:
        if verbose:
            print('Extracted batch_size ({0}) and testset_size ({1})'.format(
                batch_size, testset_size))
        # The solver file is expected to have exactly one test_iter. Only
        # overwrite an existing entry: assigning index 0 of an empty repeated
        # field raises IndexError, and a solver without test_iter has no
        # entry to update.
        if len(solver_params.test_iter) > 0:
            # Explicit floor division: same result on Python 2 and 3.
            solver_params.test_iter[0] = int(testset_size // batch_size)
        elif verbose:
            print('WARNING: The solver file has no test_iter entry, '
                  'leaving it unset.')
    elif verbose:
        print('WARNING: Couldn\'t find the batch_size or the source file '
              'containing the testset, please set the test_iter to '
              'testset_size / batch_size!')

    # Setting random seed value for reproducible results
    solver_params.random_seed = settings.CAFFE_SEED

    # Silence Caffe; must be set before caffe is imported below.
    os.environ['GLOG_minloglevel'] = '2'
    add_caffe_to_path()
    import caffe
    from caffe.proto import caffe_pb2

    if settings.CAFFE_GPU and not options.get('cpu'):
        if verbose:
            print('Using GPU')
        caffe.set_mode_gpu()
        caffe.set_device(device_id)
        solver_params.solver_mode = caffe_pb2.SolverParameter.GPU
    else:
        if verbose:
            print('Using CPU')
        caffe.set_mode_cpu()
        solver_params.solver_mode = caffe_pb2.SolverParameter.CPU

    caffefileproc.save_protobuf_file(solverfile_path, solver_params)
    return solver_params, solverfile_path
def get_solver_type(caffe, solver_params, verbose=True):
    """Return the solver class of ``caffe`` selected by ``solver_params``.

    Args:
        caffe: Object (normally the ``caffe`` module) exposing the
            ``<Name>Solver`` classes as attributes.
        solver_params: Solver parameters. If it has no ``type`` attribute,
            ``caffe.SGDSolver`` is used.
        verbose: When true, print the name of the chosen solver class.

    Returns:
        The solver class (not an instance).

    Raises:
        ValueError: If ``solver_params.type`` is not a supported solver name.
    """
    valid_solver_names = (
        'SGD', 'AdaDelta', 'AdaGrad', 'Adam', 'Nesterov', 'RMSProp',
    )

    if not hasattr(solver_params, 'type'):
        solver_type = caffe.SGDSolver
    elif solver_params.type not in valid_solver_names:
        raise ValueError('Unexpected solver type: %s' % solver_params.type)
    else:
        solver_type = getattr(caffe, solver_params.type + 'Solver')

    if verbose:
        # Call form with a single argument: identical output on Python 2,
        # and valid syntax on Python 3.
        print('Creating solver "%s"...' % solver_type.__name__)
    return solver_type
def train_network(solver_params, solverfile_path, options,
                  caffe_cnn_trrun_id):
    """Run the training loop and record outputs and snapshots for a run.

    The solver is stepped one iteration at a time. Train-net outputs are
    recorded every ``solver_params.display`` iterations. Test-net scores
    (read from ``solver.test_mean_scores``) are recorded every
    ``solver_params.test_interval`` iterations. The training-run row is
    refreshed whenever either was recorded. Every ``solver_params.snapshot``
    iterations the corresponding ``.caffemodel`` file is uploaded; an
    interval of 0 disables this.

    Args:
        solver_params: Parsed solver parameters.
        solverfile_path: Path of the solver file handed to the solver class.
        options: Dict of optional settings.
            ``weights``: path relative to ``settings.CAFFE_ROOT``; a
            ``.solverstate`` file resumes training, anything else is copied
            into every net as initial weights.
            ``data_layer_params``: passed to a Python data layer's
            ``set_params``.
            ``start_snapshot``: also snapshot at iteration 0.
            ``verbose``: forwarded to ``upload_snapshot``.
        caffe_cnn_trrun_id: Id of the training run being updated.

    Returns:
        The id of the last uploaded snapshot, or ``None`` if none was
        uploaded.
    """
    add_caffe_to_path()
    import caffe

    for p in settings.TRAINING_EXTRA_PYTHON_PATH:
        add_to_path(p)

    verbose = options.get('verbose', False)

    solver = get_solver_type(caffe, solver_params)(str(solverfile_path))
    solver_nets = [solver.net] + list(solver.test_nets)

    restore = False
    weights = options.get('weights')
    if weights is not None:
        weights_path = os.path.join(settings.CAFFE_ROOT, weights)
        if os.path.splitext(weights)[1] == '.solverstate':
            solver.restore(weights_path)
            restore = True
        else:
            for n in solver_nets:
                n.copy_from(weights_path)

    print('solver_net count: %d' % len(solver_nets))
    for n in solver_nets:
        data_layer = n.layers[0]
        # If the data layer is python, we try to set the random seed for
        # reproducibility
        if data_layer.type == 'Python':
            # Note: The python implementation of the layer should have a
            # "set_random_seed" function. If we can't find a function with
            # this name, we won't set the random seed
            set_random_seed_func = getattr(data_layer, 'set_random_seed', None)
            if callable(set_random_seed_func):
                set_random_seed_func(settings.CAFFE_SEED)

            # Note: The python implementation of the layer should have a
            # "set_params" function.
            set_params_func = getattr(data_layer, 'set_params', None)
            if 'data_layer_params' in options and callable(set_params_func):
                set_params_func(options['data_layer_params'])

        # NOTE(review): the original nesting of this call could not be
        # recovered; it is run for every net, which covers every case the
        # narrower placements would -- confirm against the upstream file.
        n.reshape()

    # Index 0 holds the train net's data, index 1 the test nets' data.
    # {key: output_num, value: output_name}
    output_names = [{}, {}]
    # for backward compatibility
    name_to_num = [{}, {}]
    # {key: output_num, value: {key: it_num value: output_value}}
    outputs = [{}, {}]

    for i, op in enumerate(solver.net.outputs):
        output_names[0][i] = op
        name_to_num[0][op] = i
        outputs[0][i] = {}

    # Outputs of all test nets share one running numbering.
    op_num = 0
    for test_net in solver.test_nets:
        for op in test_net.outputs:
            output_names[1][op_num] = op
            name_to_num[1][op] = op_num
            outputs[1][op_num] = {}
            op_num += 1

    if restore:
        # Load back the saved outputs, so we can start from the figures where
        # we left off
        outputs, max_iter = _load_training_run(caffe_cnn_trrun_id)
        start_it = int(solver.iter)
        # Filter out data which happened after the snapshot
        for per_net in outputs:
            for series in per_net.values():
                for stale_it in [i for i in series if i > start_it]:
                    del series[stale_it]
    else:
        max_iter = solver_params.max_iter
        start_it = 0

    snapshot_interval = solver_params.snapshot
    final_snapshot_id = None
    for it in range(start_it, max_iter + 1):
        solver.step(1)  # SGD by Caffe

        display = solver_params.display and it % solver_params.display == 0
        if display:
            for op in solver.net.outputs:
                val = solver.net.blobs[op].data
                outputs[0][name_to_num[0][op]][it] = float(val)

        test_display = (
            solver_params.test_interval
            and it % solver_params.test_interval == 0
            and (it != 0 or solver_params.test_initialization)
        )
        if test_display:
            for i, test_net in enumerate(solver.test_nets):
                for op in test_net.outputs:
                    val = solver.test_mean_scores[i][op]
                    outputs[1][name_to_num[1][op]][it] = float(val)

        if display or test_display:
            _refresh_training_run(
                caffe_cnn_trrun_id,
                outputs,
                output_names,
                it,
                max_iter,
            )

        # A snapshot interval of 0 means snapshots are disabled; checking it
        # first also avoids a modulo-by-zero, as for display/test_interval.
        snapshot = (
            snapshot_interval
            and it % snapshot_interval == 0
            and (it != 0 or options.get('start_snapshot', False))
        )
        if snapshot:
            snapshot_path = os.path.join(
                settings.CAFFE_ROOT,
                '{}_iter_{}.caffemodel'.format(
                    solver_params.snapshot_prefix,
                    it,
                ),
            )
            if it == 0:
                # At iteration 0 the model file is written here, explicitly,
                # before it is uploaded.
                solver.net.save(snapshot_path)
            final_snapshot = upload_snapshot(
                caffe_cnn_trrun_id=caffe_cnn_trrun_id,
                snapshot_path=snapshot_path,
                it=it,
                verbose=verbose,
            )
            final_snapshot_id = final_snapshot.id

    return final_snapshot_id
def start_training(model_name, model_file_content, solver_file_content,
                   options, caffe_cnn_trrun_id, device_id=0):
    """Prepare the solver for a training run and train the network.

    Steps: build the solver file via ``setup_solverfile``, store the final
    solver content on the training-run row, change the working directory to
    ``settings.CAFFE_ROOT``, then run ``train_network``.

    Returns:
        Whatever ``train_network`` returns.
    """
    if options['verbose']:
        print('Running training for model {}...'.format(model_name))
        print('with options: {}'.format(options))

    solver_params, solverfile_path = setup_solverfile(
        model_name,
        model_file_content,
        solver_file_content,
        options,
        caffe_cnn_trrun_id,
        device_id,
    )

    # Save the final solver file's content to database
    final_solver_content = caffefileproc.gen_protobuf_file_content(solver_params)
    _refresh_solverfile(caffe_cnn_trrun_id, final_solver_content)

    # Change working directory to Caffe
    os.chdir(settings.CAFFE_ROOT)

    # TODO: Figure out to detect divergence and that the loss doesn't decrease,
    # so we can stop training
    final_snapshot_id = train_network(
        solver_params, solverfile_path, options, caffe_cnn_trrun_id,
    )
    return final_snapshot_id