-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
137 lines (117 loc) · 5.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
## GAN Variants
from GAN import GAN
from CGAN import CGAN
from infoGAN import infoGAN
from ACGAN import ACGAN
from EBGAN import EBGAN
from WGAN import WGAN
from WGAN_GP import WGAN_GP
from DRAGAN import DRAGAN
from LSGAN import LSGAN
from BEGAN import BEGAN
## VAE Variants
from VAE import VAE
from CVAE import CVAE
from utils import show_all_variables
import tensorflow as tf
import argparse
"""parsing and configuration"""
def parse_args():
    """Build the command-line parser, parse the arguments and validate them.

    Returns:
        Whatever ``check_args`` returns for the parsed namespace.
    """
    model_choices = ['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN',
                     'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN', 'VAE', 'CVAE']
    dataset_choices = ['mnist', 'fashion-mnist', 'celebA']

    parser = argparse.ArgumentParser(
        description="Tensorflow implementation of GAN collections")
    add_option = parser.add_argument

    # Model and data selection. --gan_type is mandatory on the command line.
    add_option('--gan_type', required=True, type=str, default='GAN',
               choices=model_choices, help='The type of GAN')
    add_option('--dataset', type=str, default='mnist',
               choices=dataset_choices, help='The name of dataset')

    # Training hyper-parameters.
    add_option('--epoch', type=int, default=20,
               help='The number of epochs to run')
    add_option('--batch_size', type=int, default=100,
               help='The size of batch')

    # Output locations; all three share the same option shape.
    directory_options = (
        ('--checkpoint_dir', 'checkpoint',
         'Directory name to save the checkpoints'),
        ('--result_dir', 'results',
         'Directory name to save the generated images'),
        ('--log_dir', 'logs',
         'Directory name to save training logs'),
    )
    for flag, default_name, description in directory_options:
        add_option(flag, type=str, default=default_name, help=description)

    parsed = parser.parse_args()
    return check_args(parsed)
"""checking arguments"""
def check_args(args):
    """Validate the parsed arguments and create the output directories.

    Args:
        args: Parsed namespace providing ``epoch``, ``batch_size``,
            ``checkpoint_dir``, ``result_dir`` and ``log_dir``.

    Returns:
        ``args`` itself when every check passes; ``None`` (after printing
        each violated requirement) when ``epoch`` or ``batch_size`` is
        invalid, so the caller can abort the run.
    """
    # Explicit comparisons instead of ``assert`` inside a bare ``except``:
    # asserts disappear under ``python -O``, and the previous handler only
    # printed a message and then carried on with the invalid value.
    requirements = (
        ('epoch', 'number of epochs must be larger than or equal to one'),
        ('batch_size', 'batch size must be larger than or equal to one'),
    )
    valid = True
    for attr, message in requirements:
        try:
            satisfied = getattr(args, attr) >= 1
        except TypeError:
            # A value that cannot be compared with an int is invalid too.
            satisfied = False
        if not satisfied:
            print(message)
            valid = False

    if not valid:
        # Validate before touching the filesystem so a rejected run leaves
        # no empty directories behind.
        return None

    # --checkpoint_dir, --result_dir, --log_dir
    for directory in (args.checkpoint_dir, args.result_dir, args.log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    return args
"""main"""
def main():
    """Parse the command line, then build, train and visualize one model.

    The model class is selected by ``--gan_type``; every supported class is
    constructed with the same keyword arguments taken from the parsed
    options.

    Raises:
        SystemExit: With status 1 when ``parse_args`` returns ``None``.
        Exception: When ``--gan_type`` does not name a known model.
    """
    # parse arguments
    args = parse_args()
    if args is None:
        # Abort with a non-zero status. ``SystemExit`` is raised directly
        # because the bare ``exit()`` helper is installed by the ``site``
        # module for interactive use and reports success (status 0).
        raise SystemExit(1)

    # One lookup table instead of a dozen identical if/elif branches.
    model_classes = {
        'GAN': GAN,
        'CGAN': CGAN,
        'ACGAN': ACGAN,
        'infoGAN': infoGAN,
        'EBGAN': EBGAN,
        'WGAN': WGAN,
        'WGAN_GP': WGAN_GP,
        'DRAGAN': DRAGAN,
        'LSGAN': LSGAN,
        'BEGAN': BEGAN,
        'VAE': VAE,
        'CVAE': CVAE,
    }

    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # declare instance for the requested model
        model_cls = model_classes.get(args.gan_type)
        if model_cls is None:
            raise Exception("[!] There is no option for " + args.gan_type)

        gan = model_cls(sess, epoch=args.epoch, batch_size=args.batch_size,
                        dataset_name=args.dataset,
                        checkpoint_dir=args.checkpoint_dir,
                        result_dir=args.result_dir, log_dir=args.log_dir)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        # launch the graph in a session
        gan.train()
        print(" [*] Training finished!")

        # visualize learned generator, passing the index of the final epoch
        gan.visualize_results(args.epoch - 1)
        print(" [*] Testing finished!")
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()