Adds the option to turn off final_readout_nonlinearity in the Encoder and the se_core_full_gauss_readout model #34

Open
wants to merge 8 commits into base: master
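Intended usage, as a minimal sketch (only `final_readout_nonlinearity` is part of this PR; the `dataloaders` and `seed` values below are placeholders):

```python
from nndichromacy.models.models import se_core_full_gauss_readout

# Placeholder: a dict of session dataloaders, set up as usual for this repo.
dataloaders = ...

# Hypothetical builder call; final_readout_nonlinearity=False is the new switch.
model = se_core_full_gauss_readout(
    dataloaders=dataloaders,
    seed=42,
    final_readout_nonlinearity=False,  # return the raw readout output, no ELU(x + offset) + 1
)
```

With the flag left at its default of `True`, behaviour is unchanged: the encoder still applies `ELU(x + elu_offset) + 1`.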
8 changes: 6 additions & 2 deletions nndichromacy/models/encoders.py
@@ -7,12 +7,13 @@


 class Encoder(nn.Module):
-    def __init__(self, core, readout, elu_offset, shifter=None):
+    def __init__(self, core, readout, final_nonlinearity, elu_offset, shifter=None):
         super().__init__()
         self.core = core
         self.readout = readout
         self.offset = elu_offset
         self.shifter = shifter
+        self.readout_nonlinearity = final_nonlinearity

     def forward(
         self, *args, data_key=None, eye_pos=None, shift=None, trial_idx=None, **kwargs
@@ -52,7 +53,10 @@ def forward(
             shift = self.shifter[data_key](eye_pos)

         x = self.readout(x, data_key=data_key, shift=shift, **kwargs)
-        return F.elu(x + self.offset) + 1
+        if self.readout_nonlinearity is True:
+            return F.elu(x + self.offset) + 1
+        else:
+            return x

     def regularizer(self, data_key):
         return self.core.regularizer() + self.readout.regularizer(data_key=data_key)
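The change to `Encoder.forward` boils down to the following standalone sketch (the helper name and the fake tensor are illustrative only, not repo code):

```python
import torch
import torch.nn.functional as F


def apply_final_nonlinearity(x, offset=0.0, final_nonlinearity=True):
    # Mirrors the new return logic: the shifted ELU keeps predictions positive,
    # otherwise the readout output is passed through unchanged.
    return F.elu(x + offset) + 1 if final_nonlinearity else x


x = torch.randn(4, 10)                                        # stand-in for the readout output
rates = apply_final_nonlinearity(x)                           # strictly positive, since ELU > -1
raw = apply_final_nonlinearity(x, final_nonlinearity=False)   # identical to x
```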
99 changes: 27 additions & 72 deletions nndichromacy/models/models.py
@@ -29,9 +29,7 @@
 except ModuleNotFoundError:
     pass
 except:
-    print(
-        "dj database connection could not be established. no access to pretrained models available."
-    )
+    print("dj database connection could not be established. no access to pretrained models available.")


 # from . import logger as log
@@ -162,11 +160,7 @@ def se_core_gauss_readout(
     in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()}
     input_channels = [v[in_name][1] for v in session_shape_dict.values()]

-    core_input_channels = (
-        list(input_channels.values())[0]
-        if isinstance(input_channels, dict)
-        else input_channels[0]
-    )
+    core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0]

     class Encoder(nn.Module):
         def __init__(self, core, readout, elu_offset):
@@ -251,6 +245,7 @@ def se_core_full_gauss_readout(
     init_sigma=1.0,
     readout_bias=True,  # readout args,
     gamma_readout=4,
+    final_readout_nonlinearity=True,
     elu_offset=0,
     stack=None,
     se_reduction=32,
@@ -313,11 +308,7 @@ def se_core_full_gauss_readout(
     in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()}
     input_channels = [v[in_name][1] for v in session_shape_dict.values()]

-    core_input_channels = (
-        list(input_channels.values())[0]
-        if isinstance(input_channels, dict)
-        else input_channels[0]
-    )
+    core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0]

     source_grids = None
     grid_mean_predictor_type = None
@@ -326,18 +317,13 @@
         grid_mean_predictor_type = grid_mean_predictor.pop("type")
         if grid_mean_predictor_type == "cortex":
             input_dim = grid_mean_predictor.pop("input_dimensions", 2)
-            source_grids = {
-                k: v.dataset.neurons.cell_motor_coordinates[:, :input_dim]
-                for k, v in dataloaders.items()
-            }
+            source_grids = {k: v.dataset.neurons.cell_motor_coordinates[:, :input_dim] for k, v in dataloaders.items()}
         elif grid_mean_predictor_type == "shared":
             pass

     shared_match_ids = None
     if share_features or share_grid:
-        shared_match_ids = {
-            k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()
-        }
+        shared_match_ids = {k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()}
         all_multi_unit_ids = set(np.hstack(shared_match_ids.values()))

         for match_id in shared_match_ids.values():
@@ -414,7 +400,13 @@ def se_core_full_gauss_readout(
             _, targets = next(iter(value))[:2]
             readout[key].bias.data = targets.mean(0)

-    model = Encoder(core=core, readout=readout, elu_offset=elu_offset, shifter=shifter)
+    model = Encoder(
+        core=core,
+        readout=readout,
+        final_nonlinearity=final_readout_nonlinearity,
+        elu_offset=elu_offset,
+        shifter=shifter,
+    )

     return model
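A quick smoke test for the new flag could look like this sketch (the nested loader layout, the `seed` argument, and the forward call signature are assumptions, not shown in this diff):

```python
# Build one model per setting; dataloaders/seed are placeholders as above.
model_elu = se_core_full_gauss_readout(dataloaders, seed=1, final_readout_nonlinearity=True)
model_raw = se_core_full_gauss_readout(dataloaders, seed=1, final_readout_nonlinearity=False)

# Pull one batch from one session (assumed "train" -> data_key -> loader nesting).
loaders = dataloaders["train"] if "train" in dataloaders else dataloaders
data_key, loader = next(iter(loaders.items()))
images = next(iter(loader))[0]

assert (model_elu(images, data_key=data_key) > 0).all()  # ELU(x + offset) + 1 is strictly positive
# model_raw(images, data_key=data_key) may contain negative values: no final nonlinearity is applied
```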

@@ -503,11 +495,7 @@ def se_core_behavior_gauss(
     in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()}
     input_channels = [v[in_name][1] for v in session_shape_dict.values()]

-    core_input_channels = (
-        list(input_channels.values())[0]
-        if isinstance(input_channels, dict)
-        else input_channels[0]
-    )
+    core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0]

     if "train" in dataloaders.keys():
         dataloaders = dataloaders["train"]
@@ -527,18 +515,13 @@
         grid_mean_predictor_type = grid_mean_predictor.pop("type")
         if grid_mean_predictor_type == "cortex":
             input_dim = grid_mean_predictor.pop("input_dimensions", 2)
-            source_grids = {
-                k: v.dataset.neurons.cell_motor_coordinates[:, :input_dim]
-                for k, v in dataloaders.items()
-            }
+            source_grids = {k: v.dataset.neurons.cell_motor_coordinates[:, :input_dim] for k, v in dataloaders.items()}
         elif grid_mean_predictor_type == "shared":
             pass

     shared_match_ids = None
     if share_features or share_grid:
-        shared_match_ids = {
-            k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()
-        }
+        shared_match_ids = {k: v.dataset.neurons.multi_match_id for k, v in dataloaders.items()}
         all_multi_unit_ids = set(np.hstack(shared_match_ids.values()))

         for match_id in shared_match_ids.values():
@@ -618,9 +601,7 @@ def se_core_behavior_gauss(
             _, targets = next(iter(value))[:2]
             readout[key].bias.data = targets.mean(0)

-    model = GeneralEncoder(
-        core=core, readout=readout, elu_offset=elu_offset, shifter=shifter
-    )
+    model = GeneralEncoder(core=core, readout=readout, elu_offset=elu_offset, shifter=shifter)

     return model

@@ -683,11 +664,7 @@ def se_core_point_readout(
     in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()}
     input_channels = [v[in_name][1] for v in session_shape_dict.values()]

-    core_input_channels = (
-        list(input_channels.values())[0]
-        if isinstance(input_channels, dict)
-        else input_channels[0]
-    )
+    core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0]

     set_random_seed(seed)

@@ -792,11 +769,7 @@ def stacked2d_core_gaussian_readout(
     in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()}
     input_channels = [v[in_name][1] for v in session_shape_dict.values()]

-    core_input_channels = (
-        list(input_channels.values())[0]
-        if isinstance(input_channels, dict)
-        else input_channels[0]
-    )
+    core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0]

     class Encoder(nn.Module):
         def __init__(self, core, readout, elu_offset):
@@ -905,11 +878,7 @@ def vgg_core_gauss_readout(
     in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()}
     input_channels = [v[in_name][1] for v in session_shape_dict.values()]

-    core_input_channels = (
-        list(input_channels.values())[0]
-        if isinstance(input_channels, dict)
-        else input_channels[0]
-    )
+    core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0]

     class Encoder(nn.Module):
         """
@@ -1011,11 +980,7 @@ def vgg_core_full_gauss_readout(
     in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()}
     input_channels = [v[in_name][1] for v in session_shape_dict.values()]

-    core_input_channels = (
-        list(input_channels.values())[0]
-        if isinstance(input_channels, dict)
-        else input_channels[0]
-    )
+    core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0]

     class Encoder(nn.Module):
         """
@@ -1118,11 +1083,7 @@ def se_core_spatialXfeature_readout(
     in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()}
     input_channels = [v[in_name][1] for v in session_shape_dict.values()]

-    core_input_channels = (
-        list(input_channels.values())[0]
-        if isinstance(input_channels, dict)
-        else input_channels[0]
-    )
+    core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0]

     class Encoder(nn.Module):
         def __init__(self, core, readout, elu_offset):
@@ -1219,11 +1180,7 @@ def rotation_equivariant_gauss_readout(
     in_shapes_dict = {k: v[in_name] for k, v in session_shape_dict.items()}
     input_channels = [v[in_name][1] for v in session_shape_dict.values()]

-    core_input_channels = (
-        list(input_channels.values())[0]
-        if isinstance(input_channels, dict)
-        else input_channels[0]
-    )
+    core_input_channels = list(input_channels.values())[0] if isinstance(input_channels, dict) else input_channels[0]

     class Encoder(nn.Module):
         def __init__(self, core, readout, elu_offset):
@@ -1337,9 +1294,9 @@ def augmented_full_readout(
        model.readout["augmentation"].features.data[
            :, :, :, insert_index : insert_index + neuron_repeats
        ] = features[:, :, :, None]
-       model.readout["augmentation"].bias.data[
-           insert_index : insert_index + neuron_repeats
-       ] = model.readout[data_key].bias.data[i]
+       model.readout["augmentation"].bias.data[insert_index : insert_index + neuron_repeats] = model.readout[
+           data_key
+       ].bias.data[i]
        model.readout["augmentation"].sigma.data[
            :, insert_index : insert_index + neuron_repeats, :, :
        ] = model.readout[data_key].sigma.data[:, i, ...]
@@ -1362,9 +1319,7 @@

     if rename_data_key is False:
         if len(sessions) > 1:
-            raise ValueError(
-                "Renaming to original data key is only possible when dataloader has one data key only"
-            )
+            raise ValueError("Renaming to original data key is only possible when dataloader has one data key only")
         model.readout[sessions[0]] = model.readout.pop("augmentation")

     return models