diff --git a/overcomplete/optimization/nmf.py b/overcomplete/optimization/nmf.py
index 2e3ddb8..90bbe8e 100644
--- a/overcomplete/optimization/nmf.py
+++ b/overcomplete/optimization/nmf.py
@@ -9,8 +9,9 @@
 from tqdm import tqdm
 import torch
 
-from scipy.optimize import nnls as scipy_nnls
 from sklearn.decomposition._nmf import _initialize_nmf
+from scipy.optimize import nnls as scipy_nnls
+
 from .base import BaseOptimDictionaryLearning
 from .utils import matrix_nnls, stopping_criterion, _assert_shapes
diff --git a/overcomplete/sae/base.py b/overcomplete/sae/base.py
index 1b406b0..bc4ebd2 100644
--- a/overcomplete/sae/base.py
+++ b/overcomplete/sae/base.py
@@ -9,6 +9,7 @@
 from ..base import BaseDictionaryLearning
 from .dictionary import DictionaryLayer
 from .factory import EncoderFactory
+from .modules import TieableEncoder
 
 
 class SAE(BaseDictionaryLearning):
@@ -172,3 +173,53 @@ def fit(self, x):
         """
         raise NotImplementedError('SAE does not support fit method. You have to train the model \
                                   using a custom training loop.')
+
+    def tied(self, bias=False):
+        """
+        Tie encoder weights to dictionary (use D^T as encoder).
+
+        Parameters
+        ----------
+        bias : bool, optional
+            Whether to include bias in encoder, by default False.
+
+        Returns
+        -------
+        self
+            Returns self for method chaining.
+        """
+        self.encoder = TieableEncoder(
+            in_dimensions=self.dictionary.in_dimensions,
+            nb_concepts=self.nb_concepts,
+            bias=bias,
+            tied_to=self.dictionary,
+            device=self.device
+        )
+        return self
+
+    def untied(self, bias=False, copy_from_dictionary=True):
+        """
+        Create a new encoder with weights from the current dictionary (or random init).
+
+        Parameters
+        ----------
+        bias : bool, optional
+            Whether to include bias in encoder, by default False.
+        copy_from_dictionary : bool, optional
+            If True, initialize encoder with current dictionary weights, by default True.
+
+        Returns
+        -------
+        self
+            Returns self for method chaining.
+        """
+        weight_init = self.get_dictionary().clone().detach() if copy_from_dictionary else None
+        self.encoder = TieableEncoder(
+            in_dimensions=self.dictionary.in_dimensions,
+            nb_concepts=self.nb_concepts,
+            bias=bias,
+            tied_to=None,
+            weight_init=weight_init,
+            device=self.device
+        )
+        return self
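
Usage note (not part of the patch): a minimal sketch of how the tied()/untied() methods above are meant to be chained, inferred from the tests later in this diff; the sizes (10 input dims, 5 concepts) are illustrative only.

    import torch
    from overcomplete.sae import SAE

    model = SAE(10, 5)                          # 10 input dims, 5 concepts
    model.tied(bias=True)                       # encoder reuses D: z_pre = x @ D.T + b
    z_pre, z, x_hat = model(torch.randn(3, 10))

    model.untied(copy_from_dictionary=True)     # warm-start an independent encoder from D
    assert model.encoder.tied_to is None        # encoder and dictionary now train separately
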
diff --git a/overcomplete/sae/kernels.py b/overcomplete/sae/kernels.py
index 2e39a52..469437e 100644
--- a/overcomplete/sae/kernels.py
+++ b/overcomplete/sae/kernels.py
@@ -6,7 +6,6 @@
 """
 
 import torch
-import matplotlib.pyplot as plt
 
 
 def rectangle_kernel(x, bandwith):
diff --git a/overcomplete/sae/modules.py b/overcomplete/sae/modules.py
index bd76a50..65d3491 100644
--- a/overcomplete/sae/modules.py
+++ b/overcomplete/sae/modules.py
@@ -2,6 +2,7 @@
 Collections of torch modules for the encoding of SAE.
 """
 
+import torch
 from torch import nn
 from einops import rearrange
 
@@ -463,3 +464,74 @@ def forward(self, x):
         z = self.final_activation(pre_z)
 
         return pre_z, z
+
+
+class TieableEncoder(nn.Module):
+    """
+    Linear encoder that can be tied to a dictionary or use independent weights.
+
+    Parameters
+    ----------
+    in_dimensions : int
+        Input dimensionality.
+    nb_concepts : int
+        Number of latent dimensions.
+    bias : bool, optional
+        Whether to include bias, by default False.
+    tied_to : DictionaryLayer, optional
+        If provided, uses D^T (tied weights). If None, uses independent weights, by default None.
+    weight_init : torch.Tensor, optional
+        Initial weights for untied mode, by default None (uses Xavier initialization).
+    device : str, optional
+        Device for parameters, by default 'cpu'.
+    """
+
+    def __init__(self, in_dimensions, nb_concepts, bias=False,
+                 tied_to=None, weight_init=None, device='cpu'):
+        super().__init__()
+        self.tied_to = tied_to
+
+        if tied_to is None:
+            # untied: create own weights
+            self.weight = nn.Parameter(torch.empty(nb_concepts, in_dimensions, device=device))
+            if weight_init is not None:
+                self.weight.data.copy_(weight_init)
+            else:
+                nn.init.xavier_uniform_(self.weight)
+        else:
+            # tied weights: we use the dictionary transpose as encoder
+            # no weights needed
+            self.register_parameter('weight', None)
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(nb_concepts, device=device))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, x):
+        """
+        Encode input.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input of shape (batch_size, in_dimensions).
+
+        Returns
+        -------
+        z_pre : torch.Tensor
+            Pre-activation codes.
+        z : torch.Tensor
+            Activated codes (ReLU applied).
+        """
+        if self.tied_to is not None:
+            z_pre = x @ self.tied_to.get_dictionary().T
+        else:
+            # untied mode: use own weights
+            z_pre = x @ self.weight.T
+
+        if self.bias is not None:
+            z_pre = z_pre + self.bias
+
+        z = torch.relu(z_pre)
+        return z_pre, z
diff --git a/overcomplete/sae/train.py b/overcomplete/sae/train.py
index 5558bf0..ef36a1b 100644
--- a/overcomplete/sae/train.py
+++ b/overcomplete/sae/train.py
@@ -91,8 +91,8 @@ def _log_metrics(monitoring, logs, model, z, loss, optimizer):
 
     if monitoring > 1:
         # store directly some z values
-        # and the params / gradients norms
-        logs['z'].append(z.detach()[::10])
+        # if needed you can even store z statistics here,
+        # e.g. logs['z'].append(z.detach()[::10])
         logs['z_l2'].append(l2(z).item())
         logs['dictionary_sparsity'].append(l0_eps(model.get_dictionary()).mean().item())
 
diff --git a/pyproject.toml b/pyproject.toml
index 98dc87d..a3d0649 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ opencv-python = "*"
 torch = "*"
 torchvision = "*"
 timm = "*"
+scipy = "*"
 
 # requirements dev
 [tool.poetry.group.dev.dependencies]
@@ -29,6 +30,7 @@ pylint = "*"
 bumpversion = "*"
 mkdocs = "*"
 mkdocs-material = "*"
+numkdoc = "*"
 
 # versioning
 [tool.bumpversion]
diff --git a/tests/sae/test_base_sae.py b/tests/sae/test_base_sae.py
index d107cf7..fc7482c 100644
--- a/tests/sae/test_base_sae.py
+++ b/tests/sae/test_base_sae.py
@@ -2,6 +2,9 @@
 import torch
 
 from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE
+from overcomplete.sae.modules import TieableEncoder
+
+from ..utils import epsilon_equal
 
 all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE]
 
@@ -57,3 +60,100 @@ def test_sae_device(sae_class):
     # ensure dictionary is on the meta device
     dictionary = model.get_dictionary()
     assert dictionary.device.type == 'meta'
+
+
+def test_tieable_encoder_basic():
+    """Test TieableEncoder can be created in both tied and untied modes."""
+    input_size = 10
+    nb_concepts = 5
+
+    # Create a dummy dictionary layer
+    dictionary = DictionaryLayer(input_size, nb_concepts)
+
+    # Test untied mode
+    encoder_untied = TieableEncoder(input_size, nb_concepts, tied_to=None)
+    assert encoder_untied.weight is not None
+    assert encoder_untied.tied_to is None
+
+    # Test tied mode
+    encoder_tied = TieableEncoder(input_size, nb_concepts, tied_to=dictionary)
+    assert encoder_tied.weight is None
+    assert encoder_tied.tied_to is dictionary
+
+
+def test_tieable_encoder_forward():
+    """Test TieableEncoder forward pass in both modes."""
+    input_size = 10
+    nb_concepts = 5
+    batch_size = 3
+
+    dictionary = DictionaryLayer(input_size, nb_concepts)
+    x = torch.randn(batch_size, input_size)
+
+    # Test untied forward
+    encoder_untied = TieableEncoder(input_size, nb_concepts, tied_to=None)
+    z_pre, z = encoder_untied(x)
+    assert z_pre.shape == (batch_size, nb_concepts)
+    assert z.shape == (batch_size, nb_concepts)
+    assert (z >= 0).all()  # ReLU activation
+
+    # Test tied forward
+    encoder_tied = TieableEncoder(input_size, nb_concepts, tied_to=dictionary)
+    z_pre, z = encoder_tied(x)
+    assert z_pre.shape == (batch_size, nb_concepts)
+    assert z.shape == (batch_size, nb_concepts)
+    assert (z >= 0).all()
+
+
+@pytest.mark.parametrize("sae_class", all_sae)
+def test_sae_tied_untied(sae_class):
+    """Test that SAE can switch between tied and untied modes."""
+    input_size = 10
+    nb_concepts = 5
+
+    model = sae_class(input_size, nb_concepts)
+
+    # Tie weights
+    model.tied()
+    assert isinstance(model.encoder, TieableEncoder)
+    assert model.encoder.tied_to is not None
+
+    # Untie weights
+    model.untied()
+    assert isinstance(model.encoder, TieableEncoder)
+    assert model.encoder.tied_to is None
+
+
+@pytest.mark.parametrize("sae_class", all_sae)
+def test_sae_tied_forward(sae_class):
+    """Test that tied SAE produces valid outputs."""
+    input_size = 10
+    nb_concepts = 5
+
+    model = sae_class(input_size, nb_concepts)
+    model.tied()
+
+    x = torch.randn(3, input_size)
+    z_pre, z, x_hat = model(x)
+
+    assert z.shape == (3, nb_concepts)
+    assert x_hat.shape == (3, input_size)
+
+
+@pytest.mark.parametrize("sae_class", all_sae)
+def test_sae_untied_copy_weights(sae_class):
+    """Test that untied with copy_from_dictionary copies weights correctly."""
+    input_size = 10
+    nb_concepts = 5
+
+    model = sae_class(input_size, nb_concepts)
+    model.tied()
+
+    # Get dictionary before untying
+    dict_before = model.get_dictionary().clone()
+
+    # Untie and copy
+    model.untied(copy_from_dictionary=True)
+
+    # Check that encoder weights match dictionary
+    assert epsilon_equal(model.encoder.weight, dict_before)
diff --git a/tests/sae/test_sae_dictionary.py b/tests/sae/test_sae_dictionary.py
index 631d040..bc2d8f3 100644
--- a/tests/sae/test_sae_dictionary.py
+++ b/tests/sae/test_sae_dictionary.py
@@ -2,6 +2,7 @@
 import pytest
 
 from overcomplete.sae import DictionaryLayer, SAE, QSAE, TopKSAE, JumpSAE, BatchTopKSAE, MpSAE, OMPSAE
+from overcomplete.sae.modules import TieableEncoder
 
 from ..utils import epsilon_equal
 
@@ -270,3 +271,46 @@ def test_multiplier_optimizer_step():
     # Check that the multiplier has been updated.
     assert not torch.allclose(init_multiplier, layer.multiplier.detach(), atol=1e-6)
+
+
+def test_tied_encoder_shares_dictionary_weights():
+    """Test that tied encoder uses dictionary weights (not a copy)."""
+    input_size = 10
+    nb_concepts = 5
+
+    dictionary = DictionaryLayer(input_size, nb_concepts)
+    encoder = TieableEncoder(input_size, nb_concepts, tied_to=dictionary)
+
+    x = torch.randn(3, input_size)
+
+    # Forward pass
+    z_pre1, z1 = encoder(x)
+
+    # Modify dictionary weights
+    with torch.no_grad():
+        dictionary._weights.data *= 10.0
+        dictionary._weights.data += torch.randn_like(dictionary._weights)
+
+    # Forward pass again
+    z_pre2, z2 = encoder(x)
+
+    # Results should be different (weights are shared)
+    assert not epsilon_equal(z_pre1, z_pre2)
+
+
+def test_tied_encoder_gradient_flow():
+    """Test that gradients flow to dictionary through tied encoder."""
+    input_size = 10
+    nb_concepts = 5
+
+    dictionary = DictionaryLayer(input_size, nb_concepts)
+    encoder = TieableEncoder(input_size, nb_concepts, tied_to=dictionary)
+
+    x = torch.randn(3, input_size, requires_grad=True)
+    z_pre, z = encoder(x)
+
+    loss = z.sum()
+    loss.backward()
+
+    # Dictionary should have gradients
+    assert dictionary._weights.grad is not None
diff --git a/tests/sae/test_save_and_load.py b/tests/sae/test_save_and_load.py
index 59177e3..f61c21d 100644
--- a/tests/sae/test_save_and_load.py
+++ b/tests/sae/test_save_and_load.py
@@ -3,12 +3,17 @@
 import torch
 
 from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE
+from overcomplete.sae.modules import TieableEncoder
 
 from ..utils import epsilon_equal
 
 all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE]
 
 
+def _load(path):
+    return torch.load(path, map_location="cpu", weights_only=False)
+
+
 @pytest.mark.parametrize("nb_concepts, dimensions", [(5, 10)])
 def test_save_and_load_dictionary_layer(nb_concepts, dimensions, tmp_path):
     # Initialize and run layer
@@ -86,3 +91,62 @@ def test_eval_and_save_sae(sae_class, tmp_path):
     assert epsilon_equal(z, z_loaded)
     assert epsilon_equal(x_hat, x_hat_loaded)
     assert epsilon_equal(z_pre, z_pre_loaded)
+
+
+@pytest.mark.parametrize("sae_class", all_sae)
+def test_save_and_load_tied_sae(sae_class, tmp_path):
+    """Test that tied SAE can be saved and loaded."""
+    input_size = 10
+    nb_concepts = 5
+
+    model = sae_class(input_size, nb_concepts)
+    model.tied()
+
+    x = torch.randn(3, input_size)
+    output = model(x)
+    z_pre, z, x_hat = output
+
+    # Save
+    model_path = tmp_path / "test_tied_sae.pth"
+    torch.save(model, model_path)
+
+    # Load
+    model_loaded = _load(model_path)
+    assert isinstance(model_loaded, sae_class)
+
+    # Test output consistency
+    output_loaded = model_loaded(x)
+    z_pre_loaded, z_loaded, x_hat_loaded = output_loaded
+
+    assert epsilon_equal(z, z_loaded)
+    assert epsilon_equal(x_hat, x_hat_loaded)
+
+
+@pytest.mark.parametrize("sae_class", all_sae)
+def test_save_and_load_untied_with_copy(sae_class, tmp_path):
+    """Test that untied SAE with copied weights can be saved and loaded."""
+    input_size = 10
+    nb_concepts = 5
+
+    model = sae_class(input_size, nb_concepts)
+    model.tied()
+    model.untied(copy_from_dictionary=True)
+
+    x = torch.randn(3, input_size)
+    output = model(x)
+    z_pre, z, x_hat = output
+
+    # Save
+    model_path = tmp_path / "test_untied_sae.pth"
+    torch.save(model, model_path)
+
+    # Load
+    model_loaded = _load(model_path)
+    assert isinstance(model_loaded, sae_class)
+
+    # Test output consistency
+    output_loaded = model_loaded(x)
+    z_pre_loaded, z_loaded, x_hat_loaded = output_loaded
+
+    assert epsilon_equal(z, z_loaded)
+    assert epsilon_equal(x_hat, x_hat_loaded)
diff --git a/tests/sae/test_train.py b/tests/sae/test_train.py
index 0fb44ff..5eb6f2c 100644
--- a/tests/sae/test_train.py
+++ b/tests/sae/test_train.py
@@ -40,7 +40,6 @@ def get_dictionary(self):
     logs = train_sae(model, dataloader, criterion, optimizer, scheduler, nb_epochs=2, monitoring=2, device="cpu")
 
     assert isinstance(logs, defaultdict)
-    assert "z" in logs
     assert "z_l2" in logs
     assert "z_sparsity" in logs
     assert "time_epoch" in logs
@@ -72,7 +71,6 @@ def get_dictionary(self):
     logs = train_sae(model, dataloader, criterion, optimizer, monitoring=2, device="cpu")
 
     assert isinstance(logs, defaultdict)
-    assert "z" in logs
     assert "z_l2" in logs
     assert "z_sparsity" in logs
     assert "time_epoch" in logs
diff --git a/tests/sae/test_train_sae.py b/tests/sae/test_train_sae.py
index 6b00bda..4359aff 100644
--- a/tests/sae/test_train_sae.py
+++ b/tests/sae/test_train_sae.py
@@ -9,6 +9,7 @@
 from overcomplete.sae.train import train_sae, train_sae_amp
 from overcomplete.sae.losses import mse_l1
 from overcomplete.sae import SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE
+from overcomplete.sae.modules import TieableEncoder
 
 from ..utils import epsilon_equal
 
@@ -70,7 +71,6 @@ def test_train_mlp_sae(module_name, sae_class):
     )
 
     assert isinstance(logs, defaultdict)
-    assert "z" in logs
     assert "z_l2" in logs
     assert "z_sparsity" in logs
     assert "time_epoch" in logs
@@ -102,7 +102,6 @@ def criterion(x, x_hat, z_pre, z, dictionary):
     logs = train_sae_amp(model, dataloader, criterion, optimizer, scheduler, nb_epochs=2, monitoring=2, device="cpu")
 
     assert isinstance(logs, defaultdict)
-    assert "z" in logs
     assert "z_l2" in logs
     assert "z_sparsity" in logs
     assert "time_epoch" in logs
@@ -135,7 +134,6 @@ def criterion(x, x_hat, z_pre, z, dictionary):
     logs = train_sae(model, dataloader, criterion, optimizer, scheduler, nb_epochs=2, monitoring=2, device="cpu")
 
     assert isinstance(logs, defaultdict)
-    assert "z" in logs
     assert "z_l2" in logs
     assert "z_sparsity" in logs
     assert "time_epoch" in logs
@@ -221,7 +219,6 @@ def test_monitoring():
 
     assert isinstance(logs, defaultdict)
     assert "lr" in logs
-    assert "z" not in logs
 
     logs = train_sae(
         model,
@@ -235,7 +232,6 @@
     )
 
     assert isinstance(logs, defaultdict)
-    assert "z" in logs
 
 
 def test_top_k_constraint():
@@ -326,3 +322,79 @@ def test_q_sae_quantization_levels():
 
     assert isinstance(logs, defaultdict)
     assert len(logs) == 0
+
+
+@pytest.mark.parametrize("sae_class", all_sae)
+def test_train_tied_sae(sae_class):
+    """Test that tied SAE can be trained."""
+    data = torch.randn(10, 10)
+    dataset = TensorDataset(data)
+    dataloader = DataLoader(dataset, batch_size=10)
+    criterion = mse_l1
+    n_components = 2
+
+    model = sae_class(data.shape[1], n_components)
+    model.tied()
+
+    optimizer = optim.SGD(model.parameters(), lr=0.001)
+
+    logs = train_sae(
+        model,
+        dataloader,
+        criterion,
+        optimizer,
+        None,
+        nb_epochs=2,
+        monitoring=False,
+        device="cpu",
+    )
+
+    assert isinstance(logs, defaultdict)
+
+
+@pytest.mark.parametrize("sae_class", all_sae)
+def test_train_untied_after_tied(sae_class):
+    """Test training workflow: tied -> train -> untied with copy -> train."""
+    data = torch.randn(10, 10)
+    dataset = TensorDataset(data)
+    dataloader = DataLoader(dataset, batch_size=10)
+    criterion = mse_l1
+    n_components = 2
+
+    model = sae_class(data.shape[1], n_components)
+    model.tied()
+
+    # Train tied
+    optimizer = optim.SGD(model.parameters(), lr=0.001)
+    train_sae(model, dataloader, criterion, optimizer, None, nb_epochs=1, monitoring=False)
+
+    # Untie and copy weights
+    dict_after_training = model.get_dictionary().clone()
+    model.untied(copy_from_dictionary=True)
+
+    # Check weights were copied
+    assert epsilon_equal(model.encoder.weight, dict_after_training)
+
+    # Train untied
+    optimizer = optim.SGD(model.parameters(), lr=0.001)
+    train_sae(model, dataloader, criterion, optimizer, None, nb_epochs=1, monitoring=False)
+
+
+def test_tied_encoder_bias_training():
+    """Test that bias in tied encoder is trainable."""
+    data = torch.randn(10, 10)
+    dataset = TensorDataset(data)
+    dataloader = DataLoader(dataset, batch_size=10)
+    criterion = mse_l1
+
+    model = SAE(10, 5)
+    model.tied(bias=True)
+
+    # Record initial bias
+    initial_bias = model.encoder.bias.clone()
+
+    optimizer = optim.SGD(model.parameters(), lr=0.01)
+    train_sae(model, dataloader, criterion, optimizer, None, nb_epochs=5, monitoring=False)
+
+    # Bias should have changed
+    assert not epsilon_equal(initial_bias, model.encoder.bias)
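
Usage note (not part of the patch): an end-to-end sketch of the tie -> train -> untie-and-copy -> keep-training workflow exercised by test_train_untied_after_tied above; the data, batch size, learning rate and epoch counts are placeholders.

    import torch
    from torch import optim
    from torch.utils.data import DataLoader, TensorDataset

    from overcomplete.sae import SAE
    from overcomplete.sae.train import train_sae
    from overcomplete.sae.losses import mse_l1

    data = torch.randn(256, 10)                                   # placeholder data
    dataloader = DataLoader(TensorDataset(data), batch_size=32)

    # phase 1: encoder shares the dictionary weights, so gradients reach D
    # through both the encoding and decoding paths
    model = SAE(10, 5).tied()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    train_sae(model, dataloader, mse_l1, optimizer, None, nb_epochs=5, monitoring=False)

    # phase 2: copy D into an independent encoder and continue training untied
    model.untied(copy_from_dictionary=True)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    train_sae(model, dataloader, mse_l1, optimizer, None, nb_epochs=5, monitoring=False)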