Skip to content

Commit

Permalink
Fix linting errors in tests (#188)
Browse files Browse the repository at this point in the history
* apply auto-fixes

* Fix linting errors in tests/

* Fix version check
  • Loading branch information
stes authored Oct 27, 2024
1 parent e652b9a commit 9898850
Show file tree
Hide file tree
Showing 15 changed files with 18 additions and 42 deletions.
1 change: 0 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,5 @@
#
def test_api():
import cebra.distributions
from cebra.distributions import TimedeltaDistribution

cebra.distributions.TimedeltaDistribution
3 changes: 0 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse

import pytest
3 changes: 1 addition & 2 deletions tests/test_criterions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import pytest
import torch
from torch import nn
Expand Down Expand Up @@ -294,7 +293,7 @@ def _sample_dist_matrices(seed):


@pytest.mark.parametrize("seed", [42, 4242, 424242])
def test_infonce(seed):
def test_infonce_check_output_parts(seed):
pos_dist, neg_dist = _sample_dist_matrices(seed)

ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist)
Expand Down
6 changes: 0 additions & 6 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ def test_demo():
@pytest.mark.requires_dataset
def test_hippocampus():
pytest.skip("Outdated")

from cebra.datasets import hippocampus # noqa: F401
dataset = cebra.datasets.init("rat-hippocampus-single")
loader = cebra.data.ContinuousDataLoader(
dataset=dataset,
Expand Down Expand Up @@ -99,8 +97,6 @@ def test_hippocampus():

@pytest.mark.requires_dataset
def test_monkey():
from cebra.datasets import monkey_reaching # noqa: F401

dataset = cebra.datasets.init(
"area2-bump-pos-active-passive",
path=pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40",
Expand All @@ -111,8 +107,6 @@ def test_monkey():

@pytest.mark.requires_dataset
def test_allen():
from cebra.datasets import allen # noqa: F401

pytest.skip("Test takes too long")

ca_dataset = cebra.datasets.init("allen-movie-one-ca-VISp-100-train-10-111")
Expand Down
1 change: 0 additions & 1 deletion tests/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#
import glob
import re
import sys

import pytest

Expand Down
6 changes: 3 additions & 3 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def prepare(N=1000, n=128, d=5, probs=[0.3, 0.1, 0.6], device="cpu"):
continuous = torch.randn(N, d).to(device)

rand = torch.from_numpy(np.random.randint(0, N, (n,))).to(device)
qidx = discrete[rand].to(device)
_ = discrete[rand].to(device)
query = continuous[rand] + 0.1 * torch.randn(n, d).to(device)
query = query.to(device)

Expand Down Expand Up @@ -173,7 +173,7 @@ def test_mixed():
discrete, continuous)

reference_idx = distribution.sample_prior(10)
positive_idx = distribution.sample_conditional(reference_idx)
_ = distribution.sample_conditional(reference_idx)

# The conditional distribution p(· | disc, cont) should yield
# samples where the label exactly matches the reference sample.
Expand All @@ -193,7 +193,7 @@ def test_continuous(benchmark):
def _test_distribution(dist):
distribution = dist(continuous)
reference_idx = distribution.sample_prior(10)
positive_idx = distribution.sample_conditional(reference_idx)
_ = distribution.sample_conditional(reference_idx)
return distribution

distribution = _test_distribution(
Expand Down
1 change: 0 additions & 1 deletion tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
# limitations under the License.
#
import numpy as np
import pytest

import cebra
import cebra.grid_search
Expand Down
1 change: 0 additions & 1 deletion tests/test_integration_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
# limitations under the License.
#
import itertools
from typing import List

import pytest
import torch
Expand Down
8 changes: 2 additions & 6 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
import itertools
import pathlib
import pickle
import platform
import tempfile
import unittest
from unittest.mock import patch

import h5py
import hdf5storage
Expand Down Expand Up @@ -125,7 +122,7 @@ def generate_numpy_confounder(filename, dtype):


@register("npz")
def generate_numpy_path(filename, dtype):
def generate_numpy_path_2(filename, dtype):
A = np.arange(1000, dtype=dtype).reshape(10, 100)
np.savez(filename, array=A, other_data="test")
loaded_A = cebra_load.load(pathlib.Path(filename))
Expand Down Expand Up @@ -418,7 +415,7 @@ def generate_csv_path(filename, dtype):

@register_error("csv")
def generate_csv_empty_file(filename, dtype):
with open(filename, "w") as creating_new_csv_file:
with open(filename, "w") as _:
pass
_ = cebra_load.load(filename)

Expand Down Expand Up @@ -619,7 +616,6 @@ def generate_pickle_invalid_key(filename, dtype):

@register_error("pkl", "p")
def generate_pickle_no_array(filename, dtype):
A = np.arange(1000, dtype=dtype).reshape(10, 100)
with open(filename, "wb") as f:
pickle.dump({"A": "test_1", "B": "test_2"}, f)
_ = cebra_load.load(filename)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def test_version_check(version, raises):
cebra.models.model._check_torch_version(raise_error=True)


def test_version_check():
raises = not cebra.models.model._check_torch_version(raise_error=False)
def test_version_check_dropout_available():
raises = cebra.models.model._check_torch_version(raise_error=False)
if raises:
assert len(cebra.models.get_options("*dropout*")) == 0
else:
Expand Down
4 changes: 1 addition & 3 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ def test_plot_imports():
def test_colormaps():
import matplotlib

import cebra

cmap = matplotlib.colormaps["cebra"]
assert cmap is not None
plt.scatter([1], [2], c=[2], cmap="cebra")
Expand Down Expand Up @@ -241,7 +239,7 @@ def test_compare_models():
_ = cebra_plot.compare_models(models, labels=long_labels, ax=ax)
with pytest.raises(ValueError, match="Invalid.*labels"):
invalid_labels = copy.deepcopy(labels)
ele = invalid_labels.pop()
_ = invalid_labels.pop()
invalid_labels.append(["a"])
_ = cebra_plot.compare_models(models, labels=invalid_labels, ax=ax)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_override():
_Foo1 = test_module.register("foo")(Foo)
assert _Foo1 == Foo
assert _Foo1 != Bar
assert f"foo" in test_module.get_options()
assert "foo" in test_module.get_options()

# Check that the class was actually added to the module
assert (
Expand All @@ -137,15 +137,15 @@ def test_override():
_Foo2 = test_module.register("foo", override=True)(Bar)
assert _Foo2 != Foo
assert _Foo2 == Bar
assert f"foo" in test_module.get_options()
assert "foo" in test_module.get_options()


def test_depreciation():
test_module = _make_registry()
Foo = _make_class()
_Foo1 = test_module.register("foo")(Foo)
assert _Foo1 == Foo
assert f"foo" in test_module.get_options()
assert "foo" in test_module.get_options()

# Registering the same class under different names
# also raises and error
Expand Down
7 changes: 2 additions & 5 deletions tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def test_api(estimator, check):
pytest.skip(f"Model architecture {estimator.model_architecture} "
f"requires longer input sizes than 20 samples.")

success = True
exception = None
num_successful = 0
total_runs = 0
Expand Down Expand Up @@ -334,7 +333,6 @@ def test_sklearn(model_architecture, device):
y_c1 = np.random.uniform(0, 1, (1000, 5))
y_c1_s2 = np.random.uniform(0, 1, (800, 5))
y_c2 = np.random.uniform(0, 1, (1000, 2))
y_c2_s2 = np.random.uniform(0, 1, (800, 2))
y_d = np.random.randint(0, 10, (1000,))
y_d_s2 = np.random.randint(0, 10, (800,))

Expand Down Expand Up @@ -817,7 +815,6 @@ def test_sklearn_full(model_architecture, device, pad_before_transform):
X = np.random.uniform(0, 1, (1000, 50))
y_c1 = np.random.uniform(0, 1, (1000, 5))
y_c2 = np.random.uniform(0, 1, (1000, 2))
y_d = np.random.randint(0, 10, (1000,))

# time contrastive
cebra_model.fit(X)
Expand Down Expand Up @@ -883,7 +880,7 @@ def test_sklearn_resampling_model_not_yet_supported(model_architecture, device):

with pytest.raises(ValueError):
cebra_model.fit(X, y_c1)
output = cebra_model.transform(X)
_ = cebra_model.transform(X)


def _iterate_actions():
Expand Down Expand Up @@ -1097,7 +1094,7 @@ def test_move_cpu_to_cuda_device(device):
def test_move_cpu_to_mps_device(device):

if not cebra.helper._is_mps_availabe(torch):
pytest.skip(f"MPS device is not available")
pytest.skip("MPS device is not available")

X = np.random.uniform(0, 1, (10, 5))
cebra_model = cebra_sklearn_cebra.CEBRA(model_architecture="offset1-model",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import itertools

import pytest
import torch
Expand Down Expand Up @@ -100,11 +99,12 @@ def test_single_session(data_name, loader_initfunc, solver_initfunc):
@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc",
single_session_tests)
def test_single_session_auxvar(data_name, loader_initfunc, solver_initfunc):
return # TODO

pytest.skip("Not yet supported")

loader = _get_loader(data_name, loader_initfunc)
model = _make_model(loader.dataset)
behavior_model = _make_behavior_model(loader.dataset)
behavior_model = _make_behavior_model(loader.dataset) # noqa: F841

criterion = cebra.models.InfoNCE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc):

@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc",
multi_session_tests)
def test_multi_session(data_name, loader_initfunc, solver_initfunc):
def test_multi_session_2(data_name, loader_initfunc, solver_initfunc):
loader = _get_loader(data_name, loader_initfunc)
criterion = cebra.models.InfoNCE()
model = nn.ModuleList(
Expand Down
1 change: 0 additions & 1 deletion tests/test_usecases.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
"""

import itertools
import pickle

import numpy as np
import pytest
Expand Down

0 comments on commit 9898850

Please sign in to comment.