-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfetcomp.py
484 lines (395 loc) · 14.9 KB
/
fetcomp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
import os
from cnntools.common_utils import iter_batch, progress_bar_widgets
from cnntools.descriptor_aggregator import DescriptorAggregator
from cnntools.descstore import DescriptorStoreHdf5
from cnntools.models import CaffeCNNSnapshot
from cnntools.redis_aggregator import (detach_patch_interrupt_signal,
patch_interrupt_signal)
from cnntools.tasks import compute_cnn_features_gpu_task
from cnntools.utils import RedisItemKey
from progressbar import ProgressBar
def get_snapshot_id(caffe_cnn, snapshot_id):
    """Return a validated snapshot id, or the id of the latest snapshot.

    If ``snapshot_id`` is truthy, verify that it belongs to ``caffe_cnn``
    and return it unchanged; otherwise look up the most recent snapshot of
    the network and return its id.
    """
    if not snapshot_id:
        # No explicit snapshot requested: fall back to the newest one.
        latest = CaffeCNNSnapshot.objects.\
            filter(training_run__net__id=caffe_cnn.id).\
            order_by('-id')[0]
        return latest.id

    snapshot = CaffeCNNSnapshot.objects.get(id=snapshot_id)
    if snapshot.training_run.net.id != caffe_cnn.id:
        raise ValueError(
            'Snapshot\'s corresponding network id doesn\'t match given'
            ' network id!'
        )
    return snapshot_id
def get_slug(caffe_cnn, snapshot_id, slug_extra):
    """Build the descriptor-store slug for a network snapshot.

    Returns a ``(slug, snapshot_id)`` tuple where the snapshot id has been
    resolved (latest snapshot if none was given) and the slug optionally
    carries the ``slug_extra`` suffix.
    """
    resolved_id = get_snapshot_id(caffe_cnn, snapshot_id)
    base = 'netid-{}:snapshot-{}'.format(caffe_cnn.netid, resolved_id)
    if slug_extra:
        base += slug_extra
    return base, resolved_id
def get_redis_key_core(slug, feature_name_list):
    """Return the common redis key prefix for a feature computation task.

    The feature names are joined with dashes; batch ids / the task marker
    are appended to this prefix by the other key helpers.
    """
    joined_names = '-'.join(feature_name_list)
    return 'desc:slug-{}:fetname{}:batchid#'.format(slug, joined_names)
def get_redis_key(slug, feature_name_list, batchid_counter):
    """Return the redis key identifying one dispatched batch."""
    prefix = get_redis_key_core(slug, feature_name_list)
    return '{}{}'.format(prefix, batchid_counter)
def get_task_id(slug, feature_name_list):
    """Return the redis key under which the whole task is tracked."""
    return get_redis_key_core(slug, feature_name_list) + 'task'
def get_descstore_dirname(netid, feature_name):
    """Return a filesystem-safe directory name for a descriptor store.

    Characters that are problematic in paths (slashes, brackets, braces,
    parentheses, spaces, colons) are each replaced by a dash.
    """
    raw = '{}-{}'.format(netid, feature_name)
    unsafe = set('/()[]{} :')
    return ''.join('-' if ch in unsafe else ch for ch in raw)
def get_descstore_filename(netid, feature_name):
    """Return the HDF5 file name for a descriptor store."""
    return '{}.hdf5'.format(get_descstore_dirname(netid, feature_name))
def get_descstore_filepaths(desc_rootpath, feature_name_list, caffe_cnn,
                            snapshot_id, slug_extra=''):
    """Return the full HDF5 descriptor store path for every feature name.

    The paths are rooted at ``desc_rootpath`` and named from the slug of the
    given network/snapshot, in the order of ``feature_name_list``.
    """
    slug, snapshot_id = get_slug(caffe_cnn, snapshot_id, slug_extra)
    paths = []
    for feature_name in feature_name_list:
        filename = get_descstore_filename(slug, feature_name)
        paths.append(os.path.join(desc_rootpath, filename))
    return paths
def dispatch_feature_comp(
desc_rootpath,
item_type,
item_ids,
node_batchsize,
feature_name_list,
num_dims_list,
fetcomp_func,
fetcomp_kwargs,
slug,
verbose=False,
):
if verbose:
print 'Dispatching feature computation for incomplete {} models'.format(item_type)
# TODO: Use DescriptorAggregator instead, so we don't have to assemble the filename...
completed_ids = set()
for feature_name in feature_name_list:
filename = get_descstore_filename(slug, feature_name)
if os.path.exists(os.path.join(desc_rootpath, filename)):
src_store = DescriptorStoreHdf5(
path=os.path.join(desc_rootpath, filename),
readonly=True,
verbose=verbose,
)
if verbose:
print 'Loaded completed descriptors for {}.'.format(filename)
# Get already completed ids
current_ids = set(src_store.ids[...])
if completed_ids:
completed_ids.intersection_update(current_ids)
else:
completed_ids = current_ids
if verbose:
print 'Getting ids for incomplete items...'
todo_ids = item_ids.difference(completed_ids)
if verbose:
print '{} ids completed already'.format(len(completed_ids))
print '{} ids to do...'.format(len(todo_ids))
task_id = get_task_id(slug, feature_name_list)
if verbose:
pbar = ProgressBar(widgets=progress_bar_widgets(), maxval=len(todo_ids))
pbar_counter = 0
pbar.start()
batchid_counter = 0
batch = []
for batch in iter_batch(todo_ids, node_batchsize):
# Dispatch job
batch_id = get_redis_key(
slug,
feature_name_list,
batchid_counter,
)
fetcomp_func.delay(
item_type,
task_id,
batch_id,
batch,
feature_name_list,
fetcomp_kwargs,
)
batchid_counter += 1
if verbose:
pbar_counter += len(batch)
pbar.update(pbar_counter)
if verbose:
pbar.finish()
def extract_item_ids(item_type, item_ids):
    """Return the plain item ids as a list, preserving order.

    For 'redis' item types the given ids are redis keys, so the underlying
    item id is parsed out of each key; otherwise the ids are returned as a
    new list unchanged.
    """
    if item_type != 'redis':
        return list(item_ids)
    # Redis keys encode the item id; extract it from each key.
    return [
        RedisItemKey.create_from_key(key).item_id
        for key in item_ids
    ]
def aggregate_feature_comp(
    desc_rootpath,
    item_type,
    item_ids,
    feature_name_list,
    num_dims_list,
    aggr_batchsize,
    slug,
    handle_interrupt_signal=True,
    verbose=False,
):
    """Aggregate dispatched feature batches into the HDF5 descriptor stores.

    Builds a DescriptorAggregator over one store per feature name and runs
    it until all ids are collected (or the run is interrupted).

    :param desc_rootpath: Directory holding the HDF5 descriptor stores.
    :param item_type: The type of the model class for the items (or 'redis').
    :param item_ids: Ids whose features are being aggregated.
    :param feature_name_list: The features' names to aggregate.
    :param num_dims_list: Dimensions of the computed features.
    :param aggr_batchsize: Number of batches to wait for before forcing a
        download/removal of those batches from redis.
    :param slug: Unique humanly readable name of this computation.
    :param handle_interrupt_signal: If True, patch the interrupt signal so
        the descriptor store is saved before exiting on Ctrl + C.
    :param verbose: If True, print progress information to the console.
    :returns: Whatever ``aggregator.run`` returns -- presumably truthy on a
        clean completion; callers treat a falsy value as "interrupted".
        TODO(review): confirm against DescriptorAggregator.run.
    """
    if verbose:
        print 'Aggregating feature computation for {} models'.format(item_type)
    if handle_interrupt_signal:
        # Install the graceful-shutdown handler before touching the stores.
        patch_interrupt_signal()
    dirname_list = [
        get_descstore_dirname(slug, feature_name)
        for feature_name in feature_name_list
    ]
    aggregator = DescriptorAggregator(
        feature_name_list=feature_name_list, filename_list=dirname_list,
        num_dims_list=num_dims_list, verbose=verbose
    )
    # Open the stores writable so aggregated batches can be appended.
    aggregator.load(desc_rootpath, readonly=False)
    task_id = get_task_id(slug, feature_name_list)
    # If we use 'redis' objects, this is important
    all_ids = extract_item_ids(item_type=item_type, item_ids=item_ids)
    ret = aggregator.run(
        all_ids=all_ids, task_id=task_id, aggr_batchsize=aggr_batchsize)
    if handle_interrupt_signal:
        # Restore the original signal handler on the way out.
        detach_patch_interrupt_signal()
    return ret
def retrieve_features(
    desc_rootpath,
    item_type,
    item_ids,
    feature_name_list,
    slug,
    verbose=False,
):
    """Load previously computed features from the HDF5 descriptor stores.

    :param desc_rootpath: Directory holding the HDF5 descriptor stores.
    :param item_type: The type of the model class for the items (or 'redis').
    :param item_ids: Ids to fetch, or None to fetch every stored id (taken
        from the first feature's store).
    :param feature_name_list: Names of the features to load.
    :param slug: Unique humanly readable name of the computation.
    :param verbose: If True, print progress information to the console.
    :returns: ``(item_ids, features)`` when ``item_ids`` was None, otherwise
        just ``features`` -- a dict mapping feature name to the block
        returned by the store's ``block_get``.
    """
    # Remember whether the caller asked us to discover the ids, so we can
    # return them alongside the features.
    ret_item_ids = item_ids is None
    if item_ids is None:
        item_ids_to_get = None
    else:
        # If we use 'redis' objects, this is important
        # Note: The order of item_ids is very important!
        item_ids_to_get = extract_item_ids(item_type=item_type, item_ids=item_ids)
    features = {}
    for feature_name in feature_name_list:
        filename = get_descstore_filename(slug, feature_name)
        src_store = DescriptorStoreHdf5(
            path=os.path.join(desc_rootpath, filename),
            readonly=True,
            verbose=verbose,
        )
        # Will be called only once
        if item_ids_to_get is None:
            item_ids_to_get = src_store.ids[...]
        features[feature_name] = src_store.block_get(
            item_ids_to_get, show_progress=verbose
        )
        # Release the store handle before opening the next feature's file.
        del src_store
    if ret_item_ids:
        return item_ids_to_get, features
    else:
        return features
def compute_features(
    desc_rootpath,
    item_type,
    item_ids,
    node_batchsize,
    aggr_batchsize,
    feature_name_list,
    num_dims_list,
    fetcomp_func,
    fetcomp_kwargs,
    slug,
    handle_interrupt_signal=True,
    verbose=False,
):
    """
    Extracts the specified feature for the specified items using a trained CNN.

    :param desc_rootpath: The root path of the directory where the computed
    features will be stored in.

    :param item_type: The type of the model class for the item which are
    classified (e.g. FgPhoto). This class should have 'title', 'photo'
    attributes/properties. The photo attribute should have most of the Photo
    model's fields. It is advised to use an actual Photo instance here.

    :param item_ids: List (or numpy array) of ids into the :ref:`item_type`
    table. The length of this list is the same as the length of :ref:`y_true`
    list and they have the same order.

    :param node_batchsize: The number of feature computations to put in one
    task executed on a worker node.

    :param aggr_batchsize: The number of batches to wait for before forcing the
    aggregator to download and remove those batches from redis.

    :param feature_name_list: The features' names in the network which will be
    extracted. May also be a single feature name, in which case the return
    value is unwrapped accordingly.

    :param num_dims_list: Dimensions of the computed features.

    :param fetcomp_func: The function which will be executed to compute
    features.

    :param fetcomp_kwargs: The parameters to pass to the feature computer
    function.

    :param slug: The unique humanly readable name associated with the feature
    computation task.

    :param handle_interrupt_signal: If True, we patch the interrupt signal so the descriptor store is saved before exiting, when the user hits Ctrl + C.

    :param verbose: If True, print progress information to the console.

    :returns: None if the aggregation was interrupted; otherwise a dict
    mapping feature name to the retrieved features, or just the features
    themselves when a single feature name (not a list) was passed.
    """
    item_ids_set = set(item_ids)
    # Accept a single feature name (with a single dimension) for convenience;
    # normalize both to one-element lists and unwrap the result at the end.
    single_feature = False
    if isinstance(feature_name_list, (str, unicode)):
        if isinstance(num_dims_list, list):
            raise ValueError(
                'If "feature_name_list" is not specified as a list, '
                '"num_dims_list" has to be a single number!'
            )
        feature_name_list = [feature_name_list]
        num_dims_list = [num_dims_list]
        single_feature = True

    # Enqueue computation tasks for the items whose features are not stored
    # yet; already-complete ids are skipped inside dispatch_feature_comp.
    dispatch_feature_comp(
        desc_rootpath=desc_rootpath,
        item_type=item_type,
        item_ids=item_ids_set,
        node_batchsize=node_batchsize,
        feature_name_list=feature_name_list,
        num_dims_list=num_dims_list,
        fetcomp_func=fetcomp_func,
        fetcomp_kwargs=fetcomp_kwargs,
        slug=slug,
        verbose=verbose,
    )
    # A falsy return from the aggregator is treated as an interrupted run
    # (e.g. the user hit Ctrl + C) -- presumably it returns truthy on a
    # clean completion; TODO(review): confirm against DescriptorAggregator.
    was_interrupted = not aggregate_feature_comp(
        desc_rootpath=desc_rootpath,
        item_type=item_type,
        item_ids=item_ids_set,
        feature_name_list=feature_name_list,
        num_dims_list=num_dims_list,
        aggr_batchsize=aggr_batchsize,
        slug=slug,
        handle_interrupt_signal=handle_interrupt_signal,
        verbose=verbose,
    )
    if was_interrupted:
        return None

    # The order of the item ids is important!
    fets = retrieve_features(
        desc_rootpath=desc_rootpath,
        item_type=item_type,
        item_ids=item_ids,
        feature_name_list=feature_name_list,
        slug=slug,
        verbose=verbose,
    )
    if single_feature:
        # Unwrap the single-feature dict back to the bare feature block.
        fets = fets[feature_name_list[0]]
    return fets
def compute_cnn_features(
    desc_rootpath,
    item_type,
    item_ids,
    node_batchsize,
    aggr_batchsize,
    caffe_cnn,
    feature_name_list,
    num_dims_list,
    snapshot_id,
    do_preprocessing=True,
    image_dims=None,
    mean=None,
    grayscale=False,
    auto_reshape=False,
    transfer_weights=False,
    slug_extra='',
    input_trafo_func_name=None,
    input_trafo_kwargs=None,
    fet_trafo_type_id=None,
    fet_trafo_kwargs=None,
    handle_interrupt_signal=True,
    verbose=False,
):
    """
    Extracts the specified feature for the specified items using a trained CNN.

    :param desc_rootpath: The root path of the directory where the computed
    features will be stored in.

    :param item_type: The type of the model class for the item which are
    classified (e.g. FgPhoto). This class should have 'title', 'photo'
    attributes/properties. The photo attribute should have most of the Photo
    model's fields. It is advised to use an actual Photo instance here.

    :param item_ids: List (or numpy array) of ids into the :ref:`item_type`
    table. The length of this list is the same as the length of :ref:`y_true`
    list and they have the same order.

    :param node_batchsize: The number of feature computations to put in one
    task executed on a worker node.

    :param aggr_batchsize: The number of batches to wait for before forcing the
    aggregator to download and remove those batches from redis.

    :param caffe_cnn: The CaffeCNN instance which represents the CNN which is
    used for feature computation.

    :param feature_name_list: The features' names in the network which will be
    extracted.

    :param num_dims_list: Dimensions of the computed features.

    :param snapshot_id: The ID of the snapshot to use. If falsy, the latest
    snapshot of the network is used.

    :param do_preprocessing: True, if we should do preprocessing (resizing/cropping the image for example) on the input.

    :param image_dims: Tuple with two elements which defines the image
    dimensions (width, height). All input images will be resized to this size.
    This should be the same as the one was used for training.

    :param mean: Tuple with 3 elements which defines the mean which will be
    subtracted from all images. This should be the same as the one was used for
    training.

    :param grayscale: If true we convert all images to grayscale.

    :param auto_reshape: Automatically reshapes the network to the size of the
    input image. The user should still produce the same feature dimension for
    all images after the feature transformation.

    :param transfer_weights: Automatically convert the net weights to fully
    convolutional form. The deploy file of the model has to be fully
    convolutional!

    :param slug_extra: Additional string to add to the slug.

    :param input_trafo_func_name: Input (usually an image) transformation
    primitive. This should include the module and the function name like
    'parent_module.child_module.example_function_name'. This module should be
    on the python path.

    :param input_trafo_kwargs: Keyword arguments for the chosen image
    transformation primitive.

    :param fet_trafo_type_id: Feature transformation primitive. Currently the
    user can select from:
    ['MINC-spatial-avg', 'MINC-gram', 'MINC-mean-std']

    :param fet_trafo_kwargs: Keyword arguments for the chosen feature
    transformation primitive.

    :param handle_interrupt_signal: If True, we patch the interrupt signal so the descriptor store is saved before exiting, when the user hits Ctrl + C.

    :param verbose: If True, print progress information to the console.

    :returns: The return value of :func:`compute_features` -- None if
    interrupted, otherwise the retrieved features.
    """
    # Resolve the snapshot and build the slug that names the descriptor
    # stores / redis keys for this computation.
    slug, snapshot_id = get_slug(caffe_cnn, snapshot_id, slug_extra)
    # All CNN-specific options are forwarded to the GPU task via kwargs.
    fetcomp_kwargs = {
        'snapshot_id': snapshot_id,
        'do_preprocessing': do_preprocessing,
        'image_dims': image_dims,
        'mean': mean,
        'grayscale': grayscale,
        'auto_reshape': auto_reshape,
        'transfer_weights': transfer_weights,
        'input_trafo_func_name': input_trafo_func_name,
        'input_trafo_kwargs': input_trafo_kwargs,
        'fet_trafo_type_id': fet_trafo_type_id,
        'fet_trafo_kwargs': fet_trafo_kwargs,
    }
    # Delegate the dispatch/aggregate/retrieve pipeline to compute_features,
    # using the GPU feature computation task.
    return compute_features(
        desc_rootpath,
        item_type,
        item_ids,
        node_batchsize,
        aggr_batchsize,
        feature_name_list,
        num_dims_list,
        compute_cnn_features_gpu_task,
        fetcomp_kwargs,
        slug,
        handle_interrupt_signal=handle_interrupt_signal,
        verbose=verbose,
    )