diff --git a/model_tools/brain_transformation/temporal_map.py b/model_tools/brain_transformation/temporal_map.py new file mode 100644 index 0000000..646768f --- /dev/null +++ b/model_tools/brain_transformation/temporal_map.py @@ -0,0 +1,111 @@ +from typing import Optional + +from brainio_base.assemblies import merge_data_arrays + +from model_tools.brain_transformation import ModelCommitment + +from brainscore.model_interface import BrainModel +from brainscore.metrics.regression import pls_regression + +from result_caching import store, store_dict + +class TemporalModelCommitment(BrainModel): + def __init__(self, identifier, base_model, layers, region_layer_map: Optional[dict] = None): + self.layers = layers + self.identifier = identifier + self.base_model = base_model + # + self.model_commitment = ModelCommitment(self.identifier, self.base_model, self.layers) + self.commit_region = self.model_commitment.commit_region + self.region_assemblies = self.model_commitment.region_assemblies + self.region_layer_map = self.model_commitment.layer_model.region_layer_map + self.recorded_regions = [] + + self.time_bins = None + self._temporal_maps = {} + self._layer_regions = None + + def make_temporal(self, assembly): + if not self.region_layer_map: + for region in self.region_assemblies.keys(): + self.model_commitment.do_commit_region(region) + # assert self.region_layer_map # force commit_region to come before + assert len(set(assembly.time_bin.values)) > 1 # force temporal recordings/assembly + + temporal_mapped_regions = set(assembly['region'].values) + + temporal_mapped_regions = list(set(self.region_layer_map.keys()).intersection(self.region_layer_map.keys())) + layer_regions = {self.region_layer_map[region]: region for region in temporal_mapped_regions} + + stimulus_set = assembly.stimulus_set + + activations = self.base_model(stimulus_set, layers=list(layer_regions.keys())) + activations = self._set_region_coords(activations, layer_regions) + + self._temporal_maps = self._set_temporal_maps(self.identifier, temporal_mapped_regions, activations, assembly) + + def look_at(self, stimuli): + layer_regions = {self.region_layer_map[region]: region for region in self.recorded_regions} + assert len(layer_regions) == len(self.recorded_regions), f"duplicate layers for {self.recorded_regions}" + activations = self.base_model(stimuli, layers=list(layer_regions.keys())) + + activations = self._set_region_coords(activations ,layer_regions) + return self._temporal_activations(self.identifier, activations) + + @store(identifier_ignore=['assembly']) + def _temporal_activations(self, identifier, assembly): + temporal_assembly = [] + for region in self.recorded_regions: + temporal_regressors = self._temporal_maps[region] + region_activations = assembly.sel(region=region) + for time_bin in self.time_bins: + regressor = temporal_regressors[time_bin] + regressed_act = regressor.predict(region_activations) + regressed_act = self._package_temporal(time_bin, region, regressed_act) + temporal_assembly.append(regressed_act) + temporal_assembly = merge_data_arrays(temporal_assembly) + return temporal_assembly + + @store_dict(dict_key='temporal_mapped_regions', identifier_ignore=['temporal_mapped_regions', 'activations' ,'assembly']) + def _set_temporal_maps(self, identifier, temporal_mapped_regions, activations, assembly): + temporal_maps = {} + for region in temporal_mapped_regions: + time_bin_regressor = {} + region_activations = activations.sel(region=region) + for time_bin in assembly.time_bin.values: + target_assembly = assembly.sel(region=region, time_bin=time_bin) + regressor = pls_regression(neuroid_coord=('neuroid_id' ,'layer' ,'region')) + regressor.fit(region_activations, target_assembly) + time_bin_regressor[time_bin] = regressor + temporal_maps[region] = time_bin_regressor + return temporal_maps + + def _set_region_coords(self, activations, layer_regions): + coords = { 'region' : (('neuroid'), [layer_regions[layer] for layer in activations['layer'].values]) } + activations = activations.assign_coords(**coords) + activations = activations.set_index({'neuroid' :'region'}, append=True) + return activations + + def _package_temporal(self, time_bin, region, assembly): + assert len(time_bin) == 2 + assembly = assembly.expand_dims('time_bin', axis=-1) + coords = { + 'time_bin_start': (('time_bin'), [time_bin[0]]) + , 'time_bin_end': (('time_bin'), [time_bin[1]]) + , 'region' : (('neuroid'), [region] * assembly.shape[1]) + } + assembly = assembly.assign_coords(**coords) + assembly = assembly.set_index(time_bin=['time_bin_start', 'time_bin_end'], neuroid='region', append=True) + return assembly + + def start_recording(self, recording_target, time_bins: Optional[list] = None): + self.model_commitment.start_recording(recording_target) + assert self._temporal_maps + assert self.region_layer_map + assert recording_target in self._temporal_maps.keys() + if self.time_bins is None: + self.time_bins = self._temporal_maps[recording_target].keys() + else: + assert set(self._temporal_maps[recording_target].keys()).issuperset(set(time_bins)) + self.recorded_regions = [recording_target] + self.time_bins = time_bins diff --git a/tests/brain_transformation/test_temporal_map.py b/tests/brain_transformation/test_temporal_map.py new file mode 100644 index 0000000..45988f4 --- /dev/null +++ b/tests/brain_transformation/test_temporal_map.py @@ -0,0 +1,141 @@ +import pytest +import functools +import numpy as np +from os import path + +from model_tools.activations import PytorchWrapper +from model_tools.brain_transformation.temporal_map import TemporalModelCommitment + +from brainscore.metrics.regression import pls_regression +from brainscore.assemblies.public import load_assembly + +from xarray import DataArray +from pandas import DataFrame + +from brainio_base.stimuli import StimulusSet +from brainio_base.assemblies import NeuronRecordingAssembly + +def load_test_assemblies(variation, region): + image_dir = path.join(path.dirname(path.abspath(__file__)), 'test_temporal_stimulus') + if type(variation) is not list: + variation = [variation] + + num_stim = 5 + neuroid_cnt = 168 + time_bin_cnt = 5 + resp = np.random.rand(num_stim, neuroid_cnt, time_bin_cnt) + + dims = ['presentation', 'neuroid', 'time_bin'] + coords = { + 'image_id': ('presentation', range(num_stim)), + 'y': ('presentation', range(num_stim)), + 'neuroid_id': ('neuroid', [f'{i}' for i in range(neuroid_cnt)]), + 'region': ('neuroid', ['IT'] * neuroid_cnt), + 'x': ('neuroid', range(neuroid_cnt)), + 'time_bin_start': ('time_bin', range(-10, 40, 10)), + 'time_bin_end': ('time_bin', range(0, 50, 10)) + } + + assembly = DataArray(data=resp, dims=dims, coords=coords) + assembly = assembly.set_index(presentation=['image_id', 'y'], + neuroid=['neuroid_id','region', 'x'], + time_bin=['time_bin_start', 'time_bin_end'], + append=True) + + stim_meta = [{'id': k} for k in range(num_stim)] + image_paths = {} + for i in range(num_stim): + f_name = f"im_{i:05}.jpg" + im_path = path.join(image_dir, f_name) + + meta = stim_meta[i] + meta['image_id'] = f'{i}' + meta['image_file_name'] = f_name + image_paths[f'{i}'] = im_path + + stim_set = DataFrame(stim_meta) + + stim_set = StimulusSet(stim_set) + stim_set.image_paths = image_paths + stim_set.name = f'testing_temporal_stims_{region}_var{"".join(str(v) for v in variation)}' + + assembly = NeuronRecordingAssembly(assembly) + + assembly.attrs['stimulus_set'] = stim_set + assembly.attrs['stimulus_set_name'] = stim_set.name + return assembly + +def pytorch_custom(): + import torch + from torch import nn + from model_tools.activations.pytorch import load_preprocess_images + + class MyModel_Temporal(nn.Module): + def __init__(self): + super(MyModel_Temporal, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3) + self.relu1 = torch.nn.ReLU() + linear_input_size = np.power((224 - 3 + 2 * 0) / 1 + 1, 2) * 2 + self.linear = torch.nn.Linear(int(linear_input_size), 1000) + self.relu2 = torch.nn.ReLU() # can't get named ReLU output otherwise + + def forward(self, x): + x = self.conv1(x) + x = self.relu1(x) + x = x.view(x.size(0), -1) + x = self.linear(x) + x = self.relu2(x) + return x + + preprocessing = functools.partial(load_preprocess_images, image_size=224) + return PytorchWrapper(model=MyModel_Temporal(), preprocessing=preprocessing) + +class TestTemporalModelCommitment: + test_data = [(pytorch_custom, ['linear', 'relu2'], 'IT')] + @pytest.mark.parametrize("model_ctr, layers, region", test_data) + def test(self, model_ctr, layers, region): + commit_assembly = load_assembly(name='dicarlo.Majaj2015.lowvar.IT', + **{'average_repetition': False}) + + training_assembly = load_test_assemblies([0,3], region) + validation_assembly = load_test_assemblies(6, region) + + expected_region = region if type(region)==list else [region] + expected_region_count = len(expected_region) + expected_time_bin_count = len(training_assembly.time_bin.values) + + extractor = model_ctr() + + t_bins = [t for t in training_assembly.time_bin.values if 0 <= t[0] < 30] + expected_recorded_time_count = len(t_bins) + + temporal_model = TemporalModelCommitment('', extractor, layers) + # commit region: + temporal_model.commit_region(region, commit_assembly) + # make temporal: + temporal_model.make_temporal(training_assembly) + assert len(temporal_model._temporal_maps.keys()) == expected_region_count + assert len(temporal_model._temporal_maps[region].keys()) == expected_time_bin_count + # start recording: + temporal_model.start_recording(region, t_bins) + assert temporal_model.recorded_regions == expected_region + # look at: + stim = validation_assembly.stimulus_set + temporal_activations = temporal_model.look_at(stim) + assert set(temporal_activations.region.values) == set(expected_region) + assert len(set(temporal_activations.time_bin.values)) == expected_recorded_time_count + # + test_layer = temporal_model.region_layer_map[region] + train_stim_set = training_assembly.stimulus_set + for time_test in t_bins: + target_assembly = training_assembly.sel(time_bin=time_test, region=region) + region_activations = extractor(train_stim_set, [test_layer]) + regressor = pls_regression(neuroid_coord=('neuroid_id', 'layer', 'region')) + regressor.fit(region_activations, target_assembly) + # + test_activations = extractor(stim, [test_layer]) + test_predictions = regressor.predict(test_activations).values + # + temporal_model_prediction = temporal_activations.sel(region=region, time_bin=time_test).values + assert temporal_model_prediction == pytest.approx(test_predictions, rel=1e-3, abs=1e-6) + diff --git a/tests/brain_transformation/test_temporal_stimulus/im_00000.jpg b/tests/brain_transformation/test_temporal_stimulus/im_00000.jpg new file mode 100644 index 0000000..bcf9d88 Binary files /dev/null and b/tests/brain_transformation/test_temporal_stimulus/im_00000.jpg differ diff --git a/tests/brain_transformation/test_temporal_stimulus/im_00001.jpg b/tests/brain_transformation/test_temporal_stimulus/im_00001.jpg new file mode 100644 index 0000000..3385a4f Binary files /dev/null and b/tests/brain_transformation/test_temporal_stimulus/im_00001.jpg differ diff --git a/tests/brain_transformation/test_temporal_stimulus/im_00002.jpg b/tests/brain_transformation/test_temporal_stimulus/im_00002.jpg new file mode 100644 index 0000000..3ad52da Binary files /dev/null and b/tests/brain_transformation/test_temporal_stimulus/im_00002.jpg differ diff --git a/tests/brain_transformation/test_temporal_stimulus/im_00003.jpg b/tests/brain_transformation/test_temporal_stimulus/im_00003.jpg new file mode 100644 index 0000000..72a6ab9 Binary files /dev/null and b/tests/brain_transformation/test_temporal_stimulus/im_00003.jpg differ diff --git a/tests/brain_transformation/test_temporal_stimulus/im_00004.jpg b/tests/brain_transformation/test_temporal_stimulus/im_00004.jpg new file mode 100644 index 0000000..8bbb4c3 Binary files /dev/null and b/tests/brain_transformation/test_temporal_stimulus/im_00004.jpg differ