Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batched inference CEBRA & padding at the Solver level #168

Open
wants to merge 49 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
283de06
first proposal for batching in tranform method
gonlairo Jun 21, 2023
202e379
first running version of padding with batched inference
gonlairo Jun 22, 2023
1f1989d
start tests
gonlairo Jun 23, 2023
8665660
add pad_before_transform to fit function and add support for convolut…
gonlairo Sep 27, 2023
8d5b114
remove print statements
gonlairo Sep 27, 2023
32c5ecd
first passing test
gonlairo Sep 27, 2023
9928f63
add support for hybrid models
gonlairo Sep 28, 2023
be5630a
rewrite transform in sklearn API
gonlairo Sep 28, 2023
1300b20
baseline version of a torch.Datset
gonlairo Oct 16, 2023
bc6af24
move batching logic outside solver
gonlairo Oct 20, 2023
ec377b9
move functionality to base file in solver and separate in functions
gonlairo Oct 27, 2023
6f9ca98
add test_select_model for single session
gonlairo Oct 27, 2023
fbe7eb4
add checks and test for _process_batch
gonlairo Oct 27, 2023
463b0f8
add test_select_model for multisession
gonlairo Oct 30, 2023
5219171
make self.num_sessions compatible with single session training
gonlairo Oct 31, 2023
f9bd1a6
improve test_batched_transform_singlesession
gonlairo Nov 1, 2023
e23a7ef
make it work with small batches
gonlairo Nov 7, 2023
19c3f87
make test with multisession work
gonlairo Nov 8, 2023
87bebac
change to torch padding
gonlairo Nov 9, 2023
f0303e0
add argument to sklearn api
gonlairo Nov 9, 2023
8c8be85
add torch padding to _transform
gonlairo Nov 9, 2023
59df402
convert to torch if numpy array as inputs
gonlairo Nov 9, 2023
1aadc8b
add distinction between pad with data and pad with zeros and modify t…
gonlairo Nov 15, 2023
bc8ee25
differentiate between data padding and zero padding
gonlairo Nov 17, 2023
5e7a14c
remove float16
gonlairo Nov 24, 2023
928d882
change argument position
gonlairo Nov 27, 2023
07bac1c
clean test
gonlairo Nov 27, 2023
0823b54
clean test
gonlairo Nov 27, 2023
9fe3af3
Fix warning
CeliaBenquet Mar 26, 2024
b417a23
Improve modularity remove duplicate code and todos
CeliaBenquet Aug 21, 2024
83c1669
Add tests to solver
CeliaBenquet Aug 22, 2024
9c46eb9
Remove unused import in solver/utils
CeliaBenquet Aug 22, 2024
c845ec3
Fix test plot
CeliaBenquet Aug 22, 2024
9db3e37
Add some coverage
CeliaBenquet Aug 22, 2024
8e5f933
Fix save/load
CeliaBenquet Aug 22, 2024
d08e400
Remove duplicate configure_for in multi dataset
CeliaBenquet Aug 22, 2024
0c693dd
Make save/load cleaner
CeliaBenquet Aug 22, 2024
ae056b2
Merge branch 'main' into batched-inference-and-padding
CeliaBenquet Sep 18, 2024
794867b
Fix codespell errors
CeliaBenquet Sep 18, 2024
0bb6549
Fix docs compilation errors
CeliaBenquet Sep 18, 2024
04a102f
Fix formatting
CeliaBenquet Sep 18, 2024
7aab282
Fix extra docs errors
CeliaBenquet Sep 18, 2024
ffa66eb
Fix offset in docs
CeliaBenquet Sep 18, 2024
7f58607
Remove attribute ref
CeliaBenquet Sep 18, 2024
c2544c7
Add review updates
CeliaBenquet Sep 19, 2024
ad5da03
Merge branch 'main' into batched-inference-and-padding
stes Oct 20, 2024
f6aa2e6
Merge branch 'main' into batched-inference-and-padding
MMathisLab Oct 20, 2024
e1b7cc7
apply ruff auto-fixes
stes Oct 27, 2024
0eac868
Merge remote-tracking branch 'origin/main' into batched-inference-and…
stes Oct 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cebra/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def load_batch(self, index: BatchIndex) -> Batch:
"""
raise NotImplementedError()

@abc.abstractmethod
def configure_for(self, model: "cebra.models.Model"):
"""Configure the dataset offset for the provided model.

Expand All @@ -205,6 +206,7 @@ def configure_for(self, model: "cebra.models.Model"):
Args:
model: The model to configure the dataset for.
"""
raise NotImplementedError
self.offset = model.get_offset()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo? / missing cleanup?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i.e. should the line below be removed here? why is that relevant for batched inference?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

configure_for was done in the cebra.CEBRA class (in configure_for_all) and now it is moved to the solvers directly, and the configure_for in the multisession solver was wrongly implemented and not used.

So now not implemented in the base class and defined in multi and single solvers.



Expand All @@ -230,6 +232,8 @@ class Loader(abc.ABC, cebra.io.HasDevice):
doc="""A dataset instance specifying a ``__getitem__`` function.""",
)

time_offset: int = dataclasses.field(default=10)

num_steps: int = dataclasses.field(
default=None,
doc=
Expand Down
27 changes: 16 additions & 11 deletions cebra/data/multi_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import torch

import cebra.data as cebra_data
import cebra.distributions as cebra_distr
import cebra.distributions
from cebra.data.datatypes import Batch
from cebra.data.datatypes import BatchIndex

Expand Down Expand Up @@ -106,10 +106,17 @@ def load_batch(self, index: BatchIndex) -> List[Batch]:
) for session_id, session in enumerate(self.iter_sessions())
]

def configure_for(self, model):
self.offset = model.get_offset()
for session in self.iter_sessions():
session.configure_for(model)
def configure_for(self, model: "cebra.models.Model"):
"""Configure the dataset offset for the provided model.

Call this function before indexing the dataset. This sets the
`offset` attribute of the dataset.

Args:
model: The model to configure the dataset for.
"""
for i, session in enumerate(self.iter_sessions()):
session.configure_for(model[i])


@dataclasses.dataclass
Expand All @@ -121,12 +128,10 @@ class MultiSessionLoader(cebra_data.Loader):
dimension, it is better to use a :py:class:`cebra.data.single_session.MixedDataLoader`.
"""

time_offset: int = dataclasses.field(default=10)

def __post_init__(self):
super().__post_init__()
self.sampler = cebra_distr.MultisessionSampler(self.dataset,
self.time_offset)
self.sampler = cebra.distributions.MultisessionSampler(
self.dataset, self.time_offset)

def get_indices(self, num_samples: int) -> List[BatchIndex]:
ref_idx = self.sampler.sample_prior(self.batch_size)
Expand All @@ -151,7 +156,6 @@ class ContinuousMultiSessionDataLoader(MultiSessionLoader):
"""Contrastive learning conditioned on a continuous behavior variable."""

conditional: str = "time_delta"
time_offset: int = dataclasses.field(default=10)

@property
def index(self):
Expand All @@ -165,7 +169,8 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader):
# Overwrite sampler with the discrete implementation
# Generalize MultisessionSampler to avoid doing this?
def __post_init__(self):
self.sampler = cebra_distr.DiscreteMultisessionSampler(self.dataset)
self.sampler = cebra.distributions.DiscreteMultisessionSampler(
self.dataset)

@property
def index(self):
Expand Down
14 changes: 11 additions & 3 deletions cebra/data/single_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ def load_batch(self, index: BatchIndex) -> Batch:
reference=self[index.reference],
)

def configure_for(self, model: "cebra.models.Model"):
"""Configure the dataset offset for the provided model.

Call this function before indexing the dataset. This sets the
`offset` attribute of the dataset.

Args:
model: The model to configure the dataset for.
"""
self.offset = model.get_offset()


@dataclasses.dataclass
class DiscreteDataLoader(cebra_data.Loader):
Expand Down Expand Up @@ -192,7 +203,6 @@ class ContinuousDataLoader(cebra_data.Loader):
and become equivalent to time contrastive learning.
""",
)
time_offset: int = dataclasses.field(default=10)
stes marked this conversation as resolved.
Show resolved Hide resolved
delta: float = dataclasses.field(default=0.1)

def __post_init__(self):
Expand Down Expand Up @@ -274,7 +284,6 @@ class MixedDataLoader(cebra_data.Loader):
"""

conditional: str = dataclasses.field(default="time_delta")
time_offset: int = dataclasses.field(default=10)

@property
def dindex(self):
Expand Down Expand Up @@ -337,7 +346,6 @@ class HybridDataLoader(cebra_data.Loader):
"""

conditional: str = dataclasses.field(default="time_delta")
time_offset: int = dataclasses.field(default=10)
delta: float = dataclasses.field(default=0.1)

@property
Expand Down
74 changes: 22 additions & 52 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,6 @@ def _configure_for_all(
f"receptive fields/offsets larger than 1 via the sklearn API. "
f"Please use a different model, or revert to the pytorch "
f"API for training.")

d.configure_for(model[n])
else:
if not isinstance(model, cebra.models.ConvolutionalModelMixin):
if len(model.get_offset()) > 1:
Expand All @@ -791,37 +789,11 @@ def _configure_for_all(
f"Please use a different model, or revert to the pytorch "
f"API for training.")

dataset.configure_for(model)
dataset.configure_for(model)

def _select_model(self, X: Union[npt.NDArray, torch.Tensor],
session_id: int):
# Choose the model and get its corresponding offset
if self.num_sessions is not None: # multisession implementation
if session_id is None:
raise RuntimeError(
"No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape."
)
if session_id >= self.num_sessions or session_id < 0:
raise RuntimeError(
f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}."
)
if self.n_features_[session_id] != X.shape[1]:
raise ValueError(
f"Invalid input shape: model for session {session_id} requires an input of shape"
f"(n_samples, {self.n_features_[session_id]}), got (n_samples, {X.shape[1]})."
)

model = self.model_[session_id]
model.to(self.device_)
else: # single session
if session_id is not None and session_id > 0:
raise RuntimeError(
f"Invalid session_id {session_id}: single session models only takes an optional null session_id."
)
model = self.model_

offset = model.get_offset()
return model, offset
return self.solver_._select_model(X, session_id=session_id)

def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
"""Check that the input labels are compatible with the labels used to fit the model.
Expand Down Expand Up @@ -1203,11 +1175,13 @@ def fit(

def transform(self,
X: Union[npt.NDArray, torch.Tensor],
batch_size: Optional[int] = None,
CeliaBenquet marked this conversation as resolved.
Show resolved Hide resolved
session_id: Optional[int] = None) -> npt.NDArray:
"""Transform an input sequence and return the embedding.

Args:
X: A numpy array or torch tensor of size ``time x dimension``.
batch_size:
session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
multisession, set to ``None`` for single session.

Expand All @@ -1225,34 +1199,25 @@ def transform(self,
>>> embedding = cebra_model.transform(dataset)

"""

sklearn_utils_validation.check_is_fitted(self, "n_features_")
model, offset = self._select_model(X, session_id)
self.solver_._check_is_session_id_valid(session_id=session_id)

# Input validation
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
input_dtype = X.dtype

with torch.no_grad():
model.eval()
if torch.is_tensor(X) and X.device.type == "cuda":
CeliaBenquet marked this conversation as resolved.
Show resolved Hide resolved
X = X.detach().cpu()
CeliaBenquet marked this conversation as resolved.
Show resolved Hide resolved

if self.pad_before_transform:
X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)),
mode="edge")
X = torch.from_numpy(X).float().to(self.device_)
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))

if isinstance(model, cebra.models.ConvolutionalModelMixin):
# Fully convolutional evaluation, switch (T, C) -> (1, C, T)
X = X.transpose(1, 0).unsqueeze(0)
output = model(X).cpu().numpy().squeeze(0).transpose(1, 0)
else:
# Standard evaluation, (T, C, dt)
output = model(X).cpu().numpy()
if isinstance(X, np.ndarray):
X = torch.from_numpy(X)

if input_dtype == "float64":
return output.astype(input_dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite a change to the transform logic --- can we add a test that the new CEBRA transform function matches exactly the outputs of the old CEBRA transform function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, see in test_sklearn if that's satisfying.

with torch.no_grad():
output = self.solver_.transform(
inputs=X,
pad_before_transform=self.pad_before_transform,
session_id=session_id,
batch_size=batch_size)

return output
return output.detach().cpu().numpy()

def fit_transform(
self,
Expand Down Expand Up @@ -1455,6 +1420,11 @@ def load(cls,
else:
cebra_ = _check_type_checkpoint(checkpoint)

n_features = cebra_.n_features_
cebra_.solver_.n_features = ([
session_n_features for session_n_features in n_features
] if isinstance(n_features, list) else n_features)

return cebra_

def to(self, device: Union[str, torch.device]):
Expand Down
3 changes: 2 additions & 1 deletion cebra/integrations/sklearn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def infonce_loss(
f"got {len(y[0])} sessions.")

model, _ = cebra_model._select_model(
X, session_id) # check session_id validity and corresponding model
X, session_id=session_id
) # check session_id validity and corresponding model
cebra_model._check_labels_types(y, session_id=session_id)

dataset, is_multisession = cebra_model._prepare_data(X, y) # single session
Expand Down
3 changes: 2 additions & 1 deletion cebra/integrations/sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
X,
accept_sparse=False,
accept_large_sparse=False,
dtype=("float16", "float32", "float64"),
# NOTE: remove float16 because F.pad does not allow float16.
dtype=("float32", "float64"),
order=None,
copy=False,
force_all_finite=True,
Expand Down
Loading
Loading