Commit

Add files via upload
HongtaoYang authored Jul 23, 2018
1 parent 32e0ce5 commit f5cad38
Showing 2 changed files with 212 additions and 0 deletions.
145 changes: 145 additions & 0 deletions DAC_mnist.py
@@ -0,0 +1,145 @@
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

import module
import util


mode = 'Training'  # set to 'Testing' to evaluate a saved checkpoint
num_cluster = 10
eps = 1e-10 # term added for numerical stability of log computations


# ------------------------------------build the computation graph------------------------------------------
image_pool_input = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name='image_pool_input')
u_thres = tf.placeholder(shape=[], dtype=tf.float32, name='u_thres')
l_thres = tf.placeholder(shape=[], dtype=tf.float32, name='l_thres')
lr = tf.placeholder(shape=[], dtype=tf.float32, name='learning_rate')

# get similarity matrix
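# each row of label_feat is one image's label feature; after L2 normalization,
# sim_mat[i, j] is the cosine similarity between the features of images i and j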
label_feat = module.mnistNetwork(image_pool_input, num_cluster, name='mnistNetwork', reuse=False)
label_feat_norm = tf.nn.l2_normalize(label_feat, dim=1)
sim_mat = tf.matmul(label_feat_norm, label_feat_norm, transpose_b=True)

pos_loc = tf.greater(sim_mat, u_thres, name='greater')
neg_loc = tf.less(sim_mat, l_thres, name='less')
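# only confident pairs contribute to the loss: similarity above u_thres is
# treated as "same cluster", below l_thres as "different cluster"; pairs in
# between are ignored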
# select_mask = tf.cast(tf.logical_or(pos_loc, neg_loc, name='mask'), dtype=tf.float32)
pos_loc_mask = tf.cast(pos_loc, dtype=tf.float32)
neg_loc_mask = tf.cast(neg_loc, dtype=tf.float32)

# get clusters
pred_label = tf.argmax(label_feat, axis=1)
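# a sample's cluster is the index of its largest label feature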

# define losses and train op
pos_entropy = tf.multiply(-tf.log(tf.clip_by_value(sim_mat, eps, 1.0)), pos_loc_mask)
neg_entropy = tf.multiply(-tf.log(tf.clip_by_value(1-sim_mat, eps, 1.0)), neg_loc_mask)
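# together this is a binary cross-entropy over the selected pairs, pushing
# confident positives towards similarity 1 and confident negatives towards 0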

loss_sum = tf.reduce_mean(pos_entropy) + tf.reduce_mean(neg_entropy)
train_op = tf.train.RMSPropOptimizer(lr).minimize(loss_sum)
# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# with tf.control_dependencies(update_ops):
#     train_op = tf.train.RMSPropOptimizer(lr).minimize(loss_sum)


# -------------------------------------------prepared datasets----------------------------------------------
# read mnist data (1 channel)
# mnist_1 = tf.contrib.learn.datasets.load_dataset("mnist")
mnist = input_data.read_data_sets('MNIST-data')  # MNIST data should be stored in 'MNIST-data/'
mnist_train = mnist.train.images
mnist_train = np.reshape(mnist_train, (-1, 28, 28, 1)) # reshape into 1-channel image
mnist_train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
mnist_test = mnist.test.images
mnist_test = np.reshape(mnist_test, (-1, 28, 28, 1)) # reshape into 1-channel image
mnist_test_labels = np.asarray(mnist.test.labels, dtype=np.int32)

mnist_data = np.concatenate([mnist_train, mnist_test], axis=0)
mnist_labels = np.concatenate([mnist_train_labels, mnist_test_labels], axis=0)
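# clustering is unsupervised, so train and test images are pooled together;
# the labels are kept only for evaluating the clustering metrics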
# print(len(mnist_labels))

# # read cifar data
# cifar_data = []
# cifar_label = []
# for i in range(1, 6):
#     file_name = 'cifar-10-data/' + 'data_batch_' + str(i)
#     with open(file_name, 'rb') as fo:
#         cifar_dict = cPickle.load(fo)
#         data = cifar_dict['data']
#         label = cifar_dict['labels']

#     data = data.astype('float32')/255
#     data = np.reshape(data, (-1, 3, 32, 32))
#     data = np.transpose(data, (0, 2, 3, 1))
#     cifar_data.append(data)
#     cifar_label.append(label)

# cifar_data = np.concatenate(cifar_data, axis=0)
# cifar_label = np.concatenate(cifar_label, axis=0)
# # print cifar_data.shape


# --------------------------------------------run the graph-------------------------------------------------
saver = tf.train.Saver()
if mode == 'Training':
    batch_size = 128
    base_lr = 0.001
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        lamda = 0  # 'lambda' is a reserved word in Python, hence the spelling
        epoch = 1
        u = 0.95
        l = 0.455
        # adaptive thresholds: u decreases and l increases every epoch;
        # training stops once they cross
        while u > l:
            u = 0.95 - lamda
            l = 0.455 + 0.1*lamda
            for i in range(1, 1001):  # 1000 iterations is roughly 1 epoch
                data_samples, _ = util.get_mnist_batch(batch_size, mnist_data, mnist_labels)
                feed_dict = {image_pool_input: data_samples,
                             u_thres: u,
                             l_thres: l,
                             lr: base_lr}
                train_loss, _ = sess.run([loss_sum, train_op], feed_dict=feed_dict)
                if i % 20 == 0:
                    print('training loss at iter %d is %f' % (i, train_loss))

            lamda += 1.1 * 0.009

            # run testing every epoch
            data_samples, data_labels = util.get_mnist_batch(512, mnist_data, mnist_labels)
            feed_dict = {image_pool_input: data_samples}
            pred_cluster = sess.run(pred_label, feed_dict=feed_dict)

            acc = util.clustering_acc(data_labels, pred_cluster)
            nmi = util.NMI(data_labels, pred_cluster)
            ari = util.ARI(data_labels, pred_cluster)
            print('testing NMI, ARI, ACC at epoch %d are %f, %f, %f.' % (epoch, nmi, ari, acc))

            if epoch % 5 == 0:  # save the model every 5 epochs
                model_name = 'DAC_ep_' + str(epoch) + '.ckpt'
                save_path = saver.save(sess, 'DAC_models/' + model_name)
                print("Model saved in file: %s" % save_path)

            epoch += 1

elif mode == 'Testing':
    batch_size = 1000
    with tf.Session() as sess:
        saver.restore(sess, "DAC_models/DAC_ep_45.ckpt")
        print('model restored!')
        all_predictions = np.zeros([len(mnist_labels)], dtype=np.float32)
        for i in range(65):  # 65 batches of 1000 cover all 65000 images
            data_samples = util.get_mnist_batch_test(batch_size, mnist_data, i)
            feed_dict = {image_pool_input: data_samples}
            pred_cluster = sess.run(pred_label, feed_dict=feed_dict)
            all_predictions[i*batch_size:(i+1)*batch_size] = pred_cluster

        acc = util.clustering_acc(mnist_labels.astype(int), all_predictions.astype(int))
        nmi = util.NMI(mnist_labels.astype(int), all_predictions.astype(int))
        ari = util.ARI(mnist_labels.astype(int), all_predictions.astype(int))
        print('testing NMI, ARI, ACC are %f, %f, %f.' % (nmi, ari, acc))
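
# Note: module.py is imported at the top of this file but is not included in
# this commit. Below is a minimal, hypothetical mnistNetwork consistent with
# the call site (a batch of 28x28x1 images in, num_cluster label features out);
# the softmax output keeps the features non-negative, so the cosine
# similarities in sim_mat fall in [0, 1] as the log-based loss requires.
# The layer sizes are illustrative assumptions, not the author's architecture.
#
# def mnistNetwork(images, num_cluster, name, reuse=False):
#     # hypothetical stand-in for the missing module.py, not the author's model
#     with tf.variable_scope(name, reuse=reuse):
#         x = tf.layers.conv2d(images, 64, 3, padding='same', activation=tf.nn.relu)
#         x = tf.layers.max_pooling2d(x, 2, 2)
#         x = tf.layers.conv2d(x, 128, 3, padding='same', activation=tf.nn.relu)
#         x = tf.layers.max_pooling2d(x, 2, 2)
#         x = tf.layers.flatten(x)
#         x = tf.layers.dense(x, 256, activation=tf.nn.relu)
#         return tf.layers.dense(x, num_cluster, activation=tf.nn.softmax)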





67 changes: 67 additions & 0 deletions util.py
@@ -0,0 +1,67 @@
import numpy as np
import random
from sklearn import metrics
from sklearn.utils.linear_assignment_ import linear_assignment


def get_cifar_batch(batch_size, cifar_data, cifar_label):
    batch_index = random.sample(range(len(cifar_label)), batch_size)

    batch_data = np.empty([batch_size, 32, 32, 3], dtype=np.float32)
    batch_label = np.empty([batch_size], dtype=np.int32)
    for n, i in enumerate(batch_index):
        batch_data[n, ...] = cifar_data[i, ...]
        batch_label[n] = cifar_label[i]

    return batch_data, batch_label


def get_mnist_batch(batch_size, mnist_data, mnist_labels):
    batch_index = random.sample(range(len(mnist_labels)), batch_size)

    batch_data = np.empty([batch_size, 28, 28, 1], dtype=np.float32)
    batch_label = np.empty([batch_size], dtype=np.int32)
    for n, i in enumerate(batch_index):
        batch_data[n, ...] = mnist_data[i, ...]
        batch_label[n] = mnist_labels[i]

    return batch_data, batch_label


def get_mnist_batch_test(batch_size, mnist_data, i):
    batch_data = np.copy(mnist_data[batch_size*i:batch_size*(i+1), ...])
    # batch_label = np.copy(mnist_labels[batch_size*i:batch_size*(i+1)])

    return batch_data


def get_svhn_batch(batch_size, svhn_data, svhn_labels):
    batch_index = random.sample(range(len(svhn_labels)), batch_size)

    batch_data = np.empty([batch_size, 32, 32, 3], dtype=np.float32)
    batch_label = np.empty([batch_size], dtype=np.int32)
    for n, i in enumerate(batch_index):
        batch_data[n, ...] = svhn_data[i, ...]
        batch_label[n] = svhn_labels[i]

    return batch_data, batch_label


def clustering_acc(y_true, y_pred):
    # best one-to-one mapping between predicted clusters and true labels,
    # found with the Hungarian algorithm on the contingency matrix
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1
    ind = linear_assignment(w.max() - w)

    return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size


def NMI(y_true, y_pred):
    return metrics.normalized_mutual_info_score(y_true, y_pred)


def ARI(y_true, y_pred):
    return metrics.adjusted_rand_score(y_true, y_pred)
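
# Note: sklearn.utils.linear_assignment_ was removed in scikit-learn 0.23.
# On newer environments the same accuracy can be computed with
# scipy.optimize.linear_sum_assignment; a minimal sketch, assuming scipy is
# installed (clustering_acc_scipy is a name introduced here for illustration):
#
# from scipy.optimize import linear_sum_assignment
#
# def clustering_acc_scipy(y_true, y_pred):
#     # same contingency-matrix construction as clustering_acc above
#     y_true = y_true.astype(np.int64)
#     assert y_pred.size == y_true.size
#     D = max(y_pred.max(), y_true.max()) + 1
#     w = np.zeros((D, D), dtype=np.int64)
#     for i in range(y_pred.size):
#         w[y_pred[i], y_true[i]] += 1
#     # linear_sum_assignment minimizes total cost, so invert the counts
#     row_ind, col_ind = linear_sum_assignment(w.max() - w)
#     return w[row_ind, col_ind].sum() * 1.0 / y_pred.size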
