From cbfd87dea7cf34445707faa394bbefababca5f8b Mon Sep 17 00:00:00 2001
From: Thomas Fel
Date: Thu, 31 Jul 2025 22:31:20 -0400
Subject: [PATCH 1/4] optim: introduce archetypal analysis method & associated test suite

---
 .../optimization/archetypal_analysis.py      | 263 ++++++++++++++++++
 .../optimization/test_archetypal_analysis.py |  89 ++++++
 tests/optimization/test_save_and_load.py     |   2 +-
 tests/sae/test_archetypal_dict.py            |   2 +-
 tests/sae/test_batchtopk.py                  |   4 +-
 tests/sae/test_save_and_load.py              |   6 +-
 6 files changed, 359 insertions(+), 7 deletions(-)
 create mode 100644 overcomplete/optimization/archetypal_analysis.py
 create mode 100644 tests/optimization/test_archetypal_analysis.py

diff --git a/overcomplete/optimization/archetypal_analysis.py b/overcomplete/optimization/archetypal_analysis.py
new file mode 100644
index 0000000..4f391a4
--- /dev/null
+++ b/overcomplete/optimization/archetypal_analysis.py
@@ -0,0 +1,263 @@
+"""
+Archetypal Analysis (AA) module.
+
+We use the following notation:
+- A: data matrix, tensor of shape (n_samples, n_features)
+- Z: codes matrix, tensor of shape (n_samples, nb_concepts)
+- W: coefficient matrix, tensor of shape (nb_concepts, n_samples)
+- D: dictionary matrix, computed as D = W @ A
+
+The objective is:
+    min_{Z,W} ||A - Z D||_F^2
+    subject to Z in Δ^nb_concepts and D in conv(A)
+
+In other words, both Z and W are row stochastic.
+Currently, only a projected gradient descent (PGD) solver is supported.
+
+For a complete and more faithful implementation, see the great SPAMS toolbox:
+https://thoth.inrialpes.fr/people/mairal/spams/
+"""
+
+import torch
+from tqdm import tqdm
+
+from .base import BaseOptimDictionaryLearning
+from .utils import stopping_criterion
+
+
+def project_simplex(W, temperature=0.0, dim=1):
+    """
+    Project matrix W onto the simplex using softmax.
+
+    Parameters
+    ----------
+    W : torch.Tensor
+        Input tensor.
+    temperature : float, optional
+        Temperature parameter for scaling before softmax, by default 0.0.
+    dim : int, optional
+        Dimension along which to apply softmax, by default 1.
+
+    Returns
+    -------
+    torch.Tensor
+        Row- or column-stochastic matrix.
+    """
+    return torch.softmax(W / torch.exp(torch.tensor(temperature)), dim=dim)
+
+
+def aa_pgd_solver(A, Z, W, lr=1e-2, update_Z=True, update_W=True,
+                  max_iter=500, tol=1e-5, verbose=False):
+    """
+    Alternating Projected Gradient Descent (PGD) solver for Archetypal Analysis.
+
+    Parameters
+    ----------
+    A : torch.Tensor
+        Input data matrix (n_samples, n_features).
+    Z : torch.Tensor
+        Initial codes matrix (n_samples, nb_concepts).
+    W : torch.Tensor
+        Initial coefficient matrix (nb_concepts, n_samples).
+    lr : float
+        Learning rate.
+    update_Z : bool
+        Whether to update Z.
+    update_W : bool
+        Whether to update W.
+    max_iter : int
+        Maximum number of iterations.
+    tol : float
+        Convergence tolerance.
+    verbose : bool
+        Whether to display a progress bar.
+
+    Returns
+    -------
+    Z : torch.Tensor
+        Final codes matrix.
+    W : torch.Tensor
+        Final coefficient matrix.
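+
+    Examples
+    --------
+    A minimal, illustrative call pattern (shapes and values are arbitrary and
+    not taken from the test suite):
+
+    >>> import torch
+    >>> A = torch.rand(100, 8)
+    >>> Z0 = torch.softmax(torch.randn(100, 4), dim=1)  # row-stochastic codes
+    >>> W0 = torch.softmax(torch.randn(4, 100), dim=1)  # row-stochastic coefficients
+    >>> Z, W = aa_pgd_solver(A, Z0, W0, lr=1e-2, max_iter=100)
+    >>> D = W @ A  # dictionary lies in the convex hull of the data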
+ """ + if update_Z: + Z = torch.nn.Parameter(Z) + if update_W: + W = torch.nn.Parameter(W) + + params = [p for p in [Z, W] if isinstance(p, torch.nn.Parameter)] + optimizer = torch.optim.Adam(params, lr=lr) + + for _ in tqdm(range(max_iter), disable=not verbose): + optimizer.zero_grad() + D = W @ A + loss = torch.mean((A - Z @ D).pow(2)) + + if update_Z: + Z_old = Z.data.clone() + + loss.backward() + optimizer.step() + + with torch.no_grad(): + if update_Z: + Z.copy_(project_simplex(Z)) + if update_W: + W.copy_(project_simplex(W)) + + if update_Z and tol > 0 and stopping_criterion(Z, Z_old, tol): + break + + return Z.detach(), W.detach() + + +class ArchetypalAnalysis(BaseOptimDictionaryLearning): + """ + PyTorch Archetypal Analysis Dictionary Learning model. + + Objective: + min_{Z,W} ||A - Z D||_F^2 with D = W A, + rows(Z) simplex, rows(W) simplex. + + Parameters + ---------- + nb_concepts : int + Number of archetypes (concepts). + device : str, optional + Computation device. + tol : float, optional + Convergence tolerance. + solver : str, optional + Solver to use ('pgd'). + verbose : bool, optional + Verbosity flag. + """ + _SOLVERS = { + 'pgd': aa_pgd_solver, + } + + def __init__(self, nb_concepts, device='cpu', tol=1e-4, solver='pgd', verbose=False): + super().__init__(nb_concepts, device) + assert solver in self._SOLVERS, f"Unknown solver '{solver}'" + self.tol = tol + self.verbose = verbose + self.solver = solver + self.solver_fn = self._SOLVERS[solver] + + def encode(self, A, max_iter=300, tol=None): + """ + Encode the input data matrix into codes Z using fixed dictionary D. + + Parameters + ---------- + A : torch.Tensor + Input data matrix (n_samples, n_features). + max_iter : int, optional + Maximum number of solver iterations. + tol : float, optional + Convergence tolerance. + + Returns + ------- + torch.Tensor + Codes matrix Z (n_samples, nb_concepts). + """ + self._assert_fitted() + A = A.to(self.device) + tol = tol or self.tol + Z = self.init_random_z(A) + Z, _ = self.solver_fn(A, Z, self.W, update_Z=True, update_W=False, + max_iter=max_iter, tol=tol, verbose=self.verbose) + return Z + + def decode(self, Z): + """ + Decode the codes matrix Z into reconstructed data using dictionary D. + + Parameters + ---------- + Z : torch.Tensor + Codes matrix (n_samples, nb_concepts). + + Returns + ------- + torch.Tensor + Reconstructed data matrix (n_samples, n_features). + """ + self._assert_fitted() + return Z.to(self.device) @ self.D + + def fit(self, A, max_iter=500): + """ + Fit the AA model by jointly optimizing Z and W. + + Parameters + ---------- + A : torch.Tensor + Input data matrix (n_samples, n_features). + max_iter : int, optional + Maximum number of solver iterations. + + Returns + ------- + Z : torch.Tensor + Final codes matrix (n_samples, nb_concepts). + D : torch.Tensor + Learned dictionary matrix (nb_concepts, n_features). + """ + A = A.to(self.device) + Z = self.init_random_z(A) + W = self.init_random_w(A) + Z, W = self.solver_fn(A, Z, W, update_Z=True, update_W=True, + max_iter=max_iter, tol=self.tol, verbose=self.verbose) + self.register_buffer('W', W) + self.register_buffer('D', W @ A) + self._set_fitted() + return Z, self.D + + def get_dictionary(self): + """ + Return the learned dictionary D = W @ A. + + Returns + ------- + torch.Tensor + Dictionary matrix (nb_concepts, n_features). + """ + self._assert_fitted() + return self.D + + def init_random_z(self, A): + """ + Initialize the codes matrix Z with random values projected onto the simplex. 
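+
+        The noise scale mu = sqrt(mean(|A|) / nb_concepts) keeps the pre-softmax
+        entries small for data of moderate scale, so the initial rows of Z
+        typically start close to the uniform point of the simplex.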
+ + Parameters + ---------- + A : torch.Tensor + Input data matrix (n_samples, n_features). + + Returns + ------- + torch.Tensor + Initialized codes matrix (n_samples, nb_concepts). + """ + mu = torch.sqrt(torch.mean(torch.abs(A)) / self.nb_concepts) + Z = torch.randn(A.shape[0], self.nb_concepts, device=self.device) * mu + return project_simplex(Z) + + def init_random_w(self, A): + """ + Initialize the coefficient matrix W with random values projected onto the simplex. + + Parameters + ---------- + A : torch.Tensor + Input data matrix (n_samples, n_features). + + Returns + ------- + torch.Tensor + Initialized coefficient matrix (nb_concepts, n_samples). + """ + mu = torch.sqrt(torch.mean(torch.abs(A)) / self.nb_concepts) + W = torch.randn(self.nb_concepts, A.shape[0], device=self.device) * mu + return project_simplex(W) diff --git a/tests/optimization/test_archetypal_analysis.py b/tests/optimization/test_archetypal_analysis.py new file mode 100644 index 0000000..645f019 --- /dev/null +++ b/tests/optimization/test_archetypal_analysis.py @@ -0,0 +1,89 @@ +import pytest +import torch +import numpy as np + +from overcomplete.metrics import relative_avg_l2_loss +from overcomplete.optimization.archetypal_analysis import ArchetypalAnalysis, project_simplex + +data_shape = (50, 10) +nb_concepts = 5 +A = torch.rand(data_shape, dtype=torch.float32) + + +def test_archetypal_initialization(): + model = ArchetypalAnalysis(nb_concepts=nb_concepts) + assert model.nb_concepts == nb_concepts + + +def test_archetypal_fit_shapes(): + model = ArchetypalAnalysis(nb_concepts=nb_concepts) + Z, D = model.fit(A) + assert Z.shape == (data_shape[0], nb_concepts) + assert D.shape == (nb_concepts, data_shape[1]) + + +def test_archetypal_encoding(): + model = ArchetypalAnalysis(nb_concepts=nb_concepts) + model.fit(A) + Z = model.encode(A) + assert Z.shape == (data_shape[0], nb_concepts) + row_sums = torch.sum(Z, dim=1) + assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-3) + + +def test_archetypal_decoding(): + model = ArchetypalAnalysis(nb_concepts=nb_concepts) + model.fit(A) + Z = model.encode(A) + A_hat = model.decode(Z) + assert A_hat.shape == A.shape + + +def test_archetypal_simplex_W(): + model = ArchetypalAnalysis(nb_concepts=nb_concepts) + model.fit(A) + W = model.W + row_sums = torch.sum(W, dim=1) + assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-3) + assert torch.all(W >= 0) + + +def test_archetypal_loss_decrease(): + model = ArchetypalAnalysis(nb_concepts=nb_concepts) + Z0 = model.init_random_z(A) + W0 = model.init_random_w(A) + D0 = W0 @ A + initial_loss = torch.mean((A - Z0 @ D0).pow(2)).item() + Z, D = model.fit(A) + final_loss = torch.mean((A - Z @ D).pow(2)).item() + assert final_loss < initial_loss + + +def test_project_simplex_behavior(): + W = torch.randn(20, 10) + P = project_simplex(W) + assert torch.allclose(torch.sum(P, dim=1), torch.ones(P.size(0)), atol=1e-3) + assert torch.all(P >= 0) + + +def test_archetypal_zero_input(): + A_zero = torch.zeros_like(A) + model = ArchetypalAnalysis(nb_concepts=nb_concepts) + Z, D = model.fit(A_zero) + assert torch.allclose(Z@D, A_zero, atol=1e-4) + + +def test_archetypal_doubly_stoch(): + model = ArchetypalAnalysis(nb_concepts=nb_concepts) + Z, D = model.fit(A) + W = model.W + + assert Z.shape == (data_shape[0], nb_concepts) + assert D.shape == (nb_concepts, data_shape[1]) + assert W.shape == (nb_concepts, data_shape[0]) + + assert torch.all(Z >= 0) + assert torch.all(W >= 0) + + assert torch.allclose(torch.sum(Z, dim=1), 
torch.ones(data_shape[0]), atol=1e-3) + assert torch.allclose(torch.sum(W, dim=1), torch.ones(nb_concepts), atol=1e-3) diff --git a/tests/optimization/test_save_and_load.py b/tests/optimization/test_save_and_load.py index 2c8e439..371094f 100644 --- a/tests/optimization/test_save_and_load.py +++ b/tests/optimization/test_save_and_load.py @@ -23,7 +23,7 @@ def test_methods_save_and_load(methods): D = model.D torch.save(model, 'test_optimization_model.pth') - model = torch.load('test_optimization_model.pth') + model = torch.load('test_optimization_model.pth', map_location='cpu', weights_only=False) assert epsilon_equal(model.D, D), "Loaded model does not produce the same results." diff --git a/tests/sae/test_archetypal_dict.py b/tests/sae/test_archetypal_dict.py index c2630a9..a09e418 100644 --- a/tests/sae/test_archetypal_dict.py +++ b/tests/sae/test_archetypal_dict.py @@ -259,7 +259,7 @@ def test_fused(nb_concepts, dimensions, nb_points, tmp_path): torch.save(layer, model_path) # Reload and validate - layer = torch.load(model_path, map_location="cpu") + layer = torch.load(model_path, map_location="cpu", weights_only=False) assert isinstance(layer, RelaxedArchetypalDictionary) assert layer._fused_dictionary is not None diff --git a/tests/sae/test_batchtopk.py b/tests/sae/test_batchtopk.py index f3a6511..f95977d 100644 --- a/tests/sae/test_batchtopk.py +++ b/tests/sae/test_batchtopk.py @@ -171,7 +171,7 @@ def test_threshold_persistence(input_size, nb_concepts, top_k, tmp_path): torch.save(model, model_path) # Load and check threshold consistency - model_loaded = torch.load(model_path, map_location="cpu").eval() + model_loaded = torch.load(model_path, map_location="cpu", weights_only=False).eval() assert isinstance(model_loaded, BatchTopKSAE) assert model.running_threshold is not None, "Running threshold was not initialized!" 
@@ -208,7 +208,7 @@ def test_saving_loading_batchtopk_sae(input_size, nb_concepts, top_k, tmp_path): torch.save(model, model_path) # Load model - model_loaded = torch.load(model_path, map_location="cpu").eval() + model_loaded = torch.load(model_path, map_location="cpu", weights_only=False).eval() assert isinstance(model_loaded, BatchTopKSAE) # Ensure threshold is correctly persisted diff --git a/tests/sae/test_save_and_load.py b/tests/sae/test_save_and_load.py index 7d24b73..59177e3 100644 --- a/tests/sae/test_save_and_load.py +++ b/tests/sae/test_save_and_load.py @@ -24,7 +24,7 @@ def test_save_and_load_dictionary_layer(nb_concepts, dimensions, tmp_path): torch.save(layer, model_path) # Reload and validate - layer = torch.load(model_path, map_location="cpu") + layer = torch.load(model_path, map_location="cpu", weights_only=False) assert isinstance(layer, DictionaryLayer) # Check consistency after loading @@ -47,7 +47,7 @@ def test_save_and_load_sae(sae_class, tmp_path): torch.save(model, model_path) # Load and validate - model_loaded = torch.load(model_path, map_location="cpu") + model_loaded = torch.load(model_path, map_location="cpu", weights_only=False) assert isinstance(model_loaded, sae_class) # Run inference again @@ -75,7 +75,7 @@ def test_eval_and_save_sae(sae_class, tmp_path): torch.save(model, model_path) # Load, set to eval mode, and validate - model_loaded = torch.load(model_path, map_location="cpu").eval() + model_loaded = torch.load(model_path, map_location="cpu", weights_only=False).eval() assert isinstance(model_loaded, sae_class) # Run inference again From 55d59bc37665408a153eebd86d6dd75074eedb9a Mon Sep 17 00:00:00 2001 From: Thomas Fel Date: Mon, 3 Nov 2025 13:21:25 -0500 Subject: [PATCH 2/4] viz: optional normalization for show func --- overcomplete/visualization/plot_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/overcomplete/visualization/plot_utils.py b/overcomplete/visualization/plot_utils.py index 1d94f91..5010d5a 100644 --- a/overcomplete/visualization/plot_utils.py +++ b/overcomplete/visualization/plot_utils.py @@ -181,7 +181,7 @@ def clip_percentile(img, percentile=0.1, clip_method='nearest'): np.percentile(img, 100 - percentile, method=clip_method),) -def show(img, **kwargs): +def show(img, normalize=True, **kwargs): """ Display an image with normalization and channels in the last dimension. @@ -197,6 +197,7 @@ def show(img, **kwargs): None """ img = np_channel_last(img) - img = normalize(img) + if normalize: + img = normalize(img) plt.imshow(img, **kwargs) plt.axis('off') From 99df9c8e3331fab8e69d8844db27bf8f93416e30 Mon Sep 17 00:00:00 2001 From: Thomas Fel Date: Mon, 3 Nov 2025 13:50:21 -0500 Subject: [PATCH 3/4] viz: argument wording; test: fix unterminated string literal --- overcomplete/optimization/utils.py | 2 +- overcomplete/visualization/plot_utils.py | 4 ++-- tests/sae/test_jumprelu.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/overcomplete/optimization/utils.py b/overcomplete/optimization/utils.py index 2610655..a9aa906 100644 --- a/overcomplete/optimization/utils.py +++ b/overcomplete/optimization/utils.py @@ -7,7 +7,7 @@ def batched_matrix_nnls(D, X, max_iter=50, tol=1e-5, Z_init=None): """ - batched non-negative least squares (nnls) via projected gradient descent. + Batched non-negative least squares (nnls) via projected gradient descent. 
solves for Z in: min_Z ||Z @ D - X||^2 subject to Z >= 0 diff --git a/overcomplete/visualization/plot_utils.py b/overcomplete/visualization/plot_utils.py index 5010d5a..c432d00 100644 --- a/overcomplete/visualization/plot_utils.py +++ b/overcomplete/visualization/plot_utils.py @@ -181,7 +181,7 @@ def clip_percentile(img, percentile=0.1, clip_method='nearest'): np.percentile(img, 100 - percentile, method=clip_method),) -def show(img, normalize=True, **kwargs): +def show(img, norm=True, **kwargs): """ Display an image with normalization and channels in the last dimension. @@ -197,7 +197,7 @@ def show(img, normalize=True, **kwargs): None """ img = np_channel_last(img) - if normalize: + if norm: img = normalize(img) plt.imshow(img, **kwargs) plt.axis('off') diff --git a/tests/sae/test_jumprelu.py b/tests/sae/test_jumprelu.py index 923fb1f..c7d3c72 100644 --- a/tests/sae/test_jumprelu.py +++ b/tests/sae/test_jumprelu.py @@ -145,5 +145,4 @@ def test_jumprelu_backward_specific_values(): # minus sign come from the fact that: # 'increasing the threshold will decrease the output' expected_grad_threshold = torch.tensor([-1.0, -1.0, 0.0], dtype=torch.float32) - assert torch.allclose(threshold.grad, expected_grad_threshold), f"Expected grad_threshold { - expected_grad_threshold}, but got {threshold.grad}" + assert torch.allclose(threshold.grad, expected_grad_threshold), "Grad Threshold error" From 93a8c1c2fc47213bd0821e78568f4952dc44f310 Mon Sep 17 00:00:00 2001 From: Thomas Fel Date: Mon, 3 Nov 2025 14:05:12 -0500 Subject: [PATCH 4/4] test: introduce mark.flaky for stochastic instances --- tests/optimization/test_archetypal_analysis.py | 1 + tests/optimization/test_multi_nnls.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/optimization/test_archetypal_analysis.py b/tests/optimization/test_archetypal_analysis.py index 645f019..26e9a5b 100644 --- a/tests/optimization/test_archetypal_analysis.py +++ b/tests/optimization/test_archetypal_analysis.py @@ -48,6 +48,7 @@ def test_archetypal_simplex_W(): assert torch.all(W >= 0) +@pytest.mark.flaky(reruns=9, reruns_delay=0) def test_archetypal_loss_decrease(): model = ArchetypalAnalysis(nb_concepts=nb_concepts) Z0 = model.init_random_z(A) diff --git a/tests/optimization/test_multi_nnls.py b/tests/optimization/test_multi_nnls.py index a781f41..59a0020 100644 --- a/tests/optimization/test_multi_nnls.py +++ b/tests/optimization/test_multi_nnls.py @@ -88,6 +88,7 @@ def test_batch_shape(): assert Z.shape == (n, k), f"unexpected shape: {Z.shape}" +@pytest.mark.flaky(reruns=9, reruns_delay=0) def test_scipy_nnls_vs_pgd(): n, k, d = 16, 8, 6 tol = 1e-2