forked from cs231n/cs231n.github.io
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_activation_visualize.py
62 lines (52 loc) · 2.23 KB
/
mnist_activation_visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import matplotlib as mp
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.examples.tutorials.mnist import input_data
import math
def getActivations(layer,stimuli):
    """Evaluate `layer` for a single input and plot the resulting activations.

    The input is reshaped to one row of 784 values and fed to the
    module-level placeholder `x`, with `keep_prob` fed as 1.0. The tensor
    is evaluated in the module-level session `sess`, and the result is
    handed to `plotNNFilter`. Nothing is returned.

    Args:
        layer: graph tensor to evaluate with `sess.run`.
        stimuli: array holding 784 values (it is reshaped to [1, 784]).
    """
    single_row = np.reshape(stimuli, [1, 784], order='F')
    feed = {x: single_row, keep_prob: 1.0}
    activations = sess.run(layer, feed_dict=feed)
    plotNNFilter(activations)
def plotNNFilter(units, n_columns=6):
    """Draw every channel of an activation array as a grayscale image grid.

    Each slice `units[0, :, :, i]` is drawn in its own subplot of figure 1,
    titled 'Filter <i>'. Only index 0 along the first axis is plotted.

    Args:
        units: array with at least four axes; axis 3 enumerates the filters
            and axes 1-2 are drawn as the image.
        n_columns: number of subplot columns in the grid (default 6, the
            previously hard-coded layout).
    """
    filters = units.shape[3]
    plt.figure(1, figsize=(20,20))
    # Integer ceiling division: exactly enough rows to hold every filter.
    # The former `math.ceil(filters / n_columns) + 1` always left one blank
    # row under Python 3's true division; this form also stays an int on
    # Python 2, where math.ceil returns a float.
    n_rows = -(-filters // n_columns)
    for i in range(filters):
        # Subplot indices are 1-based, filter titles stay 0-based.
        plt.subplot(n_rows, n_columns, i+1)
        plt.title('Filter ' + str(i))
        plt.imshow(units[0,:,:,i], interpolation="nearest", cmap="gray")
# ---------------------------------------------------------------------------
# Script body: build a small convolutional network for MNIST, train it for a
# few steps, print accuracy, then plot the first conv layer's activations for
# one test image.
# ---------------------------------------------------------------------------
tf.reset_default_graph()
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Graph inputs: flattened 784-value images and 10-way one-hot labels.
x = tf.placeholder(tf.float32, [None, 784],name="x-in")
true_y = tf.placeholder(tf.float32, [None, 10],name="y-in")
# NOTE(review): keep_prob is fed on every sess.run call below, but no op
# visible in this file consumes it (there is no dropout layer), so the fed
# values currently have no effect on the network.
keep_prob = tf.placeholder("float")

# Network: three 5x5 conv layers (5, 5 and 20 outputs) with 2x2 max-pooling
# after the first two, followed by a 10-way softmax layer.
x_image = tf.reshape(x,[-1,28,28,1])
hidden_1 = slim.conv2d(x_image,5,[5,5])
pool_1 = slim.max_pool2d(hidden_1,[2,2])
hidden_2 = slim.conv2d(pool_1,5,[5,5])
pool_2 = slim.max_pool2d(hidden_2,[2,2])
hidden_3 = slim.conv2d(pool_2,20,[5,5])
out_y = slim.fully_connected(slim.flatten(hidden_3),10,activation_fn=tf.nn.softmax)

# Clip the probabilities before taking the log: a softmax output that
# underflows to 0 would give log(0) = -inf, and 0 * -inf = NaN would then
# poison the loss and every subsequent gradient update.
cross_entropy = -tf.reduce_sum(true_y*tf.log(tf.clip_by_value(out_y,1e-10,1.0)))
correct_prediction = tf.equal(tf.argmax(out_y,1), tf.argmax(true_y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

batchSize = 50
numSteps = 100
reportEvery = 20

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

for i in range(numSteps):
    batch = mnist.train.next_batch(batchSize)
    sess.run(train_step, feed_dict={x:batch[0],true_y:batch[1], keep_prob:0.5})
    # Report on completed steps (i + 1). The previous test
    # `i % 100 == 0 and i != 0` could never be true for i in range(100),
    # so no training accuracy was ever printed.
    if (i + 1) % reportEvery == 0:
        trainAccuracy = sess.run(accuracy, feed_dict={x:batch[0],true_y:batch[1], keep_prob:1.0})
        print("step %d, training accuracy %g"%(i + 1, trainAccuracy))

# Accuracy over the whole test split, evaluated in a single run.
testAccuracy = sess.run(accuracy, feed_dict={x:mnist.test.images,true_y:mnist.test.labels, keep_prob:1.0})
print("test accuracy %g"%(testAccuracy))

imageToUse = mnist.test.images[0]
# Draw the input digit on its own figure. plotNNFilter draws into figure 1;
# without an explicit figure here the digit would land on the implicit
# figure 1 and the filter grid would then be drawn over the same canvas.
plt.figure(2)
plt.imshow(np.reshape(imageToUse,[28,28]), interpolation="nearest", cmap="gray")
getActivations(hidden_1,imageToUse)
plt.show()
print("dcm")