Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
dherrera1911 committed Feb 4, 2025
1 parent e880d11 commit 763e0b0
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/sqfa/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def fisher_rao_lower_bound_sq(means, covariances):
distance_squared = affine_invariant_sq(embedding, embedding)
return distance_squared


def fisher_rao_lower_bound(means, covariances):
"""
Compute the Calvo & Oller lower bound of the Fisher-Rao squared
Expand All @@ -186,4 +187,3 @@ def fisher_rao_lower_bound(means, covariances):
Shape (n_classes, n_classes), the lower bound of the Fisher-Rao distance.
"""
return torch.sqrt(fisher_rao_lower_bound_sq(means, covariances) + EPSILON)

16 changes: 8 additions & 8 deletions src/sqfa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ def _check_statistics(data_statistics):


class SecondMomentsSQFA(nn.Module):
"""Second-moments Supervised Quadratic Feature Analysis (SQFA) model.
"""
Second-moments Supervised Quadratic Feature Analysis (SQFA) model.
This version of the model uses only the second moment matrices of the data,
and distances in the SPD manifold."""
and distances in the SPD manifold.
"""

def __init__(
self,
Expand Down Expand Up @@ -319,7 +321,7 @@ def fit(

# Store initial filters
filters_original = self.filters.detach().clone()
noise_original = self.diagonal_noise.detach().clone()[0,0]
noise_original = self.diagonal_noise.detach().clone()[0, 0]

# Require n_pairs to be even
if self.filters.shape[0] % 2 != 0:
Expand All @@ -332,18 +334,16 @@ def fit(
training_time = torch.tensor([])
filters_last_trained = torch.zeros(0)
for i in range(n_pairs):

# Re-initialize filters, to be a tensor of shape (2*(i+1), n_dim)
# with the first 2*i filters being the filters from the previous
# iteration
filters_last_trained = self.filters.detach().clone()
if i == 0:
filters_new_init = filters_original[:2]
else:
filters_new_init = torch.cat((
filters_last_trained,
filters_original[2 * i : 2 * (i + 1)]
))
filters_new_init = torch.cat(
(filters_last_trained, filters_original[2 * i : 2 * (i + 1)])
)
remove_parametrizations(self, "filters")
self.filters = nn.Parameter(filters_new_init)
self._add_constraint(constraint=self.constraint)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from make_examples import sample_spd
from sqfa.distances import (
affine_invariant_sq,
log_euclidean_sq,
fisher_rao_lower_bound_sq,
log_euclidean_sq,
)

torch.set_default_dtype(torch.float64)
Expand All @@ -19,6 +19,7 @@ def sample_spd_matrices(n_classes, n_dim):
spd_mat = sample_spd(n_classes, n_dim)
return spd_mat


@pytest.fixture(scope="function")
def sample_vectors(n_classes, n_dim):
"""Generate a tensor of vectors."""
Expand Down Expand Up @@ -115,4 +116,3 @@ def test_fisher_rao_sq(sample_spd_matrices, sample_vectors, n_classes, n_dim):
assert torch.allclose(
get_diag(fr_distances), torch.zeros(n_classes), atol=1e-5
), "The diagonal of the self-distance matrix for AIRM is not zero"

4 changes: 2 additions & 2 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def test_training_function(model_type):
)
# Make dictionary with covariance and means input
stats_dict = {
"covariances": covariances,
"means": torch.zeros_like(covariances[:,:,0]),
"covariances": covariances,
"means": torch.zeros_like(covariances[:, :, 0]),
}
loss, time = sqfa._optim.fitting_loop(
model=model,
Expand Down

0 comments on commit 763e0b0

Please sign in to comment.