1
1
import mmcv
2
2
import numpy as np
3
+ import torch
3
4
4
5
5
6
def intersect_and_union (pred_label ,
@@ -11,8 +12,10 @@ def intersect_and_union(pred_label,
11
12
"""Calculate intersection and Union.
12
13
13
14
Args:
14
- pred_label (ndarray): Prediction segmentation map.
15
- label (ndarray): Ground truth segmentation map.
15
+ pred_label (ndarray | str): Prediction segmentation map
16
+ or predict result filename.
17
+ label (ndarray | str): Ground truth segmentation map
18
+ or label filename.
16
19
num_classes (int): Number of categories.
17
20
ignore_index (int): Index that will be ignored in evaluation.
18
21
label_map (dict): Mapping old labels to new labels. The parameter will
@@ -21,25 +24,29 @@ def intersect_and_union(pred_label,
21
24
work only when label is str. Default: False.
22
25
23
26
Returns:
24
- ndarray : The intersection of prediction and ground truth histogram
25
- on all classes.
26
- ndarray : The union of prediction and ground truth histogram on all
27
- classes.
28
- ndarray : The prediction histogram on all classes.
29
- ndarray : The ground truth histogram on all classes.
27
+ torch.Tensor : The intersection of prediction and ground truth
28
+ histogram on all classes.
29
+ torch.Tensor : The union of prediction and ground truth histogram on
30
+ all classes.
31
+ torch.Tensor : The prediction histogram on all classes.
32
+ torch.Tensor : The ground truth histogram on all classes.
30
33
"""
31
34
32
35
if isinstance (pred_label , str ):
33
- pred_label = np .load (pred_label )
36
+ pred_label = torch .from_numpy (np .load (pred_label ))
37
+ else :
38
+ pred_label = torch .from_numpy ((pred_label ))
34
39
35
40
if isinstance (label , str ):
36
- label = mmcv .imread (label , flag = 'unchanged' , backend = 'pillow' )
37
- # modify if custom classes
41
+ label = torch .from_numpy (
42
+ mmcv .imread (label , flag = 'unchanged' , backend = 'pillow' ))
43
+ else :
44
+ label = torch .from_numpy (label )
45
+
38
46
if label_map is not None :
39
47
for old_id , new_id in label_map .items ():
40
48
label [label == old_id ] = new_id
41
49
if reduce_zero_label :
42
- # avoid using underflow conversion
43
50
label [label == 0 ] = 255
44
51
label = label - 1
45
52
label [label == 254 ] = 255
@@ -49,13 +56,13 @@ def intersect_and_union(pred_label,
49
56
label = label [mask ]
50
57
51
58
intersect = pred_label [pred_label == label ]
52
- area_intersect , _ = np .histogram (
53
- intersect , bins = np .arange (num_classes + 1 ))
54
- area_pred_label , _ = np .histogram (
55
- pred_label , bins = np .arange (num_classes + 1 ))
56
- area_label , _ = np .histogram (label , bins = np .arange (num_classes + 1 ))
59
+ area_intersect = torch .histc (
60
+ intersect .float (), bins = (num_classes ), min = 0 , max = num_classes )
61
+ area_pred_label = torch .histc (
62
+ pred_label .float (), bins = (num_classes ), min = 0 , max = num_classes )
63
+ area_label = torch .histc (
64
+ label .float (), bins = (num_classes ), min = 0 , max = num_classes )
57
65
area_union = area_pred_label + area_label - area_intersect
58
-
59
66
return area_intersect , area_union , area_pred_label , area_label
60
67
61
68
@@ -68,8 +75,10 @@ def total_intersect_and_union(results,
68
75
"""Calculate Total Intersection and Union.
69
76
70
77
Args:
71
- results (list[ndarray]): List of prediction segmentation maps.
72
- gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
78
+ results (list[ndarray] | list[str]): List of prediction segmentation
79
+ maps or list of prediction result filenames.
80
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
81
+ segmentation maps or list of label filenames.
73
82
num_classes (int): Number of categories.
74
83
ignore_index (int): Index that will be ignored in evaluation.
75
84
label_map (dict): Mapping old labels to new labels. Default: dict().
@@ -83,23 +92,23 @@ def total_intersect_and_union(results,
83
92
ndarray: The prediction histogram on all classes.
84
93
ndarray: The ground truth histogram on all classes.
85
94
"""
86
-
87
95
num_imgs = len (results )
88
96
assert len (gt_seg_maps ) == num_imgs
89
- total_area_intersect = np .zeros ((num_classes , ), dtype = np . float )
90
- total_area_union = np .zeros ((num_classes , ), dtype = np . float )
91
- total_area_pred_label = np .zeros ((num_classes , ), dtype = np . float )
92
- total_area_label = np .zeros ((num_classes , ), dtype = np . float )
97
+ total_area_intersect = torch .zeros ((num_classes , ), dtype = torch . float64 )
98
+ total_area_union = torch .zeros ((num_classes , ), dtype = torch . float64 )
99
+ total_area_pred_label = torch .zeros ((num_classes , ), dtype = torch . float64 )
100
+ total_area_label = torch .zeros ((num_classes , ), dtype = torch . float64 )
93
101
for i in range (num_imgs ):
94
102
area_intersect , area_union , area_pred_label , area_label = \
95
- intersect_and_union (results [i ], gt_seg_maps [i ], num_classes ,
96
- ignore_index , label_map , reduce_zero_label )
103
+ intersect_and_union (
104
+ results [i ], gt_seg_maps [i ], num_classes , ignore_index ,
105
+ label_map , reduce_zero_label )
97
106
total_area_intersect += area_intersect
98
107
total_area_union += area_union
99
108
total_area_pred_label += area_pred_label
100
109
total_area_label += area_label
101
- return total_area_intersect , total_area_union , \
102
- total_area_pred_label , total_area_label
110
+ return total_area_intersect , total_area_union , total_area_pred_label , \
111
+ total_area_label
103
112
104
113
105
114
def mean_iou (results ,
@@ -112,8 +121,10 @@ def mean_iou(results,
112
121
"""Calculate Mean Intersection and Union (mIoU)
113
122
114
123
Args:
115
- results (list[ndarray]): List of prediction segmentation maps.
116
- gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
124
+ results (list[ndarray] | list[str]): List of prediction segmentation
125
+ maps or list of prediction result filenames.
126
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
127
+ segmentation maps or list of label filenames.
117
128
num_classes (int): Number of categories.
118
129
ignore_index (int): Index that will be ignored in evaluation.
119
130
nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -126,7 +137,6 @@ def mean_iou(results,
126
137
ndarray: Per category accuracy, shape (num_classes, ).
127
138
ndarray: Per category IoU, shape (num_classes, ).
128
139
"""
129
-
130
140
all_acc , acc , iou = eval_metrics (
131
141
results = results ,
132
142
gt_seg_maps = gt_seg_maps ,
@@ -149,8 +159,10 @@ def mean_dice(results,
149
159
"""Calculate Mean Dice (mDice)
150
160
151
161
Args:
152
- results (list[ndarray]): List of prediction segmentation maps.
153
- gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
162
+ results (list[ndarray] | list[str]): List of prediction segmentation
163
+ maps or list of prediction result filenames.
164
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
165
+ segmentation maps or list of label filenames.
154
166
num_classes (int): Number of categories.
155
167
ignore_index (int): Index that will be ignored in evaluation.
156
168
nan_to_num (int, optional): If specified, NaN values will be replaced
@@ -186,8 +198,10 @@ def eval_metrics(results,
186
198
reduce_zero_label = False ):
187
199
"""Calculate evaluation metrics
188
200
Args:
189
- results (list[ndarray]): List of prediction segmentation maps.
190
- gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
201
+ results (list[ndarray] | list[str]): List of prediction segmentation
202
+ maps or list of prediction result filenames.
203
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
204
+ segmentation maps or list of label filenames.
191
205
num_classes (int): Number of categories.
192
206
ignore_index (int): Index that will be ignored in evaluation.
193
207
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
@@ -200,17 +214,16 @@ def eval_metrics(results,
200
214
ndarray: Per category accuracy, shape (num_classes, ).
201
215
ndarray: Per category evalution metrics, shape (num_classes, ).
202
216
"""
203
-
204
217
if isinstance (metrics , str ):
205
218
metrics = [metrics ]
206
219
allowed_metrics = ['mIoU' , 'mDice' ]
207
220
if not set (metrics ).issubset (set (allowed_metrics )):
208
221
raise KeyError ('metrics {} is not supported' .format (metrics ))
222
+
209
223
total_area_intersect , total_area_union , total_area_pred_label , \
210
- total_area_label = total_intersect_and_union (results , gt_seg_maps ,
211
- num_classes , ignore_index ,
212
- label_map ,
213
- reduce_zero_label )
224
+ total_area_label = total_intersect_and_union (
225
+ results , gt_seg_maps , num_classes , ignore_index , label_map ,
226
+ reduce_zero_label )
214
227
all_acc = total_area_intersect .sum () / total_area_label .sum ()
215
228
acc = total_area_intersect / total_area_label
216
229
ret_metrics = [all_acc , acc ]
@@ -222,6 +235,7 @@ def eval_metrics(results,
222
235
dice = 2 * total_area_intersect / (
223
236
total_area_pred_label + total_area_label )
224
237
ret_metrics .append (dice )
238
+ ret_metrics = [metric .numpy () for metric in ret_metrics ]
225
239
if nan_to_num is not None :
226
240
ret_metrics = [
227
241
np .nan_to_num (metric , nan = nan_to_num ) for metric in ret_metrics
0 commit comments