forked from qhoangdl/MGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
68 lines (61 loc) · 2.94 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import sys
import pickle
import argparse
import numpy as np
import tensorflow as tf
from models import MGAN
# Parsed command-line options; assigned by parser.parse_known_args() in the
# __main__ block before tf.app.run invokes main().
FLAGS = None


def main(_):
    """Load the training data, build an MGAN model and train it.

    All configuration is read from the module-level ``FLAGS`` namespace.

    Args:
        _: argv list passed in by ``tf.app.run``; unused.
    """
    # The dataset file is a pickled dict whose 'data' entry holds the images.
    # NOTE(review): unpickling can execute arbitrary code -- only point
    # --dataset_file at files you trust.
    # Use a context manager so the file handle is closed once loaded.
    with open(FLAGS.dataset_file, "rb") as f:
        dataset = pickle.load(f)

    # Reshape the flat rows into [-1] + image_size, then map x -> x/127.5 - 1.
    # Assumes raw pixel values in [0, 255] (giving [-1, 1]) -- TODO confirm
    # against the script that produced the pickle.
    x_train = (dataset['data'].astype(np.float32)
               .reshape([-1] + list(FLAGS.image_size)) / 127.5 - 1.)

    model = MGAN(
        num_z=FLAGS.num_z,
        beta=FLAGS.beta,
        num_gens=FLAGS.num_gens,
        d_batch_size=FLAGS.d_batch_size,
        g_batch_size=FLAGS.g_batch_size,
        z_prior=FLAGS.z_prior,
        learning_rate=FLAGS.learning_rate,
        img_size=tuple(FLAGS.image_size),
        num_conv_layers=FLAGS.num_conv_layers,
        num_gen_feature_maps=FLAGS.num_gen_feature_maps,
        num_dis_feature_maps=FLAGS.num_dis_feature_maps,
        num_epochs=FLAGS.num_epochs,
        # Output path templates; the {epoch} field is presumably filled in
        # by MGAN when it writes samples -- confirm in models.py.
        sample_fp="samples/samples_{epoch:04d}.png",
        sample_by_gen_fp="samples_by_gen/samples_{epoch:04d}.png",
        random_seed=6789)
    # NOTE(review): presumably resumes from a saved checkpoint; kept disabled.
    # model._restore('mgan_checkpoint_450', 450)
    model.fit(x_train)
if __name__ == '__main__':
    # Command-line interface: every option below lands on the module-level
    # FLAGS namespace that main() reads.
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_z', type=int, default=100,
                        help='Number of latent units.')
    parser.add_argument('--beta', type=float, default=0.01,
                        help='Diversity parameter beta.')
    parser.add_argument('--num_gens', type=int, default=10,
                        help='Number of generators.')
    parser.add_argument('--d_batch_size', type=int, default=64,
                        help='Minibatch size for the discriminator.')
    parser.add_argument('--g_batch_size', type=int, default=12,
                        help='Minibatch size for the generators.')
    parser.add_argument('--z_prior', type=str, default="uniform",
                        help='Prior distribution of the noise (uniform/gaussian).')
    parser.add_argument('--learning_rate', type=float, default=0.0002,
                        help='Learning rate.')
    parser.add_argument('--num_conv_layers', type=int, default=3,
                        help='Number of convolutional layers.')
    parser.add_argument('--num_gen_feature_maps', type=int, default=128,
                        help='Number of feature maps of Generator.')
    parser.add_argument('--num_dis_feature_maps', type=int, default=128,
                        help='Number of feature maps of Discriminator.')
    parser.add_argument('--num_epochs', type=int, default=500,
                        help='Number of epochs.')
    parser.add_argument('--dataset_file', type=str, default='./data/cifar10_train.pkl',
                        help='Dataset as a pickled dictionary {"data": Train_np_array, "labels": array-like}.')
    # Previously help='' left this option undocumented in --help output.
    parser.add_argument('--image_size', nargs='+', type=int, default=(32, 32, 3),
                        help='Shape of a single training image as space-separated '
                             'integers, e.g. "--image_size 32 32 3"; the loaded '
                             'data is reshaped to [-1] + image_size.')
    # parse_known_args leaves unrecognised flags alone so they can be
    # forwarded to tf.app.run together with the program name.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)