-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTrain_test.py
92 lines (70 loc) · 3.67 KB
/
Train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from __future__ import print_function
from __future__ import division
import argparse
from glob import glob
import natsort
import tensorflow as tf
from model import denoiser
from utils import *
import os
# Command-line options in --help order: (flag, keyword arguments forwarded
# verbatim to ArgumentParser.add_argument). Parsed values land on ``args``
# under each option's ``dest``.
_CLI_OPTIONS = (
    ('--epoch', dict(dest='epoch', type=int, default=70, help='# of epoch')),
    ('--batch_size', dict(dest='batch_size', type=int, default=128, help='# images in batch')),
    ('--lr', dict(dest='lr', type=float, default=0.001, help='initial learning rate for sgd')),
    ('--sigma', dict(dest='sigma', type=float, default=50.0, help='noise level (for testing)')),
    ('--data', dict(dest='data', default='./patches/clean_pats_blind.npy', help='training data path')),
    ('--checkpoint_dir', dict(dest='ckpt_dir', default='./trained_model', help='models are saved here')),
    ('--sample_dir', dict(dest='sample_dir', default='./Output', help='sample are saved here')),
    ('--log_dir', dict(dest='log_dir', default='./logs', help='tensorboard logs are saved here')),
    ('--test_dir', dict(dest='test_dir', default='./Output', help='test sample are saved here')),
    ('--eval_set', dict(dest='eval_set', default='CBSD68', help='dataset for eval in training')),
    ('--test_set', dict(dest='test_set', default='CBSD68', help='dataset for testing')),
    ('--gpu', dict(dest='gpu', default='0', help='which gpu to use')),
    ('--type', dict(dest='type', default='', help='arg to give unique names to realizations')),
    ('--phase', dict(dest='phase', default='train', help='train or test')),
    ('--use_gpu', dict(dest='use_gpu', type=int, default=1, help='gpu flag, 1 for GPU and 0 for CPU')),
)


def _build_parser():
    """Return an ArgumentParser with every option in ``_CLI_OPTIONS`` registered."""
    cli = argparse.ArgumentParser(description='')
    for flag, options in _CLI_OPTIONS:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()
# Parsed once at import time; the functions below read this global namespace.
args = parser.parse_args()
def denoiser_train(denoiser, lr):
    """Train ``denoiser`` on the ``--data`` file and evaluate on ``--eval_set``.

    Evaluation images are every ``*.png`` under ``./data/test/<eval_set>/``,
    ordered with ``natsort.natsorted``. Batch size and epoch count come from
    the module-level ``args``.

    Args:
        denoiser: model object exposing a ``train`` method (``main`` passes a
            ``model.denoiser`` instance).
        lr: learning-rate schedule forwarded unchanged to ``denoiser.train``.
            ``main`` passes a numpy array of length ``args.epoch``.
    """
    eval_pattern = './data/test/{}/*.png'.format(args.eval_set)
    ordered_eval_files = natsort.natsorted(glob(eval_pattern))
    denoiser.train(
        args.data,
        ordered_eval_files,
        batch_size=args.batch_size,
        epoch=args.epoch,
        lr=lr,
    )
def denoiser_test(denoiser, save_dir):
    """Run ``denoiser.test`` on every PNG of the ``--test_set`` dataset.

    Files are collected from ``./data/test/<test_set>/*.png`` and passed in
    natural-sort order.

    Args:
        denoiser: model object exposing a ``test`` method.
        save_dir: output directory forwarded unchanged to ``denoiser.test``.
    """
    dataset_name = args.test_set
    print('Testing on {} dataset'.format(dataset_name))
    image_paths = natsort.natsorted(glob('./data/test/{}/*.png'.format(dataset_name)))
    denoiser.test(image_paths, save_dir)
def main(_):
    """Entry point for ``tf.app.run``: build the denoiser and run one phase.

    All settings come from the module-level ``args`` namespace. The function
    makes sure the checkpoint directory exists and builds a step
    learning-rate schedule. It then opens a TensorFlow session, either on
    the GPU(s) or CPU-only according to ``--use_gpu``, and dispatches to
    training or testing according to ``--phase``.

    Args:
        _: argv list supplied by ``tf.app.run``; unused.

    Raises:
        SystemExit: with status 1 when ``--phase`` is neither 'train' nor 'test'.
    """
    import sys

    # Explicit import instead of relying on ``from utils import *`` for ``np``.
    import numpy as np

    # Reject a bad phase before touching the filesystem or building a graph.
    # The exit status is non-zero so calling scripts can detect the error.
    if args.phase not in ('train', 'test'):
        print('[!]Unknown phase')
        sys.exit(1)

    ckpt_dir = args.ckpt_dir
    sample_dir = args.sample_dir
    test_dir = args.test_dir
    log_dir = args.log_dir
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    # Step decay over an array of length args.epoch: base lr for indices
    # [0, 40), lr/10 for [40, 50), lr/20 from 50 on.
    # Presumably denoiser.train indexes it by epoch; confirm in model.py.
    lr = args.lr * np.ones([args.epoch])
    lr[40:] = lr[0] / 10.0  # lr decay
    lr[50:] = lr[0] / 20.0  # lr decay

    if args.use_gpu:
        print("GPU\n")
        # Let this process claim at most 90% of GPU memory.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        config = tf.ConfigProto(gpu_options=gpu_options)
    else:
        print("CPU\n")
        # A bare tf.Session() would still place ops on any visible GPU.
        # Expose zero GPU devices so "--use_gpu 0" really runs on the CPU.
        config = tf.ConfigProto(device_count={'GPU': 0})

    with tf.Session(config=config) as sess:
        model = denoiser(sess, sigma=args.sigma, ckpt_dir=ckpt_dir,
                         sample_dir=sample_dir, log_dir=log_dir)
        if args.phase == 'train':
            denoiser_train(model, lr=lr)
        else:  # 'test' -- any other value was rejected above
            denoiser_test(model, test_dir)
if __name__ == '__main__':
    # Pin CUDA to the device id(s) given by --gpu. This runs before main()
    # opens any TF session. tf.app.run() with no arguments then invokes this
    # module's main().
    visible_devices = str(args.gpu)
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices
    tf.app.run()