# main.py — training/testing entry point for the SAML prostate-segmentation
# model: argument parsing, output-directory setup, model construction,
# checkpoint restore, and dispatch to train() or test().
import sys
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
from data_generator import ImageDataGenerator
from saml_func import SAML
from train import train
from train import test
import datetime
import argparse
from utils import check_folder, show_all_variables
import logging
# Timestamp string for this run, e.g. '2020_01_31_12_00_00'.
# NOTE(review): not referenced below — output dirs are namespaced by
# `train_date` in main() instead; presumably kept for manual runs. Verify.
currtime = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
# Fix the TF graph-level random seed for reproducibility.
tf.set_random_seed(2)
def parse_args(train_date):
    """Build, parse and validate the command-line arguments.

    Args:
        train_date: string used to namespace the output directories
            (checkpoints / results / logs / samples) under ../output/.

    Returns:
        The parsed argparse.Namespace, after check_args() has created
        every output directory.
    """
    desc = "Tensorflow implementation of DenseUNet for prostate segmentation"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--gpu', type=str, default='0', help='value for CUDA_VISIBLE_DEVICES, e.g. "0"')
    parser.add_argument('--phase', type=str, default='train', help='train or test')
    parser.add_argument('--n_class', type=int, default=2, help='number of segmentation classes')
    ## Training operations
    parser.add_argument('--target_domain', type=str, default='ISBI', help='held-out (unseen) target domain name')
    # BUG FIX: `type=list` makes argparse split a CLI value into single
    # characters; `nargs=3, type=int` parses "--volume_size 384 384 3"
    # correctly. Defaults are unchanged, so existing callers see no change.
    parser.add_argument('--volume_size', type=int, nargs=3, default=[384, 384, 3], help='H W C of the input volumes')
    parser.add_argument('--label_size', type=int, nargs=3, default=[384, 384, 1], help='H W C of the label maps')
    parser.add_argument('--epoch', type=int, default=1, help='The number of epochs to run')
    parser.add_argument('--train_iterations', type=int, default=10000, help='The number of training iterations')
    parser.add_argument('--meta_batch_size', type=int, default=5, help='number of images sampled per source domain')
    parser.add_argument('--test_batch_size', type=int, default=1, help='batch size used at test time')
    parser.add_argument('--inner_lr', type=float, default=1e-4, help='inner-loop (task) learning rate')
    parser.add_argument('--outer_lr', type=float, default=1e-3, help='outer-loop (meta) learning rate')
    parser.add_argument('--metric_lr', type=float, default=1e-3, help='learning rate of the metric-embedding network')
    parser.add_argument('--margin', type=float, default=10.0, help='margin of the metric loss')
    parser.add_argument('--compactness_loss_weight', type=float, default=1.0, help='weight of the compactness loss')
    parser.add_argument('--smoothness_loss_weight', type=float, default=0.005, help='weight of the smoothness loss')
    parser.add_argument('--clipNorm', type=int, default=True, help='whether to clip gradients by norm (0/1)')
    parser.add_argument('--gradients_clip_value', type=float, default=10.0, help='gradient clipping threshold')
    # Logging, saving, and testing options
    parser.add_argument('--resume', type=int, default=False, help='resume from the latest checkpoint (0/1)')
    parser.add_argument('--log', type=int, default=True, help='write tensorboard summaries (0/1)')
    parser.add_argument('--decay_step', type=float, default=500, help='learning-rate decay step')
    parser.add_argument('--decay_rate', type=float, default=0.95, help='learning-rate decay rate')
    parser.add_argument('--test_freq', type=int, default=200, help='iterations between evaluations')
    parser.add_argument('--save_freq', type=int, default=200, help='iterations between checkpoint saves')
    parser.add_argument('--print_interval', type=int, default=5, help='iterations between console prints')
    parser.add_argument('--summary_interval', type=int, default=20, help='iterations between tensorboard writes')
    parser.add_argument('--restored_model', type=str, default=None, help='Model to restore')
    parser.add_argument('--test_model', type=str, default=None, help='Model to restore for testing')
    parser.add_argument('--checkpoint_dir', type=str, default='../output/' + train_date + '/checkpoints/',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='../output/' + train_date + '/results/',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='../output/' + train_date + '/logs/',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='../output/' + train_date + '/samples/',
                        help='Directory name to save the samples on training')
    return check_args(parser.parse_args())
"""checking arguments"""
def check_args(args):
# --checkpoint_dir
check_folder(args.checkpoint_dir)
# --result_dir
check_folder(args.result_dir)
# --result_dir
check_folder(args.log_dir)
# --sample_dir
check_folder(args.sample_dir)
return args
def main():
    """Entry point: parse arguments, set up logging, build the SAML model,
    restore weights if requested, then dispatch to train() or test()."""
    train_date = 'xxx'  # placeholder run name; output goes under ../output/<train_date>/
    args = parse_args(train_date)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # define logger: one file handler plus a console echo
    logging.basicConfig(filename=args.log_dir + "/" + args.phase + '_log.txt',
                        level=logging.DEBUG, format='%(asctime)s %(message)s')
    logging.getLogger().addHandler(logging.StreamHandler())
    # print the exact command line used for this run
    logging.info("Usage:")
    logging.info(" {0}".format(" ".join(sys.argv)))
    logging.debug("All settings used:")
    # back up the source files of this run into the log directory
    for src in ('main.py', 'saml_func.py', 'train.py', 'utils.py', 'data_generator.py'):
        os.system('cp %s %s' % (src, args.log_dir))
    filelist_root = '../dataset'
    # All six domains; the target domain is held out as the unseen test set.
    source_list = ['HK', 'ISBI', 'ISBI_1.5', 'I2CVB', 'UCL', 'BIDMC']
    source_list.remove(args.target_domain)
    # Constructing model (train and test graphs share variables)
    model = SAML(args)
    model.construct_model_train()
    model.construct_model_test()
    model.summ_op = tf.summary.merge_all()
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    show_all_variables()
    # restore model ----
    resume_itr = 0
    if args.resume:
        model_file = tf.train.latest_checkpoint(args.checkpoint_dir)
        if model_file:
            # checkpoint names end with 'model<iteration>'; recover the count
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
    train_file_list = [os.path.join(filelist_root, domain + '_train_list')
                       for domain in source_list]
    test_file = os.path.join(filelist_root, args.target_domain + '_train_list')
    # start training ----
    if args.phase == 'train':
        train(model, saver, sess, train_file_list, test_file, args, resume_itr)
    else:
        args.test_model = 'xxx'  # TODO: set the checkpoint path to evaluate
        saver.restore(sess, args.test_model)
        logging.info("testing model restored %s" % args.test_model)
        test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file, model, args)
        # BUG FIX: the original used Python 2 `print >> f` statements, which
        # are a SyntaxError under Python 3 (the rest of the file already uses
        # the Python 3 print function). Write via print(..., file=f) instead.
        with open(os.path.join(args.log_dir, 'test.txt'), 'a') as f:
            print('testing model %s :' % (args.test_model), file=f)
            print(' Unseen domain testing results: Dice: %f' % (test_dice), test_dice_arr, file=f)
            print(' Unseen domain testing results: Haus: %f' % (test_haus), test_haus_arr, file=f)
if __name__ == "__main__":
main()