import os.path as osp
+ import pickle
+ import shutil
import tempfile

import mmcv
import numpy as np
import torch
+ import torch.distributed as dist
from mmcv.engine import collect_results_cpu, collect_results_gpu
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info

- from mmseg.core.evaluation.metrics import intersect_and_union
+ from mmseg.core.evaluation.metrics import ResultProcessor


def np2tmp(array, temp_file_name=None, tmpdir=None):
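Note on this patch: it swaps the free function `intersect_and_union` for a `ResultProcessor` that accumulates evaluation state incrementally. The class itself is not part of this diff; the sketch below is a minimal illustrative stand-in inferred only from how it is used here (constructor keywords, `collect`, `merge`, and truthiness checks), not the actual mmseg implementation.

```python
import numpy as np


class ResultProcessorSketch:
    """Illustrative stand-in for mmseg's ResultProcessor, inferred from
    its usage in this diff; the real class's internals may differ."""

    def __init__(self, num_classes, ignore_index, collect_type='pixels_count',
                 label_map=None, reduce_zero_label=False):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.collect_type = collect_type
        self.label_map = label_map or {}
        self.reduce_zero_label = reduce_zero_label
        # running per-class pixel counts ('pixels_count' mode)
        self.area_intersect = np.zeros(num_classes, dtype=np.float64)
        self.area_union = np.zeros(num_classes, dtype=np.float64)
        # raw (pred, gt) pairs ('seg_map' mode)
        self.seg_maps = []
        self.num_samples = 0

    def collect(self, results, gt_map, meta=None):
        """Fold one batch of predictions and its ground truth into the
        running state."""
        for pred in results:
            self.num_samples += 1
            if self.collect_type == 'pixels_count':
                mask = gt_map != self.ignore_index
                for c in range(self.num_classes):
                    self.area_intersect[c] += np.sum(
                        (pred == c) & (gt_map == c) & mask)
                    self.area_union[c] += np.sum(
                        ((pred == c) | (gt_map == c)) & mask)
            else:  # 'seg_map': keep raw outputs for later evaluation
                self.seg_maps.append((pred, gt_map))

    def merge(self, others):
        """Fold the partial state of other collectors (e.g. from other
        ranks) into this one."""
        for other in others:
            self.area_intersect += other.area_intersect
            self.area_union += other.area_union
            self.seg_maps.extend(other.seg_maps)
            self.num_samples += other.num_samples

    def __bool__(self):
        # the collect_collector_* helpers test collectors for truthiness
        return self.num_samples > 0
```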
@@ -169,24 +172,39 @@ def multi_gpu_test(model,
def progressive_single_gpu_test(model,
                                data_loader,
+                               efficient_test,
                                show=False,
                                out_dir=None,
                                opacity=0.5):
    model.eval()
    dataset = data_loader.dataset
-    num_classes = len(dataset.CLASSES)
    prog_bar = mmcv.ProgressBar(len(dataset))

-    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
+    if efficient_test:
+        collector = ResultProcessor(
+            num_classes=len(dataset.CLASSES),
+            ignore_index=dataset.ignore_index,
+            collect_type='pixels_count',
+            label_map=dataset.label_map,
+            reduce_zero_label=dataset.reduce_zero_label)
+    else:
+        collector = ResultProcessor(
+            num_classes=len(dataset.CLASSES),
+            ignore_index=dataset.ignore_index,
+            collect_type='seg_map',
+            label_map=dataset.label_map,
+            reduce_zero_label=dataset.reduce_zero_label)
+
+    gt_maps_generator = dataset.get_gt_seg_maps()

-    cur = 0
    for _, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)

+        gt_map = next(gt_maps_generator)
+        meta = data['img_metas'][0].data
+        collector.collect(result, gt_map, meta)
+
        if show or out_dir:
            img_tensor = data['img'][0]
            img_metas = data['img_metas'][0].data[0]
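For orientation, a possible end-to-end call of the single-GPU path is sketched below. It is hypothetical: `model` and `data_loader` are assumed to be in scope, and the IoU derivation uses the pixel-count attributes from the `ResultProcessorSketch` above, which are assumptions rather than the actual mmseg API (the real collector may expose a dedicated evaluation method instead).

```python
import numpy as np

# Hypothetical usage; attribute names on the returned collector follow the
# sketch above and are assumptions, not the actual mmseg API.
collector = progressive_single_gpu_test(
    model, data_loader, efficient_test=True, show=False)
iou = collector.area_intersect / np.maximum(collector.area_union, 1)
for cls_name, cls_iou in zip(data_loader.dataset.CLASSES, iou):
    print(f'{cls_name}: {cls_iou:.4f}')
```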
@@ -213,101 +231,163 @@ def progressive_single_gpu_test(model,
                    out_file=out_file,
                    opacity=opacity)

-        for i in range(len(result)):
-            gt_semantic_map = dataset.get_gt_seg_map(cur + i)
-
-            area_intersect, area_union, area_pred_label, area_label = \
-                intersect_and_union(
-                    result[i], gt_semantic_map, num_classes,
-                    dataset.ignore_index, dataset.label_map,
-                    dataset.reduce_zero_label)
-
-            total_area_intersect += area_intersect
-            total_area_union += area_union
-            total_area_pred_label += area_pred_label
-            total_area_label += area_label
-
-        print(total_area_intersect / total_area_union)
-
+        batch_size = len(result)
+        for _ in range(batch_size):
            prog_bar.update()

-        cur += len(result)
-
-    return total_area_intersect, total_area_union, total_area_pred_label, \
-        total_area_label
+    return collector


# TODO: Support distributed test api
def progressive_multi_gpu_test(model,
                               data_loader,
+                              efficient_test,
                               tmpdir=None,
                               gpu_collect=False):

    model.eval()
    dataset = data_loader.dataset
-    num_classes = len(dataset.CLASSES)
+    if efficient_test:
+        collector = ResultProcessor(
+            num_classes=len(dataset.CLASSES),
+            ignore_index=dataset.ignore_index,
+            collect_type='pixels_count',
+            label_map=dataset.label_map,
+            reduce_zero_label=dataset.reduce_zero_label)
+    else:
+        collector = ResultProcessor(
+            num_classes=len(dataset.CLASSES),
+            ignore_index=dataset.ignore_index,
+            collect_type='seg_map',
+            label_map=dataset.label_map,
+            reduce_zero_label=dataset.reduce_zero_label)
+
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))

-    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
-
    cur = 0
    for _, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

-        for i in range(len(result)):
-            gt_semantic_map = dataset.get_gt_seg_map(cur + i * world_size)
-
-            area_intersect, area_union, area_pred_label, area_label = \
-                intersect_and_union(
-                    result[i], gt_semantic_map, num_classes,
-                    dataset.ignore_index, dataset.label_map,
-                    dataset.reduce_zero_label)
+        gt_seg_map = dataset.index_gt_seg_maps(cur + rank)
+        meta = data['img_metas'][0].data
+        collector.collect(result, gt_seg_map, meta)

-            total_area_intersect += area_intersect
-            total_area_union += area_union
-            total_area_pred_label += area_pred_label
-            total_area_label += area_label
-
-        if rank == 0:
-            for _ in range(len(result) * world_size):
-                prog_bar.update()
+        if rank == 0:
+            for _ in range(len(result) * world_size):
+                prog_bar.update()

        cur += len(result) * world_size

-    pixel_count_matrix = [
-        total_area_intersect, total_area_union, total_area_pred_label,
-        total_area_label
-    ]
    # collect results from all ranks
    if gpu_collect:
-        results = collect_count_results_gpu(pixel_count_matrix, 4 * world_size)
+        collector = collect_collector_gpu(collector)
    else:
-        results = collect_count_results_cpu(pixel_count_matrix, 4 * world_size,
-                                            tmpdir)
-    return results
+        collector = collect_collector_cpu(collector, tmpdir)
+    return collector


- def collect_count_results_gpu(result_part, size):
-     """Collect pixel count matrix result under gpu mode.
+ def collect_collector_gpu(collector):
+     """Collect result collectors under gpu mode.

    On gpu mode, this function will encode results to gpu tensors and use gpu
    communication for results collection.

    Args:
-        result_part (list[Tensor]): four type of pixel count matrix --
-            {area_intersect, area_union, area_pred_label, area_label}, These
-            four tensor shape of (num_classes, ).
-        size (int): Size of the results, commonly equal to length of
-            the results.
+        collector (object): Result collector containing predictions and labels
+            to be collected.
+
+    Returns:
+        object: The gathered collector.
    """
-    pass
+    rank, world_size = get_dist_info()
+    # serialize the collector to a byte tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(collector)), dtype=torch.uint8, device='cuda')
+    # gather the byte-tensor shapes of all ranks
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+    # pad each part tensor to the max length so all_gather shapes match
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+    # gather the serialized collectors from all ranks
+    dist.all_gather(part_recv_list, part_send)

+    if rank == 0:
+        # deserialize the gathered parts, trimming the padding with the
+        # gathered shapes, and merge them into rank 0's collector; parts
+        # 1..world_size-1 are merged so rank 0 is not counted twice
+        main_collector = pickle.loads(
+            part_recv_list[0][:shape_list[0]].cpu().numpy().tobytes())
+        sub_collectors = []
+        for recv, shape in zip(part_recv_list[1:], shape_list[1:]):
+            part_collector = pickle.loads(
+                recv[:shape[0]].cpu().numpy().tobytes())
+            # When data is severely insufficient, an empty part_collector
+            # on a certain gpu could make the overall outputs empty.
+            if part_collector:
+                sub_collectors.append(part_collector)
+        main_collector.merge(sub_collectors)
+        return main_collector
+
+
+ def collect_collector_cpu(collector, tmpdir=None):
+     """Collect result collectors under cpu mode.
+
+    On cpu mode, this function will save the result collectors on different
+    gpus to ``tmpdir`` and collect them by the rank 0 worker.

- def collect_count_results_cpu(result_part, size, tmpdir=None):
-     pass
+    Args:
+        collector (object): Result collector containing predictions and labels
+            to be collected.
+        tmpdir (str | None): Temporary directory for collected results to
+            be stored. If set to None, a random temporary directory will be
+            created for it.
+
+    Returns:
+        object: The gathered collector.
+    """
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            mmcv.mkdir_or_exist('.dist_test')
+            tmpdir = tempfile.mkdtemp(dir='.dist_test')
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    mmcv.dump(collector, osp.join(tmpdir, f'part_{rank}.pkl'))
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load the results of all parts from the tmp dir and merge them
+        # into rank 0's collector
+        main_collector = mmcv.load(osp.join(tmpdir, 'part_0.pkl'))
+        sub_collectors = []
+        for i in range(1, world_size):
+            part_file = osp.join(tmpdir, f'part_{i}.pkl')
+            part_collector = mmcv.load(part_file)
+            # When data is severely insufficient, an empty part_collector
+            # on a certain gpu could make the overall outputs empty.
+            if part_collector:
+                sub_collectors.append(part_collector)
+        main_collector.merge(sub_collectors)
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return main_collector
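The two collection helpers above implement the same rendezvous two ways: the CPU path pickles each rank's collector into a shared `tmpdir` (whose randomly generated name rank 0 broadcasts as a space-padded uint8 tensor), while the GPU path pickles the collector straight into a uint8 CUDA tensor, pads all parts to a common length, and `all_gather`s them. Below is a single-process sketch of that serialize/pad/trim round trip, with a plain dict standing in for a collector and no distributed calls involved.

```python
import pickle

import torch

# Single-process demo of the serialize/pad/trim round trip used by
# collect_collector_gpu; a plain dict stands in for a pickled collector.
payload = {'area_intersect': [1.0, 2.0], 'area_union': [3.0, 4.0]}
buf = torch.tensor(bytearray(pickle.dumps(payload)), dtype=torch.uint8)
# all_gather requires equal shapes, so every part is padded to a common
# (max) length before being exchanged
padded = torch.zeros(buf.numel() + 32, dtype=torch.uint8)
padded[:buf.numel()] = buf
# slicing with the gathered true length recovers the exact bytes each
# rank sent, mirroring recv[:shape[0]] in collect_collector_gpu
restored = pickle.loads(padded[:buf.numel()].numpy().tobytes())
assert restored == payload
```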