-
Notifications
You must be signed in to change notification settings - Fork 4
/
metrics.py
114 lines (98 loc) · 3.62 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Common image segmentation metrics.
"""
# Taken from:
# https://github.com/kevinzakka/pytorch-goodies
import torch
# Small constant added to every denominator below so that an empty class (or an
# empty confusion matrix) yields 0 rather than a 0/0 NaN.
EPS = 1.0e-10
def nanmean(x):
    """Return the arithmetic mean of ``x``, skipping NaN entries.

    The tensor is flattened by the boolean mask, so the result is a 0-d
    tensor. If every entry is NaN the mean is taken over an empty tensor,
    which yields NaN.
    """
    finite_or_inf = ~torch.isnan(x)
    return x[finite_or_inf].mean()
def overall_pixel_accuracy(hist):
    """Return the fraction of all counted pixels that were classified correctly.

    The overall pixel accuracy provides an intuitive approximation of the
    qualitative perception of a label map viewed in its overall shape,
    but not in its details.

    Args:
        hist: confusion matrix; correct predictions lie on its diagonal.

    Returns:
        overall_acc: 0-d tensor, sum(diagonal) / sum(all entries). An empty
        matrix gives 0 thanks to ``EPS``.
    """
    num_correct = hist.diagonal().sum()
    num_total = hist.sum()
    return num_correct / (num_total + EPS)
def per_class_pixel_accuracy(hist):
    """Compute pixel accuracy per class and its mean over classes.

    This is a finer-grained view than the overall pixel accuracy: a model
    can score well overall just by getting the dominant labels or regions
    right while missing rarer (possibly more important) labels. Such a
    model scores low here.

    Args:
        hist: confusion matrix with ground-truth classes along rows.

    Returns:
        A tuple ``(avg_per_class_acc, per_class_acc)``:
        avg_per_class_acc is the mean over classes, and per_class_acc is
        a 1-D tensor of per-class accuracies (diagonal / row sum).

    NOTE: because ``EPS`` is added to every row sum, a class with no
    ground-truth pixels scores 0 rather than NaN. ``nanmean`` therefore
    does not exclude such classes from the average.
    """
    hits = hist.diagonal()
    support = hist.sum(dim=1)
    per_class_acc = hits / (support + EPS)
    return nanmean(per_class_acc), per_class_acc
def jaccard_index(hist):
    """Compute the Jaccard index, also known as Intersection over Union (IoU).

    Per class: ``|A ∩ B| / |A ∪ B|``, where A is the ground-truth set
    (row sums of ``hist``), B is the predicted set (column sums), and the
    intersection is the diagonal.

    Args:
        hist: confusion matrix with ground-truth classes along rows and
            predicted classes along columns.

    Returns:
        A tuple ``(avg_jacc, jaccard)``: the mean IoU over classes and the
        1-D tensor of per-class IoU.

    NOTE: ``EPS`` turns 0/0 into 0 for a class absent from both ground
    truth and prediction, so ``nanmean`` does not skip such classes.
    """
    intersection = hist.diagonal()
    union = hist.sum(dim=1) + hist.sum(dim=0) - intersection
    per_class_iou = intersection / (union + EPS)
    return nanmean(per_class_iou), per_class_iou
def F1_Score(hist):
    """Compute the Sørensen–Dice coefficient, also known as the F1 score.

    Per class: ``2 * |A ∩ B| / (|A| + |B|)``, where A is the ground-truth
    set (row sums of ``hist``), B is the predicted set (column sums), and
    the intersection is the diagonal.

    Args:
        hist: confusion matrix with ground-truth classes along rows and
            predicted classes along columns.

    Returns:
        A tuple ``(avg_dice, dice)``: the mean Dice over classes and the
        1-D tensor of per-class Dice coefficients.

    NOTE: ``EPS`` turns 0/0 into 0 for a class absent from both ground
    truth and prediction, so ``nanmean`` does not skip such classes.
    """
    A_inter_B = torch.diag(hist)
    A = hist.sum(dim=1)
    B = hist.sum(dim=0)
    dice = (2 * A_inter_B) / (A + B + EPS)
    avg_dice = nanmean(dice)
    return avg_dice, dice


# The metric used to be defined as ``dice_coefficient``. Keep that name bound
# so code still calling it does not fail with NameError.
dice_coefficient = F1_Score
def _fast_hist(true, pred, num_classes):
mask = (true >= 0) & (true < num_classes)
hist = torch.bincount(
num_classes * true[mask] + pred[mask],
minlength=num_classes ** 2,
).reshape(num_classes, num_classes).float()
return hist
def eval_metrics(true, pred, num_classes):
    """Computes various segmentation metrics on 2D feature maps.

    Args:
        true: a tensor of shape [B, H, W] or [B, 1, H, W] of ground-truth
            labels.
        pred: a tensor of shape [B, H, W] or [B, 1, H, W] of predicted
            labels, on the same device as ``true``.
        num_classes: the number of classes to segment. This number
            should be less than the ID of the ignored class.

    Returns:
        A 4-tuple ``(overall_acc, avg_per_class_acc, avg_jacc, avg_dice)``:
        overall_acc: 0-d tensor, the overall pixel accuracy.
        avg_per_class_acc: tuple ``(mean, per_class)`` of per-class pixel
            accuracy, as returned by ``per_class_pixel_accuracy``.
        avg_jacc: tuple ``(mean, per_class)`` of the Jaccard index / IoU.
        avg_dice: tuple ``(mean, per_class)`` of the Dice coefficient / F1.
    """
    # Allocate the accumulator on the inputs' device. A CPU tensor cannot be
    # updated in place with a CUDA tensor.
    hist = torch.zeros((num_classes, num_classes), device=true.device)
    for t, p in zip(true, pred):
        hist += _fast_hist(t.flatten(), p.flatten(), num_classes)
    overall_acc = overall_pixel_accuracy(hist)
    avg_per_class_acc = per_class_pixel_accuracy(hist)
    avg_jacc = jaccard_index(hist)
    # The Dice metric is defined as ``F1_Score``; the old name
    # ``dice_coefficient`` is not defined by its original def any more.
    avg_dice = F1_Score(hist)
    return overall_acc, avg_per_class_acc, avg_jacc, avg_dice