-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdae_walkback.py
executable file
·133 lines (101 loc) · 3.77 KB
/
dae_walkback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!env/bin/python
import tensorflow as tf
import input_data
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy
# Helpers
def trunc(x):
    """Return str(x) truncated to its first 8 characters (used for compact log printing)."""
    return str(x)[:8]


def cast32(x):
    """Cast *x* to a float32 numpy array.

    Uses numpy.asarray instead of the deprecated numpy.cast dict
    (removed in NumPy 2.0); behavior is equivalent.
    """
    return numpy.asarray(x, dtype='float32')
def binomial_draw(shape=[1], p=0.5, dtype='float32'):
    """Return a tensor of independent Bernoulli(p) draws as 0/1 values of the given shape."""
    uniform = tf.random_uniform(shape=shape, minval=0, maxval=1, dtype='float32')
    threshold = tf.fill(shape, p)
    ones = tf.ones(shape, dtype=dtype)
    zeros = tf.zeros(shape, dtype=dtype)
    # Element is 1 where the uniform sample fell below p, else 0.
    return tf.select(tf.less(uniform, threshold), ones, zeros)
def binomial_draw_vec(p_vec, shape=[1], dtype='float32'):
    """Return elementwise Bernoulli draws where each element's success probability comes from p_vec."""
    uniform = tf.random_uniform(shape=shape, minval=0, maxval=1, dtype='float32')
    ones = tf.ones(shape, dtype=dtype)
    zeros = tf.zeros(shape, dtype=dtype)
    # Compare each uniform sample against its own per-element probability.
    return tf.select(tf.less(uniform, p_vec), ones, zeros)
def salt_and_pepper(X, rate=0.3):
    """Corrupt X with salt-and-pepper noise.

    Each unit is kept with probability 1 - rate; a corrupted unit is
    replaced by 0 or 1 with equal probability.
    """
    keep_mask = binomial_draw(shape=tf.shape(X), p=1 - rate)
    salt = binomial_draw(shape=tf.shape(X), p=0.5)
    zeros = tf.zeros(tf.shape(X), dtype='float32')
    # Where keep_mask is 0 (corrupted), inject the 0/1 salt value.
    noise = tf.select(tf.equal(keep_mask, zeros), salt, zeros)
    return tf.add(tf.mul(X, keep_mask), noise)
# Xavier Initializers
def get_shared_weights(n_in, n_out, interval):
    """Create a (n_in, n_out) TF weight variable, uniformly initialized in [-interval, interval]."""
    init_val = cast32(numpy.random.uniform(-interval, interval, size=(n_in, n_out)))
    return tf.Variable(init_val)
def get_shared_bias(n, offset = 0):
    """Create a length-n TF bias variable initialized to -offset (all zeros by default)."""
    return tf.Variable(cast32(numpy.zeros(n) - offset))
# Load MNIST train/validation/test splits with one-hot labels
# (downloads into MNIST_data/ if not already present)
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# Number of hidden units in the single hidden layer
hidden_size = 2000
# Number of walkback steps (corrupt/reconstruct/resample cycles per example)
walkbacks = 5
# Number of training epochs
n_epoch = 500
# Mini-batch size
batch_size = 100
# Fraction of input units corrupted with salt-and-pepper noise
input_salt_and_pepper = 0.4
# Input placeholder: flattened 28x28 MNIST images, variable batch size
x0 = tf.placeholder(tf.float32, [None, 784])
x0_copy = x0
# Tied encoder/decoder weight matrix, Xavier-style uniform init interval sqrt(6/(fan_in+fan_out))
W1 = get_shared_weights(784, hidden_size, numpy.sqrt(6. / (784 + hidden_size)))
b0 = get_shared_bias(784)  # visible (decoder) bias
b1 = get_shared_bias(hidden_size)  # hidden (encoder) bias
# Reconstructions x1 collected from every walkback step
p_X_chain = []
for i in range(walkbacks):
    # Binarize: floor(x + 0.5) thresholds each unit at 0.5
    x0_bin = tf.floor(tf.add(x0_copy, tf.fill(tf.shape(x0), 0.5)))
    # Add salt-and-pepper corruption
    x_corrupt = salt_and_pepper(x0_bin, input_salt_and_pepper)
    # Encode
    h1 = tf.sigmoid(tf.matmul(x_corrupt, W1) + b1)
    # Decode with tied (transposed) weights
    x1 = tf.sigmoid(tf.matmul(h1, tf.transpose(W1)) + b0)
    p_X_chain.append(x1)
    # Sample the next chain input elementwise from Bernoulli(x1)
    x0_copy = binomial_draw_vec(x1, shape=tf.shape(x1))
# Binary cross-entropy between the clean input x0 and each walkback reconstruction;
# probabilities are clipped to [1e-10, 1] to avoid log(0)
cross_entropies = [-tf.reduce_sum(x0*tf.log(tf.clip_by_value(x1,1e-10,1.0)) + (1-x0)*tf.log(tf.clip_by_value(1-x1,1e-10,1.0))) for x1 in p_X_chain]
# Total loss: sum over all walkback steps
cross_entropy = tf.add_n(cross_entropies)
train_step = tf.train.AdamOptimizer().minimize(cross_entropy)
# Initialization
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(n_epoch):
    print 'Epoch: ', i+1,
    # train: run the optimizer and record the batch loss
    train_cost = []
    for j in range(mnist.train.num_examples/batch_size):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        result = sess.run((cross_entropy, train_step), feed_dict={x0: batch_xs})
        train_cost.append(result[0])
    train_cost = numpy.mean(train_cost)
    # Normalize by batch_size (100) * pixels (28*28) to report a per-pixel cost
    print 'Train: ', train_cost/(100*28*28),
    # valid: loss only, no parameter updates
    valid_cost = []
    for j in range(mnist.validation.num_examples/batch_size):
        batch_xs, batch_ys = mnist.validation.next_batch(batch_size)
        result = sess.run(cross_entropy, feed_dict={x0: batch_xs})
        valid_cost.append(result)
    valid_cost = numpy.mean(valid_cost)
    print 'Valid: ', valid_cost/(100*28*28),
    # test: loss only, no parameter updates
    test_cost = []
    for j in range(mnist.test.num_examples/batch_size):
        batch_xs, batch_ys = mnist.test.next_batch(batch_size)
        result = sess.run(cross_entropy, feed_dict={x0: batch_xs})
        test_cost.append(result)
    test_cost = numpy.mean(test_cost)
    print 'Test: ', test_cost/(100*28*28)
# Sample a Markov chain from the trained network: repeatedly feed the
# latest reconstruction back in as input, and tile the first 400 chain
# states into a 40x10 image grid.
test_input = mnist.test.next_batch(1)[0]
samples = [test_input]
fig, axs = plt.subplots(40, 10, figsize=(10, 40))
for i in range(400):
    samples.append(sess.run(x1, feed_dict={x0: samples[-1]}))
    ax = axs[i/10][i%10]
    ax.imshow(numpy.reshape(samples[i], (28,28)), cmap='gray')
    # Disable the axis on each subplot individually: plt.axis('off')
    # only affects the *current* axes, so the original call left axis
    # decorations visible on all but the last subplot.
    ax.axis('off')
plt.savefig('dae_walkback.png')