-
Notifications
You must be signed in to change notification settings - Fork 3
/
inference.py
116 lines (95 loc) · 5.44 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import gc
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from chestxray14 import ChestXray14Dataset
from chexpert import CheXpertDataset
from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14
from model import InferenceModel
from utils import calculate_auroc
# Switch torch.multiprocessing to the 'file_system' tensor-sharing strategy at import time.
# NOTE(review): presumably chosen to avoid exhausting file descriptors when DataLoader
# worker processes hand data back to the main process -- confirm it is still required.
torch.multiprocessing.set_sharing_strategy('file_system')
def inference_chexpert(split='test'):
    """Run descriptor-based inference on a CheXpert split and print AUROC scores.

    Each dataset item is unpacked as ``(image_paths, labels, keys)``. Descriptor
    probabilities are computed for every image of the item and mean-aggregated
    before being turned into per-disease probabilities. Classes without any
    positive label in the evaluated split are excluded from the AUROC computation.

    Args:
        split: Name used to build the label file path
            ``data/chexpert/{split}_labels.csv``. Defaults to ``'test'``.

    Raises:
        RuntimeError: If the dataloader yields no samples.
    """
    dataset = CheXpertDataset(f'data/chexpert/{split}_labels.csv')
    # batch_size=1 with an identity collate_fn: every batch is a one-element list holding the raw dataset item.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=0)
    inference_model = InferenceModel()
    all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chexpert)

    all_labels = []
    all_probs_neg = []
    keys = None  # class names of the most recent item; used to label the per-disease output

    for batch in tqdm(dataloader):
        image_paths, labels, keys = batch[0]

        per_image_probs = []
        per_image_negative_probs = []
        for image_path in image_paths:
            probs, negative_probs = inference_model.get_descriptor_probs(Path(image_path), descriptors=all_descriptors)
            per_image_probs.append(probs)
            per_image_negative_probs.append(negative_probs)

        # Mean aggregation of the descriptor probabilities over all images of this item.
        probs = {
            key: sum(p[key] for p in per_image_probs) / len(per_image_probs)
            for key in per_image_probs[0]
        }
        negative_probs = {
            key: sum(p[key] for p in per_image_negative_probs) / len(per_image_negative_probs)
            for key in per_image_negative_probs[0]
        }

        disease_probs, negative_disease_probs = inference_model.get_diseases_probs(
            disease_descriptors_chexpert,
            pos_probs=probs,
            negative_probs=negative_probs,
        )
        # Only the probability vector is needed for AUROC; the predicted disease list is discarded.
        _, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(
            disease_descriptors_chexpert,
            disease_probs=disease_probs,
            negative_disease_probs=negative_disease_probs,
            keys=keys,
        )
        all_labels.append(labels)
        all_probs_neg.append(prob_vector_neg_prompt)

    if not all_labels:
        # torch.stack rejects an empty list with an opaque message; fail with a clear one instead.
        raise RuntimeError(f'No samples were loaded for CheXpert split {split!r}; nothing to evaluate.')

    all_labels = torch.stack(all_labels)
    all_probs_neg = torch.stack(all_probs_neg)

    # evaluation: keep only the classes that have at least one positive label
    existing_mask = all_labels.sum(dim=0) > 0
    all_labels_clean = all_labels[:, existing_mask]
    all_probs_neg_clean = all_probs_neg[:, existing_mask]
    all_keys_clean = [key for key, present in zip(keys, existing_mask.tolist()) if present]

    overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean, all_labels_clean)
    print(f"AUROC: {overall_auroc:.5f}\n")
    for key, auroc in zip(all_keys_clean, per_disease_auroc):
        print(f'{key}: {auroc:.5f}')
def inference_chestxray14():
    """Run descriptor-based inference on ChestX-ray14 and print AUROC scores.

    Each dataset item is unpacked as ``(image_path, labels, keys)``. The first
    class in ``keys`` is always excluded from the evaluation, as are classes
    without any positive label.

    Raises:
        RuntimeError: If the dataloader yields no samples.
    """
    dataset = ChestXray14Dataset('data/chestxray14/Data_Entry_2017_v2020_modified.csv')
    # batch_size=1 with an identity collate_fn: every batch is a one-element list holding the raw dataset item.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=1)
    inference_model = InferenceModel()
    all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chestxray14)

    all_labels = []
    all_probs_neg = []
    keys = None  # class names of the most recent item; used to label the per-disease output

    for batch in tqdm(dataloader):
        image_path, labels, keys = batch[0]
        probs, negative_probs = inference_model.get_descriptor_probs(Path(image_path), descriptors=all_descriptors)
        disease_probs, negative_disease_probs = inference_model.get_diseases_probs(
            disease_descriptors_chestxray14,
            pos_probs=probs,
            negative_probs=negative_probs,
        )
        # Only the probability vector is needed for AUROC; the predicted disease list is discarded.
        _, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(
            disease_descriptors_chestxray14,
            disease_probs=disease_probs,
            negative_disease_probs=negative_disease_probs,
            keys=keys,
        )
        all_labels.append(labels)
        all_probs_neg.append(prob_vector_neg_prompt)
        # NOTE(review): explicit collection after every sample, presumably to bound memory
        # over the long run -- confirm it is still needed.
        gc.collect()

    if not all_labels:
        # torch.stack rejects an empty list with an opaque message; fail with a clear one instead.
        raise RuntimeError('No samples were loaded for ChestX-ray14; nothing to evaluate.')

    all_labels = torch.stack(all_labels)
    all_probs_neg = torch.stack(all_probs_neg)

    # evaluation: keep only the classes that have at least one positive label
    existing_mask = all_labels.sum(dim=0) > 0
    # Drop the first class by its position in `keys` (presumably the "No Finding" entry --
    # confirm against ChestXray14Dataset). Doing this in the mask, rather than slicing
    # `[:, 1:]` after masking, keeps a real disease column from being dropped when the
    # first class has no positive labels and is therefore already masked out.
    existing_mask[:1] = False

    all_labels_clean = all_labels[:, existing_mask]
    all_probs_neg_clean = all_probs_neg[:, existing_mask]
    all_keys_clean = [key for key, present in zip(keys, existing_mask.tolist()) if present]

    overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean, all_labels_clean)
    print(f"AUROC: {overall_auroc:.5f}\n")
    for key, auroc in zip(all_keys_clean, per_disease_auroc):
        print(f'{key}: {auroc:.5f}')
if __name__ == '__main__':
    # Map every supported --dataset value to the function that evaluates it.
    runners = {
        'chexpert': inference_chexpert,
        'chestxray14': inference_chestxray14,
    }

    parser = argparse.ArgumentParser()
    # `choices` makes argparse reject unknown dataset names with a usage error
    # instead of the script silently exiting without running anything.
    parser.add_argument('--dataset', type=str, default='chexpert', choices=list(runners),
                        help='chexpert or chestxray14')
    args = parser.parse_args()

    runners[args.dataset]()