instance_selection.py
import os
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from embeddings import *
from density_estimators import *


def get_embedder(embedding):
    if embedding == 'inceptionv3':
        embedder = Inceptionv3Embedding().eval().cuda()
    elif embedding == 'resnet50':
        embedder = ResNet50Embedding().eval().cuda()
    elif embedding == 'places365':
        embedder = Places365Embedding().eval().cuda()
    elif embedding == 'resnextwsl':
        embedder = ResNextWSL().eval().cuda()
    elif embedding == 'swav':
        embedder = SwAVEmbedding().eval().cuda()
    else:
        raise ValueError('Unrecognized embedding: {}'.format(embedding))

    # Spread the embedder across all available GPUs when more than one is present
    if torch.cuda.device_count() > 1:
        embedder = nn.DataParallel(embedder)

    return embedder


def get_embeddings_from_loader(dataloader,
                               embedder,
                               return_labels=False,
                               verbose=False):
    embeddings = []
    labels = []

    with torch.no_grad():
        if verbose:
            dataloader = tqdm(dataloader, desc='Extracting embeddings')

        for data in dataloader:
            if len(data) == 2:
                images, label = data
                images = images.cuda()
            else:
                # Unlabelled dataset: fall back to a dummy label of zero
                images = data.cuda()
                label = torch.zeros(len(images))

            embed = embedder(images)
            embeddings.append(embed.cpu())
            labels.append(label)

    embeddings = torch.cat(embeddings, dim=0)
    labels = torch.cat(labels, dim=0)

    if return_labels:
        return embeddings, labels
    else:
        return embeddings


def get_keep_indices(embeddings,
                     labels,
                     density_measure,
                     retention_ratio,
                     verbose=False):
    keep_indices = []

    unique_labels = torch.unique(labels)
    if verbose:
        unique_labels = tqdm(unique_labels, desc='Scoring instances')

    for label in unique_labels:
        class_indices = torch.where(labels == label)[0]
        class_embeddings = embeddings[class_indices]

        if density_measure == 'ppca':
            scores = PPCA(class_embeddings)
        elif density_measure == 'gaussian':
            scores = GaussianModel(class_embeddings)
        elif density_measure == 'nn_dist':
            # make negative so that larger values are better
            scores = -compute_nearest_neighbour_distances(class_embeddings,
                                                          nearest_k=5)

        # Keep instances scoring above the (100 - retention_ratio)th percentile
        cutoff = np.percentile(scores, (100 - retention_ratio))
        keep_mask = torch.from_numpy(scores > cutoff).bool()
        keep_indices.append(class_indices[keep_mask])

    keep_indices = torch.cat(keep_indices, dim=0)

    return keep_indices


def select_instances(dataset,
                     retention_ratio,
                     embedding='inceptionv3',
                     density_measure='gaussian',
                     indices_filepath=None,
                     batch_size=128,
                     num_workers=4):
    """
    Args:
        dataset (Dataset): dataset to be subsampled with instance selection.
        retention_ratio (float): percentage of the dataset to keep.
        embedding (str): embedding function for extracting image features.
            Options are 'inceptionv3', 'resnet50', 'places365', 'resnextwsl',
            and 'swav'.
        density_measure (str): scoring function to use when determining whether
            to select instances. Options are 'ppca', 'gaussian', or 'nn_dist'.
        indices_filepath (str): filepath for saving indices so that they don't
            need to be recomputed each time. Should have a .pkl file extension.
        batch_size (int): how many samples per batch to load.
        num_workers (int): how many subprocesses to use for data loading.

    Returns:
        instance_selected_dataset (Subset): subset of the original dataset
            containing the best scoring instances.
    """
    assert (retention_ratio > 0) and (retention_ratio <= 100), \
        'retention_ratio should be between 0 and 100'

    if retention_ratio == 100:
        print('Retention ratio = 100, skipping dataset reduction')
        return dataset

    # Reuse previously computed indices if they have been cached to disk
    if indices_filepath is not None:
        if os.path.exists(indices_filepath):
            keep_indices = torch.load(indices_filepath)
            return Subset(dataset, keep_indices)

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            drop_last=False,
                            num_workers=num_workers,
                            pin_memory=True)

    embedder = get_embedder(embedding)
    embeddings, labels = get_embeddings_from_loader(dataloader,
                                                    embedder,
                                                    return_labels=True,
                                                    verbose=True)

    keep_indices = get_keep_indices(embeddings,
                                    labels,
                                    density_measure,
                                    retention_ratio=retention_ratio,
                                    verbose=True)

    if indices_filepath is not None:
        torch.save(keep_indices, indices_filepath)

    return Subset(dataset, keep_indices)
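

if __name__ == '__main__':
    # Example usage: an illustrative sketch, not part of the original module.
    # It assumes a torchvision ImageFolder-style dataset that yields
    # (image, label) pairs and at least one available CUDA device; the dataset
    # path and transform below are placeholders.
    from torchvision import datasets, transforms

    transform = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor()])
    dataset = datasets.ImageFolder('/path/to/image/dataset',
                                   transform=transform)

    # Keep the top-scoring 50% of instances per class, caching the selected
    # indices so they are not recomputed on later runs.
    reduced_dataset = select_instances(dataset,
                                       retention_ratio=50,
                                       embedding='inceptionv3',
                                       density_measure='gaussian',
                                       indices_filepath='keep_indices.pkl')
    print('Kept {} of {} instances'.format(len(reduced_dataset), len(dataset)))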