diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 0d932e9..0661507 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -5,7 +5,6 @@ import torch import sqfa - from make_examples import rotated_classes_dataset MAX_EPOCHS = 100 diff --git a/tests/test_training.py b/tests/test_training.py index fbb2ca3..a259148 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -4,10 +4,10 @@ import torch import sqfa - from make_examples import rotated_classes_dataset MAX_EPOCHS = 50 +MAX_TRIES = 3 torch.manual_seed(0) @@ -68,9 +68,8 @@ def test_training_method(make_dataset, feature_noise, n_filters, pairwise): return else: n_tries = 0 - max_tries = 3 loss_decreased = False - while n_tries < 3 and not loss_decreased: + while n_tries < MAX_TRIES and not loss_decreased: loss, time = model.fit( data_scatters=covariances, lr=0.1,