Commit fd40e9c
Update main.py
add function check_and_write_list_tfrecord()
KaijieMo1 authored Nov 2, 2020
1 parent e512f63 commit fd40e9c
Showing 1 changed file with 36 additions and 15 deletions.
51 changes: 36 additions & 15 deletions main.py
@@ -1,6 +1,7 @@
 import yaml
 import tensorflow as tf
 from med_io.preprocess_raw_dataset import *
+from med_io.read_and_save_datapath import *
 from train import *
 from evaluate import *
 from predict import *
@@ -16,20 +17,23 @@
 
 def args_argument():
     parser = argparse.ArgumentParser(prog='MedSeg')
-    parser.add_argument('-e', '--exp_name', type=str, default='exp0', help='Name of experiment (subfolder in result_rootdir)')
+    parser.add_argument('-e', '--exp_name', type=str, default='exp0',
+                        help='Name of experiment (subfolder in result_rootdir)')
 
     parser.add_argument('--preprocess', type=bool, default=False, help='Preprocess the data')
     parser.add_argument('--train', type=bool, default=True, help='Train the model')
     parser.add_argument('--evaluate', type=bool, default=False, help='Evaluate the model')
     parser.add_argument('--predict', type=bool, default=True, help='Predict the model')
     parser.add_argument('--restore', type=bool, default=False, help='Restore the unfinished trained model')
-    #parser.add_argument('-c', '--config_path', type=str, default='./config/bi.yaml', help='Configuration file of the project')
-    parser.add_argument('-c', '--config_path', type=str, default='./config/config1.yaml', help='Configuration file of the project')
-    #parser.add_argument('-c', '--config_path', type=str, default='./config/nifti_AT.yaml', help='Configuration file of the project')
+    # parser.add_argument('-c', '--config_path', type=str, default='./config/bi.yaml', help='Configuration file of the project')
+    # parser.add_argument('-c', '--config_path', type=str, default='./config/config1.yaml', help='Configuration file of the project')
+    # parser.add_argument('-c', '--config_path', type=str, default='./config/nifti_AT.yaml', help='Configuration file of the project')
+    parser.add_argument('-c', '--config_path', type=str, default='./config/new/config_body_ident.yaml',
+                        help='Configuration file of the project')
 
     parser.add_argument("--gpu", type=int, default=0, help="Specify the GPU to use")
     parser.add_argument('--gpu_memory', type=float, default=None, help='Set GPU allocation. (in GB) ')
-    parser.add_argument('--calculate_max_shape_only', type=bool, default=False,
+    parser.add_argument('--calculate_max_shape_only', type=bool, default=True,
                         help='Only calculate the max shape of each dataset')
     parser.add_argument('--split_only', type=bool, default=False,
                         help='Only split the whole dataset to train, validation, and test dataset')
@@ -57,10 +61,9 @@ def main(args):
         except RuntimeError as e:
             print(e)
     else:  # allocate dynamic growth
-        config = tf.ConfigProto()
+        config = tf.compat.v1.ConfigProto()
         config.gpu_options.allow_growth = True
-        tf.set_session(tf.Session(config=config))
-
+        tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
 
     with open(args.config_path, "r") as yaml_file:
         config = yaml.load(yaml_file.read())
@@ -75,7 +78,7 @@ def main(args):
     random.seed(config['random_seed'])
 
     if args.exp_name:
-        config['exp_name']=args.exp_name
+        config['exp_name'] = args.exp_name
 
     if args.train_epoch:
         config['epoch'] = args.train_epoch
@@ -86,6 +89,22 @@ def main(args):
     if args.dataset:
         config['dataset'] = [args.dataset]
 
+    def check_and_write_list_tfrecord():
+        for dataset in config['dataset']:
+
+            if (not os.path.isfile(
+                    config['dir_list_tfrecord'] + '/' + config['filename_tfrec_pickle'][dataset] + '.pickle') and not
+                    config['read_body_identification']):
+                read_and_save_tfrec_path(config, rootdir=config['rootdir_tfrec'][dataset],
+                                         filename_tfrec_pickle=config['filename_tfrec_pickle'][dataset],
+                                         dataset=dataset)
+
+            if (not os.path.isfile(
+                    config['dir_list_tfrecord'] + '/' + config['filename_tfrec_pickle'][dataset] + '_bi.pickle') and
+                    config['read_body_identification']):
+                read_and_save_tfrec_path(config, rootdir=config['rootdir_tfrec'][dataset],
+                                         filename_tfrec_pickle=config['filename_tfrec_pickle'][dataset],
+                                         dataset=dataset)
 
     # preprocess and convert input to TFRecords
     if args.preprocess:
@@ -94,21 +113,23 @@ def main(args):
         split(config)  # split into train, validation and test set
 
     if args.calculate_max_shape_only:
+        check_and_write_list_tfrecord()
         calculate_max_shape(config)  # find and dump the max shape
         split(config)  # split into train, validation and test set
     if args.split_only:
+        check_and_write_list_tfrecord()
         split(config)  # split into train, validation and test set
 
     if args.train:  # train the model
         train(config, args.restore)
-        print("Training finished for %s" % (config['dir_model_checkpoint']+os.sep+config['exp_name']))
+        print("Training finished for %s" % (config['dir_model_checkpoint'] + os.sep + config['exp_name']))
     if args.evaluate:  # evaluate the metrics of a trained model
-        evaluate(config,datasets=config['dataset'])
-        print("Evaluation finished for %s" % (config['result_rootdir']+os.sep+config['exp_name']))
+        evaluate(config, datasets=config['dataset'])
+        print("Evaluation finished for %s" % (config['result_rootdir'] + os.sep + config['exp_name']))
     if args.predict:  # predict and generate output masks of a trained model
         predict(config, datasets=config['dataset'], save_predict_data=config['save_predict_data'])
-        print("Prediction finished for %s" % (config['result_rootdir']+os.sep+config['exp_name']))
+        print("Prediction finished for %s" % (config['result_rootdir'] + os.sep + config['exp_name']))
 
 
 if __name__ == '__main__':
     main(args_argument())
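Note: the two branches of the new check_and_write_list_tfrecord() differ only in the pickle suffix they look for ('.pickle' vs. '_bi.pickle') before calling read_and_save_tfrec_path(). Below is a minimal, hypothetical sketch of an equivalent deduplicated version, not part of this commit: it takes config as an explicit argument instead of closing over it, and it assumes read_and_save_tfrec_path is provided by the med_io.read_and_save_datapath module that main.py wildcard-imports above.

import os

# Assumed to be the function exported by the module wildcard-imported in main.py.
from med_io.read_and_save_datapath import read_and_save_tfrec_path


def check_and_write_list_tfrecord(config):
    """Write the TFRecord path-list pickle for each dataset if it does not exist yet."""
    # Mirror the two branches of the committed helper: body-identification mode
    # expects the '_bi.pickle' suffix, normal mode expects '.pickle'.
    suffix = '_bi.pickle' if config['read_body_identification'] else '.pickle'
    for dataset in config['dataset']:
        pickle_path = os.path.join(config['dir_list_tfrecord'],
                                   config['filename_tfrec_pickle'][dataset] + suffix)
        if not os.path.isfile(pickle_path):
            read_and_save_tfrec_path(config,
                                     rootdir=config['rootdir_tfrec'][dataset],
                                     filename_tfrec_pickle=config['filename_tfrec_pickle'][dataset],
                                     dataset=dataset)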
