# match_net.py (forked from zergylord/oneshot)
import numpy as np
import time
cur_time = time.time()
mb_dim = 32 #training examples per minibatch
x_dim = 28 #size of one side of square image
y_dim = 5 #possible classes
n_samples_per_class = 1 #samples of each class
n_samples = y_dim*n_samples_per_class #total number of labeled samples
eps = 1e-10 #term added for numerical stability of log computations
tie = False #tie the weights of the query network to the labeled network
x_i_learn = True #toggle learning for the labeled-sample (x_i) encoder network
learning_rate = 1e-1
data = np.load('data.npy')
data = np.reshape(data,[-1,20,28,28]) #each of the 1600 classes has 20 examples
data = np.random.permutation(data)
train_data = data[:1200,:,:,:]
test_data = data[1200:,:,:,:]
'''
Samples a minibatch of size mb_dim. Each training example contains
n_samples labeled samples, such that n_samples_per_class samples
come from each of y_dim randomly chosen classes. An additional example
of one of these classes is then chosen to be the query, and its label
is the target of the network.
'''
def get_minibatch(test=False):
if test:
cur_data = test_data
print('testing')
else:
cur_data = train_data
mb_x_i = np.zeros((mb_dim,n_samples,x_dim,x_dim,1))
mb_y_i = np.zeros((mb_dim,n_samples))
    mb_x_hat = np.zeros((mb_dim,x_dim,x_dim,1)) #query images, kept as floats
    mb_y_hat = np.zeros((mb_dim,),dtype=int) #integer class indices
for i in range(mb_dim):
ind = 0
pinds = np.random.permutation(n_samples)
classes = np.random.choice(cur_data.shape[0],y_dim,False)
x_hat_class = np.random.randint(y_dim)
for j,cur_class in enumerate(classes): #each class
example_inds = np.random.choice(cur_data.shape[1],n_samples_per_class,False)
for eind in example_inds:
mb_x_i[i,pinds[ind],:,:,0] = np.rot90(cur_data[cur_class][eind],np.random.randint(4))
mb_y_i[i,pinds[ind]] = j
ind +=1
if j == x_hat_class:
mb_x_hat[i,:,:,0] = np.rot90(cur_data[cur_class][np.random.choice(cur_data.shape[1])],np.random.randint(4))
mb_y_hat[i] = j
return mb_x_i,mb_y_i,mb_x_hat,mb_y_hat
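#Illustrative shape check (a sketch, not part of the original training flow;
#runs once using only the definitions above):
_x_i,_y_i,_x_hat,_y_hat = get_minibatch()
assert _x_i.shape == (mb_dim,n_samples,x_dim,x_dim,1) #labeled support images
assert _y_i.shape == (mb_dim,n_samples) #support labels in [0,y_dim)
assert _x_hat.shape == (mb_dim,x_dim,x_dim,1) #query images
assert _y_hat.shape == (mb_dim,) #query labels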
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('summary_dir', '/tmp/oneshot_logs', 'Summaries directory')
if tf.gfile.Exists(FLAGS.summary_dir):
tf.gfile.DeleteRecursively(FLAGS.summary_dir)
tf.gfile.MakeDirs(FLAGS.summary_dir)
x_hat = tf.placeholder(tf.float32,shape=[None,x_dim,x_dim,1])
x_i = tf.placeholder(tf.float32,shape=[None,n_samples,x_dim,x_dim,1])
y_i_ind = tf.placeholder(tf.int32,shape=[None,n_samples])
y_i = tf.one_hot(y_i_ind,y_dim)
y_hat_ind = tf.placeholder(tf.int32,shape=[None])
y_hat = tf.one_hot(y_hat_ind,y_dim)
'''
Creates a stack of 4 layers. Each layer applies a 3x3 convolution,
batch normalization, a rectified-linear activation, and 2x2 max pooling.
The net effect is to transform the mb_dim x 28 x 28 x 1 images into an
mb_dim x 1 x 1 x 64 embedding; the extra dims are then squeezed out,
resulting in mb_dim x 64.
'''
def make_conv_net(inp,scope,reuse=False,stop_grad=False):
with tf.variable_scope(scope) as varscope:
if reuse: varscope.reuse_variables()
cur_input = inp
cur_filters = 1
for i in range(4):
with tf.variable_scope('conv'+str(i)):
W = tf.get_variable('W',[3,3,cur_filters,64])
beta = tf.get_variable('beta',[64],initializer=tf.constant_initializer(0.0))
gamma = tf.get_variable('gamma',[64],initializer=tf.constant_initializer(1.0))
cur_filters = 64
pre_norm = tf.nn.conv2d(cur_input,W,strides=[1,1,1,1],padding='SAME')
mean,variance = tf.nn.moments(pre_norm,[0,1,2])
post_norm = tf.nn.batch_normalization(pre_norm,mean,variance,beta,gamma,eps)
conv = tf.nn.relu(post_norm)
cur_input = tf.nn.max_pool(conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='VALID')
if stop_grad:
return tf.stop_gradient(tf.squeeze(cur_input,[1,2]))
else:
return tf.squeeze(cur_input,[1,2])
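#Why the squeeze of dims [1,2] above is safe: four rounds of 2x2 VALID max
#pooling shrink a 28x28 input to 1x1. A small sanity sketch of that trace:
_side = x_dim
for _ in range(4):
    _side //= 2 #28 -> 14 -> 7 -> 3 -> 1
assert _side == 1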
'''
Assembles a computational graph that processes minibatches of the n_samples
labeled examples and one unlabeled query sample. All labeled examples share
the same convolutional network, whereas the query sample defaults to using
separate parameters. After the convolutional networks encode the inputs, the
pairwise cosine similarity between the query and each labeled embedding is
computed; its softmax-normalized version weights each label's contribution
to the predicted query label.
'''
scope = 'encode_x'
x_hat_encode = make_conv_net(x_hat,scope)
#x_hat_inv_mag = tf.rsqrt(tf.clip_by_value(tf.reduce_sum(tf.square(x_hat_encode),1,keep_dims=True),eps,float("inf")))
cos_sim_list = []
if not tie:
scope = 'encode_x_i'
for i in range(n_samples):
    x_i_encode = make_conv_net(x_i[:,i,:,:,:],scope,tie or i > 0,not x_i_learn) #reuse vars after the first sample (always when tied)
x_i_inv_mag = tf.rsqrt(tf.clip_by_value(tf.reduce_sum(tf.square(x_i_encode),1,keep_dims=True),eps,float("inf")))
    dotted = tf.squeeze(
            tf.batch_matmul(tf.expand_dims(x_hat_encode,1),tf.expand_dims(x_i_encode,2)),[1,])
    cos_sim_list.append(dotted*x_i_inv_mag) #*x_hat_inv_mag
cos_sim = tf.concat(1,cos_sim_list)
tf.histogram_summary('cos sim',cos_sim)
weighting = tf.nn.softmax(cos_sim)
label_prob = tf.squeeze(tf.batch_matmul(tf.expand_dims(weighting,1),y_i),[1]) #explicit axis keeps the batch dim when mb_dim is 1
tf.histogram_summary('label prob',label_prob)
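#Hedged numpy sketch of the read-out above: softmax attention over support
#similarities times one-hot labels yields a label distribution (toy inputs,
#not taken from the graph):
_sims = np.random.randn(mb_dim,n_samples) #stand-in for the cosine similarities
_attn = np.exp(_sims)/np.exp(_sims).sum(1,keepdims=True) #softmax over the support set
_onehot = np.eye(y_dim)[np.random.randint(0,y_dim,(mb_dim,n_samples))]
_probs = np.einsum('bs,bsy->by',_attn,_onehot) #mirrors weighting x y_i above
assert np.allclose(_probs.sum(1),1.0)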
top_k = tf.nn.in_top_k(label_prob,y_hat_ind,1)
acc = tf.reduce_mean(tf.to_float(top_k))
tf.scalar_summary('train avg accuracy',acc)
correct_prob = tf.reduce_sum(tf.log(tf.clip_by_value(label_prob,eps,1.0))*y_hat,1)
loss = tf.reduce_mean(-correct_prob,0)
tf.scalar_summary('loss',loss)
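#Hedged numpy check of the loss above: the mean negative log-likelihood of
#the true class under label_prob (toy values, not taken from the data):
_p = np.array([[0.7,0.1,0.1,0.05,0.05]]) #one example's label_prob over y_dim classes
_t = np.eye(y_dim)[[0]] #one-hot y_hat for class 0
_nll = -np.sum(np.log(np.clip(_p,eps,1.0))*_t,1).mean() #equals -log(0.7)
assert np.isclose(_nll,-np.log(0.7))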
optim = tf.train.GradientDescentOptimizer(learning_rate)
#optim = tf.train.AdamOptimizer(learning_rate)
grads = optim.compute_gradients(loss)
grad_summaries = [tf.histogram_summary(v.name,g) for g,v in grads if g is not None]
train_step = optim.apply_gradients(grads)
#testing: the same accuracy op as above, summarized separately for the test stream
test_acc = tf.reduce_mean(tf.to_float(top_k))
'''
End of the construction of the computational graph. The remaining code runs training steps.
'''
sess = tf.Session()
merged = tf.merge_all_summaries()
test_summ = tf.scalar_summary('test avg accuracy',test_acc)
writer = tf.train.SummaryWriter(FLAGS.summary_dir,sess.graph)
sess.run(tf.initialize_all_variables())
for i in range(int(1e7)):
mb_x_i,mb_y_i,mb_x_hat,mb_y_hat = get_minibatch()
feed_dict = {x_hat: mb_x_hat,
y_hat_ind: mb_y_hat,
x_i: mb_x_i,
y_i_ind: mb_y_i}
_,mb_loss,summary,ans = sess.run([train_step,loss,merged,cos_sim],feed_dict=feed_dict)
    if i % int(1e2) == 0:
        mb_x_i,mb_y_i,mb_x_hat,mb_y_hat = get_minibatch(True)
        feed_dict = {x_hat: mb_x_hat,
                y_hat_ind: mb_y_hat,
                x_i: mb_x_i,
                y_i_ind: mb_y_i}
        #trace this evaluation so the run metadata written below is populated
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        _,test_summary = sess.run([test_acc,test_summ],feed_dict=feed_dict,
                options=run_options,run_metadata=run_metadata)
        writer.add_summary(test_summary,i)
        print(i,'loss: ',mb_loss,'time: ',time.time()-cur_time)
        cur_time = time.time()
        writer.add_run_metadata(run_metadata,'step%d' % i)
        writer.add_summary(summary,i)