Adding example script with custom Loader for PyTorch API documentation #97

Closed
timonmerk opened this issue Oct 29, 2023 · 3 comments
Labels: documentation (Improvements or additions to documentation), enhancement (New feature or request)


@timonmerk
Contributor

The CEBRA documentation is very comprehensive and presents the parameterization in a lot of detail.
In its current form, however, the focus seems to be on explaining the scikit-learn API, and there is no example script for using the PyTorch API: https://cebra.ai/docs/usage.html

For many options, I am unsure how to set them through the scikit-learn API. For example, when using discrete behavioral data, it's currently not possible to specify empirical or discrete sampling:

prior: str = dataclasses.field(

I think this is intentional, to avoid overloading the cebra.Cebra initialization or the model.fit() function with too many parameters?
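To illustrate what a sampling prior over discrete labels can control, here is a minimal NumPy sketch. It is not CEBRA's implementation: the function name `sample_reference_indices` and the toy labels are hypothetical, and the two modes are modeled on the assumption that "empirical" follows the label frequencies in the data while "uniform" rebalances the classes.

```python
import numpy as np


def sample_reference_indices(labels, num_samples, prior="empirical", rng=None):
    """Draw sample indices under two illustrative priors (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    if prior == "empirical":
        # Every sample is equally likely, so class frequencies follow the data.
        return rng.integers(0, len(labels), size=num_samples)
    if prior == "uniform":
        # Pick a class uniformly first, then a random sample within that class.
        classes = np.unique(labels)
        indices_by_class = {c: np.flatnonzero(labels == c) for c in classes}
        chosen = rng.choice(classes, size=num_samples)
        return np.array([rng.choice(indices_by_class[c]) for c in chosen])
    raise ValueError(f"Unknown prior: {prior!r}")


# Toy, imbalanced labels: 90% class 0 and 10% class 1.
labels = np.array([0] * 900 + [1] * 100)
rng = np.random.default_rng(0)

empirical_idx = sample_reference_indices(labels, 5000, "empirical", rng)
uniform_idx = sample_reference_indices(labels, 5000, "uniform", rng)

# Fraction of drawn samples that belong to the rare class.
empirical_fraction = labels[empirical_idx].mean()
uniform_fraction = labels[uniform_idx].mean()
print(f"empirical: {empirical_fraction:.2f}, uniform: {uniform_fraction:.2f}")
```

With imbalanced labels like these, the rare class shows up in roughly a tenth of the empirical draws and roughly half of the uniform draws, which is why the choice of prior matters for discrete behavioral data.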

Therefore, I thought it might help to add a minimal example to usage.rst showing how a data loader with parameters that are not exposed by the scikit-learn API can be used with PyTorch directly:

import numpy as np
import cebra.datasets
from cebra import plot_embedding
import torch

neural_data = cebra.load_data(file="neural_data.npz", key="neural")
# continuous_label = cebra.load_data(
#    file="auxiliary_behavior_data.h5",
#    key="auxiliary_variables",
#    columns=["continuous1", "continuous2", "continuous3"],
# )

discrete_label = cebra.load_data(
    file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
)

# 1. Define Cebra Dataset
InputData = cebra.data.TensorDataset(
    torch.from_numpy(neural_data).type(torch.FloatTensor),
    # continuous=torch.from_numpy(np.array(continuous_label)).type(torch.FloatTensor),
    discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to("cpu")

# 2. Define Cebra Model
neural_model = cebra.models.init(
    name="offset10-model",
    num_neurons=InputData.input_dimension,
    num_units=32,
    num_output=2,
).to("cpu")

InputData.configure_for(neural_model)

# 3. Define Loss Function Criterion and Optimizer
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
    # temperature=0.001,
    # min_temperature=0.0001
).to("cpu")

Opt = torch.optim.Adam(
    list(neural_model.parameters()) + list(Crit.parameters()),
    # lr=0.001,
    weight_decay=0,
)

# 4. Initialize Cebra Model
cebra_model = cebra.solver.init(
    name="single-session",
    model=neural_model,
    criterion=Crit,
    optimizer=Opt,
    tqdm_on=True,
).to("cpu")

# 5. Define Data Loader
# loader = cebra.data.single_session.ContinuousDataLoader(
#    dataset=InputData, num_steps=1000, batch_size=200
# )
loader = cebra.data.single_session.DiscreteDataLoader(
    dataset=InputData, num_steps=1000, batch_size=200, prior="uniform"
)

# 6. Fit model
cebra_model.fit(loader=loader)

# 7. Transform embedding
# One window per time step; the window length equals the model's offset length.
TrainBatches = np.lib.stride_tricks.sliding_window_view(
    neural_data, len(neural_model.get_offset()), axis=0
)
X_train_emb = cebra_model.transform(
    torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to("cpu")
).to("cpu")

# 8. Potentially plot embedding
plot_embedding(
    X_train_emb,
    discrete_label[len(neural_model.get_offset()) - 1 :, 0],
    markersize=10,
)
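
A note on steps 7 and 8: the label slice has to match the number of windows produced by `sliding_window_view`. The following standalone sketch uses made-up array sizes (100 time steps, 5 neurons, a window of 10, all assumptions for illustration) to show the resulting shapes and why the labels are trimmed by `window - 1`, as the snippet above does.

```python
import numpy as np

# Illustrative sizes only; a real window length would come from the model's offset.
num_timesteps, num_neurons, window = 100, 5, 10

neural = np.arange(num_timesteps * num_neurons, dtype=np.float32).reshape(
    num_timesteps, num_neurons
)
labels = np.arange(num_timesteps)

# Sliding over axis 0 appends the window dimension as the last axis.
windows = np.lib.stride_tricks.sliding_window_view(neural, window, axis=0)
print(windows.shape)  # (91, 5, 10)

# There are num_timesteps - window + 1 windows, so drop the first window - 1 labels.
aligned_labels = labels[window - 1 :]
print(aligned_labels.shape)  # (91,)

# Window i covers time steps i .. i + window - 1; its last column is step i + window - 1.
last_step_of_window_3 = windows[3, :, -1]
```

Trimming from the front pairs each window with the label at its last time step. Other alignments (for example, the center of the window) are possible, so this is one choice rather than the only one.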
@MMathisLab
Member

MMathisLab commented Oct 29, 2023

Thanks @timonmerk I think that's a great idea! We can also link it to the Allen Demo, which uses the PyTorch API: https://cebra.ai/docs/demo_notebooks/Demo_Allen.html -- would you like to make a PR?

Edit to add: we DO have an example, just to be sure you saw it: https://cebra.ai/docs/usage.html#quick-start-torch-api-example. But I agree it's not nearly as expressive as your good suggestion to make the sklearn Quick Start guide!

MMathisLab added the documentation and enhancement labels Oct 29, 2023
MMathisLab self-assigned this Oct 29, 2023
@timonmerk
Contributor Author

timonmerk commented Oct 29, 2023

Thanks! Yes, that notebook was super helpful already, but it might be good to also link it in usage.rst.
I also compiled the documentation and ran my linked example locally, but of course it's a bit difficult to test it since it's in a rst file only..

@stes
Member

stes commented Oct 29, 2023

but of course it's a bit difficult to test it since it's in a rst file only..

Actually this is included in the unit tests!

Locally, you can run make test, or run pytest directly (adapt as needed):

python -m pytest --ff --doctest-modules -m "not requires_dataset" tests ./docs/source/usage.rst cebra
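
For readers unfamiliar with testing examples that live in a text file, here is a minimal sketch using only Python's standard `doctest` module. It assumes examples written in the plain `>>>` prompt format. The thread does not say which format usage.rst uses or how pytest collects it, so the file name `usage_example.rst` and the helper `run_rst_doctests` are hypothetical stand-ins.

```python
import doctest
import os
import tempfile

# A tiny stand-in for a documentation file with embedded examples.
RST_TEXT = '''\
Usage
=====

>>> 1 + 1
2
>>> sorted({"b": 1, "a": 2})
['a', 'b']
'''


def run_rst_doctests(text):
    """Write text to a temporary .rst file and run its doctest examples."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "usage_example.rst")
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(text)
        # module_relative=False lets doctest accept a plain filesystem path.
        return doctest.testfile(path, module_relative=False)


results = run_rst_doctests(RST_TEXT)
print(results)
```

The returned object reports how many examples were attempted and how many failed, which is the same kind of signal a test runner uses to pass or fail documentation examples.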
