KeyError with `cebra.fit(adapt=True)` on multi-session embeddings (only) #108
Minimal example:

```python
import numpy as np
import cebra

timesteps = 10
neurons = 50
out_dim = 8
n_sessions = 3

neural_data = [np.random.normal(0, 1, (timesteps, neurons)) for _ in range(n_sessions)]
continuous_label = [np.random.normal(0, 1, (timesteps, 3)) for _ in range(n_sessions)]

multi_cebra_model = cebra.CEBRA(
    batch_size=512,
    output_dimension=out_dim,
    max_iterations=10,
    max_adapt_iterations=10,
    verbose=True,
)

# Train a multi-session model on the first two sessions...
multi_cebra_model.fit(neural_data[:2], continuous_label[:2], adapt=False)
# ...then adapt it to the third session
multi_cebra_model.fit(neural_data[2], continuous_label[2], adapt=True)
```

which raises the `KeyError`.
However, if I do:

```python
multi_cebra_model.partial_fit(neural_data[:2], continuous_label[:2])
multi_cebra_model.partial_fit(neural_data[2], continuous_label[2])
```

that seems to work (i.e. I don't get an error); however, my understanding is that …

This also errors:

```python
multi_cebra_model.partial_fit(neural_data[:2], continuous_label[:2])
multi_cebra_model.fit(neural_data[2], continuous_label[2], adapt=True)
```
Hi @FedeClaudi, thanks for the detailed report. The intended behavior would be to adapt the model by running

```python
multi_cebra_model.fit(neural_data[:2], continuous_label[:2], adapt=False)
multi_cebra_model.fit(neural_data[2], continuous_label[2], adapt=True)
```

as you posted. I will investigate how we could fix this for multi-session training; it seems to be an issue with loading the model and replacing the first layer in a multi-session model.
Hi @FedeClaudi, right now this is not implemented in the sklearn API, but we will prioritize it! Thanks for raising the issue.
Bug description
Hey everyone, thanks for releasing this awesome tool.
I get an error when calling `cebra_model.fit(..., adapt=True)` to fine-tune a model on new data. This only happens if the model was trained on multiple sessions, not when adapting a model trained on a single session.
The actual usage of CEBRA is within a larger codebase, so I don't have a minimal working example, but these are the steps I follow:

1. Define a `CEBRA` model with `hybrid=False`.
2. Fit `self.model` on multiple sessions, where `Xs` and `Ys` are lists of `np.ndarray` with neural and behavioral (continuous) data from multiple experimental sessions.
3. Adapt the model to new data, with `X_new, Y_new` being arrays with the data for a single new session.

I get `KeyError: '0.net.0.weight'`; this is the CEBRA part of the error stack:

I did some digging, and the problem is that the `adapt_model` created here has different keys in `.state_dict()` compared to `self.model_`.

`adapt_model` keys:

`self.model_` keys:

When I first train CEBRA on a single session, the keys of `self.model_` match those of `adapt_model`.

Is training on multiple sessions and then fine-tuning on a new one not allowed? Am I doing something wrong in how I use CEBRA?
Thanks,
Federico
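The key mismatch reported above can be illustrated without CEBRA or PyTorch. This is a minimal sketch that uses plain dictionaries as stand-ins for `state_dict()` outputs. It assumes, based only on the `'0.net.0.weight'` key in the error, that a multi-session model prefixes each parameter name with a session index while a single-session `adapt_model` does not. Following the maintainer's remark about replacing the first layer, it skips the input layer when copying weights. The helpers `strip_session_prefix`, `copy_all_but_first_layer` and `fresh_adapt_model`, as well as all key names, are hypothetical. This is not how CEBRA implements adaptation and not a confirmed fix.

```python
# Plain-dict mock of PyTorch-style state dicts. Key names mirror the KeyError
# in this issue; they are assumptions, not dumped from a real CEBRA model.


def strip_session_prefix(state_dict, session):
    """Return one session's entries with the leading "<session>." removed."""
    prefix = f"{session}."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def copy_all_but_first_layer(source, target, first_layer="net.0."):
    """Copy weights from source into target, skipping the input layer.

    The input layer is skipped because its shape depends on the number of
    neurons in the new session. Raises KeyError if a key is missing in target.
    """
    for key, value in source.items():
        if key.startswith(first_layer):
            continue
        if key not in target:
            raise KeyError(key)
        target[key] = value
    return target


# Hypothetical multi-session keys: one sub-model per session, so every key
# carries a "<session>." prefix.
multi_session = {
    "0.net.0.weight": "s0_input",
    "0.net.2.weight": "s0_hidden",
    "1.net.0.weight": "s1_input",
    "1.net.2.weight": "s1_hidden",
}


def fresh_adapt_model():
    """Hypothetical single-session model: keys have no session prefix."""
    return {"net.0.weight": "new_input", "net.2.weight": "new_hidden"}


# Copying directly fails on the first prefixed key.
try:
    copy_all_but_first_layer(multi_session, fresh_adapt_model())
except KeyError as err:
    print("KeyError:", err)  # KeyError: '0.net.0.weight'

# Stripping one session's prefix (session 0, an arbitrary choice) lets it work.
adapted = copy_all_but_first_layer(
    strip_session_prefix(multi_session, 0), fresh_adapt_model()
)
print(adapted)  # {'net.0.weight': 'new_input', 'net.2.weight': 's0_hidden'}
```

Which session's weights to reuse is a choice the sketch makes arbitrarily. If each session in a multi-session model has its own sub-network (again an assumption here), a real fix would have to decide how to pick or combine them.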
Operating System
Windows
CEBRA version
cebra version: '0.2.0'
Device type
gpu: NVIDIA GeForce RTX 3080
Steps To Reproduce
No response
Relevant log output
No response
Anything else?
No response