-
Notifications
You must be signed in to change notification settings - Fork 0
/
coreset.py
133 lines (108 loc) · 4.7 KB
/
coreset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
import numpy as np
class Base:
    """Common state shared by the class-balanced coreset selectors.

    The constructor stores the dataset, the run arguments and the target
    device, then derives a per-class budget for the reduced set together with
    the matching label vector.

    Attributes:
        data: Dataset object; ``feat_train`` and ``labels_train`` are read here.
        args: Run configuration; only ``reduction_rate`` is read here.
        device: Device that ``labels_syn`` is moved to.
        nnodes_syn: ``int(feat_train.shape[0] * reduction_rate)``.
        labels_syn: Long tensor holding the labels of the reduced set,
            grouped class by class.
        num_class_dict: Maps a class label to the number of samples to keep.
        syn_class_indices: Maps a class label to the ``[start, end]`` bounds
            (end exclusive) of that class inside ``labels_syn``.
    """

    def __init__(self, data, args, device='cuda', **kwargs):
        """Store the inputs and build the per-class budget and label tensor.

        Args:
            data: Object exposing ``feat_train`` (needs ``.shape``) and
                ``labels_train`` (an iterable of hashable labels with ``len``).
            args: Object exposing a numeric ``reduction_rate``.
            device: Device passed to ``Tensor.to`` for ``labels_syn``.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        self.data = data
        self.args = args
        self.device = device
        self.nnodes_syn = int(data.feat_train.shape[0] * args.reduction_rate)
        self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device)

    def generate_labels_syn(self, data):
        """Compute how many samples each class keeps and the resulting labels.

        Classes are processed from least to most frequent. Every class except
        the most frequent one keeps ``max(int(count * reduction_rate), 1)``
        samples; the most frequent class receives whatever is left of
        ``int(len(labels_train) * reduction_rate)``.

        Side effects: sets ``self.num_class_dict`` and
        ``self.syn_class_indices``.

        Args:
            data: Object exposing ``labels_train``.

        Returns:
            A list of labels for the reduced set, grouped by class.
        """
        from collections import Counter

        counter = Counter(data.labels_train)
        total_budget = int(len(data.labels_train) * self.args.reduction_rate)
        # Ascending by frequency, so the most frequent class comes last and
        # absorbs the remainder of the overall budget.
        sorted_counter = sorted(counter.items(), key=lambda item: item[1])
        last_ix = len(sorted_counter) - 1

        num_class_dict = {}
        self.syn_class_indices = {}
        labels_syn = []
        allocated = 0
        for ix, (c, num) in enumerate(sorted_counter):
            if ix == last_ix:
                # The one-sample minimum granted to the smaller classes can
                # over-spend the total budget; never hand out a negative count.
                budget = max(total_budget - allocated, 0)
            else:
                budget = max(int(num * self.args.reduction_rate), 1)
                allocated += budget
            start = len(labels_syn)
            self.syn_class_indices[c] = [start, start + budget]
            labels_syn += [c] * budget
            num_class_dict[c] = budget

        self.num_class_dict = num_class_dict
        return labels_syn

    def select(self):
        """Placeholder overridden by concrete selectors; returns ``None``."""
        return
class KCenter(Base):
    """Greedy k-center selection, run independently for every class."""

    def __init__(self, data, args, device='cuda', **kwargs):
        """Initialise the shared selector state.

        Args:
            data: Dataset object forwarded to ``Base``.
            args: Run configuration forwarded to ``Base``.
            device: Device forwarded to ``Base``.
            **kwargs: Extra keyword arguments forwarded to ``Base``.
        """
        # Forward the caller's device instead of a hard-coded 'cuda', which
        # silently ignored the argument and broke CPU-only runs.
        super(KCenter, self).__init__(data, args, device=device, **kwargs)

    def select(self, embeds, inductive=False):
        """Pick ``num_class_dict[c]`` training indices for every class ``c``.

        For each class the first center is the sample closest to the class
        mean embedding; every further center is the sample whose distance to
        its nearest already-chosen center is largest.

        Args:
            embeds: Tensor of embeddings, indexed by the training indices
                (``torch.mean(dim=0)`` and ``torch.cdist`` are applied to the
                selected rows).
            inductive: When true, positions ``0..len(idx_train)-1`` are used
                as indices; otherwise ``self.data.idx_train`` is used.
                NOTE(review): ``idx_train`` / ``labels_train`` are assumed to
                be NumPy arrays so boolean masking works - confirm.

        Returns:
            The selected indices of all classes, concatenated with
            ``np.hstack``.
        """
        if inductive:
            idx_train = np.arange(len(self.data.idx_train))
        else:
            idx_train = self.data.idx_train
        labels_train = self.data.labels_train

        idx_selected = []
        for class_id, cnt in self.num_class_dict.items():
            idx = idx_train[labels_train == class_id]
            feature = embeds[idx]
            mean = torch.mean(feature, dim=0, keepdim=True)
            dis = torch.cdist(feature, mean)[:, 0]
            rank = torch.argsort(dis)
            # Seed with the sample nearest to the class mean. A non-positive
            # budget yields no centers so the result stays aligned with the
            # per-class counts.
            idx_centers = rank[:1].tolist() if cnt > 0 else []
            for _ in range(cnt - 1):
                feature_centers = feature[idx_centers]
                dis_center = torch.cdist(feature, feature_centers)
                # Distance of every sample to its closest chosen center.
                dis_min, _ = torch.min(dis_center, dim=-1)
                # The farthest such sample becomes the next center.
                idx_centers.append(torch.argmax(dis_min).item())
            idx_selected.append(idx[idx_centers])
        return np.hstack(idx_selected)
class Herding(Base):
    """Herding selection, run independently for every class."""

    def __init__(self, data, args, device='cuda', **kwargs):
        """Initialise the shared selector state.

        Args:
            data: Dataset object forwarded to ``Base``.
            args: Run configuration forwarded to ``Base``.
            device: Device forwarded to ``Base``.
            **kwargs: Extra keyword arguments forwarded to ``Base``.
        """
        # Forward the caller's device instead of a hard-coded 'cuda', which
        # silently ignored the argument and broke CPU-only runs.
        super(Herding, self).__init__(data, args, device=device, **kwargs)

    def select(self, embeds, inductive=False):
        """Pick ``num_class_dict[c]`` training indices for every class ``c``.

        At step ``i`` the target is ``(i + 1) * class_mean`` minus the sum of
        the embeddings chosen so far; the not-yet-chosen sample closest to
        that target is selected next.

        Args:
            embeds: Tensor of embeddings, indexed by the training indices
                (``torch.mean(dim=0)`` and ``torch.cdist`` are applied to the
                selected rows).
            inductive: When true, positions ``0..len(idx_train)-1`` are used
                as indices; otherwise ``self.data.idx_train`` is used.
                NOTE(review): ``idx_train`` / ``labels_train`` are assumed to
                be NumPy arrays so boolean masking works - confirm.

        Returns:
            The selected indices of all classes, concatenated with
            ``np.hstack``.
        """
        if inductive:
            idx_train = np.arange(len(self.data.idx_train))
        else:
            idx_train = self.data.idx_train
        labels_train = self.data.labels_train

        idx_selected = []
        for class_id, cnt in self.num_class_dict.items():
            idx = idx_train[labels_train == class_id]
            features = embeds[idx]
            mean = torch.mean(features, dim=0, keepdim=True)
            selected = []
            # Positions (within this class) that are still available.
            idx_left = np.arange(features.shape[0]).tolist()
            for i in range(cnt):
                # Gap between the scaled class mean and what has been picked.
                det = mean * (i + 1) - torch.sum(features[selected], dim=0)
                dis = torch.cdist(det, features[idx_left])
                id_min = torch.argmin(dis).item()
                # Move the best candidate from the pool into the selection.
                selected.append(idx_left.pop(id_min))
            idx_selected.append(idx[selected])
        return np.hstack(idx_selected)
class Random(Base):
    """Uniform random selection, run independently for every class."""

    def __init__(self, data, args, device='cuda', **kwargs):
        """Initialise the shared selector state.

        Args:
            data: Dataset object forwarded to ``Base``.
            args: Run configuration forwarded to ``Base``.
            device: Device forwarded to ``Base``.
            **kwargs: Extra keyword arguments forwarded to ``Base``.
        """
        # Forward the caller's device instead of a hard-coded 'cuda', which
        # silently ignored the argument and broke CPU-only runs.
        super(Random, self).__init__(data, args, device=device, **kwargs)

    def select(self, embeds, inductive=False):
        """Pick ``num_class_dict[c]`` random training indices per class ``c``.

        Args:
            embeds: Not used by this selector; kept so all selectors share
                the same call signature.
            inductive: When true, positions ``0..len(idx_train)-1`` are used
                as indices; otherwise ``self.data.idx_train`` is used.
                NOTE(review): ``idx_train`` / ``labels_train`` are assumed to
                be NumPy arrays so boolean masking works - confirm.

        Returns:
            The selected indices of all classes, concatenated with
            ``np.hstack``. Results depend on NumPy's global random state.
        """
        if inductive:
            idx_train = np.arange(len(self.data.idx_train))
        else:
            idx_train = self.data.idx_train
        labels_train = self.data.labels_train

        idx_selected = []
        for class_id, cnt in self.num_class_dict.items():
            idx = idx_train[labels_train == class_id]
            shuffled = np.random.permutation(idx)
            # A negative count would slice "all but the last k" samples;
            # clamp so a non-positive budget selects nothing.
            idx_selected.append(shuffled[:max(cnt, 0)])
        return np.hstack(idx_selected)