-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_SC-GAN.py
92 lines (61 loc) · 2.9 KB
/
test_SC-GAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
from util import html
from data_custom.data_load import load_nifty_volume_as_array
from data_custom.data_load import save_array_as_nifty_volume
import numpy as np
import torch
import math
# Dataset splits processed by the script, in order.
PHASES = ('train', 'val', 'test')


def normalize_volume(volume):
    """Return *volume* as float32, negatives clipped to 0, scaled to [0, 1].

    Scaling is per volume (subject-based normalization): every voxel is
    divided by the volume's own maximum. A volume whose maximum is not
    positive (all zeros / all negative) is returned as zeros instead of
    producing NaNs from a 0/0 division.

    Args:
        volume: array-like of voxel intensities.

    Returns:
        A new float32 ``np.ndarray`` with the same shape; the input is not
        modified.
    """
    clipped = np.maximum(np.asarray(volume, dtype=np.float32), 0)
    peak = clipped.max()
    if peak <= 0:
        return np.zeros_like(clipped)
    return clipped / peak


def to_model_range(array):
    """Map values from [0, 1] to the [-1, 1] range fed to the model."""
    return (array - 0.5) * 2


def from_model_range(array):
    """Inverse of ``to_model_range``: map [-1, 1] back to [0, 1]."""
    return array * 0.5 + 0.5


def load_subject_volume(subject_dir, modality):
    """Load ``<subject_dir>/<modality>.nii`` and normalize it to [0, 1]."""
    path = os.path.join(subject_dir, modality + '.nii')
    return normalize_volume(load_nifty_volume_as_array(filename=path))


def main():
    """Write ``<opt.out>_syn_SC-GAN.nii`` for every subject of every split.

    For each split in ``PHASES`` under ``opt.dataroot``, each subject
    directory must contain ``<opt.input1>.nii``, ``<opt.input2>.nii`` and
    ``<opt.out>.nii``. The two inputs are stacked as the channels of the
    model's ``'A'`` tensor. The target volume is passed as ``'B'``. The
    generated ``model.fake_B`` is mapped back to [0, 1] and saved next to
    the inputs. Missing split directories and non-directory entries are
    skipped.
    """
    opt = TestOptions().parse()
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    model = create_model(opt)
    # NOTE(review): the Visualizer instance is never used below. It is still
    # constructed so any constructor side effects match the original script.
    Visualizer(opt)

    in_type1, in_type2, out_type = opt.input1, opt.input2, opt.out
    for phase in PHASES:
        phase_dir = os.path.join(opt.dataroot, phase)
        if not os.path.isdir(phase_dir):
            print('Skipping missing split directory: %s' % phase_dir)
            continue

        # Sorted for a deterministic order; stray files are ignored.
        subjects = sorted(
            name for name in os.listdir(phase_dir)
            if os.path.isdir(os.path.join(phase_dir, name))
        )
        for subject in subjects:
            subject_dir = os.path.join(phase_dir, subject)
            in_im1 = load_subject_volume(subject_dir, in_type1)
            in_im2 = load_subject_volume(subject_dir, in_type2)
            out_im1 = load_subject_volume(subject_dir, out_type)

            # Stack inputs as channels, then add a leading batch axis:
            # A -> (1, 2, ...), B -> (1, 1, ...).
            data_x = np.expand_dims(to_model_range(np.stack([in_im1, in_im2])), axis=0)
            data_y = np.expand_dims(to_model_range(np.stack([out_im1])), axis=0)

            # Same string as before: subject directory with a trailing separator.
            subject_path = os.path.join(subject_dir, '')
            model.set_input({
                'A': torch.from_numpy(data_x),
                'B': torch.from_numpy(data_y),
                'A_paths': subject_path,
                'B_paths': subject_path,
            })
            model.test()

            fake_im = np.squeeze(from_model_range(model.fake_B.detach().cpu().numpy()))
            save_array_as_nifty_volume(
                fake_im,
                filename=os.path.join(subject_dir, out_type + '_syn_SC-GAN.nii'),
            )


if __name__ == '__main__':
    main()