Commit 9dfe1c0

Author: sennnnn (committed)

Temp code.
1 parent f3aaecc commit 9dfe1c0

File tree: 6 files changed, +393 -221 lines changed


mmseg/apis/test.py (+147 -67)
@@ -1,14 +1,17 @@
 import os.path as osp
+import pickle
+import shutil
 import tempfile
 
 import mmcv
 import numpy as np
 import torch
+import torch.distributed as dist
 from mmcv.engine import collect_results_cpu, collect_results_gpu
 from mmcv.image import tensor2imgs
 from mmcv.runner import get_dist_info
 
-from mmseg.core.evaluation.metrics import intersect_and_union
+from mmseg.core.evaluation.metrics import ResultProcessor
 
 
 def np2tmp(array, temp_file_name=None, tmpdir=None):
@@ -169,24 +172,39 @@ def multi_gpu_test(model,
 
 def progressive_single_gpu_test(model,
                                 data_loader,
+                                efficient_test,
                                 show=False,
                                 out_dir=None,
                                 opacity=0.5):
     model.eval()
     dataset = data_loader.dataset
-    num_classes = len(dataset.CLASSES)
     prog_bar = mmcv.ProgressBar(len(dataset))
 
-    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
+    if efficient_test:
+        collector = ResultProcessor(
+            num_classes=len(dataset.CLASSES),
+            ignore_index=dataset.ignore_index,
+            collect_type='pixels_count',
+            label_map=dataset.label_map,
+            reduce_zero_label=dataset.reduce_zero_label)
+    else:
+        collector = ResultProcessor(
+            num_classes=len(dataset.CLASSES),
+            ignore_index=dataset.ignore_index,
+            collect_type='seg_map',
+            label_map=dataset.label_map,
+            reduce_zero_label=dataset.reduce_zero_label)
+
+    gt_maps_generator = dataset.get_gt_seg_maps()
 
-    cur = 0
     for _, data in enumerate(data_loader):
         with torch.no_grad():
             result = model(return_loss=False, **data)
 
+        gt_map = next(gt_maps_generator)
+        meta = data['img_metas'][0].data
+        collector.collect(result, gt_map, meta)
+
         if show or out_dir:
             img_tensor = data['img'][0]
             img_metas = data['img_metas'][0].data[0]
@@ -213,101 +231,163 @@ def progressive_single_gpu_test(model,
                     out_file=out_file,
                     opacity=opacity)
 
-        for i in range(len(result)):
-            gt_semantic_map = dataset.get_gt_seg_map(cur + i)
-
-            area_intersect, area_union, area_pred_label, area_label = \
-                intersect_and_union(
-                    result[i], gt_semantic_map, num_classes,
-                    dataset.ignore_index, dataset.label_map,
-                    dataset.reduce_zero_label)
-
-            total_area_intersect += area_intersect
-            total_area_union += area_union
-            total_area_pred_label += area_pred_label
-            total_area_label += area_label
-
-        print(total_area_intersect / total_area_union)
-
+        batch_size = len(result)
+        for _ in range(batch_size):
             prog_bar.update()
 
-        cur += len(result)
-
-    return total_area_intersect, total_area_union, total_area_pred_label, \
-        total_area_label
+    return collector
 
 
 # TODO: Support distributed test api
 def progressive_multi_gpu_test(model,
                               data_loader,
+                              efficient_test,
                               tmpdir=None,
                               gpu_collect=False):
 
     model.eval()
     dataset = data_loader.dataset
-    num_classes = len(dataset.CLASSES)
+    if efficient_test:
+        collector = ResultProcessor(
+            num_classes=len(dataset.CLASSES),
+            ignore_index=dataset.ignore_index,
+            collect_type='pixels_count',
+            label_map=dataset.label_map,
+            reduce_zero_label=dataset.reduce_zero_label)
+    else:
+        collector = ResultProcessor(
+            num_classes=len(dataset.CLASSES),
+            ignore_index=dataset.ignore_index,
+            collect_type='seg_map',
+            label_map=dataset.label_map,
+            reduce_zero_label=dataset.reduce_zero_label)
+
     rank, world_size = get_dist_info()
     if rank == 0:
         prog_bar = mmcv.ProgressBar(len(dataset))
 
-    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
-    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
-
     cur = 0
     for _, data in enumerate(data_loader):
         with torch.no_grad():
             result = model(return_loss=False, rescale=True, **data)
 
-        for i in range(len(result)):
-            gt_semantic_map = dataset.get_gt_seg_map(cur + i * world_size)
-
-            area_intersect, area_union, area_pred_label, area_label = \
-                intersect_and_union(
-                    result[i], gt_semantic_map, num_classes,
-                    dataset.ignore_index, dataset.label_map,
-                    dataset.reduce_zero_label)
+        gt_seg_map = dataset.index_gt_seg_maps(cur + rank)
+        meta = data['img_metas'][0].data
+        collector.collect(result, gt_seg_map, meta)
 
-            total_area_intersect += area_intersect
-            total_area_union += area_union
-            total_area_pred_label += area_pred_label
-            total_area_label += area_label
-
-            if rank == 0:
-                for _ in range(len(result) * world_size):
-                    prog_bar.update()
+        if rank == 0:
+            for _ in range(len(result) * world_size):
+                prog_bar.update()
 
         cur += len(result) * world_size
 
-    pixel_count_matrix = [
-        total_area_intersect, total_area_union, total_area_pred_label,
-        total_area_label
-    ]
     # collect results from all ranks
     if gpu_collect:
-        results = collect_count_results_gpu(pixel_count_matrix, 4 * world_size)
+        collector = collect_collector_gpu(collector)
    else:
-        results = collect_count_results_cpu(pixel_count_matrix, 4 * world_size,
-                                            tmpdir)
-    return results
+        collector = collect_collector_cpu(collector, tmpdir)
+    return collector
 
 
-def collect_count_results_gpu(result_part, size):
-    """Collect pixel count matrix result under gpu mode.
+def collect_collector_gpu(collector):
+    """Collect result collectors under gpu mode.
 
     On gpu mode, this function will encode results to gpu tensors and use gpu
     communication for results collection.
 
     Args:
-        result_part (list[Tensor]): four type of pixel count matrix --
-            {area_intersect, area_union, area_pred_label, area_label}, These
-            four tensor shape of (num_classes, ).
-        size (int): Size of the results, commonly equal to length of
-            the results.
+        collector (object): Result collector containing predictions and labels
+            to be collected.
+    Returns:
+        object: The gathered collector.
     """
-    pass
+    rank, world_size = get_dist_info()
+    # dump result part to tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(collector)), dtype=torch.uint8, device='cuda')
+    # gather all result part tensor shape
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+    # padding result part tensor to max length
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+    # gather all result part
+    dist.all_gather(part_recv_list, part_send)
 
+    if rank == 0:
+        # merge the collectors from the other ranks into rank 0's collector
+        main_collector = pickle.loads(
+            part_recv_list[0][:shape_list[0]].cpu().numpy().tobytes())
+        sub_collectors = []
+        for recv, shape in zip(part_recv_list[1:], shape_list[1:]):
+            part_collector = pickle.loads(
+                recv[:shape[0]].cpu().numpy().tobytes())
+            # When data is severely insufficient, an empty part_collector
+            # on a certain gpu could make the overall outputs empty.
+            if part_collector:
+                sub_collectors.append(part_collector)
+        main_collector.merge(sub_collectors)
+        return main_collector
+
+
+def collect_collector_cpu(collector, tmpdir=None):
+    """Collect result collectors under cpu mode.
+
+    On cpu mode, this function will save the result collectors on different
+    gpus to ``tmpdir`` and collect them by the rank 0 worker.
 
-def collect_count_results_cpu(result_part, size, tmpdir=None):
-    pass
+    Args:
+        collector (object): Result collector containing predictions and labels
+            to be collected.
+        tmpdir (str | None): temporary directory for collected results to
+            store. If set to None, it will create a random temporary
+            directory for it.
+
+    Returns:
+        object: The gathered collector.
+    """
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            mmcv.mkdir_or_exist('.dist_test')
+            tmpdir = tempfile.mkdtemp(dir='.dist_test')
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    mmcv.dump(collector, osp.join(tmpdir, f'part_{rank}.pkl'))
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        main_collector = mmcv.load(osp.join(tmpdir, f'part_{0}.pkl'))
+        sub_collectors = []
+        for i in range(1, world_size):
+            part_file = osp.join(tmpdir, f'part_{i}.pkl')
+            part_collector = mmcv.load(part_file)
+            # When data is severely insufficient, an empty part_collector
+            # on a certain gpu could make the overall outputs empty.
+            if part_collector:
+                sub_collectors.append(part_collector)
+        main_collector.merge(sub_collectors)
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return main_collector
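
Note: the test functions above assume a ResultProcessor class with at least collect() and merge() methods; that class lives in mmseg/core/evaluation/metrics.py and is not shown in this commit. As orientation only, a minimal sketch of the interface these functions rely on (covering just the 'pixels_count' mode, with hypothetical internals) could look like:

# Hypothetical sketch only; the real ResultProcessor is defined in
# mmseg/core/evaluation/metrics.py and may differ substantially.
import numpy as np


class ResultProcessor:

    def __init__(self, num_classes, ignore_index, collect_type='pixels_count',
                 label_map=None, reduce_zero_label=False):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.collect_type = collect_type
        self.label_map = label_map or {}
        self.reduce_zero_label = reduce_zero_label
        # running per-class pixel counts for the 'pixels_count' mode
        self.area_intersect = np.zeros(num_classes, dtype=np.float64)
        self.area_union = np.zeros(num_classes, dtype=np.float64)
        self.area_pred = np.zeros(num_classes, dtype=np.float64)
        self.area_label = np.zeros(num_classes, dtype=np.float64)

    def collect(self, results, gt_map, meta=None):
        # accumulate statistics for one batch of predictions against one
        # ground-truth map (batch size is typically 1 during testing)
        gt = np.asarray(gt_map)
        for pred in results:
            pred = np.asarray(pred)
            valid = gt != self.ignore_index
            p, g = pred[valid], gt[valid]
            for cls in range(self.num_classes):
                pred_mask, gt_mask = p == cls, g == cls
                self.area_intersect[cls] += np.sum(pred_mask & gt_mask)
                self.area_union[cls] += np.sum(pred_mask | gt_mask)
                self.area_pred[cls] += np.sum(pred_mask)
                self.area_label[cls] += np.sum(gt_mask)

    def merge(self, others):
        # fold the counts of collectors gathered from other ranks into this one
        for other in others:
            self.area_intersect += other.area_intersect
            self.area_union += other.area_union
            self.area_pred += other.area_pred
            self.area_label += other.area_label

In the 'pixels_count' mode only these per-class counts need to be pickled and gathered across ranks, which is what keeps the efficient_test path cheap; the 'seg_map' mode would instead have to carry the raw prediction maps.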

mmseg/core/evaluation/__init__.py (+4 -3)
@@ -1,9 +1,10 @@
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .metrics import (calculate_metrics, eval_metrics, mean_dice, mean_fscore,
-                      mean_iou)
+from .metrics import (ResultProcessor, calculate_metrics, eval_metrics,
+                      mean_dice, mean_fscore, mean_iou)
 
 __all__ = [
     'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
-    'eval_metrics', 'get_classes', 'get_palette', 'calculate_metrics'
+    'eval_metrics', 'get_classes', 'get_palette', 'calculate_metrics',
+    'ResultProcessor'
 ]
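
For context, a hedged usage sketch of the new progressive test APIs; how a driver script such as tools/test.py actually wires these calls together is not part of this commit, and every name below other than the two imported functions is illustrative:

# Illustrative only; not taken from this commit.
from mmseg.apis.test import (progressive_multi_gpu_test,
                             progressive_single_gpu_test)


def run_progressive_test(model, data_loader, distributed, gpu_collect=False):
    if not distributed:
        collector = progressive_single_gpu_test(
            model, data_loader, efficient_test=True, show=False)
    else:
        collector = progressive_multi_gpu_test(
            model, data_loader, efficient_test=True, gpu_collect=gpu_collect)
    # on rank 0 the gathered collector holds the accumulated statistics (or
    # raw seg maps); non-zero ranks get None back from the cpu gather path
    return collector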
