Commit 3a50e7c (1 parent: 5bcfa8d). Showing 7 changed files with 750 additions and 0 deletions.
@@ -0,0 +1,5 @@
python==3.7.11
pytorch==1.9.0
numpy==1.20.1
scikit-learn==0.22.2.post1
scipy==1.6.2
@@ -0,0 +1,163 @@
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from torch.utils.data import Dataset
import scipy.io
import torch

class BDGP(Dataset):
    def __init__(self, path):
        # Read the .mat file once and pull both views plus the labels from it.
        mat = scipy.io.loadmat(path + 'BDGP.mat')
        self.x1 = mat['X1'].astype(np.float32)
        self.x2 = mat['X2'].astype(np.float32)
        self.y = mat['Y'].transpose()

    def __len__(self):
        return self.x1.shape[0]

    def __getitem__(self, idx):
        # Each item is ([per-view tensors], label, sample index).
        return [torch.from_numpy(self.x1[idx]), torch.from_numpy(self.x2[idx])], \
            torch.from_numpy(self.y[idx]), torch.from_numpy(np.array(idx)).long()

class CCV(Dataset):
    def __init__(self, path):
        # Only the STIP view is min-max scaled; SIFT and MFCC are used as-is.
        self.data1 = np.load(path + 'STIP.npy').astype(np.float32)
        scaler = MinMaxScaler()
        self.data1 = scaler.fit_transform(self.data1)
        self.data2 = np.load(path + 'SIFT.npy').astype(np.float32)
        self.data3 = np.load(path + 'MFCC.npy').astype(np.float32)
        self.labels = np.load(path + 'label.npy')

    def __len__(self):
        return 6773

    def __getitem__(self, idx):
        x1 = self.data1[idx]
        x2 = self.data2[idx]
        x3 = self.data3[idx]
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
            torch.from_numpy(self.labels[idx]), torch.from_numpy(np.array(idx)).long()

class MNIST_USPS(Dataset):
    def __init__(self, path):
        mat = scipy.io.loadmat(path + 'MNIST_USPS.mat')
        self.Y = mat['Y'].astype(np.int32).reshape(5000,)
        self.V1 = mat['X1'].astype(np.float32)
        self.V2 = mat['X2'].astype(np.float32)

    def __len__(self):
        return 5000

    def __getitem__(self, idx):
        # Flatten each 28x28 digit image into a 784-dimensional vector.
        x1 = self.V1[idx].reshape(784)
        x2 = self.V2[idx].reshape(784)
        return [torch.from_numpy(x1), torch.from_numpy(x2)], self.Y[idx], torch.from_numpy(np.array(idx)).long()

class Fashion(Dataset):
    def __init__(self, path):
        mat = scipy.io.loadmat(path + 'Fashion.mat')
        self.Y = mat['Y'].astype(np.int32).reshape(10000,)
        self.V1 = mat['X1'].astype(np.float32)
        self.V2 = mat['X2'].astype(np.float32)
        self.V3 = mat['X3'].astype(np.float32)

    def __len__(self):
        return 10000

    def __getitem__(self, idx):
        # Flatten each 28x28 image into a 784-dimensional vector.
        x1 = self.V1[idx].reshape(784)
        x2 = self.V2[idx].reshape(784)
        x3 = self.V3[idx].reshape(784)
        return [torch.from_numpy(x1), torch.from_numpy(x2), torch.from_numpy(x3)], \
            self.Y[idx], torch.from_numpy(np.array(idx)).long()

class Caltech(Dataset):
    def __init__(self, path, view):
        data = scipy.io.loadmat(path)
        scaler = MinMaxScaler()
        self.view1 = scaler.fit_transform(data['X1'].astype(np.float32))
        self.view2 = scaler.fit_transform(data['X2'].astype(np.float32))
        self.view3 = scaler.fit_transform(data['X3'].astype(np.float32))
        self.view4 = scaler.fit_transform(data['X4'].astype(np.float32))
        self.view5 = scaler.fit_transform(data['X5'].astype(np.float32))
        self.labels = data['Y'].transpose()
        self.view = view

    def __len__(self):
        return 1400

    def __getitem__(self, idx):
        # Views are stacked in the fixed order 1, 2, 5, 4, 3, matching the
        # per-view dimensions listed in load_data below.
        label = torch.from_numpy(self.labels[idx])
        index = torch.from_numpy(np.array(idx)).long()
        if self.view == 2:
            return [torch.from_numpy(self.view1[idx]), torch.from_numpy(self.view2[idx])], label, index
        if self.view == 3:
            return [torch.from_numpy(self.view1[idx]), torch.from_numpy(self.view2[idx]),
                    torch.from_numpy(self.view5[idx])], label, index
        if self.view == 4:
            return [torch.from_numpy(self.view1[idx]), torch.from_numpy(self.view2[idx]),
                    torch.from_numpy(self.view5[idx]), torch.from_numpy(self.view4[idx])], label, index
        if self.view == 5:
            return [torch.from_numpy(self.view1[idx]), torch.from_numpy(self.view2[idx]),
                    torch.from_numpy(self.view5[idx]), torch.from_numpy(self.view4[idx]),
                    torch.from_numpy(self.view3[idx])], label, index
        # Guard against unsupported view counts instead of silently returning None.
        raise ValueError('view must be one of 2, 3, 4, 5')

def load_data(dataset):
    if dataset == "BDGP":
        dataset = BDGP('./data/')
        dims = [1750, 79]
        view = 2
        data_size = 2500
        class_num = 5
    elif dataset == "MNIST-USPS":
        dataset = MNIST_USPS('./data/')
        dims = [784, 784]
        view = 2
        class_num = 10
        data_size = 5000
    elif dataset == "CCV":
        dataset = CCV('./data/')
        dims = [5000, 5000, 4000]
        view = 3
        data_size = 6773
        class_num = 20
    elif dataset == "Fashion":
        dataset = Fashion('./data/')
        dims = [784, 784, 784]
        view = 3
        data_size = 10000
        class_num = 10
    elif dataset == "Caltech-2V":
        dataset = Caltech('data/Caltech-5V.mat', view=2)
        dims = [40, 254]
        view = 2
        data_size = 1400
        class_num = 7
    elif dataset == "Caltech-3V":
        dataset = Caltech('data/Caltech-5V.mat', view=3)
        dims = [40, 254, 928]
        view = 3
        data_size = 1400
        class_num = 7
    elif dataset == "Caltech-4V":
        dataset = Caltech('data/Caltech-5V.mat', view=4)
        dims = [40, 254, 928, 512]
        view = 4
        data_size = 1400
        class_num = 7
    elif dataset == "Caltech-5V":
        dataset = Caltech('data/Caltech-5V.mat', view=5)
        dims = [40, 254, 928, 512, 1984]
        view = 5
        data_size = 1400
        class_num = 7
    else:
        raise NotImplementedError
    return dataset, dims, view, data_size, class_num
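
Below is a minimal usage sketch, not part of the commit, showing how load_data pairs with a standard PyTorch DataLoader; the dataset name and batch size are illustrative choices, and it assumes MNIST_USPS.mat is present under ./data/:

from torch.utils.data import DataLoader

dataset, dims, view, data_size, class_num = load_data("MNIST-USPS")
loader = DataLoader(dataset, batch_size=256, shuffle=True)
xs, y, idx = next(iter(loader))          # xs: list of per-view tensors
print(view, dims, xs[0].shape, y.shape)  # 2 [784, 784] torch.Size([256, 784]) torch.Size([256])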
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
import math

class Loss(nn.Module):
    def __init__(self, batch_size, class_num, temperature_f, temperature_l, device):
        super(Loss, self).__init__()
        self.batch_size = batch_size
        self.class_num = class_num
        self.temperature_f = temperature_f  # temperature for the feature-level loss
        self.temperature_l = temperature_l  # temperature for the label-level loss
        self.device = device

        self.mask = self.mask_correlated_samples(batch_size)
        self.similarity = nn.CosineSimilarity(dim=2)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, N):
        # Boolean mask over the N x N similarity matrix that keeps only the
        # negatives: the diagonal and the cross-view positive pairs are zeroed.
        mask = torch.ones((N, N))
        mask = mask.fill_diagonal_(0)
        for i in range(N // 2):
            mask[i, N // 2 + i] = 0
            mask[N // 2 + i, i] = 0
        mask = mask.bool()
        return mask

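As a worked illustration, not part of the commit: for N = 4 (two views of a batch of 2, with illustrative constructor arguments), entry (i, j) of the mask is True iff sample j is kept as a negative for sample i:

m = Loss(batch_size=2, class_num=2, temperature_f=0.5,
         temperature_l=1.0, device='cpu').mask_correlated_samples(4)
# tensor([[False,  True, False,  True],
#         [ True, False,  True, False],
#         [False,  True, False,  True],
#         [ True, False,  True, False]])
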
    def forward_feature(self, h_i, h_j):
        # Instance-level contrastive loss (InfoNCE-style) between the
        # high-level features of two views.
        N = 2 * self.batch_size
        h = torch.cat((h_i, h_j), dim=0)

        sim = torch.matmul(h, h.T) / self.temperature_f
        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        mask = self.mask_correlated_samples(N)
        negative_samples = sim[mask].reshape(N, -1)

        # Each row's positive sits in column 0, so the target class is 0.
        labels = torch.zeros(N).to(positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss

    def forward_label(self, q_i, q_j):
        # Entropy regularizer: discourages degenerate solutions that push
        # all samples into a single cluster.
        p_i = q_i.sum(0).view(-1)
        p_i /= p_i.sum()
        ne_i = math.log(p_i.size(0)) + (p_i * torch.log(p_i)).sum()
        p_j = q_j.sum(0).view(-1)
        p_j /= p_j.sum()
        ne_j = math.log(p_j.size(0)) + (p_j * torch.log(p_j)).sum()
        entropy = ne_i + ne_j

        # Cluster-level contrastive loss: after the transpose, each row of q
        # is one cluster's assignment vector, contrasted across the two views.
        q_i = q_i.t()
        q_j = q_j.t()
        N = 2 * self.class_num
        q = torch.cat((q_i, q_j), dim=0)

        sim = self.similarity(q.unsqueeze(1), q.unsqueeze(0)) / self.temperature_l
        sim_i_j = torch.diag(sim, self.class_num)
        sim_j_i = torch.diag(sim, -self.class_num)

        positive_clusters = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        mask = self.mask_correlated_samples(N)
        negative_clusters = sim[mask].reshape(N, -1)

        labels = torch.zeros(N).to(positive_clusters.device).long()
        logits = torch.cat((positive_clusters, negative_clusters), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss + entropy
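
A hedged smoke test for both loss terms (sizes and temperatures are arbitrary values for illustration, not fixed by this commit):

criterion = Loss(batch_size=4, class_num=3, temperature_f=0.5,
                 temperature_l=1.0, device='cpu')
h_i, h_j = torch.randn(4, 16), torch.randn(4, 16)   # per-view feature batches
q_i = torch.softmax(torch.randn(4, 3), dim=1)       # per-view soft cluster assignments
q_j = torch.softmax(torch.randn(4, 3), dim=1)
print(criterion.forward_feature(h_i, h_j).item())
print(criterion.forward_label(q_i, q_j).item())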
@@ -0,0 +1,133 @@
from sklearn.metrics import v_measure_score, adjusted_rand_score, accuracy_score
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader
import numpy as np
import torch

def cluster_acc(y_true, y_pred):
    # Clustering accuracy: find the best one-to-one mapping between predicted
    # cluster ids and ground-truth labels via the Hungarian algorithm.
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size

def purity(y_true, y_pred):
    # Purity: assign every cluster to its majority true label, then score accuracy.
    y_true = y_true.copy()  # avoid mutating the caller's array when relabelling
    y_voted_labels = np.zeros(y_true.shape)
    labels = np.unique(y_true)
    ordered_labels = np.arange(labels.shape[0])
    for k in range(labels.shape[0]):
        y_true[y_true == labels[k]] = ordered_labels[k]
    labels = np.unique(y_true)
    bins = np.concatenate((labels, [np.max(labels) + 1]), axis=0)

    for cluster in np.unique(y_pred):
        hist, _ = np.histogram(y_true[y_pred == cluster], bins=bins)
        winner = np.argmax(hist)
        y_voted_labels[y_pred == cluster] = winner

    return accuracy_score(y_true, y_voted_labels)

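A quick sanity check (illustrative, not part of the commit): both metrics are invariant to a permutation of cluster ids.

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([1, 1, 0, 0, 2, 2])   # same partition, cluster ids permuted
print(cluster_acc(y_true, y_pred))       # 1.0, Hungarian matching undoes the permutation
print(purity(y_true, y_pred))            # 1.0
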
def evaluate(label, pred):
    nmi = v_measure_score(label, pred)
    ari = adjusted_rand_score(label, pred)
    acc = cluster_acc(label, pred)
    pur = purity(label, pred)
    return nmi, ari, acc, pur

def inference(loader, model, device, view, data_size):
    """
    :return:
        total_pred: fused prediction over all views
        pred_vectors: per-view predictions (list)
        labels_vector: ground-truth labels
        Hs: per-view high-level features
        Zs: per-view low-level features
    """
    model.eval()
    soft_vector = []
    pred_vectors = []
    Hs = []
    Zs = []
    for v in range(view):
        pred_vectors.append([])
        Hs.append([])
        Zs.append([])
    labels_vector = []

    for step, (xs, y, _) in enumerate(loader):
        for v in range(view):
            xs[v] = xs[v].to(device)
        with torch.no_grad():
            qs, preds = model.forward_cluster(xs)
            hs, _, _, zs = model.forward(xs)
            # Fuse the soft cluster assignments by averaging over views.
            q = sum(qs) / view
        for v in range(view):
            pred_vectors[v].extend(preds[v].cpu().numpy())
            Hs[v].extend(hs[v].cpu().numpy())
            Zs[v].extend(zs[v].cpu().numpy())
        soft_vector.extend(q.cpu().numpy())
        labels_vector.extend(y.numpy())

    labels_vector = np.array(labels_vector).reshape(data_size)
    total_pred = np.argmax(np.array(soft_vector), axis=1)
    for v in range(view):
        Hs[v] = np.array(Hs[v])
        Zs[v] = np.array(Zs[v])
        pred_vectors[v] = np.array(pred_vectors[v])
    return total_pred, pred_vectors, Hs, labels_vector, Zs

def valid(model, device, dataset, view, data_size, class_num, eval_h=False):
    test_loader = DataLoader(
        dataset,
        batch_size=256,
        shuffle=False,
    )
    total_pred, pred_vectors, high_level_vectors, labels_vector, low_level_vectors = inference(
        test_loader, model, device, view, data_size)
    if eval_h:
        print("Clustering results on low-level features of each view:")
        for v in range(view):
            kmeans = KMeans(n_clusters=class_num, n_init=100)
            y_pred = kmeans.fit_predict(low_level_vectors[v])
            nmi, ari, acc, pur = evaluate(labels_vector, y_pred)
            print('ACC{} = {:.4f} NMI{} = {:.4f} ARI{} = {:.4f} PUR{} = {:.4f}'.format(
                v + 1, acc, v + 1, nmi, v + 1, ari, v + 1, pur))

        print("Clustering results on high-level features of each view:")
        for v in range(view):
            kmeans = KMeans(n_clusters=class_num, n_init=100)
            y_pred = kmeans.fit_predict(high_level_vectors[v])
            nmi, ari, acc, pur = evaluate(labels_vector, y_pred)
            print('ACC{} = {:.4f} NMI{} = {:.4f} ARI{} = {:.4f} PUR{} = {:.4f}'.format(
                v + 1, acc, v + 1, nmi, v + 1, ari, v + 1, pur))

        print("Clustering results on cluster assignments of each view:")
        for v in range(view):
            nmi, ari, acc, pur = evaluate(labels_vector, pred_vectors[v])
            print('ACC{} = {:.4f} NMI{} = {:.4f} ARI{} = {:.4f} PUR{} = {:.4f}'.format(
                v + 1, acc, v + 1, nmi, v + 1, ari, v + 1, pur))

    print("Clustering results on semantic labels: " + str(labels_vector.shape[0]))
    nmi, ari, acc, pur = evaluate(labels_vector, total_pred)
    print('ACC = {:.4f} NMI = {:.4f} ARI = {:.4f} PUR = {:.4f}'.format(acc, nmi, ari, pur))
    return acc, nmi, pur
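
Finally, a hedged sketch of a call site (illustrative only: model and device come from a training script outside this excerpt, the model is assumed to expose forward and forward_cluster as used by inference above, and load_data is assumed importable from the data-loading module in this commit):

dataset, dims, view, data_size, class_num = load_data("MNIST-USPS")
acc, nmi, pur = valid(model, device, dataset, view, data_size, class_num, eval_h=True)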