forked from vvanirudh/tensorflow-vrnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsample_vrnn.py
51 lines (38 loc) · 1.35 KB
/
sample_vrnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
'''
Variational RNN sample script using TensorFlow
Model introduced in https://arxiv.org/abs/1506.02216
Chung, J., Kastner, K., Dinh, L., Goel, K., Courville, A. C., & Bengio, Y. (2015).
A recurrent latent variable model for sequential data.
In Advances in neural information processing systems (pp. 2980-2988).
Code original author : phreeza (taken from https://github.com/phreeza/tensorflow-vrnn)
Author : Anirudh Vemula
Date : December 5th, 2016
'''
import tensorflow as tf
import os
import cPickle
from model_vrnn import VRNN
from utils_vrnn import DataLoader
import numpy as np
def main(save_dir='save-vrnn'):
    '''
    Restore a trained VRNN from a checkpoint directory and draw a sample.

    params:
    save_dir : directory containing 'config.pkl' (the pickled arguments the
               model was built with) and the TensorFlow checkpoint state.
               Defaults to 'save-vrnn', the directory this script always used.

    returns:
    The (sample_data, mus, sigmas) tuple produced by model.sample.

    raises:
    IOError : if no checkpoint is found in save_dir.
    '''
    # Load the saved arguments. Pickle streams are binary, so open in 'rb'
    # mode; text mode can corrupt them on platforms that translate newlines.
    # NOTE(review): unpickling can execute arbitrary code -- only load
    # config files from a trusted source.
    with open(os.path.join(save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)

    # Initialize the model with the saved arguments in inference mode
    model = VRNN(saved_args, True)

    # Locate the checkpoint before opening a session. get_checkpoint_state
    # returns None when the directory holds no checkpoint; reading
    # .model_checkpoint_path on None would raise an opaque AttributeError.
    ckpt = tf.train.get_checkpoint_state(save_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise IOError('no checkpoint found in %r' % save_dir)

    # Build the saver after the model so the variable list it is given
    # includes the model's variables.
    # NOTE(review): tf.all_variables was later renamed tf.global_variables;
    # switch if the installed TensorFlow no longer provides all_variables.
    saver = tf.train.Saver(tf.all_variables())

    sess = tf.InteractiveSession()
    try:
        # Single formatted string so the output is the same under Python 2
        # and 3; the double space matches the original
        # `print "loading model: ", path` statement.
        print('loading model:  %s' % ckpt.model_checkpoint_path)
        # Restore the model from the saved file
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Sample the model
        # NOTE(review): assumes model.sample returns materialized values
        # (e.g. numpy arrays) that remain valid after the session closes.
        sample_data, mus, sigmas = model.sample(sess, saved_args)
    finally:
        # Release the session's resources even if restore/sampling fails.
        sess.close()
    return sample_data, mus, sigmas
if __name__ == '__main__':
    # Script entry point: restore the saved model and sample from it only
    # when this file is executed directly, not when it is imported.
    main()