Skip to content

Commit

Permalink
apply pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
stes committed Mar 28, 2024
1 parent bbaf35d commit 91f07d9
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 29 deletions.
2 changes: 1 addition & 1 deletion cebra/data/multi_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_indices(self, num_samples: int) -> List[BatchIndex]:
ref_idx = torch.from_numpy(ref_idx)
neg_idx = torch.from_numpy(neg_idx)
pos_idx = torch.from_numpy(pos_idx)

return BatchIndex(
reference=ref_idx,
positive=pos_idx,
Expand Down
9 changes: 5 additions & 4 deletions cebra/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,16 @@ class MultiDiscrete(cebra.data.DatasetCollection):
"""Demo dataset for testing."""

def __init__(
self,
nums_neural=[3, 4, 5],
num_timepoints=_DEFAULT_NUM_TIMEPOINTS,
):
self,
nums_neural=[3, 4, 5],
num_timepoints=_DEFAULT_NUM_TIMEPOINTS,
):
super().__init__(*[
DemoDatasetDiscrete(num_timepoints, num_neural)
for num_neural in nums_neural
])


@register("demo-continuous-multisession")
class MultiContinuous(cebra.data.DatasetCollection):

Expand Down
9 changes: 5 additions & 4 deletions cebra/distributions/multisession.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,9 @@ def __getitem__(self, pos_idx):
pos_samples[i] = self.data[i][pos_idx[i]]
return pos_samples


class DiscreteMultisessionSampler(cebra_distr.PriorDistribution,
cebra_distr.ConditionalDistribution):
cebra_distr.ConditionalDistribution):
"""Discrete multi-session sampling.
Discrete indices don't need to be aligned. Positive pairs are found
Expand Down Expand Up @@ -370,9 +371,9 @@ def sample_conditional(self, idx: torch.Tensor) -> torch.Tensor:
# sample conditional for each assigned session
pos_idx = torch.zeros(shape, device=_device).long()
for i in range(self.num_sessions):
pos_idx[i] = self.index[i].sample_conditional(query[i])
pos_idx[i] = self.index[i].sample_conditional(query[i])
pos_idx = pos_idx.cpu().numpy()

# reverse indices to recover the ref/pos samples matching
idx_rev = _invert_index(idx)
return pos_idx, idx, idx_rev
Expand All @@ -381,4 +382,4 @@ def __getitem__(self, pos_idx):
pos_samples = np.zeros(pos_idx.shape[:2] + (self.data.shape[2],))
for i in range(self.num_sessions):
pos_samples[i] = self.data[i][pos_idx[i]]
return pos_samples
return pos_samples
4 changes: 1 addition & 3 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,7 @@ def _require_arg(key):

# Discrete behavior contrastive training is selected with the default dataloader
if not is_cont and is_disc:
kwargs = dict(
**shared_kwargs,
)
kwargs = dict(**shared_kwargs,)
if is_full:
if is_hybrid:
raise_not_implemented_error = True
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,3 @@ dev =

[bdist_wheel]
universal=1

10 changes: 7 additions & 3 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def test_multi_session_time_contrastive(time_offset):
assert (idx.flatten()[rev_idx.flatten()].all() == np.arange(
len(rev_idx.flatten())).all())


def test_multi_session_discrete():
dataset = cebra_datasets.init("demo-discrete-multisession")
sampler = cebra_distr.DiscreteMultisessionSampler(dataset)
Expand All @@ -312,13 +313,16 @@ def test_multi_session_discrete():
# NOTE(celia): test the private function ``_inverse_idx()``, with idx arrays flat
assert (idx.flatten()[rev_idx.flatten()].all() == np.arange(
len(rev_idx.flatten())).all())

# Check positive samples' labels match reference samples' labels
sample_labels = sampler.all_data[(sample + sampler.lengths[:, None]).flatten()]
sample_labels = sampler.all_data[(sample +
sampler.lengths[:, None]).flatten()]
sample_labels = sample_labels[idx.reshape(sample.shape[:2])].flatten()
positive_labels = sampler.all_data[(positive + sampler.lengths[:, None]).flatten()]
positive_labels = sampler.all_data[(positive +
sampler.lengths[:, None]).flatten()]
assert (sample_labels == positive_labels).all()


class OldDeltaDistribution(cebra_distr_base.JointDistribution,
cebra_distr_base.HasGenerator):
"""
Expand Down
14 changes: 6 additions & 8 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,10 @@ def _process(batch, feature_dim=1):
assert dummy_prediction.shape == (3, 32, 6)
_mix(dummy_prediction, batch[0].index)


def test_multisession_disc_loader():
data = cebra.datasets.MultiDiscrete(nums_neural=[3, 4, 5],
num_timepoints=100)
num_timepoints=100)
loader = cebra.data.DiscreteMultiSessionDataLoader(
data,
num_steps=10,
Expand Down Expand Up @@ -313,15 +314,13 @@ def _process(batch, feature_dim=1):
assert dummy_prediction.shape == (3, 32, 6)
_mix(dummy_prediction, batch[0].index)


@parametrize_device
@pytest.mark.parametrize(
"data_name, loader_initfunc",
[
('demo-discrete-multisession',
cebra.data.DiscreteMultiSessionDataLoader),
("demo-continuous-multisession",
cebra.data.ContinuousMultiSessionDataLoader)
],
[('demo-discrete-multisession', cebra.data.DiscreteMultiSessionDataLoader),
("demo-continuous-multisession",
cebra.data.ContinuousMultiSessionDataLoader)],
)
def test_multisession_loader(data_name, loader_initfunc, device):
# TODO change number of timepoints across the sessions
Expand All @@ -339,4 +338,3 @@ def test_multisession_loader(data_name, loader_initfunc, device):
_check_attributes(batch, is_list=True)
for session_batch in batch:
assert len(session_batch.positive) == 32

1 change: 1 addition & 0 deletions tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class __Dataset(cebra.datasets.MultiContinuous):
neural = torch.zeros((50, 10), dtype=torch.float)
continuous_index = torch.zeros((50, 10), dtype=torch.float)
elif is_multi and is_disc:

class __Dataset(cebra.datasets.MultiDiscrete):
neural = torch.zeros((50, 10), dtype=torch.float)
discrete_index = torch.zeros((50,), dtype=torch.int)
Expand Down
1 change: 0 additions & 1 deletion tests/test_sklearn_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ def test_sklearn_datasets_consistency():
dataset_ids=["achilles", "buddy"],
between="datasets",
)



def test_sklearn_runs_consistency():
Expand Down
9 changes: 5 additions & 4 deletions tests/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@
(*args, cebra.solver.SingleSessionHybridSolver))

multi_session_tests = []
for args in [
("demo-continuous-multisession", cebra.data.ContinuousMultiSessionDataLoader),
("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader)
]:
for args in [("demo-continuous-multisession",
cebra.data.ContinuousMultiSessionDataLoader),
("demo-discrete-multisession",
cebra.data.DiscreteMultiSessionDataLoader)]:
multi_session_tests.append((*args, cebra.solver.MultiSessionSolver))
# multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver))

Expand Down Expand Up @@ -169,6 +169,7 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc):

solver.fit(loader)


@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc",
multi_session_tests)
def test_multi_session(data_name, loader_initfunc, solver_initfunc):
Expand Down

0 comments on commit 91f07d9

Please sign in to comment.