This repository has been archived by the owner on Feb 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
/
bison_eval.py
94 lines (73 loc) · 2.87 KB
/
bison_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import json
import argparse
import numpy as np
class BisonEval:
    """Compute BISON accuracy of predictions against ground-truth annotations.

    ``anno`` is used through ``getBisonIds()``, item lookup by bison id
    (returning a record with a ``'true_image_id'`` entry) and a ``dataset``
    attribute, as provided by :class:`Annotation`. ``pred`` is used through
    ``getBisonIds()`` and item lookup returning the predicted image id, as
    provided by :class:`Prediction`.
    """

    def __init__(self, anno, pred):
        """Store the annotation and prediction sources.

        If the set of prediction ids differs from the set of annotation ids,
        a warning is printed; evaluation always runs over the ids present in
        ``pred``.
        """
        if pred.getBisonIds() != anno.getBisonIds():
            # The pieces are concatenated, so each one must carry its own
            # separating space.
            print('[Warning] The prediction does not ' +
                  'cover the entire set of bison data. ' +
                  'The evaluation is running on the {} '.format(
                      len(pred.getBisonIds())) +
                  'subset from prediction file.')
        # Evaluation is driven by the ids the prediction actually covers.
        self.params = {'bison_ids': pred.getBisonIds()}
        self.anno = anno
        self.pred = pred

    def evaluate(self):
        """Print and return the fraction of correctly predicted bison ids.

        Returns:
            ``np.mean`` of the per-id correctness flags, i.e. a value in
            [0, 1]. With no ids to evaluate, ``np.mean`` of an empty list
            yields NaN (and numpy emits a RuntimeWarning).

        Raises:
            Whatever ``self.anno[bison_id]`` raises for a predicted id that
            is not annotated (``KeyError`` for :class:`Annotation`).
        """
        # One boolean per evaluated id: did the model pick the true image?
        accuracy = [self.anno[bison_id]['true_image_id'] == self.pred[bison_id]
                    for bison_id in self.params['bison_ids']]
        mean_accuracy = np.mean(accuracy)
        print("[Result] Mean BISON accuracy on {}: {:.2f}%".format(
            self.anno.dataset, mean_accuracy * 100)
        )
        return mean_accuracy
class Annotation:
    """BISON ground-truth annotations loaded from JSON, indexed by bison id.

    The file is read as a JSON object with a ``'data'`` list of records
    (each carrying a ``'bison_id'``) and an ``'info'`` object with
    ``'source'`` and ``'split'`` entries.
    """

    def __init__(self, anno_filepath):
        """Load and index the annotation file at ``anno_filepath``.

        Raises:
            AssertionError: if ``anno_filepath`` does not exist (note that
                ``assert`` checks are skipped under ``python -O``).
        """
        assert os.path.exists(anno_filepath), 'Annotation file does not exist'
        # JSON is UTF-8 by specification; don't depend on the platform's
        # locale-dependent default text encoding.
        with open(anno_filepath, encoding='utf-8') as fd:
            anno_results = json.load(fd)
        # Map each bison id to its full annotation record.
        self._data = {res['bison_id']: res for res in anno_results['data']}
        # Dataset label of the form "<source>.<split>".
        self.dataset = "{}.{}".format(anno_results['info']['source'],
                                      anno_results['info']['split'])

    def getBisonIds(self):
        """Return a (live) keys view of the annotated bison ids."""
        return self._data.keys()

    def __getitem__(self, key):
        """Return the annotation record for bison id ``key``."""
        return self._data[key]
class Prediction:
    """Model predictions loaded from JSON, indexed by bison id.

    The file is read as a JSON list of records, each with ``'bison_id'`` and
    ``'predicted_image_id'`` entries. If an id repeats, the later record wins.
    """

    def __init__(self, pred_filepath):
        """Load and index the prediction file at ``pred_filepath``.

        Raises:
            AssertionError: if ``pred_filepath`` does not exist (note that
                ``assert`` checks are skipped under ``python -O``).
        """
        assert os.path.exists(pred_filepath), 'Prediction file does not exist'
        # JSON is UTF-8 by specification; don't depend on the platform's
        # locale-dependent default text encoding.
        with open(pred_filepath, encoding='utf-8') as fd:
            pred_results = json.load(fd)
        # Map each bison id to the image id the model selected.
        self._data = {result['bison_id']: result['predicted_image_id']
                      for result in pred_results}

    def getBisonIds(self):
        """Return a (live) keys view of the predicted bison ids."""
        return self._data.keys()

    def __getitem__(self, key):
        """Return the predicted image id for bison id ``key``."""
        return self._data[key]
def _command_line_parser():
parser = argparse.ArgumentParser()
default_anno = './annotations/bison_annotations.cocoval2014.json'
default_pred = './predictions/fake_predictions.cocoval2014.json'
parser.add_argument('--anno_path', default=default_anno,
help='Path to the annotation file')
parser.add_argument('--pred_path', default=default_pred,
help='Path to the prediction file')
return parser
def main(args):
    """Load annotations and predictions from ``args`` and run the evaluation.

    ``args`` must provide ``anno_path`` and ``pred_path`` attributes, as
    produced by the parser from ``_command_line_parser``.
    """
    # Annotation is loaded before Prediction (argument evaluation order).
    evaluator = BisonEval(Annotation(args.anno_path),
                          Prediction(args.pred_path))
    evaluator.evaluate()
if __name__ == '__main__':
    # Script entry point: parse the command line and run the evaluation.
    cli_args = _command_line_parser().parse_args()
    main(cli_args)