-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_vae.py
34 lines (28 loc) · 897 Bytes
/
main_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""CS446 2018 Spring MP10.
Implementation of a variational autoencoder for image generation.
"""
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from vaes.vae import VariationalAutoencoder
import input_data
from anime_data import *
import time
def main(_):
    """Run the high-level VAE training pipeline.

    Loads the anime image dataset, builds a ``VariationalAutoencoder`` with
    ``64*64`` input dimensions and a 100-dimensional latent space, trains it
    for 5000 steps, and prints the elapsed wall-clock time in seconds.

    Args:
        _: Argument list passed in by ``tf.app.run``; unused.
    """
    # Get dataset.
    # To train on MNIST instead, load
    # input_data.read_data_sets('MNIST_data', one_hot=True).train
    # and build the model with ndims=28*28.
    # NOTE(review): mode="L" presumably selects single-channel (grayscale)
    # images, matching ndims=64*64 below -- confirm in anime_data.
    dataset = get_dataset(low_memory=False, mode="L")

    # Build model.
    nlatent = 100
    # ndims is presumably the flattened 64x64 image size -- TODO confirm
    # against VariationalAutoencoder.
    model = VariationalAutoencoder(ndims=64 * 64, nlatent=nlatent)

    # Start training.
    print("Training...")
    start_t = time.time()
    model.train(dataset, num_steps=5000)
    # Format explicitly: print() does not apply %-formatting to its
    # arguments, so passing the value as a second argument would leave
    # "%r" unsubstituted.
    print("Done! %r" % (time.time() - start_t))
if __name__ == "__main__":
    # Hand main to TensorFlow's app runner explicitly; tf.app.run would
    # otherwise look up `main` in the __main__ module, which is this script.
    tf.app.run(main=main)