-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer.py
116 lines (105 loc) · 4.18 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import glob
import os, losses, utils
from torch.utils.data import DataLoader
from data import datasets, trans
import numpy as np
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from natsort import natsorted
from models import RDP
import random
def same_seeds(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's
    default generator. When CUDA is available, the current-device and
    all-device CUDA generators are seeded as well.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
    for seeder in seeders:
        seeder(seed)
    # NOTE: cuDNN determinism is deliberately not forced here; enable
    # ``torch.backends.cudnn.deterministic = True`` for bit-exact GPU runs.
# Seed all RNGs once, at import time, so repeated runs of this script match.
same_seeds(seed=24)
class AverageMeter(object):
    """Running statistics for a scalar metric.

    Attributes:
        val: Most recent value passed to ``update``.
        sum: Weighted sum of all values (each value multiplied by its ``n``).
        count: Total weight, i.e. the sum of every ``n`` seen.
        avg: Weighted mean, ``sum / count``.
        vals: Every value passed to ``update``, unweighted, in call order.
        std: ``np.std`` of ``vals``; note it ignores the ``n`` weights.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far and zero all statistics."""
        self.val = self.avg = self.sum = 0
        self.count = 0
        self.vals, self.std = [], 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (callers here pass the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.vals.append(val)
        self.std = np.std(self.vals)
def main(val_dir='/LPBA_path/Val/', model_idx=-1):
    """Evaluate a trained RDP checkpoint on the LPBA validation files.

    Loads one checkpoint from ``experiments/<model_folder>``, runs the model
    on every validation pair and prints:

    * per-pair Dice of the warped and of the unwarped moving segmentation;
    * mean +- std of both Dice scores over all pairs;
    * mean +- std of the fraction of voxels with a non-positive Jacobian
      determinant of the predicted flow.

    Args:
        val_dir: Directory holding the validation ``*.pkl`` files. It is
            string-concatenated with ``'*.pkl'``, so it must end with a
            path separator.
        model_idx: Index into the naturally sorted checkpoint filenames;
            ``-1`` selects the last one.

    Raises:
        FileNotFoundError: If the checkpoint directory is empty or
            ``val_dir`` contains no ``*.pkl`` files.
    """
    weights = [1, 1]  # loss weights; only used to rebuild the experiment folder name
    lr = 0.0001
    model_folder = 'RDP_ncc_{}_reg_{}_lr_{}_54r/'.format(weights[0], weights[1], lr)
    model_dir = 'experiments/' + model_folder
    img_size = (160, 192, 160)

    # List checkpoints once, in natural (numeric-aware) order, and fail
    # with a readable message instead of an IndexError when there are none.
    checkpoints = natsorted(os.listdir(model_dir))
    if not checkpoints:
        raise FileNotFoundError('No checkpoints found in {}'.format(model_dir))
    ckpt_name = checkpoints[model_idx]

    model = RDP(img_size, channels=16)
    # map_location='cpu' lets a checkpoint saved on any GPU index load on this
    # machine; model.cuda() below moves the weights to the selected GPU.
    best_model = torch.load(os.path.join(model_dir, ckpt_name), map_location='cpu')['state_dict']
    print('Best model: {}'.format(ckpt_name))
    model.load_state_dict(best_model)
    model.cuda()
    model.eval()  # set once; nothing below switches the model back to training mode

    reg_model = utils.register_model(img_size, 'nearest')
    reg_model.cuda()

    test_composed = transforms.Compose([trans.Seg_norm(),
                                        trans.NumpyType((np.float32, np.int16)),
                                        ])
    # glob.glob returns files in filesystem-dependent order; sort so subjects
    # are processed (and printed) reproducibly. NOTE(review): if the dataset
    # forms pairs from list positions, this fixes which pairs are evaluated.
    val_files = natsorted(glob.glob(val_dir + '*.pkl'))
    if not val_files:
        raise FileNotFoundError('No *.pkl files found in {}'.format(val_dir))
    test_set = datasets.LPBABrainInferDatasetS2S(val_files, transforms=test_composed)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    eval_dsc_def = AverageMeter()
    eval_dsc_raw = AverageMeter()
    eval_det = AverageMeter()
    with torch.no_grad():
        for data in test_loader:
            data = [t.cuda() for t in data]
            x, y, x_seg, y_seg = data[:4]
            _, flow = model(x, y)  # the warped intensity image is not needed here
            # Resample the moving segmentation with the predicted flow
            # (reg_model was built in 'nearest' mode above).
            def_out = reg_model([x_seg.float(), flow])
            # Fraction of voxels whose Jacobian determinant is <= 0. The voxel
            # count comes from the spatial dims of y, without a GPU->CPU copy.
            jac_det = utils.jacobian_determinant_vxm(flow.detach().cpu().numpy()[0, :, :, :, :])
            num_voxels = np.prod(y.shape[2:])
            eval_det.update(np.sum(jac_det <= 0) / num_voxels, x.size(0))
            dsc_trans = utils.dice_val_VOI(def_out.long(), y_seg.long())
            dsc_raw = utils.dice_val_VOI(x_seg.long(), y_seg.long())
            print('Trans dsc: {:.4f}, Raw dsc: {:.4f}'.format(dsc_trans.item(), dsc_raw.item()))
            eval_dsc_def.update(dsc_trans.item(), x.size(0))
            eval_dsc_raw.update(dsc_raw.item(), x.size(0))
    print('Deformed DSC: {:.3f} +- {:.3f}, Affine DSC: {:.3f} +- {:.3f}'.format(eval_dsc_def.avg,
                                                                                eval_dsc_def.std,
                                                                                eval_dsc_raw.avg,
                                                                                eval_dsc_raw.std))
    print('deformed det: {}, std: {}'.format(eval_det.avg, eval_det.std))
if __name__ == '__main__':
    # GPU configuration: list the visible CUDA devices, select one, then run.
    GPU_iden = 0
    GPU_num = torch.cuda.device_count()
    print('Number of GPU: ' + str(GPU_num))
    for GPU_idx in range(GPU_num):
        GPU_name = torch.cuda.get_device_name(GPU_idx)
        print(' GPU #' + str(GPU_idx) + ': ' + GPU_name)
    GPU_avai = torch.cuda.is_available()
    # Check before set_device/get_device_name: both raise an unhelpful error
    # without a usable device, and main() moves the model to the GPU anyway.
    if not GPU_avai or GPU_iden >= GPU_num:
        raise SystemExit('CUDA device #{} is required for inference but is not available '
                         '(found {} GPU(s)).'.format(GPU_iden, GPU_num))
    torch.cuda.set_device(GPU_iden)
    print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
    print('If the GPU is available? ' + str(GPU_avai))
    main()