Skip to content
This repository has been archived by the owner on Jan 5, 2024. It is now read-only.

TemporalModelCommitments #3

Open
wants to merge 59 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
55fa08b
add temporal model commitment
fksato Mar 19, 2019
4189ae4
complete temporal map
fksato Mar 20, 2019
64790c7
complete temporal map
fksato Mar 20, 2019
b7f2480
complete temporal map
fksato Mar 20, 2019
e418b30
complete temporal map
fksato Mar 20, 2019
dc4c152
code clean-up for TemporalModelCommitments
fksato Mar 21, 2019
f5ecd21
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
7418a86
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
4aee6bb
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
76e9c8b
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
0e245b4
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
dc16eb6
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
8d272c9
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
40fd5d1
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
262e0ae
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
79a4a63
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
45056d6
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
1ece6b0
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
0060c57
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
163bb74
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
e97ca1e
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
a34cf2f
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
25bac61
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
815db33
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
5e8fe13
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
1ef87c6
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
ac30543
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
d5dfc33
code clean-up TemporalModelCommitment
fksato Mar 21, 2019
2cb3219
add result caching parameters to temporal model commitments
fksato Mar 23, 2019
135c93a
add result caching parameters to temporal model commitments
fksato Mar 23, 2019
f6e8037
add result caching parameters to temporal model commitments
fksato Mar 23, 2019
95ebcbf
add result caching parameters to temporal model commitments
fksato Mar 23, 2019
bc134db
add result caching parameters to temporal model commitments
fksato Mar 23, 2019
ad01a0a
add result caching parameters to temporal model commitments
fksato Mar 23, 2019
f2b75e5
implement result caching for temporal model commitment data
fksato Mar 23, 2019
203deb4
code/repository clean up
fksato Mar 23, 2019
fe76c1c
add temporal_map.py
fksato Mar 25, 2019
9b6913f
add temporal_maps unit test
fksato Mar 29, 2019
250a679
fix merge conflict, add temporal_maps pytest
fksato Mar 29, 2019
0471651
Merge branch 'master' of github.com:brain-score/model-tools
fksato Mar 29, 2019
bfd21f9
merge with upstream
fksato Mar 30, 2019
f88c2b6
Merge branch 'master' of github.com:brain-score/model-tools
fksato Apr 1, 2019
d69b0a7
Delete .gitignore
fksato Apr 1, 2019
66dcb79
Delete conftest.cpython-36-PYTEST.pyc
fksato Apr 1, 2019
5f20b6e
Delete __init__.cpython-36.pyc
fksato Apr 1, 2019
e540772
first pass code cleanup
fksato Apr 2, 2019
849b86a
add testing assemblies for temporal tests
fksato Apr 3, 2019
2cea388
Merge remote-tracking branch 'upstream/master'
fksato Apr 3, 2019
ca3782e
add simplified temporal assemblies
fksato Apr 3, 2019
521181c
add pls_regression comparison to test temporal maps
fksato Apr 4, 2019
1ba7cbd
add correct stimulus set paths to temporal testing assemblies
fksato Apr 4, 2019
bc7e9f9
add correct stimulus set paths to temporal testing assemblies
fksato Apr 4, 2019
c5d20ef
Merge remote-tracking branch 'upstream/master'
fksato Apr 4, 2019
9396676
remove unnecessary imports
fksato Apr 4, 2019
07a8238
Merge remote-tracking branch 'upstream/master'
fksato Apr 5, 2019
0ecb126
add new assembly/stimulus for temporal mapping tests
fksato Apr 5, 2019
11454ea
finish pytest
fksato Apr 5, 2019
115b03e
add absolute image paths to testing stimulus set
fksato Apr 6, 2019
d4c8d54
rename pytorch model in temporal testing
fksato Apr 6, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions model_tools/brain_transformation/temporal_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from typing import Optional

from brainio_base.assemblies import merge_data_arrays

from model_tools.brain_transformation import LayerSelection

from brainscore.model_interface import BrainModel
from brainscore.metrics.regression import pls_regression

from result_caching import store, store_dict

class TemporalModelCommitment(BrainModel):
def __init__(self, identifier, base_model, layers, region_layer_map: Optional[dict] = None):
self.layers = layers
self.identifier = identifier or None
self.base_model = base_model
self.region_layer_map = region_layer_map or {}
self.recorded_regions = []

self.time_bins = None
self._temporal_maps = {}
self._layer_regions = None

def make_temporal(self, assembly):
assert self.region_layer_map # force commit_region to come before
assert len(set(assembly.time_bin.values)) > 1 # force temporal recordings/assembly

temporal_mapped_regions = set(assembly['region'].values)

temporal_mapped_regions = list(set(self.region_layer_map.keys()).intersection(self.region_layer_map.keys()))
fksato marked this conversation as resolved.
Show resolved Hide resolved
layer_regions = {self.region_layer_map[region]: region for region in temporal_mapped_regions}

stimulus_set = assembly.stimulus_set

activations = self.base_model(stimulus_set, layers=list(layer_regions.keys()))
activations = self._set_region_coords(activations, layer_regions)

self._temporal_maps = self._set_temporal_maps(self.identifier, temporal_mapped_regions, activations, assembly)

def look_at(self, stimuli):
layer_regions = {self.region_layer_map[region]: region for region in self.recorded_regions}
assert len(layer_regions) == len(self.recorded_regions), f"duplicate layers for {self.recorded_regions}"
activations = self.base_model(stimuli, layers=list(layer_regions.keys()))

activations = self._set_region_coords(activations ,layer_regions)
return self._temporal_activations(self.identifier, activations)

@store(identifier_ignore=['assembly'])
def _temporal_activations(self, identifier, assembly):
temporal_assembly = []
for region in self.recorded_regions:
temporal_regressors = self._temporal_maps[region]
region_activations = assembly.sel(region=region)
for time_bin in self.time_bins:
regressor = temporal_regressors[time_bin]
regressed_act = regressor.predict(region_activations)
regressed_act = self._package_temporal(time_bin, region, regressed_act)
temporal_assembly.append(regressed_act)
temporal_assembly = merge_data_arrays(temporal_assembly)
return temporal_assembly

@store_dict(dict_key='temporal_mapped_regions', identifier_ignore=['temporal_mapped_regions', 'activations' ,'assembly'])
def _set_temporal_maps(self, identifier, temporal_mapped_regions, activations, assembly):
temporal_maps = {}
for region in temporal_mapped_regions:
time_bin_regressor = {}
region_activations = activations.sel(region=region)
for time_bin in assembly.time_bin.values:
target_assembly = assembly.sel(time_bin=time_bin ,region=region)
regressor = pls_regression(neuroid_coord=('neuroid_id' ,'layer' ,'region'))
regressor.fit(region_activations, target_assembly)
time_bin_regressor[time_bin] = regressor
temporal_maps[region] = time_bin_regressor
return temporal_maps

def _set_region_coords(self, activations, layer_regions):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll work on a fix for this soon, so that you hopefully won't need this here anymore (brain-score/brainio_base#1)

coords = { 'region' : ( ('neuroid'), [layer_regions[layer] for layer in activations['layer'].values]) }
activations = activations.assign_coords(**coords)
activations = activations.set_index({'neuroid' :'region'}, append=True)
return activations

def _package_temporal(self, time_bin, region, assembly):
assert len(time_bin) == 2
assembly = assembly.expand_dims('time_bin', axis=-1)
coords = {
'time_bin_start': (('time_bin'), [time_bin[0]])
, 'time_bin_end': (('time_bin'), [time_bin[1]])
, 'region' : ( ('neuroid'), [region] * assembly.shape[1])
}
assembly = assembly.assign_coords(**coords)
assembly = assembly.set_index(time_bin=['time_bin_start', 'time_bin_end'], neuroid='region', append=True)
return assembly

def start_temporal_recording(self, recording_target, time_bins):
fksato marked this conversation as resolved.
Show resolved Hide resolved
assert self._temporal_maps
assert self.region_layer_map
assert recording_target in self._temporal_maps.keys()
if time_bins is not None:
assert set(self._temporal_maps[recording_target].keys()).issuperset(set(time_bins))
else:
time_bins = self._temporal_maps[recording_target].keys()

self.recorded_regions = [recording_target]
self.time_bins = time_bins

def start_recording(self, recording_target):
assert self._temporal_maps
assert self.region_layer_map
assert recording_target in self._temporal_maps.keys()
if self.time_bins is None:
self.time_bins = self._temporal_maps[recording_target].keys()
self.recorded_regions = [recording_target]

def commit_region(self, region, assembly):
fksato marked this conversation as resolved.
Show resolved Hide resolved
layer_selection = LayerSelection(model_identifier=self.identifier,
activations_model=self.base_model, layers=self.layers)
best_layer = layer_selection(assembly)
self.region_layer_map[region] = best_layer

def receptive_fields(self, record=True):
fksato marked this conversation as resolved.
Show resolved Hide resolved
pass
88 changes: 88 additions & 0 deletions tests/brain_transformation/test_temporal_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest
import functools
import numpy as np
from pytest import approx

from model_tools.activations import PytorchWrapper
from model_tools.brain_transformation.temporal_map import TemporalModelCommitment

from brainscore import get_assembly
from brainscore.benchmarks.loaders import AssemblyLoader, DicarloMajaj2015Loader, DicarloMajaj2015ITLoader

# create a test assembly:
class DicarloMajaj2015TemporalLoader(AssemblyLoader): # needed to add variation argument
def __init__(self, name='dicarlo.Majaj2015.temporal'):
mschrimpf marked this conversation as resolved.
Show resolved Hide resolved
super(DicarloMajaj2015TemporalLoader, self).__init__(name=name)
self._helper = DicarloMajaj2015Loader()

def __call__(self, average_repetition=True, variation=6):
assembly = get_assembly(name='dicarlo.Majaj2015.temporal')
assembly = self._helper._filter_erroneous_neuroids(assembly)
assembly = assembly.sel(variation=variation)
assembly = assembly.transpose('presentation', 'neuroid', 'time_bin')
if average_repetition:
assembly = self._helper.average_repetition(assembly)
return assembly

def get_stim(assembly):
return assembly.stimulus_set[assembly.stimulus_set['image_id'].isin(assembly['image_id'].values)]

def pytorch_custom():
mschrimpf marked this conversation as resolved.
Show resolved Hide resolved
import torch
from torch import nn
from model_tools.activations.pytorch import load_preprocess_images

class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3)
self.relu1 = torch.nn.ReLU()
linear_input_size = np.power((224 - 3 + 2 * 0) / 1 + 1, 2) * 2
self.linear = torch.nn.Linear(int(linear_input_size), 1000)
self.relu2 = torch.nn.ReLU() # can't get named ReLU output otherwise

def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
x = self.relu2(x)
return x

preprocessing = functools.partial(load_preprocess_images, image_size=224)
return PytorchWrapper(model=MyModel(), preprocessing=preprocessing)

class TestTemporalModelCommitment:
test_data = [(pytorch_custom, ['linear', 'relu2'], 'test_temporal', 'relu2', 'IT', 1, 39, 29)]
@pytest.mark.parametrize("model_ctr, layers, identifier, expected_best_layer, region, expected_region_count,"
" expected_time_bin_count, expected_recorded_time_bin_cnt"
, test_data)
def test(self, model_ctr, layers, identifier, expected_best_layer, region
, expected_region_count, expected_recorded_regions, expected_time_bin_count, expected_recorded_time_bin_cnt):
train_test_assembly_loader = DicarloMajaj2015TemporalLoader()
commit_loader = DicarloMajaj2015ITLoader()

training_assembly = train_test_assembly_loader(variation=3)
commit_assembly = commit_loader(average_repetition=False)
validation_assembly = train_test_assembly_loader(variation=6)

extractor = pytorch_custom()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use model_ctr


t_bins = [t for t in training_assembly.time_bin.values if t[0] >= 0]

temporal_model = TemporalModelCommitment(identifier, extractor, layers)
# commit region:
temporal_model.commit_region(region, commit_assembly)
assert temporal_model.region_layer_map[region] == expected_best_layer
# make temporal:
temporal_model.make_temporal(training_assembly)
assert len(temporal_model._temporal_maps.keys()) == expected_region_count
assert len(temporal_model._temporal_maps[region].keys()) == expected_time_bin_count
# start recording:
temporal_model.start_temporal_recording(region, t_bins)
assert temporal_model.recorded_regions == expected_recorded_regions
# look at:
stim = get_stim(commit_assembly)
temporal_activations = temporal_model.look_at(stim)
assert set(temporal_activations.region.values) == set(region)
assert len(set(temporal_activations.time_bin.values)) == expected_recorded_time_bin_cnt