-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_seg.py
119 lines (95 loc) · 4.82 KB
/
run_seg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json
from matplotlib import pyplot as plt
from pathlib import Path
from uuid import uuid4
import pycocotools.mask as mask_util
import argparse
import sys
from utils import compute_segmentation_metrics, PATH_TO_ID
from models.run_biovil import plot_phrase_grounding as ppgb
from models.BioViL.image.data.io import load_image
def parse_args():
    """Parse the script's four positional command-line arguments.

    Returns:
        argparse.Namespace with string attributes, in positional order:
        ``model``, ``test_set``, ``visualize`` and ``method``.
    """
    # (name, help) pairs; tuple order defines the positional order on the CLI.
    positional_specs = (
        ("model", "name of model (BioViL, )"),
        ("test_set", "name of test set (CheXlocalize, )"),
        ("visualize", "yes or no"),
        ("method", "how to generate heatmap (naive, grad_cam, gradcam_plus, cocoa)"),
    )
    cli = argparse.ArgumentParser()
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, type=str, help=arg_help)
    return cli.parse_args()
# Text prompt sent to the phrase-grounding model for each CheXlocalize
# pathology. All prompts share the "Findings suggesting " prefix; only the
# trailing phrase differs. The keys also select which GT annotations are
# evaluated, and insertion order fixes the order of the per-pathology report.
PROMPTS = {
    pathology: f"Findings suggesting {phrase}"
    for pathology, phrase in (
        ("Cardiomegaly", "cardiomegaly"),
        ("Edema", "an edema"),
        ("Consolidation", "consolidation"),
        ("Atelectasis", "atelectasis"),
        ("Pneumothorax", "a pneumothorax"),
        ("Pleural Effusion", "pleural effusion"),
    )
}
# Region names passed as ``seg_targets`` to the heatmap generator for each
# pathology (every method except "naive"). Pneumothorax has no target
# region. Each value is built as its own independent list.
SEG_TARGETS = {
    pathology: list(regions)
    for pathology, regions in (
        ("Cardiomegaly", ("Heart",)),
        ("Edema", ("Left Lung", "Right Lung")),
        ("Consolidation", ("Left Lung", "Right Lung")),
        ("Atelectasis", ("Left Lung", "Right Lung", "Facies Diaphragmatica")),
        ("Pneumothorax", ()),
        ("Pleural Effusion", ("Left Lung", "Right Lung", "Facies Diaphragmatica")),
    )
}
def _image_path(sample_id):
    """Return the CheXpert test-set JPEG path for a CheXlocalize sample id.

    Every underscore except the last becomes a directory separator, e.g.
    ``"a_b_c_d"`` -> ``"datasets/CheXlocalize/CheXpert/test/a/b/c_d.jpg"``.
    """
    return (
        "datasets/CheXlocalize/CheXpert/test/"
        + sample_id.replace("_", "/", sample_id.count("_") - 1)
        + ".jpg"
    )


def _safe_mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero (no samples)."""
    return total / count if count else float("nan")


def _plot_sample(image, heatmap, threshold, gt_mask, text_prompt, pathology):
    """Save a 3-panel figure: input image, thresholded heatmap, GT mask."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 6))
    panels = (
        (image, "Input image"),
        ((heatmap > threshold).astype(int), f"BioViL mask: {text_prompt}"),
        (gt_mask, f"GT mask: {pathology}"),
    )
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data)
        ax.axis('off')
        ax.set_title(title)
    fig.savefig(f"biovil_plot_{uuid4()}.png")
    # Close explicitly; otherwise one open figure accumulates per sample.
    plt.close(fig)


def _write_report(iou_sums, dice_sums, sample_counts, path="run_seg.txt"):
    """Print per-pathology and overall averages and append them to ``path``.

    A pathology with no evaluated samples is reported as ``nan`` instead of
    raising ZeroDivisionError. The overall figures are averaged over all
    samples (sum of scores / number of samples), not over pathologies.
    """
    lines = []
    for pathology in PROMPTS:
        mean_iou = _safe_mean(iou_sums[pathology], sample_counts[pathology])
        mean_dice = _safe_mean(dice_sums[pathology], sample_counts[pathology])
        print("Pathology:", pathology, "mIoU:", mean_iou, "Avg. DICE:", mean_dice)
        lines.append("Pathology: " + pathology + " mIoU: " + str(mean_iou) + " Avg. DICE: " + str(mean_dice) + "\n")
    lines.append("\n")

    total_samples = sum(sample_counts.values())
    overall_iou = _safe_mean(sum(iou_sums.values()), total_samples)
    overall_dice = _safe_mean(sum(dice_sums.values()), total_samples)
    print("mIoU:", overall_iou)
    lines.append("mIoU: " + str(overall_iou) + "\n")
    print("Avg. DICE: ", overall_dice)
    lines.append("Avg. DICE: " + str(overall_dice) + "\n")

    with open(path, "a") as report:
        report.writelines(lines)


def main():
    """Score phrase-grounding heatmaps against CheXlocalize GT segmentations.

    Parses the CLI (see ``parse_args``). It loads the RLE-encoded GT masks from
    ``datasets/CheXlocalize/gt_segmentations_test.json``. For every non-empty
    mask whose pathology is in ``PROMPTS``, it generates a heatmap with
    ``ppgb`` and scores it with ``compute_segmentation_metrics``. The
    best-IoU/Dice threshold is chosen inside that helper. Per-pathology and
    overall averages are printed and appended to ``run_seg.txt``. When
    ``visualize`` is "yes", one PNG per sample is saved.

    Raises:
        NotImplementedError: if the model is not "BioViL" or the test set is
            not "CheXlocalize".
    """
    args = parse_args()
    print(f"Running {sys.argv[0]} with args {args}")
    plot_images = args.visualize == "yes"
    if args.model != "BioViL":
        raise NotImplementedError("Only BioViL is implemented for now")
    if args.test_set != "CheXlocalize":
        raise NotImplementedError("Only CheXlocalize is implemented for now")

    # Running sums keyed by pathology name; PROMPTS fixes the key set and order.
    iou_sums = dict.fromkeys(PROMPTS, 0)
    dice_sums = dict.fromkeys(PROMPTS, 0)
    sample_counts = dict.fromkeys(PROMPTS, 0)

    with open("datasets/CheXlocalize/gt_segmentations_test.json") as gt_file:
        gt_segmentations = json.load(gt_file)

    for sample_id, annotations in gt_segmentations.items():
        filename = _image_path(sample_id)
        for pathology, rle in annotations.items():
            if pathology not in PROMPTS:
                continue
            # NOTE(review): 'ifdl3' presumably encodes an all-empty RLE mask --
            # confirm; empty masks are also caught by the max() check below.
            if rle['counts'] == 'ifdl3':
                continue
            gt_mask = mask_util.decode(rle)
            if gt_mask.max() == 0:
                continue  # no positive GT pixels: nothing to localize

            text_prompt = PROMPTS[pathology]
            if args.method == "naive":
                # NOTE(review): input_size is not passed here, unlike the other
                # methods -- confirm the heatmap matches gt_mask's resolution.
                heatmap = ppgb(filename, text_prompt)
                image = None  # the naive call yields only a heatmap
            else:
                heatmap, image = ppgb(
                    filename,
                    text_prompt,
                    method=args.method,
                    input_size=gt_mask.shape,
                    pathology=pathology,
                    seg_targets=SEG_TARGETS[pathology],
                )
            best_iou, best_dice, best_thresh = compute_segmentation_metrics(heatmap, gt_mask)

            if plot_images:
                if image is None:
                    # Previously a NameError on the naive path.
                    image = load_image(Path(filename)).convert("RGB")
                _plot_sample(image, heatmap, best_thresh, gt_mask, text_prompt, pathology)

            iou_sums[pathology] += best_iou
            dice_sums[pathology] += best_dice
            sample_counts[pathology] += 1

    _write_report(iou_sums, dice_sums, sample_counts)
# Entry point: run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()