-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
32e0ce5
commit f5cad38
Showing
2 changed files
with
212 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import tensorflow as tf | ||
from tensorflow.examples.tutorials.mnist import input_data | ||
import numpy as np | ||
|
||
import module | ||
import util | ||
|
||
|
||
mode = 'Training' | ||
num_cluster = 10 | ||
eps = 1e-10 # term added for numerical stability of log computations | ||
|
||
|
||
# ------------------------------------build the computation graph------------------------------------------ | ||
image_pool_input = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name='image_pool_input') | ||
u_thres = tf.placeholder(shape=[], dtype=tf.float32, name='u_thres') | ||
l_thres = tf.placeholder(shape=[], dtype=tf.float32, name='l_thres') | ||
lr = tf.placeholder(shape=[], dtype=tf.float32, name='learning_rate') | ||
|
||
# get similarity matrix | ||
label_feat = module.mnistNetwork(image_pool_input, num_cluster, name='mnistNetwork', reuse=False) | ||
label_feat_norm = tf.nn.l2_normalize(label_feat, dim=1) | ||
sim_mat = tf.matmul(label_feat_norm, label_feat_norm, transpose_b=True) | ||
|
||
pos_loc = tf.greater(sim_mat, u_thres, name='greater') | ||
neg_loc = tf.less(sim_mat, l_thres, name='less') | ||
# select_mask = tf.cast(tf.logical_or(pos_loc, neg_loc, name='mask'), dtype=tf.float32) | ||
pos_loc_mask = tf.cast(pos_loc, dtype=tf.float32) | ||
neg_loc_mask = tf.cast(neg_loc, dtype=tf.float32) | ||
|
||
# get clusters | ||
pred_label = tf.argmax(label_feat, axis=1) | ||
|
||
# define losses and train op | ||
pos_entropy = tf.multiply(-tf.log(tf.clip_by_value(sim_mat, eps, 1.0)), pos_loc_mask) | ||
neg_entropy = tf.multiply(-tf.log(tf.clip_by_value(1-sim_mat, eps, 1.0)), neg_loc_mask) | ||
|
||
loss_sum = tf.reduce_mean(pos_entropy) + tf.reduce_mean(neg_entropy) | ||
train_op = tf.train.RMSPropOptimizer(lr).minimize(loss_sum) | ||
# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) | ||
# with tf.control_dependencies(update_ops): | ||
# train_op = tf.train.RMSPropOptimizer(lr).minimize(loss) | ||
|
||
|
||
# -------------------------------------------prepared datasets---------------------------------------------- | ||
# read mnist data (1 channel) | ||
# mnist_1 = tf.contrib.learn.datasets.load_dataset("mnist") | ||
mnist = input_data.read_data_sets('MNIST-data') # your mnist data should be stored at 'MNIST-data' | ||
mnist_train = mnist.train.images | ||
mnist_train = np.reshape(mnist_train, (-1, 28, 28, 1)) # reshape into 1-channel image | ||
mnist_train_labels = np.asarray(mnist.train.labels, dtype=np.int32) | ||
mnist_test = mnist.test.images | ||
mnist_test = np.reshape(mnist_test, (-1, 28, 28, 1)) # reshape into 1-channel image | ||
mnist_test_labels = np.asarray(mnist.test.labels, dtype=np.int32) | ||
|
||
mnist_data = np.concatenate([mnist_train, mnist_test], axis=0) | ||
mnist_labels = np.concatenate([mnist_train_labels, mnist_test_labels], axis=0) | ||
# print(len(mnist_labels)) | ||
|
||
# # read cifar data | ||
# cifar_data = [] | ||
# cifar_label = [] | ||
# for i in range(1, 6): | ||
# file_name = 'cifar-10-data/' + 'data_batch_' + str(i) | ||
# with open(file_name, 'rb') as fo: | ||
# cifar_dict = cPickle.load(fo) | ||
# data = cifar_dict['data'] | ||
# label = cifar_dict['labels'] | ||
|
||
# data = data.astype('float32')/255 | ||
# data = np.reshape(data, (-1, 3, 32, 32)) | ||
# data = np.transpose(data, (0, 2, 3, 1)) | ||
# cifar_data.append(data) | ||
# cifar_label.append(label) | ||
|
||
# cifar_data = np.concatenate(cifar_data, axis=0) | ||
# cifar_label = np.concatenate(cifar_label, axis=0) | ||
# # print cifar_data.shape | ||
|
||
|
||
# --------------------------------------------run the graph------------------------------------------------- | ||
saver = tf.train.Saver() | ||
if mode == 'Training': | ||
batch_size = 128 | ||
base_lr = 0.001 | ||
with tf.Session() as sess: | ||
sess.run(tf.global_variables_initializer()) | ||
|
||
lamda = 0 | ||
epoch = 1 | ||
u = 0.95 | ||
l = 0.455 | ||
while u > l: | ||
u = 0.95 - lamda | ||
l = 0.455 + 0.1*lamda | ||
for i in range(1, int(1001)): # 1000 iterations is roughly 1 epoch | ||
data_samples, _ = util.get_mnist_batch(batch_size, mnist_data, mnist_labels) | ||
feed_dict={image_pool_input: data_samples, | ||
u_thres: u, | ||
l_thres: l, | ||
lr: base_lr} | ||
train_loss, _ = sess.run([loss_sum, train_op], feed_dict=feed_dict) | ||
if i % 20 == 0: | ||
print('training loss at iter %d is %f' % (i, train_loss)) | ||
|
||
lamda += 1.1 * 0.009 | ||
|
||
# run testing every epoch | ||
data_samples, data_labels = util.get_mnist_batch(512, mnist_data, mnist_labels) | ||
feed_dict={image_pool_input: data_samples} | ||
pred_cluster = sess.run(pred_label, feed_dict=feed_dict) | ||
|
||
acc = util.clustering_acc(data_labels, pred_cluster) | ||
nmi = util.NMI(data_labels, pred_cluster) | ||
ari = util.ARI(data_labels, pred_cluster) | ||
print('testing NMI, ARI, ACC at epoch %d is %f, %f, %f.' % (epoch, nmi, ari, acc)) | ||
|
||
if epoch % 5 == 0: # save model at every 5 epochs | ||
model_name = 'DAC_ep_' + str(epoch) + '.ckpt' | ||
save_path = saver.save(sess, 'DAC_models/' + model_name) | ||
print("Model saved in file: %s" % save_path) | ||
|
||
epoch += 1 | ||
|
||
elif mode == 'Testing': | ||
batch_size = 1000 | ||
with tf.Session() as sess: | ||
saver.restore(sess, "DAC_models/DAC_ep_45.ckpt") | ||
print('model restored!') | ||
all_predictions = np.zeros([len(mnist_labels)], dtype=np.float32) | ||
for i in range(65): | ||
data_samples = util.get_mnist_batch_test(batch_size, mnist_data, i) | ||
feed_dict={image_pool_input: data_samples} | ||
pred_cluster = sess.run(pred_label, feed_dict=feed_dict) | ||
all_predictions[i*batch_size:(i+1)*batch_size] = pred_cluster | ||
|
||
acc = util.clustering_acc(mnist_labels.astype(int), all_predictions.astype(int)) | ||
nmi = util.NMI(mnist_labels.astype(int), all_predictions.astype(int)) | ||
ari = util.ARI(mnist_labels.astype(int), all_predictions.astype(int)) | ||
print('testing NMI, ARI, ACC are %f, %f, %f.' % (nmi, ari, acc)) | ||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import numpy as np | ||
import random | ||
from sklearn import metrics | ||
from sklearn.utils.linear_assignment_ import linear_assignment | ||
|
||
|
||
def get_cifar_batch(batch_size, cifar_data, cifar_label): | ||
batch_index = random.sample(range(len(cifar_label)), batch_size) | ||
|
||
batch_data = np.empty([batch_size, 32, 32, 3], dtype=np.float32) | ||
batch_label = np.empty([batch_size], dtype=np.int32) | ||
for n, i in enumerate(batch_index): | ||
batch_data[n, ...] = cifar_data[i, ...] | ||
batch_label[n] = cifar_label[i] | ||
|
||
return batch_data, batch_label | ||
|
||
|
||
def get_mnist_batch(batch_size, mnist_data, mnist_labels): | ||
batch_index = random.sample(range(len(mnist_labels)), batch_size) | ||
|
||
batch_data = np.empty([batch_size, 28, 28, 1], dtype=np.float32) | ||
batch_label = np.empty([batch_size], dtype=np.int32) | ||
for n, i in enumerate(batch_index): | ||
batch_data[n, ...] = mnist_data[i, ...] | ||
batch_label[n] = mnist_labels[i] | ||
|
||
return batch_data, batch_label | ||
|
||
|
||
def get_mnist_batch_test(batch_size, mnist_data, i): | ||
batch_data = np.copy(mnist_data[batch_size*i:batch_size*(i+1), ...]) | ||
# batch_label = np.copy(mnist_labels[batch_size*i:batch_size*(i+1)]) | ||
|
||
return batch_data | ||
|
||
|
||
def get_svhn_batch(batch_size, svhn_data, svhn_labels): | ||
batch_index = random.sample(range(len(svhn_labels)), batch_size) | ||
|
||
batch_data = np.empty([batch_size, 32, 32, 3], dtype=np.float32) | ||
batch_label = np.empty([batch_size], dtype=np.int32) | ||
for n, i in enumerate(batch_index): | ||
batch_data[n, ...] = svhn_data[i, ...] | ||
batch_label[n] = svhn_labels[i] | ||
|
||
return batch_data, batch_label | ||
|
||
|
||
def clustering_acc(y_true, y_pred): | ||
y_true = y_true.astype(np.int64) | ||
assert y_pred.size == y_true.size | ||
D = max(y_pred.max(), y_true.max()) + 1 | ||
w = np.zeros((D, D), dtype=np.int64) | ||
for i in range(y_pred.size): | ||
w[y_pred[i], y_true[i]] += 1 | ||
ind = linear_assignment(w.max() - w) | ||
|
||
return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size | ||
|
||
|
||
def NMI(y_true,y_pred): | ||
return metrics.normalized_mutual_info_score(y_true, y_pred) | ||
|
||
|
||
def ARI(y_true,y_pred): | ||
return metrics.adjusted_rand_score(y_true, y_pred) |