import copy
import os
import functools
import logging
from collections import OrderedDict
from multiprocessing.pool import ThreadPool
import numpy as np
from tqdm.auto import tqdm
from brainio.assemblies import NeuroidAssembly, walk_coords
from brainio.stimuli import StimulusSet
from model_tools.utils import fullname
from result_caching import store_xarray


class Defaults:
batch_size = 64


class ActivationsExtractorHelper:
def __init__(self, get_activations, preprocessing, identifier=False, batch_size=Defaults.batch_size):
"""
:param identifier: an activations identifier for the stored results file. False to disable saving.
"""
self._logger = logging.getLogger(fullname(self))
self._batch_size = batch_size
self.identifier = identifier
self.get_activations = get_activations
self.preprocess = preprocessing or (lambda x: x)
self._stimulus_set_hooks = {}
self._batch_activations_hooks = {}
def __call__(self, stimuli, layers, stimuli_identifier=None):
"""
:param stimuli_identifier: a stimuli identifier for the stored results file. False to disable saving.
"""
if isinstance(stimuli, StimulusSet):
return self.from_stimulus_set(stimulus_set=stimuli, layers=layers, stimuli_identifier=stimuli_identifier)
else:
return self.from_paths(stimuli_paths=stimuli, layers=layers, stimuli_identifier=stimuli_identifier)
def from_stimulus_set(self, stimulus_set, layers, stimuli_identifier=None):
"""
:param stimuli_identifier: a stimuli identifier for the stored results file.
False to disable saving. None to use `stimulus_set.identifier`
"""
if stimuli_identifier is None and hasattr(stimulus_set, 'identifier'):
stimuli_identifier = stimulus_set.identifier
for hook in self._stimulus_set_hooks.copy().values(): # copy to avoid stale handles
stimulus_set = hook(stimulus_set)
stimuli_paths = [str(stimulus_set.get_stimulus(stimulus_id)) for stimulus_id in stimulus_set['stimulus_id']]
activations = self.from_paths(stimuli_paths=stimuli_paths, layers=layers, stimuli_identifier=stimuli_identifier)
activations = attach_stimulus_set_meta(activations, stimulus_set)
return activations
def from_paths(self, stimuli_paths, layers, stimuli_identifier=None):
if layers is None:
layers = ['logits']
if self.identifier and stimuli_identifier:
fnc = functools.partial(self._from_paths_stored,
identifier=self.identifier, stimuli_identifier=stimuli_identifier)
else:
            self._logger.debug(f"self.identifier `{self.identifier}` or stimuli_identifier `{stimuli_identifier}` "
                               f"is not set, will not store")
fnc = self._from_paths
# In case stimuli paths are duplicates (e.g. multiple trials), we first reduce them to only the paths that need
# to be run individually, compute activations for those, and then expand the activations to all paths again.
# This is done here, before storing, so that we only store the reduced activations.
reduced_paths = self._reduce_paths(stimuli_paths)
activations = fnc(layers=layers, stimuli_paths=reduced_paths)
activations = self._expand_paths(activations, original_paths=stimuli_paths)
return activations
@store_xarray(identifier_ignore=['stimuli_paths', 'layers'], combine_fields={'layers': 'layer'})
def _from_paths_stored(self, identifier, layers, stimuli_identifier, stimuli_paths):
return self._from_paths(layers=layers, stimuli_paths=stimuli_paths)
def _from_paths(self, layers, stimuli_paths):
if len(layers) == 0:
raise ValueError("No layers passed to retrieve activations from")
self._logger.info('Running stimuli')
layer_activations = self._get_activations_batched(stimuli_paths, layers=layers, batch_size=self._batch_size)
self._logger.info('Packaging into assembly')
return self._package(layer_activations, stimuli_paths)
def _reduce_paths(self, stimuli_paths):
return list(set(stimuli_paths))
def _expand_paths(self, activations, original_paths):
activations_paths = activations['stimulus_path'].values
argsort_indices = np.argsort(activations_paths)
sorted_x = activations_paths[argsort_indices]
sorted_index = np.searchsorted(sorted_x, original_paths)
index = [argsort_indices[i] for i in sorted_index]
return activations[{'stimulus_path': index}]
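
    # Illustration: given stimuli_paths = ['a.png', 'b.png', 'a.png'], _reduce_paths collapses the
    # list to its unique elements (order not preserved), activations are computed once per unique
    # path, and _expand_paths re-indexes the resulting assembly back to the original ordering, so
    # the returned activations have one row per entry of the original list.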
def register_batch_activations_hook(self, hook):
r"""
The hook will be called every time a batch of activations is retrieved.
The hook should have the following signature::
hook(batch_activations) -> batch_activations
The hook should return new batch_activations which will be used in place of the previous ones.
"""
handle = HookHandle(self._batch_activations_hooks)
self._batch_activations_hooks[handle.id] = hook
return handle
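
    # Illustrative sketch of a batch-activations hook (`extractor` is a hypothetical
    # ActivationsExtractorHelper instance; `batch_activations` is an OrderedDict mapping
    # layer name -> numpy array with the batch on the first axis):
    #
    #     def center_hook(batch_activations):
    #         return OrderedDict((layer, values - values.mean(axis=0, keepdims=True))
    #                            for layer, values in batch_activations.items())
    #
    #     handle = extractor.register_batch_activations_hook(center_hook)
    #     ...  # run the extractor; call handle.remove() to detach the hook again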
def register_stimulus_set_hook(self, hook):
r"""
The hook will be called every time before a stimulus set is processed.
The hook should have the following signature::
hook(stimulus_set) -> stimulus_set
The hook should return a new stimulus_set which will be used in place of the previous one.
"""
handle = HookHandle(self._stimulus_set_hooks)
self._stimulus_set_hooks[handle.id] = hook
return handle
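
    # Illustrative sketch of a stimulus-set hook (hypothetical `extractor`; assumes the sliced
    # StimulusSet keeps the columns and get_stimulus method needed downstream):
    #
    #     def subsample_hook(stimulus_set):
    #         return stimulus_set[:100]  # keep only the first 100 stimuli
    #
    #     handle = extractor.register_stimulus_set_hook(subsample_hook)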
def _get_activations_batched(self, paths, layers, batch_size):
layer_activations = None
for batch_start in tqdm(range(0, len(paths), batch_size), unit_scale=batch_size, desc="activations"):
batch_end = min(batch_start + batch_size, len(paths))
batch_inputs = paths[batch_start:batch_end]
batch_activations = self._get_batch_activations(batch_inputs, layer_names=layers, batch_size=batch_size)
for hook in self._batch_activations_hooks.copy().values(): # copy to avoid handle re-enabling messing with the loop
batch_activations = hook(batch_activations)
if layer_activations is None:
layer_activations = copy.copy(batch_activations)
else:
for layer_name, layer_output in batch_activations.items():
layer_activations[layer_name] = np.concatenate((layer_activations[layer_name], layer_output))
return layer_activations
def _get_batch_activations(self, inputs, layer_names, batch_size):
inputs, num_padding = self._pad(inputs, batch_size)
preprocessed_inputs = self.preprocess(inputs)
activations = self.get_activations(preprocessed_inputs, layer_names)
assert isinstance(activations, OrderedDict)
activations = self._unpad(activations, num_padding)
return activations
def _pad(self, batch_images, batch_size):
num_images = len(batch_images)
if num_images % batch_size == 0:
return batch_images, 0
num_padding = batch_size - (num_images % batch_size)
padding = np.repeat(batch_images[-1:], repeats=num_padding, axis=0)
return np.concatenate((batch_images, padding)), num_padding
def _unpad(self, layer_activations, num_padding):
        # when num_padding == 0, `-num_padding or None` evaluates to None, i.e. all values are kept
        return change_dict(layer_activations, lambda values: values[:-num_padding or None])
def _package(self, layer_activations, stimuli_paths):
shapes = [a.shape for a in layer_activations.values()]
self._logger.debug(f"Activations shapes: {shapes}")
self._logger.debug("Packaging individual layers")
layer_assemblies = [self._package_layer(single_layer_activations, layer=layer, stimuli_paths=stimuli_paths) for
layer, single_layer_activations in tqdm(layer_activations.items(), desc='layer packaging')]
# merge manually instead of using merge_data_arrays since `xarray.merge` is very slow with these large arrays
# complication: (non)neuroid_coords are taken from the structure of layer_assemblies[0] i.e. the 1st assembly;
# using these names/keys for all assemblies results in KeyError if the first layer contains flatten_coord_names
# (see _package_layer) not present in later layers, e.g. first layer = conv, later layer = transformer layer
self._logger.debug(f"Merging {len(layer_assemblies)} layer assemblies")
model_assembly = np.concatenate([a.values for a in layer_assemblies],
axis=layer_assemblies[0].dims.index('neuroid'))
nonneuroid_coords = {coord: (dims, values) for coord, dims, values in walk_coords(layer_assemblies[0])
if set(dims) != {'neuroid'}}
neuroid_coords = {coord: [dims, values] for coord, dims, values in walk_coords(layer_assemblies[0])
if set(dims) == {'neuroid'}}
for layer_assembly in layer_assemblies[1:]:
for coord in neuroid_coords:
neuroid_coords[coord][1] = np.concatenate((neuroid_coords[coord][1], layer_assembly[coord].values))
assert layer_assemblies[0].dims == layer_assembly.dims
for dim in set(layer_assembly.dims) - {'neuroid'}:
for coord in layer_assembly[dim].coords:
assert (layer_assembly[coord].values == nonneuroid_coords[coord][1]).all()
neuroid_coords = {coord: (dims_values[0], dims_values[1]) # re-package as tuple instead of list for xarray
for coord, dims_values in neuroid_coords.items()}
model_assembly = type(layer_assemblies[0])(model_assembly, coords={**nonneuroid_coords, **neuroid_coords},
dims=layer_assemblies[0].dims)
return model_assembly
def _package_layer(self, layer_activations, layer, stimuli_paths):
assert layer_activations.shape[0] == len(stimuli_paths)
activations, flatten_indices = flatten(layer_activations, return_index=True) # collapse for single neuroid dim
flatten_coord_names = None
if flatten_indices.shape[1] == 1: # fully connected, e.g. classifier
# see comment in _package for an explanation why we cannot simply have 'channel' for the FC layer
flatten_coord_names = ['channel', 'channel_x', 'channel_y']
elif flatten_indices.shape[1] == 2: # Transformer, e.g. ViT
flatten_coord_names = ['channel', 'embedding']
elif flatten_indices.shape[1] == 3: # 2DConv, e.g. resnet
flatten_coord_names = ['channel', 'channel_x', 'channel_y']
        elif flatten_indices.shape[1] == 4:  # temporal sliding window, e.g. omnivore
flatten_coord_names = ['channel_temporal', 'channel_x', 'channel_y', 'channel']
else:
# we still package the activations, but are unable to provide channel information
self._logger.debug(f"Unknown layer activations shape {layer_activations.shape}, not inferring channels")
# build assembly
coords = {'stimulus_path': stimuli_paths,
'neuroid_num': ('neuroid', list(range(activations.shape[1]))),
'model': ('neuroid', [self.identifier] * activations.shape[1]),
'layer': ('neuroid', [layer] * activations.shape[1]),
}
if flatten_coord_names:
flatten_coords = {flatten_coord_names[i]: [sample_index[i] if i < flatten_indices.shape[1] else np.nan
for sample_index in flatten_indices]
for i in range(len(flatten_coord_names))}
coords = {**coords, **{coord: ('neuroid', values) for coord, values in flatten_coords.items()}}
layer_assembly = NeuroidAssembly(activations, coords=coords, dims=['stimulus_path', 'neuroid'])
neuroid_id = [".".join([f"{value}" for value in values]) for values in zip(*[
layer_assembly[coord].values for coord in ['model', 'layer', 'neuroid_num']])]
layer_assembly['neuroid_id'] = 'neuroid', neuroid_id
return layer_assembly
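
    # Worked example: a conv layer with activations of shape (n_stimuli, 256, 7, 7) is flattened
    # to shape (n_stimuli, 256 * 7 * 7); each of the 12544 neuroids carries its original
    # (channel, channel_x, channel_y) position as the neuroid coordinates attached above.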
def insert_attrs(self, wrapper):
wrapper.from_stimulus_set = self.from_stimulus_set
wrapper.from_paths = self.from_paths
wrapper.register_batch_activations_hook = self.register_batch_activations_hook
wrapper.register_stimulus_set_hook = self.register_stimulus_set_hook
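
# Usage sketch (illustrative; `my_get_activations` and `my_preprocessing` are hypothetical
# stand-ins for a framework-specific activations function and preprocessing callable):
#
#     extractor = ActivationsExtractorHelper(get_activations=my_get_activations,
#                                            preprocessing=my_preprocessing,
#                                            identifier='my-model')
#     assembly = extractor(stimuli=['img1.png', 'img2.png'], layers=['conv1', 'fc'])
#
# Passing a StimulusSet instead of a list of paths routes through from_stimulus_set, which also
# attaches the stimulus-set metadata to the returned assembly.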


def change_dict(d, change_function, keep_name=False, multithread=False):
if not multithread:
map_fnc = map
else:
pool = ThreadPool()
map_fnc = pool.map
def apply_change(layer_values):
layer, values = layer_values
values = change_function(values) if not keep_name else change_function(layer, values)
return layer, values
results = map_fnc(apply_change, d.items())
results = OrderedDict(results)
if multithread:
pool.close()
return results
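
# Example: change_dict maps a function over the values of a dict, preserving keys, e.g.
#
#     change_dict({'conv1': np.zeros((2, 3))}, lambda values: values + 1)
#     # -> OrderedDict([('conv1', <(2, 3) array of ones>)])
#
# With multithread=True the mapping runs in a ThreadPool; with keep_name=True the key is also
# passed to the change function as change_function(key, values).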


def lstrip_local(path):
parts = path.split(os.sep)
try:
start_index = parts.index('.brainio')
except ValueError: # not in list -- perhaps custom directory
return path
path = os.sep.join(parts[start_index:])
return path
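
# Example (assuming a POSIX-style path):
#     lstrip_local('/home/user/.brainio/image_dirs/img1.png')  # -> '.brainio/image_dirs/img1.png'
# Paths without a '.brainio' component are returned unchanged.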


def attach_stimulus_set_meta(assembly, stimulus_set):
stimulus_paths = [str(stimulus_set.get_stimulus(stimulus_id)) for stimulus_id in stimulus_set['stimulus_id']]
stimulus_paths = [lstrip_local(path) for path in stimulus_paths]
assembly_paths = [lstrip_local(path) for path in assembly['stimulus_path'].values]
assert (np.array(assembly_paths) == np.array(stimulus_paths)).all()
assembly['stimulus_path'] = stimulus_set['stimulus_id'].values
assembly = assembly.rename({'stimulus_path': 'stimulus_id'})
for column in stimulus_set.columns:
assembly[column] = 'stimulus_id', stimulus_set[column].values
assembly = assembly.stack(presentation=('stimulus_id',))
return assembly


class HookHandle:
next_id = 0
def __init__(self, hook_dict):
self.hook_dict = hook_dict
self.id = HookHandle.next_id
HookHandle.next_id += 1
self._saved_hook = None
def remove(self):
hook = self.hook_dict[self.id]
del self.hook_dict[self.id]
return hook
def disable(self):
self._saved_hook = self.remove()
def enable(self):
self.hook_dict[self.id] = self._saved_hook
self._saved_hook = None


def flatten(layer_output, return_index=False):
flattened = layer_output.reshape(layer_output.shape[0], -1)
if not return_index:
return flattened
def cartesian_product_broadcasted(*arrays):
"""
http://stackoverflow.com/a/11146645/190597
"""
broadcastable = np.ix_(*arrays)
broadcasted = np.broadcast_arrays(*broadcastable)
dtype = np.result_type(*arrays)
rows, cols = functools.reduce(np.multiply, broadcasted[0].shape), len(broadcasted)
out = np.empty(rows * cols, dtype=dtype)
start, end = 0, rows
for a in broadcasted:
out[start:end] = a.reshape(-1)
start, end = end, end + rows
return out.reshape(cols, rows).T
index = cartesian_product_broadcasted(*[np.arange(s, dtype='int') for s in layer_output.shape[1:]])
return flattened, index
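
# Example: for layer_output of shape (10, 2, 3),
#     flattened, index = flatten(layer_output, return_index=True)
# yields flattened.shape == (10, 6) and index.shape == (6, 2), where row j of index is the
# original (dim-1, dim-2) position of flattened column j:
#     [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]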