Skip to content

Commit e86a87f

Browse files
author
谢昕辰
authored
pytorch metrics implementation (#430)
* pytorch metrics impl and test * support list[str] input, delete unused test code and delete numpy version * modify input data type * add docstring and unitest of filename inputs * add indents in docstring and use tempfile lib to create dir * using with statement
1 parent 340132d commit e86a87f

File tree

2 files changed

+97
-41
lines changed

2 files changed

+97
-41
lines changed

mmseg/core/evaluation/metrics.py

+55-41
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import mmcv
22
import numpy as np
3+
import torch
34

45

56
def intersect_and_union(pred_label,
@@ -11,8 +12,10 @@ def intersect_and_union(pred_label,
1112
"""Calculate intersection and Union.
1213
1314
Args:
14-
pred_label (ndarray): Prediction segmentation map.
15-
label (ndarray): Ground truth segmentation map.
15+
pred_label (ndarray | str): Prediction segmentation map
16+
or predict result filename.
17+
label (ndarray | str): Ground truth segmentation map
18+
or label filename.
1619
num_classes (int): Number of categories.
1720
ignore_index (int): Index that will be ignored in evaluation.
1821
label_map (dict): Mapping old labels to new labels. The parameter will
@@ -21,25 +24,29 @@ def intersect_and_union(pred_label,
2124
work only when label is str. Default: False.
2225
2326
Returns:
24-
ndarray: The intersection of prediction and ground truth histogram
25-
on all classes.
26-
ndarray: The union of prediction and ground truth histogram on all
27-
classes.
28-
ndarray: The prediction histogram on all classes.
29-
ndarray: The ground truth histogram on all classes.
27+
torch.Tensor: The intersection of prediction and ground truth
28+
histogram on all classes.
29+
torch.Tensor: The union of prediction and ground truth histogram on
30+
all classes.
31+
torch.Tensor: The prediction histogram on all classes.
32+
torch.Tensor: The ground truth histogram on all classes.
3033
"""
3134

3235
if isinstance(pred_label, str):
33-
pred_label = np.load(pred_label)
36+
pred_label = torch.from_numpy(np.load(pred_label))
37+
else:
38+
pred_label = torch.from_numpy((pred_label))
3439

3540
if isinstance(label, str):
36-
label = mmcv.imread(label, flag='unchanged', backend='pillow')
37-
# modify if custom classes
41+
label = torch.from_numpy(
42+
mmcv.imread(label, flag='unchanged', backend='pillow'))
43+
else:
44+
label = torch.from_numpy(label)
45+
3846
if label_map is not None:
3947
for old_id, new_id in label_map.items():
4048
label[label == old_id] = new_id
4149
if reduce_zero_label:
42-
# avoid using underflow conversion
4350
label[label == 0] = 255
4451
label = label - 1
4552
label[label == 254] = 255
@@ -49,13 +56,13 @@ def intersect_and_union(pred_label,
4956
label = label[mask]
5057

5158
intersect = pred_label[pred_label == label]
52-
area_intersect, _ = np.histogram(
53-
intersect, bins=np.arange(num_classes + 1))
54-
area_pred_label, _ = np.histogram(
55-
pred_label, bins=np.arange(num_classes + 1))
56-
area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
59+
area_intersect = torch.histc(
60+
intersect.float(), bins=(num_classes), min=0, max=num_classes)
61+
area_pred_label = torch.histc(
62+
pred_label.float(), bins=(num_classes), min=0, max=num_classes)
63+
area_label = torch.histc(
64+
label.float(), bins=(num_classes), min=0, max=num_classes)
5765
area_union = area_pred_label + area_label - area_intersect
58-
5966
return area_intersect, area_union, area_pred_label, area_label
6067

6168

@@ -68,8 +75,10 @@ def total_intersect_and_union(results,
6875
"""Calculate Total Intersection and Union.
6976
7077
Args:
71-
results (list[ndarray]): List of prediction segmentation maps.
72-
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
78+
results (list[ndarray] | list[str]): List of prediction segmentation
79+
maps or list of prediction result filenames.
80+
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
81+
segmentation maps or list of label filenames.
7382
num_classes (int): Number of categories.
7483
ignore_index (int): Index that will be ignored in evaluation.
7584
label_map (dict): Mapping old labels to new labels. Default: dict().
@@ -83,23 +92,23 @@ def total_intersect_and_union(results,
8392
ndarray: The prediction histogram on all classes.
8493
ndarray: The ground truth histogram on all classes.
8594
"""
86-
8795
num_imgs = len(results)
8896
assert len(gt_seg_maps) == num_imgs
89-
total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
90-
total_area_union = np.zeros((num_classes, ), dtype=np.float)
91-
total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
92-
total_area_label = np.zeros((num_classes, ), dtype=np.float)
97+
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
98+
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
99+
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
100+
total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
93101
for i in range(num_imgs):
94102
area_intersect, area_union, area_pred_label, area_label = \
95-
intersect_and_union(results[i], gt_seg_maps[i], num_classes,
96-
ignore_index, label_map, reduce_zero_label)
103+
intersect_and_union(
104+
results[i], gt_seg_maps[i], num_classes, ignore_index,
105+
label_map, reduce_zero_label)
97106
total_area_intersect += area_intersect
98107
total_area_union += area_union
99108
total_area_pred_label += area_pred_label
100109
total_area_label += area_label
101-
return total_area_intersect, total_area_union, \
102-
total_area_pred_label, total_area_label
110+
return total_area_intersect, total_area_union, total_area_pred_label, \
111+
total_area_label
103112

104113

105114
def mean_iou(results,
@@ -112,8 +121,10 @@ def mean_iou(results,
112121
"""Calculate Mean Intersection and Union (mIoU)
113122
114123
Args:
115-
results (list[ndarray]): List of prediction segmentation maps.
116-
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
124+
results (list[ndarray] | list[str]): List of prediction segmentation
125+
maps or list of prediction result filenames.
126+
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
127+
segmentation maps or list of label filenames.
117128
num_classes (int): Number of categories.
118129
ignore_index (int): Index that will be ignored in evaluation.
119130
nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -126,7 +137,6 @@ def mean_iou(results,
126137
ndarray: Per category accuracy, shape (num_classes, ).
127138
ndarray: Per category IoU, shape (num_classes, ).
128139
"""
129-
130140
all_acc, acc, iou = eval_metrics(
131141
results=results,
132142
gt_seg_maps=gt_seg_maps,
@@ -149,8 +159,10 @@ def mean_dice(results,
149159
"""Calculate Mean Dice (mDice)
150160
151161
Args:
152-
results (list[ndarray]): List of prediction segmentation maps.
153-
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
162+
results (list[ndarray] | list[str]): List of prediction segmentation
163+
maps or list of prediction result filenames.
164+
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
165+
segmentation maps or list of label filenames.
154166
num_classes (int): Number of categories.
155167
ignore_index (int): Index that will be ignored in evaluation.
156168
nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -186,8 +198,10 @@ def eval_metrics(results,
186198
reduce_zero_label=False):
187199
"""Calculate evaluation metrics
188200
Args:
189-
results (list[ndarray]): List of prediction segmentation maps.
190-
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
201+
results (list[ndarray] | list[str]): List of prediction segmentation
202+
maps or list of prediction result filenames.
203+
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
204+
segmentation maps or list of label filenames.
191205
num_classes (int): Number of categories.
192206
ignore_index (int): Index that will be ignored in evaluation.
193207
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
@@ -200,17 +214,16 @@ def eval_metrics(results,
200214
ndarray: Per category accuracy, shape (num_classes, ).
201215
ndarray: Per category evalution metrics, shape (num_classes, ).
202216
"""
203-
204217
if isinstance(metrics, str):
205218
metrics = [metrics]
206219
allowed_metrics = ['mIoU', 'mDice']
207220
if not set(metrics).issubset(set(allowed_metrics)):
208221
raise KeyError('metrics {} is not supported'.format(metrics))
222+
209223
total_area_intersect, total_area_union, total_area_pred_label, \
210-
total_area_label = total_intersect_and_union(results, gt_seg_maps,
211-
num_classes, ignore_index,
212-
label_map,
213-
reduce_zero_label)
224+
total_area_label = total_intersect_and_union(
225+
results, gt_seg_maps, num_classes, ignore_index, label_map,
226+
reduce_zero_label)
214227
all_acc = total_area_intersect.sum() / total_area_label.sum()
215228
acc = total_area_intersect / total_area_label
216229
ret_metrics = [all_acc, acc]
@@ -222,6 +235,7 @@ def eval_metrics(results,
222235
dice = 2 * total_area_intersect / (
223236
total_area_pred_label + total_area_label)
224237
ret_metrics.append(dice)
238+
ret_metrics = [metric.numpy() for metric in ret_metrics]
225239
if nan_to_num is not None:
226240
ret_metrics = [
227241
np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics

tests/test_metrics.py

+42
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,45 @@ def test_mean_dice():
164164
results, label, num_classes, ignore_index=255, nan_to_num=-1)
165165
assert acc[-1] == -1
166166
assert iou[-1] == -1
167+
168+
169+
def test_filename_inputs():
170+
import cv2
171+
import tempfile
172+
173+
def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
174+
filenames = []
175+
SUFFIX = '.png' if is_image else '.npy'
176+
for idx, arr in enumerate(input_arrays):
177+
filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX)
178+
if is_image:
179+
cv2.imwrite(filename, arr)
180+
else:
181+
np.save(filename, arr)
182+
filenames.append(filename)
183+
return filenames
184+
185+
pred_size = (10, 512, 1024)
186+
num_classes = 19
187+
ignore_index = 255
188+
results = np.random.randint(0, num_classes, size=pred_size)
189+
labels = np.random.randint(0, num_classes, size=pred_size)
190+
labels[:, 2, 5:10] = ignore_index
191+
192+
with tempfile.TemporaryDirectory() as temp_dir:
193+
194+
result_files = save_arr(results, 'pred', False, temp_dir)
195+
label_files = save_arr(labels, 'label', True, temp_dir)
196+
197+
all_acc, acc, iou = eval_metrics(
198+
result_files,
199+
label_files,
200+
num_classes,
201+
ignore_index,
202+
metrics='mIoU')
203+
204+
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
205+
ignore_index)
206+
assert all_acc == all_acc_l
207+
assert np.allclose(acc, acc_l)
208+
assert np.allclose(iou, iou_l)

0 commit comments

Comments
 (0)