Add torch API usage example #99
Merged (13 commits), Dec 7, 2023
docs/source/usage.rst: 92 additions, 0 deletions

Below is the documentation on the available arguments.
--train-ratio 0.8 Ratio of the train dataset. The remainder will be used for the valid and test splits.
--valid-ratio 0.1 Ratio of the validation set after the train split. The remainder will be the test split.
--share-model

Model training using the Torch API
----------------------------------

The scikit-learn API provides parametrization for many common use cases.
The Torch API, however, allows for more flexibility and customization,
e.g. of sampling, criterions, and data loaders.

In this minimal example we show how to initialize a CEBRA model using the Torch API.
Here the :py:class:`cebra.data.single_session.DiscreteDataLoader`
is initialized, which also allows the ``prior`` to be parametrized directly.
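To build intuition for the ``prior="uniform"`` setting accepted by the
:py:class:`cebra.data.single_session.DiscreteDataLoader`, here is a plain NumPy
sketch of the idea (an illustration of the concept only, not CEBRA's actual
sampler): the reference class is drawn uniformly at random, so each discrete
label ends up equally represented regardless of how imbalanced the labels are.

```python
import numpy as np

rng = np.random.default_rng(0)
labels = np.array([0] * 80 + [1] * 20)  # imbalanced discrete labels (80/20)
classes = np.unique(labels)

# "uniform" prior sketch: draw the class itself uniformly at random,
# rather than sampling time steps (which would reproduce the 80/20 split)
drawn = rng.choice(classes, size=10_000)
counts = {int(c): int((drawn == c).sum()) for c in classes}
print(counts)  # both classes are drawn ~5000 times despite the imbalance
```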

👉 For an example notebook using the Torch API, check out the :doc:`demo_notebooks/Demo_Allen`.
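If you want to try the snippet below without the original recordings, a
placeholder ``neural_data.npz`` can be generated with NumPy (a sketch: the
array shape ``(time, neurons)`` and the ``neural`` key are assumptions chosen
to match the loading call in the example; the HDF5 labels file would need to
be created analogously, e.g. with ``h5py``):

```python
import numpy as np

rng = np.random.default_rng(42)
# placeholder recording: 1000 time steps from 50 neurons
neural = rng.standard_normal((1000, 50)).astype(np.float32)
np.savez("neural_data.npz", neural=neural)

# check that it round-trips under the key the example loads
loaded = np.load("neural_data.npz")["neural"]
print(loaded.shape)  # (1000, 50)
```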


.. testcode::

import numpy as np
import cebra.datasets
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

neural_data = cebra.load_data(file="neural_data.npz", key="neural")

discrete_label = cebra.load_data(
file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
)

# 1. Define the CEBRA dataset
input_data = cebra.data.TensorDataset(
torch.from_numpy(neural_data).type(torch.FloatTensor),
discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to(device)

# 2. Define the CEBRA model
neural_model = cebra.models.init(
name="offset10-model",
num_neurons=input_data.input_dimension,
num_units=32,
num_output=2,
).to(device)

input_data.configure_for(neural_model)

# 3. Define the loss criterion and the optimizer
crit = cebra.models.criterions.LearnableCosineInfoNCE(
temperature=0.001,
min_temperature=0.0001
).to(device)

opt = torch.optim.Adam(
list(neural_model.parameters()) + list(crit.parameters()),
lr=0.001,
weight_decay=0,
)

# 4. Initialize the solver
solver = cebra.solver.init(
name="single-session",
model=neural_model,
criterion=crit,
optimizer=opt,
tqdm_on=True,
).to(device)

# 5. Define Data Loader
loader = cebra.data.single_session.DiscreteDataLoader(
dataset=input_data, num_steps=10, batch_size=200, prior="uniform"
)

# 6. Fit Model
solver.fit(loader=loader)

# 7. Compute the embedding
train_batches = np.lib.stride_tricks.sliding_window_view(
neural_data, len(neural_model.get_offset()), axis=0
)

x_train_emb = solver.transform(
torch.from_numpy(train_batches[:]).type(torch.FloatTensor).to(device)
).to(device)

# 8. Plot Embedding
cebra.plot_embedding(
x_train_emb,
discrete_label[len(neural_model.get_offset()) - 1 :, 0],
markersize=10,
)
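A note on step 7: ``sliding_window_view`` turns the continuous recording into
overlapping windows whose length equals the model's receptive field
(``len(neural_model.get_offset())``), which is also why the labels are trimmed
by ``len(...) - 1`` samples before plotting. The shape arithmetic can be
checked on a toy array:

```python
import numpy as np

x = np.arange(12).reshape(6, 2)  # 6 time steps, 2 "neurons"
windows = np.lib.stride_tricks.sliding_window_view(x, 4, axis=0)

# one window per valid start index: 6 - 4 + 1 = 3 windows of length 4
print(windows.shape)  # (3, 2, 4)
print(windows[0, 0])  # first neuron over the first window: [0 2 4 6]
```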