-
Notifications
You must be signed in to change notification settings - Fork 69
/
metrics.py
44 lines (40 loc) · 2.03 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
This file contains implementations for the precision@k and IoU (mean, overall) evaluation metrics.
"""
import torch
from tqdm import tqdm
from pycocotools.coco import COCO
from pycocotools.mask import decode
import numpy as np
def compute_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6):
    """Compute per-sample intersection-over-union between two batches of masks.

    Intersection and union are summed over dims 1 and 2, so the inputs are
    expected to be (N, H, W) batches and one value is produced per sample
    along dim 0. Any non-zero element is treated as foreground. `outputs` is
    first cast with `.int()`, so fractional values in (-1, 1) truncate to
    background, exactly as before.

    Args:
        outputs: predicted masks, shape (N, H, W).
        labels: ground-truth masks, same shape as `outputs`.
        EPS: smoothing term added to both numerator and denominator; it avoids
            0/0 and makes the IoU of two empty masks evaluate to 1.

    Returns:
        Tuple ``(iou, intersection, union)`` of float tensors of shape (N,);
        `intersection` and `union` are foreground pixel counts.
    """
    # Binarize both inputs before the bitwise ops. Without this, masks stored
    # as e.g. 0/255 make `&` / `|` yield the raw values, and the sums below
    # would count 255 per pixel instead of 1.
    pred = outputs.int().bool()
    gt = labels.bool()
    intersection = (pred & gt).float().sum((1, 2))  # zero if either mask is empty
    union = (pred | gt).float().sum((1, 2))  # zero only if both masks are empty
    iou = (intersection + EPS) / (union + EPS)  # EPS avoids division by zero
    return iou, intersection, union
def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO,
                                             iou_thresholds=(0.5, 0.6, 0.7, 0.8, 0.9)):
    """Compute precision@k, overall IoU and mean IoU of predictions vs. ground truth.

    For every image id in `coco_gt`, the first ground-truth annotation is
    compared against the highest-scoring prediction for the same image id in
    `coco_pred`. The code assumes each image id carries exactly one ground-truth
    instance. Masks are RLE-encoded and decoded with pycocotools `decode`.

    An image with no prediction at all is scored against an all-zero predicted
    mask, i.e. counted as a miss, instead of raising IndexError.

    Args:
        coco_gt: ground-truth annotations.
        coco_pred: predicted annotations; each needs 'score' and 'segmentation'.
        iou_thresholds: distinct IoU thresholds for precision@k. A sample
            counts as correct at threshold t when its IoU is strictly > t.

    Returns:
        Tuple ``(precision_at_k, overall_iou, mean_iou)``:
        - precision_at_k: np.ndarray with, per threshold (in the given order),
          the fraction of samples whose IoU exceeds that threshold;
        - overall_iou: summed intersection area / summed union area;
        - mean_iou: mean of the per-sample IoU values.
    """
    print('evaluating precision@k & iou metrics...')
    counters_by_iou = {iou: 0 for iou in iou_thresholds}
    total_intersection_area = 0
    total_union_area = 0
    ious_list = []
    for instance in tqdm(coco_gt.imgs.keys()):  # each image_id contains exactly one instance
        gt_annot = coco_gt.imgToAnns[instance][0]
        gt_mask = decode(gt_annot['segmentation'])
        pred_annots = coco_pred.imgToAnns[instance]
        if pred_annots:
            # Pick the highest score. Iterating in reverse makes ties resolve to
            # the last annotation, the same one sorted(...)[-1] used to pick.
            pred_annot = max(reversed(pred_annots), key=lambda a: a['score'])
            pred_mask = decode(pred_annot['segmentation'])
        else:
            # No prediction for this image: score it as an empty mask (a miss)
            # so the sample still counts toward num_samples.
            pred_mask = np.zeros_like(gt_mask)
        iou, intersection, union = compute_iou(torch.tensor(pred_mask).unsqueeze(0),
                                               torch.tensor(gt_mask).unsqueeze(0))
        iou, intersection, union = iou.item(), intersection.item(), union.item()
        for iou_threshold in counters_by_iou:
            if iou > iou_threshold:
                counters_by_iou[iou_threshold] += 1
        total_intersection_area += intersection
        total_union_area += union
        ious_list.append(iou)
    num_samples = len(ious_list)
    precision_at_k = np.array(list(counters_by_iou.values())) / num_samples
    overall_iou = total_intersection_area / total_union_area
    mean_iou = np.mean(ious_list)
    return precision_at_k, overall_iou, mean_iou