Batched inference CEBRA & padding at the Solver level #168
base: main
Changes from 44 commits
```diff
@@ -780,8 +780,6 @@ def _configure_for_all(
                         f"receptive fields/offsets larger than 1 via the sklearn API. "
                         f"Please use a different model, or revert to the pytorch "
                         f"API for training.")
-
-                d.configure_for(model[n])
         else:
             if not isinstance(model, cebra.models.ConvolutionalModelMixin):
                 if len(model.get_offset()) > 1:
@@ -791,37 +789,11 @@ def _configure_for_all(
                         f"Please use a different model, or revert to the pytorch "
                         f"API for training.")

-            dataset.configure_for(model)
+        dataset.configure_for(model)

     def _select_model(self, X: Union[npt.NDArray, torch.Tensor],
                       session_id: int):
-        # Choose the model and get its corresponding offset
-        if self.num_sessions is not None:  # multisession implementation
-            if session_id is None:
-                raise RuntimeError(
-                    "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape."
-                )
-            if session_id >= self.num_sessions or session_id < 0:
-                raise RuntimeError(
-                    f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}."
-                )
-            if self.n_features_[session_id] != X.shape[1]:
-                raise ValueError(
-                    f"Invalid input shape: model for session {session_id} requires an input of shape"
-                    f"(n_samples, {self.n_features_[session_id]}), got (n_samples, {X.shape[1]})."
-                )
-
-            model = self.model_[session_id]
-            model.to(self.device_)
-        else:  # single session
-            if session_id is not None and session_id > 0:
-                raise RuntimeError(
-                    f"Invalid session_id {session_id}: single session models only takes an optional null session_id."
-                )
-            model = self.model_
-
-        offset = model.get_offset()
-        return model, offset
+        return self.solver_._select_model(X, session_id=session_id)

     def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
         """Check that the input labels are compatible with the labels used to fit the model.
```
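For orientation, the sklearn-level `_select_model` now only delegates to the solver. A rough standalone sketch of the solver-level dispatch it delegates to (the function name, the `ModuleList` check, and the `n_features` attribute are assumptions here; the real implementation lives in `cebra.solver` and may differ):

```python
import torch

def select_model(solver, inputs, session_id=None):
    """Sketch of solver-level model selection; not the PR's exact code."""
    # Reuse the solver's own session-id validation (shown in the diff below).
    solver._check_is_session_id_valid(session_id=session_id)
    if isinstance(solver.model, torch.nn.ModuleList):  # multisession solver
        n_features = solver.n_features[session_id]
        if inputs.shape[1] != n_features:
            raise ValueError(
                f"Invalid input shape: model for session {session_id} requires "
                f"(n_samples, {n_features}), got (n_samples, {inputs.shape[1]}).")
        model = solver.model[session_id]
    else:  # single-session solver: one model for all inputs
        model = solver.model
    return model, model.get_offset()
```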
```diff
@@ -1203,11 +1175,13 @@ def fit(

     def transform(self,
                   X: Union[npt.NDArray, torch.Tensor],
+                  batch_size: Optional[int] = None,
                   session_id: Optional[int] = None) -> npt.NDArray:
         """Transform an input sequence and return the embedding.

         Args:
             X: A numpy array or torch tensor of size ``time x dimension``.
+            batch_size:
             session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
                 multisession, set to ``None`` for single session.
```
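The new keyword enables chunked inference on long recordings. A minimal usage sketch under this PR's API (architecture, sizes, and iteration counts are arbitrary choices, not taken from the diff):

```python
import numpy as np
import cebra

X = np.random.normal(size=(10_000, 30)).astype("float32")

model = cebra.CEBRA(model_architecture="offset10-model",
                    batch_size=512,
                    max_iterations=100)
model.fit(X)

# With this PR, transform can process the input in chunks of 1024 samples,
# bounding peak memory; batch_size=None keeps the single-pass behavior.
embedding = model.transform(X, batch_size=1024)
assert embedding.shape[0] == X.shape[0]
```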
@@ -1225,34 +1199,25 @@ def transform(self, | |
>>> embedding = cebra_model.transform(dataset) | ||
|
||
""" | ||
|
||
sklearn_utils_validation.check_is_fitted(self, "n_features_") | ||
model, offset = self._select_model(X, session_id) | ||
self.solver_._check_is_session_id_valid(session_id=session_id) | ||
|
||
# Input validation | ||
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) | ||
input_dtype = X.dtype | ||
|
||
with torch.no_grad(): | ||
model.eval() | ||
if torch.is_tensor(X) and X.device.type == "cuda": | ||
CeliaBenquet marked this conversation as resolved.
Show resolved
Hide resolved
|
||
X = X.detach().cpu() | ||
CeliaBenquet marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if self.pad_before_transform: | ||
X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), | ||
mode="edge") | ||
X = torch.from_numpy(X).float().to(self.device_) | ||
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) | ||
|
||
if isinstance(model, cebra.models.ConvolutionalModelMixin): | ||
# Fully convolutional evaluation, switch (T, C) -> (1, C, T) | ||
X = X.transpose(1, 0).unsqueeze(0) | ||
output = model(X).cpu().numpy().squeeze(0).transpose(1, 0) | ||
else: | ||
# Standard evaluation, (T, C, dt) | ||
output = model(X).cpu().numpy() | ||
if isinstance(X, np.ndarray): | ||
X = torch.from_numpy(X) | ||
|
||
if input_dtype == "float64": | ||
return output.astype(input_dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is quite a change to the transform logic --- can we add a test that the new CEBRA transform function matches exactly the outputs of the old CEBRA transform function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done, see in |
||
with torch.no_grad(): | ||
output = self.solver_.transform( | ||
inputs=X, | ||
pad_before_transform=self.pad_before_transform, | ||
session_id=session_id, | ||
batch_size=batch_size) | ||
|
||
return output | ||
return output.detach().cpu().numpy() | ||
|
||
def fit_transform( | ||
self, | ||
|
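A regression test along the lines the reviewer asks for could be sketched as follows (fixture sizes, tolerances, and the test name are assumptions; the PR's actual test may differ):

```python
import numpy as np
import cebra

def test_batched_transform_matches_unbatched():
    # Fit a small model, then check that chunked inference reproduces
    # the single-pass embedding.
    rng = np.random.default_rng(42)
    X = rng.normal(size=(1_000, 20)).astype("float32")

    model = cebra.CEBRA(model_architecture="offset10-model",
                        batch_size=128,
                        max_iterations=10)
    model.fit(X)

    unbatched = model.transform(X)                # single pass
    batched = model.transform(X, batch_size=250)  # chunked inference
    np.testing.assert_allclose(unbatched, batched, rtol=1e-5, atol=1e-6)
```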
```diff
@@ -1455,6 +1420,11 @@ def load(cls,
         else:
             cebra_ = _check_type_checkpoint(checkpoint)

+        n_features = cebra_.n_features_
+        cebra_.solver_.n_features = ([
+            session_n_features for session_n_features in n_features
+        ] if isinstance(n_features, list) else n_features)
+
         return cebra_

     def to(self, device: Union[str, torch.device]):
```
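These added lines re-populate `solver_.n_features` when a checkpoint is loaded, so the solver's input-shape validation keeps working after a save/load round trip. A hedged sketch of that round trip (file handling and sizes are illustrative assumptions):

```python
import os
import tempfile
import numpy as np
import cebra

X = np.random.normal(size=(500, 15)).astype("float32")
model = cebra.CEBRA(model_architecture="offset10-model",
                    batch_size=128, max_iterations=5)
model.fit(X)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cebra_model.pt")
    model.save(path)
    restored = cebra.CEBRA.load(path)

# After load(), the solver knows the expected input dimensionality again:
# an int for single-session models, a list for multisession models.
assert restored.solver_.n_features == 15
embedding = restored.transform(X, batch_size=250)
```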
Review comment: Typo? / missing cleanup?

Review comment: i.e., should the line below be removed here? Why is that relevant for batched inference?

Reply: configure_for was done in the cebra.CEBRA class (in _configure_for_all) and has now been moved to the solvers directly; the configure_for in the multisession solver was wrongly implemented and not used. So it is now not implemented in the base class and is defined in the multi- and single-session solvers instead.
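Following that reply, the per-solver placement could look roughly like this (class names and `iter_sessions` are assumptions based on the comment, not the PR's exact code):

```python
# Sketch only: configure_for leaves the base class and is specialized
# per solver, as described in the reply above.
class SingleSessionSolver:
    def configure_for(self, dataset):
        # One dataset, one model: forward the call directly.
        dataset.configure_for(self.model)

class MultiSessionSolver:
    def configure_for(self, dataset):
        # One model per session: configure each session's dataset with its
        # own model, instead of one (previously incorrect) shared call.
        for session, model in zip(dataset.iter_sessions(), self.model):
            session.configure_for(model)
```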