-
Notifications
You must be signed in to change notification settings - Fork 15
/
tsne.py
52 lines (46 loc) · 1.49 KB
/
tsne.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
'''
Description:
Author: voicebeer
Date: 2020-10-31 02:52:59
LastEditTime: 2020-10-31 03:04:46
'''
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.manifold import TSNE
class FeatureVisualize(object):
    '''
    Visualize features by TSNE.

    Embeds the given features into 2-D with scikit-learn's TSNE and draws
    every sample as the text of its label, coloured by label.
    '''

    def __init__(self, features, labels):
        '''
        features: (m,n) array-like, one row per sample.
        labels: (m,) array-like, one label per row of ``features``.
        '''
        self.features = features
        self.labels = labels

    @staticmethod
    def _minmax_normalize(points):
        '''Scale each column of ``points`` into [0, 1].

        A column whose values are all equal (zero range) is mapped to 0
        instead of producing NaN from a 0/0 division.
        '''
        points = np.asarray(points, dtype=float)
        lo = points.min(axis=0)
        span = points.max(axis=0) - lo
        # Guard degenerate (constant) dimensions against division by zero.
        span = np.where(span == 0, 1.0, span)
        return (points - lo) / span

    @staticmethod
    def _label_color_indices(labels, n_colors):
        '''Return one integer colormap index per label.

        Distinct labels are numbered by their sorted order (np.unique) and
        wrapped modulo ``n_colors``. Equal labels therefore share a colour,
        and every label gets a valid index whatever its value or type.
        '''
        _, ranks = np.unique(np.asarray(labels), return_inverse=True)
        return ranks.ravel() % n_colors

    def plot_tsne(self, save_eps=False):
        ''' Plot TSNE figure. Set save_eps=True if you want to save a .eps file.

        Draws on the current pyplot figure and then calls plt.show().
        When save_eps is True the figure is also written to 'tsne.eps'
        (600 dpi) in the working directory before it is shown.
        '''
        tsne = TSNE(n_components=2, init='pca', random_state=0)
        # plt.text does not autoscale the axes, so the embedding is rescaled
        # into the default [0, 1] view.
        data = self._minmax_normalize(tsne.fit_transform(self.features))
        labels = np.asarray(self.labels)
        cmap = plt.cm.Set1
        # Index the qualitative colormap by integer position. A float lookup
        # such as label / 10. sends labels 0 and 1 to the same Set1 colour
        # (it has only 9 entries) and clips every label >= 10.
        color_idx = self._label_color_indices(labels, cmap.N)
        for (x, y), label, idx in zip(data, labels, color_idx):
            plt.text(x, y, str(label),
                     color=cmap(int(idx)),
                     fontdict={'weight': 'bold', 'size': 5})
        plt.xticks([])
        plt.yticks([])
        plt.title('T-SNE')
        if save_eps:
            plt.savefig('tsne.eps', dpi=600, format='eps')
        plt.show()
# if __name__ == '__main__':
# digits = datasets.load_digits(n_class=5)
# features, labels = digits.data, digits.target
# print(features.shape)
# print(labels.shape)
# vis = FeatureVisualize(features, labels)
# vis.plot_tsne(save_eps=True)