263 changes: 263 additions & 0 deletions overcomplete/optimization/archetypal_analysis.py
@@ -0,0 +1,263 @@
"""
Archetypal Analysis (AA) module.

We use the following notation:
- A: data matrix, tensor of shape (n_samples, n_features)
- Z: codes matrix, tensor of shape (n_samples, nb_concepts)
- W: coefficient matrix, tensor of shape (nb_concepts, n_samples)
- D: dictionary matrix, computed as D = W @ A

The objective is:
min_{Z,W} ||A - Z D||_F^2
    subject to each row of Z lying in Δ^{nb_concepts} and each row of D in conv(A).

In other words, both Z and W are row stochastic.
Currently, the only available solver is projected gradient descent (PGD).

For a complete and more faithful implementation, see the excellent SPAMS toolbox:
https://thoth.inrialpes.fr/people/mairal/spams/
"""

import torch
from tqdm import tqdm

from .base import BaseOptimDictionaryLearning
from .utils import stopping_criterion


def project_simplex(W, temperature=0.0, dim=1):
"""
    Map matrix W onto the probability simplex using a softmax (a smooth
    surrogate for the exact Euclidean simplex projection).

Parameters
----------
W : torch.Tensor
Input tensor.
    temperature : float, optional
        Log-temperature: the input is divided by exp(temperature) before the
        softmax, by default 0.0 (i.e. unit scale).
dim : int, optional
Dimension along which to apply softmax, by default 1.

Returns
-------
torch.Tensor
Row- or column-stochastic matrix.
"""
return torch.softmax(W / torch.exp(torch.tensor(temperature)), dim=dim)


def aa_pgd_solver(A, Z, W, lr=1e-2, update_Z=True, update_W=True,
max_iter=500, tol=1e-5, verbose=False):
"""
    Projected gradient descent (PGD) solver for Archetypal Analysis. Gradient
    steps use Adam on Z and/or W jointly, each step followed by a softmax
    mapping back onto the simplex.

Parameters
----------
A : torch.Tensor
Input data matrix (n_samples, n_features).
Z : torch.Tensor
Initial codes matrix (n_samples, nb_concepts).
W : torch.Tensor
Initial coefficient matrix (nb_concepts, n_samples).
lr : float
Learning rate.
update_Z : bool
Whether to update Z.
update_W : bool
Whether to update W.
max_iter : int
Maximum number of iterations.
tol : float
Convergence tolerance.
verbose : bool
Whether to display a progress bar.

Returns
-------
Z : torch.Tensor
Final codes matrix.
W : torch.Tensor
Final coefficient matrix.
"""
if update_Z:
Z = torch.nn.Parameter(Z)
if update_W:
W = torch.nn.Parameter(W)

params = [p for p in [Z, W] if isinstance(p, torch.nn.Parameter)]
optimizer = torch.optim.Adam(params, lr=lr)

for _ in tqdm(range(max_iter), disable=not verbose):
optimizer.zero_grad()
D = W @ A
loss = torch.mean((A - Z @ D).pow(2))

if update_Z:
Z_old = Z.data.clone()

loss.backward()
optimizer.step()

with torch.no_grad():
if update_Z:
Z.copy_(project_simplex(Z))
if update_W:
W.copy_(project_simplex(W))

if update_Z and tol > 0 and stopping_criterion(Z, Z_old, tol):
break

return Z.detach(), W.detach()


class ArchetypalAnalysis(BaseOptimDictionaryLearning):
"""
PyTorch Archetypal Analysis Dictionary Learning model.

Objective:
min_{Z,W} ||A - Z D||_F^2 with D = W A,
        where the rows of Z and W are constrained to the simplex.

Parameters
----------
nb_concepts : int
Number of archetypes (concepts).
device : str, optional
Computation device.
tol : float, optional
Convergence tolerance.
solver : str, optional
Solver to use ('pgd').
verbose : bool, optional
Verbosity flag.
"""
_SOLVERS = {
'pgd': aa_pgd_solver,
}

def __init__(self, nb_concepts, device='cpu', tol=1e-4, solver='pgd', verbose=False):
super().__init__(nb_concepts, device)
assert solver in self._SOLVERS, f"Unknown solver '{solver}'"
self.tol = tol
self.verbose = verbose
self.solver = solver
self.solver_fn = self._SOLVERS[solver]

def encode(self, A, max_iter=300, tol=None):
"""
Encode the input data matrix into codes Z using fixed dictionary D.

Parameters
----------
A : torch.Tensor
Input data matrix (n_samples, n_features).
max_iter : int, optional
Maximum number of solver iterations.
tol : float, optional
Convergence tolerance.

Returns
-------
torch.Tensor
Codes matrix Z (n_samples, nb_concepts).
"""
self._assert_fitted()
A = A.to(self.device)
        tol = self.tol if tol is None else tol
Z = self.init_random_z(A)
Z, _ = self.solver_fn(A, Z, self.W, update_Z=True, update_W=False,
max_iter=max_iter, tol=tol, verbose=self.verbose)
return Z

def decode(self, Z):
"""
Decode the codes matrix Z into reconstructed data using dictionary D.

Parameters
----------
Z : torch.Tensor
Codes matrix (n_samples, nb_concepts).

Returns
-------
torch.Tensor
Reconstructed data matrix (n_samples, n_features).
"""
self._assert_fitted()
return Z.to(self.device) @ self.D

def fit(self, A, max_iter=500):
"""
Fit the AA model by jointly optimizing Z and W.

Parameters
----------
A : torch.Tensor
Input data matrix (n_samples, n_features).
max_iter : int, optional
Maximum number of solver iterations.

Returns
-------
Z : torch.Tensor
Final codes matrix (n_samples, nb_concepts).
D : torch.Tensor
Learned dictionary matrix (nb_concepts, n_features).
"""
A = A.to(self.device)
Z = self.init_random_z(A)
W = self.init_random_w(A)
Z, W = self.solver_fn(A, Z, W, update_Z=True, update_W=True,
max_iter=max_iter, tol=self.tol, verbose=self.verbose)
self.register_buffer('W', W)
self.register_buffer('D', W @ A)
self._set_fitted()
return Z, self.D

def get_dictionary(self):
"""
Return the learned dictionary D = W @ A.

Returns
-------
torch.Tensor
Dictionary matrix (nb_concepts, n_features).
"""
self._assert_fitted()
return self.D

def init_random_z(self, A):
"""
Initialize the codes matrix Z with random values projected onto the simplex.

Parameters
----------
A : torch.Tensor
Input data matrix (n_samples, n_features).

Returns
-------
torch.Tensor
Initialized codes matrix (n_samples, nb_concepts).
"""
mu = torch.sqrt(torch.mean(torch.abs(A)) / self.nb_concepts)
Z = torch.randn(A.shape[0], self.nb_concepts, device=self.device) * mu
return project_simplex(Z)

def init_random_w(self, A):
"""
Initialize the coefficient matrix W with random values projected onto the simplex.

Parameters
----------
A : torch.Tensor
Input data matrix (n_samples, n_features).

Returns
-------
torch.Tensor
Initialized coefficient matrix (nb_concepts, n_samples).
"""
mu = torch.sqrt(torch.mean(torch.abs(A)) / self.nb_concepts)
W = torch.randn(self.nb_concepts, A.shape[0], device=self.device) * mu
return project_simplex(W)
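For orientation, here is a minimal usage sketch of the new module. It is not part of the PR; the data and shapes are illustrative, and only the API visible above (ArchetypalAnalysis, fit, encode, decode) is assumed.

import torch
from overcomplete.optimization.archetypal_analysis import ArchetypalAnalysis

A = torch.rand(100, 32)                # (n_samples, n_features)
model = ArchetypalAnalysis(nb_concepts=8)

Z, D = model.fit(A, max_iter=500)      # Z: (100, 8), D: (8, 32), with D = W @ A
A_hat = model.decode(model.encode(A))  # re-encode against the frozen dictionary

# Rows of Z are convex-combination weights over the 8 archetypes, and each
# archetype (row of D) is itself a convex combination of data points.
rel_err = torch.norm(A - A_hat) / torch.norm(A)  # relative reconstruction error
print(Z.sum(dim=1)[:3])                # each row sums to ~1.0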
2 changes: 1 addition & 1 deletion overcomplete/optimization/utils.py
@@ -7,7 +7,7 @@

def batched_matrix_nnls(D, X, max_iter=50, tol=1e-5, Z_init=None):
"""
batched non-negative least squares (nnls) via projected gradient descent.
Batched non-negative least squares (nnls) via projected gradient descent.

solves for Z in:
min_Z ||Z @ D - X||^2 subject to Z >= 0
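For context, a minimal calling sketch of this helper (not in the diff): the shapes follow the tests in this PR (Z has shape (n, k)), and the return value is assumed to be the nonnegative codes Z.

import torch
from overcomplete.optimization.utils import batched_matrix_nnls

D = torch.rand(8, 32)    # dictionary, (k, d)
X = torch.rand(100, 32)  # data, (n, d)
Z = batched_matrix_nnls(D, X, max_iter=50, tol=1e-5)  # (n, k), Z >= 0
assert torch.all(Z >= 0)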
5 changes: 3 additions & 2 deletions overcomplete/visualization/plot_utils.py
@@ -181,7 +181,7 @@ def clip_percentile(img, percentile=0.1, clip_method='nearest'):
np.percentile(img, 100 - percentile, method=clip_method),)


def show(img, **kwargs):
def show(img, norm=True, **kwargs):
"""
Display an image with normalization and channels in the last dimension.

@@ -197,6 +197,7 @@ def show(img, **kwargs):
None
"""
img = np_channel_last(img)
img = normalize(img)
if norm:
img = normalize(img)
plt.imshow(img, **kwargs)
plt.axis('off')
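The new flag makes normalization opt-out rather than unconditional. A short illustration (hypothetical image, not from the PR; np_channel_last is assumed to reorder channels-first input as the docstring suggests):

import numpy as np
from overcomplete.visualization.plot_utils import show

img = np.random.rand(3, 64, 64)  # channels-first, reordered by np_channel_last
show(img)                        # normalized before display (default, previous behavior)
show(img, norm=False)            # raw values, e.g. for images already in [0, 1]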
90 changes: 90 additions & 0 deletions tests/optimization/test_archetypal_analysis.py
@@ -0,0 +1,90 @@
import pytest
import torch
import numpy as np

from overcomplete.metrics import relative_avg_l2_loss
from overcomplete.optimization.archetypal_analysis import ArchetypalAnalysis, project_simplex

data_shape = (50, 10)
nb_concepts = 5
A = torch.rand(data_shape, dtype=torch.float32)


def test_archetypal_initialization():
model = ArchetypalAnalysis(nb_concepts=nb_concepts)
assert model.nb_concepts == nb_concepts


def test_archetypal_fit_shapes():
model = ArchetypalAnalysis(nb_concepts=nb_concepts)
Z, D = model.fit(A)
assert Z.shape == (data_shape[0], nb_concepts)
assert D.shape == (nb_concepts, data_shape[1])


def test_archetypal_encoding():
model = ArchetypalAnalysis(nb_concepts=nb_concepts)
model.fit(A)
Z = model.encode(A)
assert Z.shape == (data_shape[0], nb_concepts)
row_sums = torch.sum(Z, dim=1)
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-3)


def test_archetypal_decoding():
model = ArchetypalAnalysis(nb_concepts=nb_concepts)
model.fit(A)
Z = model.encode(A)
A_hat = model.decode(Z)
assert A_hat.shape == A.shape


def test_archetypal_simplex_W():
model = ArchetypalAnalysis(nb_concepts=nb_concepts)
model.fit(A)
W = model.W
row_sums = torch.sum(W, dim=1)
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-3)
assert torch.all(W >= 0)


@pytest.mark.flaky(reruns=9, reruns_delay=0)
def test_archetypal_loss_decrease():
model = ArchetypalAnalysis(nb_concepts=nb_concepts)
Z0 = model.init_random_z(A)
W0 = model.init_random_w(A)
D0 = W0 @ A
initial_loss = torch.mean((A - Z0 @ D0).pow(2)).item()
Z, D = model.fit(A)
final_loss = torch.mean((A - Z @ D).pow(2)).item()
assert final_loss < initial_loss


def test_project_simplex_behavior():
W = torch.randn(20, 10)
P = project_simplex(W)
assert torch.allclose(torch.sum(P, dim=1), torch.ones(P.size(0)), atol=1e-3)
assert torch.all(P >= 0)


def test_archetypal_zero_input():
A_zero = torch.zeros_like(A)
model = ArchetypalAnalysis(nb_concepts=nb_concepts)
Z, D = model.fit(A_zero)
    assert torch.allclose(Z @ D, A_zero, atol=1e-4)


def test_archetypal_doubly_stoch():
model = ArchetypalAnalysis(nb_concepts=nb_concepts)
Z, D = model.fit(A)
W = model.W

assert Z.shape == (data_shape[0], nb_concepts)
assert D.shape == (nb_concepts, data_shape[1])
assert W.shape == (nb_concepts, data_shape[0])

assert torch.all(Z >= 0)
assert torch.all(W >= 0)

assert torch.allclose(torch.sum(Z, dim=1), torch.ones(data_shape[0]), atol=1e-3)
assert torch.allclose(torch.sum(W, dim=1), torch.ones(nb_concepts), atol=1e-3)
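One property the suite leaves implicit is absolute reconstruction quality. A hedged sketch of such a check (not in the PR; the threshold is illustrative, and plain torch ops are used rather than the imported relative_avg_l2_loss, whose signature is not shown here):

def test_archetypal_reconstruction_error():
    model = ArchetypalAnalysis(nb_concepts=nb_concepts)
    Z, D = model.fit(A)
    rel_err = torch.norm(A - Z @ D) / torch.norm(A)
    assert rel_err < 1.0  # must at least beat the trivial zero reconstruction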
1 change: 1 addition & 0 deletions tests/optimization/test_multi_nnls.py
@@ -88,6 +88,7 @@ def test_batch_shape():
assert Z.shape == (n, k), f"unexpected shape: {Z.shape}"


@pytest.mark.flaky(reruns=9, reruns_delay=0)
def test_scipy_nnls_vs_pgd():
n, k, d = 16, 8, 6
tol = 1e-2
2 changes: 1 addition & 1 deletion tests/optimization/test_save_and_load.py
@@ -23,7 +23,7 @@ def test_methods_save_and_load(methods):
D = model.D

torch.save(model, 'test_optimization_model.pth')
model = torch.load('test_optimization_model.pth')
model = torch.load('test_optimization_model.pth', map_location='cpu', weights_only=False)

assert epsilon_equal(model.D, D), "Loaded model does not produce the same results."
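Context for this change (a PyTorch fact, not stated in the diff): since PyTorch 2.6, torch.load defaults to weights_only=True and refuses to unpickle full module objects, so the explicit weights_only=False is required to load this trusted test artifact. The usual state_dict round trip, sketched below, is the safer default, though here W and D are buffers registered only inside fit, so a freshly constructed model would need to be fitted (or have the buffers re-registered) before load_state_dict would accept them:

# Hypothetical alternative, assuming the destination module already defines the buffers:
torch.save(model.state_dict(), 'model_state.pth')
state = torch.load('model_state.pth', weights_only=True)  # safe: tensors only
model.load_state_dict(state)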
