-
Notifications
You must be signed in to change notification settings - Fork 0
/
conformal_classification_calib.py
113 lines (82 loc) · 4.27 KB
/
conformal_classification_calib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import glob, os, tqdm, json
import numpy as np
from utils import inference_mmpretrain, compute_conformity_scores, calibrate_cp_threshold, get_prediction_set, blockPrint, enablePrint, plot_uncertainty_vs_difficulty, plot_coverage_per_class, plot_coverage_vs_size, plot_confusion_matrix
from mmengine.config import Config
from mmengine.runner import Runner
from mmpretrain import ImageClassificationInferencer
def predict(dataloader, inferencer):
    '''
    Run the inferencer over every batch of dataloader and collect results.

    Returns a tuple (scores, gt_labels): per-image prediction scores stacked
    along axis 0, and a 1-D numpy array of integer ground-truth labels in the
    same order.
    '''
    batch_scores = []
    labels = []
    for batch in tqdm.tqdm(dataloader):
        samples = batch['data_samples']
        # The inferencer works on image paths, not on the loaded tensors.
        paths = [sample.img_path for sample in samples]
        batch_scores.append(inference_mmpretrain(paths, inferencer))
        labels.extend(sample.gt_label.item() for sample in samples)
    return np.concatenate(batch_scores, axis=0), np.array(labels)
def compute_accuracy(argmax, target):
    '''
    Fraction of predictions that match the ground truth.

    Parameters
    ----------
    argmax : np.ndarray
        Predicted class indices (e.g. ``scores.argmax(axis=1)``).
    target : np.ndarray
        Ground-truth class indices, same length as ``argmax``.

    Returns
    -------
    float
        Accuracy in [0, 1]. Returns 0.0 for empty input instead of
        raising ZeroDivisionError.
    '''
    if len(target) == 0:
        return 0.0
    return float((argmax == target).sum() / len(target))
def main(args):
    '''
    Calibrate a conformal predictor on the validation split, then evaluate it
    on the test split, writing plots and a JSON summary into args.outpath.
    '''
    config = Config.fromfile(args.config)
    config.work_dir = 'outputs/conformal_prediction/'
    config.load_from = args.checkpoint
    # The Runner is built only so its val/test dataloaders can be reused below;
    # mmengine's verbose setup output is silenced via blockPrint/enablePrint.
    blockPrint()
    runner = Runner.from_cfg(config)
    enablePrint()
    # Separate inferencer used for the actual forward passes.
    inferencer = ImageClassificationInferencer(
        model=args.config,
        pretrained=args.checkpoint,
        device=f'cuda:{args.gpu_id}'
    )
    # Run summary, serialized to calibration_results.json at the end.
    calibration_results = {
        'significance_level' : args.alpha,
        'regularisation' : {
            'lambda' : args.l,
            'kreg' : args.kreg,
        }
    }
    # --- Calibration: derive the conformity-score threshold from the val split.
    calib_loader = runner.val_dataloader
    scores, gt_labels = predict(calib_loader, inferencer)
    cs_thr, true_class_conformity_scores = calibrate_cp_threshold(scores, gt_labels, args.alpha, l=args.l, kreg=args.kreg)
    print(cs_thr)
    calibration_results['conformality_score_thr'] = float(cs_thr)
    # --- Evaluation: build prediction sets on the test split with the threshold
    # calibrated above, and compute the derived uncertainty statistics.
    test_loader = runner.test_dataloader
    classes = test_loader.dataset.CLASSES
    scores, gt_labels = predict(test_loader, inferencer)
    prediction_set_list, size, credibility, confidence, ranking, covered, confusion_matrix = get_prediction_set(scores, cs_thr, true_class_conformity_scores, l=args.l, kreg=args.kreg, gt_labels=gt_labels)
    # NOTE(review): the saved scores come from the *calibration* split despite
    # being written after the test-split inference — confirm this is intended.
    np.save(f'{args.outpath}/true_class_conformity_scores_calib.npy', true_class_conformity_scores)
    argmax = scores.argmax(axis=1)
    # Print top-1 accuracy, empirical coverage, and the set-size distribution.
    print(compute_accuracy(argmax,gt_labels))
    print(covered.sum() / len(covered))
    print(np.unique(size, return_counts=True))
    plot_uncertainty_vs_difficulty('size', ranking, size, gt_labels, classes, args.outpath)
    plot_uncertainty_vs_difficulty('credibility', ranking, credibility, gt_labels, classes, args.outpath)
    plot_uncertainty_vs_difficulty('confidence', ranking, confidence, gt_labels, classes, args.outpath)
    plot_coverage_per_class(covered, gt_labels, classes, args.alpha, args.outpath)
    plot_coverage_vs_size(size, covered, gt_labels, classes, args.alpha, args.outpath)
    plot_confusion_matrix(confusion_matrix, len(ranking), classes, args.outpath)
    calibration_results['val_accuracy'] = float(compute_accuracy(argmax,gt_labels))
    calibration_results['coverage'] = float(covered.sum() / len(covered))
    # Persist the run summary alongside the plots.
    with open(f'{args.outpath}/calibration_results.json', 'w') as fout:
        json.dump(calibration_results, fout, indent = 6)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, help="mmpretrain config")
parser.add_argument("--checkpoint", required=True, help="mmpretrain checkpoint")
parser.add_argument("--im-dir", help="Directory containing the images")
parser.add_argument("--outpath", required=True, help="Path to output directory")
parser.add_argument("--gpu-id", default='0', help="ID of gpu to be used")
parser.add_argument("--alpha", type=float, default=0.1, help="significance level")
parser.add_argument("--l", type=float, default=0., help="lambda parameter for regularisation of conformality score")
parser.add_argument("--kreg", type=float, default=0., help="kreg parameter for regularisation of conformality score")
args = parser.parse_args()
os.makedirs(args.outpath, exist_ok=True)
main(args)