diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 935ac1a..cba2adc 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.2.8
+current_version = 0.3.0
 commit = True
 tag = True
diff --git a/README.md b/README.md
index 7cb2087..4d171dd 100644
--- a/README.md
+++ b/README.md
@@ -72,7 +72,7 @@ This work has been made possible in part by the generous support provided by the
 
 # Additional Resources
 
-For a complete LLM implementation of the SAE, we strongly recommend exploring the following resources. The [Sparsify library by EleutherAI](https://github.com/EleutherAI/sparsify) provides a comprehensive toolset for implementing the SAE. The original TopK implementation is available through [OpenAI's Sparse Autoencoder](https://github.com/openai/sparse_autoencoder). Additionally, [SAE Lens](https://github.com/jbloomAus/SAELens) is an excellent resource, especially if you are interested in using the [SAE-vis](https://github.com/callummcdougall/sae_vis).
+For a complete LLM implementation of the SAE, we strongly recommend exploring the following resources. The [Sparsify library by EleutherAI](https://github.com/EleutherAI/sparsify) provides a comprehensive toolset for implementing the SAE. The original TopK implementation is available through [OpenAI's Sparse Autoencoder](https://github.com/openai/sparse_autoencoder). Additionally, [SAE Lens](https://github.com/jbloomAus/SAELens) is an excellent resource, especially if you are interested in using the [SAE-vis](https://github.com/callummcdougall/sae_vis). Finally, closer to our work, for those interested in the vision domain, [ViT-Prisma](https://github.com/Prisma-Multimodal/ViT-Prisma) is an excellent mechanistic interpretability library for Vision Transformers (ViTs) that supports activation caching and SAE training.
 
 # Related Publications
diff --git a/docs/metrics.md b/docs/metrics.md
index c87e22e..844f11a 100644
--- a/docs/metrics.md
+++ b/docs/metrics.md
@@ -51,7 +51,7 @@ wasserstein_dist = wasserstein_1d(x, x_hat)
 - `r2_score(x, x_hat)`: Measures reconstruction accuracy.
 
 ### **Sparsity Metrics**
-- `sparsity(x)`: Alias for `l0(x)`.
+- `l0(x)`: Cardinality of the support of `x`.
 - `sparsity_eps(x, threshold)`: L0 with an epsilon threshold.
 - `kappa_4(x)`: Kurtosis-based sparsity measure.
 - `dead_codes(x)`: Identifies unused codes in a dictionary.
@@ -75,7 +75,6 @@ For further details, refer to the module documentation.
 {{overcomplete.metrics.relative_avg_l2_loss}}
 {{overcomplete.metrics.relative_avg_l1_loss}}
 {{overcomplete.metrics.l0}}
-{{overcomplete.metrics.sparsity}}
 {{overcomplete.metrics.l1_l2_ratio}}
 {{overcomplete.metrics.hoyer}}
 {{overcomplete.metrics.kappa_4}}
diff --git a/docs/saes/archetypal_saes.md b/docs/saes/archetypal_saes.md
new file mode 100644
index 0000000..28eabe5
--- /dev/null
+++ b/docs/saes/archetypal_saes.md
@@ -0,0 +1,40 @@
+# Archetypal Sparse Autoencoders (TopK & Jump)
+
+Archetypal SAEs combine the archetypal dictionary constraint with familiar sparse encoders:
+- **RATopKSAE**: TopK selection with archetypal atoms.
+- **RAJumpSAE**: JumpReLU selection with archetypal atoms.
+
+Dictionary atoms stay close to convex combinations of provided data points (controlled by `delta` and an optional multiplier), stabilizing training and improving interpretability. This is the SAE formulation of the [Archetypal SAE](https://arxiv.org/abs/2502.12892) idea (the `RA` prefix stands for *Relaxed Archetypal*).
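+
+As a rough sketch (our illustration with a hypothetical class name, not the library's internal code), each atom can be parameterized as a simplex-weighted combination of `points` plus a relaxation whose norm is capped by `delta`, optionally rescaled by a learned multiplier:
+
+```python
+import torch
+import torch.nn as nn
+
+class RelaxedArchetypalAtoms(nn.Module):
+    """Illustrative sketch of the relaxed archetypal constraint."""
+
+    def __init__(self, points, nb_concepts, delta=1.0, use_multiplier=True):
+        super().__init__()
+        self.register_buffer("points", points)           # (m, d) candidate archetypes
+        self.logits = nn.Parameter(torch.randn(nb_concepts, points.shape[0]) * 1e-2)
+        self.relax = nn.Parameter(torch.zeros(nb_concepts, points.shape[1]))
+        self.scale = nn.Parameter(torch.ones(nb_concepts, 1)) if use_multiplier else None
+        self.delta = delta
+
+    def get_dictionary(self):
+        weights = torch.softmax(self.logits, dim=-1)     # rows live on the simplex
+        atoms = weights @ self.points                    # convex combinations of the points
+        norm = self.relax.norm(dim=-1, keepdim=True).clamp_min(1e-8)
+        atoms = atoms + self.relax * (norm.clamp(max=self.delta) / norm)  # relaxation bounded by delta
+        return atoms * self.scale if self.scale is not None else atoms   # optional learned scaling
+```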
+
+## Basic Usage
+```python
+import torch
+from overcomplete.sae import RATopKSAE, RAJumpSAE
+
+points = torch.randn(2_000, 768)  # e.g. k-means centroids or sampled activations
+
+ra_topk = RATopKSAE(
+    input_shape=768,
+    nb_concepts=10_000,
+    points=points,
+    top_k=20,
+    delta=1.0,            # relaxation radius
+    use_multiplier=True   # learnable scaling of the archetypal hull
+)
+
+ra_jump = RAJumpSAE(
+    input_shape=768,
+    nb_concepts=10_000,
+    points=points,
+    bandwidth=1e-3,
+    delta=1.5
+)
+```
+
+Tips:
+- Provide reasonably diverse `points` (e.g., k-means cluster centers) for stable archetypes.
+- `use_multiplier` allows atoms to scale beyond the convex hull; set it to `False` to keep atoms closer to the hull.
+- All standard training utilities (`train_sae`, custom losses) work unchanged.
+
+{{overcomplete.sae.rasae.RATopKSAE | num_parents=1, skip_methods=fit}}
+{{overcomplete.sae.rasae.RAJumpSAE | num_parents=1, skip_methods=fit}}
diff --git a/docs/saes/dead_codes.md b/docs/saes/dead_codes.md
new file mode 100644
index 0000000..af5a646
--- /dev/null
+++ b/docs/saes/dead_codes.md
@@ -0,0 +1,39 @@
+# Handling Dead Codes in SAEs
+
+One of the most persistent challenges in training Sparse Autoencoders (SAEs) is the issue of **dead dictionary elements** (often called "dead codes" or feature collapse). These are latent features that, early in training, fall into a regime where they never activate (i.e., they have zero magnitude), preventing them from receiving gradients and learning useful representations.
+
+Below is a simple auxiliary loss that gently nudges inactive atoms back into use while keeping the main reconstruction objective intact.
+
+## How it works
+1. **Identify Dead Codes:** It calculates a boolean mask for features that have not fired a single time across the current batch.
+2. **Boost Pre-activations:** It isolates the "pre-codes" (the values *before* the activation function, such as ReLU or TopK, is applied) for these dead atoms.
+3. **Revive:** It subtracts these pre-activation values from the loss. Since the optimizer minimizes the loss, this effectively **pushes the pre-activations in the positive direction**, making them more likely to cross the activation threshold in future steps.
+
+## Recommended Auxiliary Loss
+
+```python
+def criterion(x, x_hat, pre_codes, codes):
+    # 1. Standard reconstruction loss (MSE)
+    loss = (x - x_hat).square().mean()
+
+    # 2. Identify dead codes
+    #    is_dead has shape [dict_size] and is 1.0 when a code never fires in the batch
+    is_dead = ((codes > 0).sum(dim=0) == 0).float().detach()
+
+    # 3. Calculate the re-animation term
+    #    We want to maximize the pre_codes of dead atoms to push them > 0,
+    #    so we subtract their mean value from the loss.
+    reanim_loss = (pre_codes * is_dead[None, :]).mean()
+
+    # 4. Combine
+    loss -= reanim_loss * 1e-3  # keep this factor small
+
+    return loss
+```
+
+### Guidance for Implementation
+
+* **Coefficient Sensitivity:** Use a **small coefficient** (e.g., `1e-4` to `1e-3`) so the reconstruction error remains the dominant term. If the coefficient is too high, the model may hallucinate features just to satisfy the auxiliary loss.
+* **Monitoring:** Track the `dead_codes` metric (e.g., via `overcomplete.metrics.dead_codes`) to confirm the auxiliary term is reducing the dead count without simply creating "dense" noise.
+* **Scheduling:** This is primarily useful during the early to mid stages of training. You can anneal the coefficient to `0` once dictionary utilization stabilizes.
+* **Compatibility:** This auxiliary loss pairs well with any SAE variant (TopK, JumpReLU, standard ReLU), provided you have access to the `pre_codes`.
\ No newline at end of file
diff --git a/docs/saes/mp.md b/docs/saes/mp.md
new file mode 100644
index 0000000..fae41a3
--- /dev/null
+++ b/docs/saes/mp.md
@@ -0,0 +1,37 @@
+# Matching Pursuit SAE (MpSAE)
+
+MpSAE replaces thresholding with a **greedy matching pursuit** loop: at each step it picks the atom most correlated with the residual, updates the codes, and subtracts the atom’s contribution, yielding sparse codes that track reconstruction progress. We encourage reading [^1] for the full method.
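+
+For intuition, a textbook matching-pursuit iteration over a fixed, unit-norm dictionary looks roughly like the sketch below (our illustration only; the actual `MpSAE` learns the dictionary, supports `dropout`, and may treat signs and non-negativity differently):
+
+```python
+import torch
+
+def matching_pursuit(x, D, k):
+    """x: (n, d) inputs, D: (c, d) dictionary with unit-norm rows, k: pursuit steps."""
+    residual = x.clone()
+    codes = torch.zeros(x.shape[0], D.shape[0], device=x.device)
+    for _ in range(k):
+        corr = residual @ D.T                        # correlation of the residual with every atom
+        best = corr.abs().argmax(dim=1)              # greedy pick: most correlated atom per sample
+        coef = corr.gather(1, best[:, None])         # coefficient = projection onto that atom
+        codes.scatter_add_(1, best[:, None], coef)   # accumulate into the sparse code
+        residual = residual - coef * D[best]         # subtract the atom's contribution
+    return residual, codes
+```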
+
+## Basic Usage
+```python
+import torch
+from overcomplete import MpSAE
+
+# define a Matching Pursuit SAE with input dimension 512 and 4k concepts;
+# k = number of pursuit steps, dropout optionally masks atoms at each step
+sae = MpSAE(512, 4_096, k=4, dropout=0.1)
+
+x = torch.randn(64, 512)  # batch of activations to encode
+residual, codes = sae.encode(x)
+```
+
+## Advanced: auxiliary loss to revive dead codes
+
+To ensure high dictionary utilization in MpSAE, we strongly recommend adding an auxiliary loss term.
+Here is an example of such a loss:
+
+```python
+def criterion(x, x_hat, residual, z, d):
+    recon_loss = ((x - x_hat) ** 2).mean()
+
+    # codes whose maximum activation over the batch is ~0 are considered dead
+    revive_mask = (z.amax(dim=0) < 1e-2).detach()  # shape: [c]
+
+    if revive_mask.sum() > 10:
+        projected = residual @ d.T  # shape: [n, c]
+        revive_term = projected[:, revive_mask].mean()
+        recon_loss -= revive_term * 1e-2
+
+    return recon_loss
+```
+
+{{overcomplete.sae.mp_sae.MpSAE | num_parents=1, skip_methods=fit}}
+
+[^1]: [From Flat to Hierarchical: Extracting Sparse Representations with Matching Pursuit](https://arxiv.org/abs/2506.03093).
diff --git a/docs/saes/omp.md b/docs/saes/omp.md
new file mode 100644
index 0000000..9e87a2e
--- /dev/null
+++ b/docs/saes/omp.md
@@ -0,0 +1,29 @@
+# Orthogonal Matching Pursuit SAE (OMPSAE)
+
+OMPSAE uses **orthogonal matching pursuit** for sparse coding. Each iteration picks the atom most correlated with the current residual, then re-solves a non-negative least squares (NNLS) problem on all selected atoms to refine the codes. This tighter refit often improves reconstruction over plain matching pursuit. For background, see [Sparse Autoencoders via Matching Pursuit](https://arxiv.org/pdf/2506.03093).
+
+## Basic Usage
+```python
+import torch
+from overcomplete.sae import OMPSAE
+
+x = torch.randn(64, 512)
+sae = OMPSAE(
+    input_shape=512,
+    nb_concepts=4_096,
+    k=4,              # pursuit steps
+    max_iter=15,      # NNLS iterations
+    dropout=0.1,      # optional atom dropout
+    encoder_module="identity",
+    device="cuda"
+)
+
+residual, codes = sae.encode(x)
+```
+
+Notes:
+- `encode` returns `(residual, codes)`; the residual is the reconstruction error left after the pursuit steps.
+- Set `dropout` to randomly mask atoms at each iteration.
+- Inputs must be 1D feature vectors (no 3D/4D tensors); `k` and `max_iter` must be positive.
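+
+For intuition, the pick-then-refit loop described above can be sketched as follows (our illustration on a single sample with a fixed dictionary; we use a plain least-squares refit where the actual class refits with NNLS and also handles dropout and batching):
+
+```python
+import torch
+
+def omp(x, D, k):
+    """x: (d,) one sample, D: (c, d) dictionary with unit-norm rows, k: pursuit steps."""
+    residual, selected = x.clone(), []
+    for _ in range(k):
+        atom = (residual @ D.T).abs().argmax().item()   # pick the atom most correlated with the residual
+        if atom not in selected:
+            selected.append(atom)
+        sub = D[selected]                               # (s, d) all atoms selected so far
+        # joint refit over the selected atoms (NNLS would additionally clamp coefficients >= 0)
+        coef = torch.linalg.lstsq(sub.T, x[:, None]).solution.squeeze(1)
+        residual = x - coef @ sub                       # residual after the refit
+    codes = torch.zeros(D.shape[0])
+    codes[selected] = coef
+    return residual, codes
+```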
+
+{{overcomplete.sae.omp_sae.OMPSAE | num_parents=1, skip_methods=fit}}
diff --git a/mkdocs.yml b/mkdocs.yml
index e206612..83780eb 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -39,7 +39,11 @@ nav:
       - TopK: saes/topk.md
       - JumpReLU: saes/jumprelu.md
       - BatchTopK: saes/batchtopk.md
-      - Archetypal: saes/archetypal.md
+      - Archetypal Dictionary: saes/archetypal.md
+      - Archetypal SAEs: saes/archetypal_saes.md
+      - Matching Pursuit: saes/mp.md
+      - Orthogonal Matching Pursuit: saes/omp.md
+      - Dead Codes: saes/dead_codes.md
   - Optimization:
       - NMF: optimization/nmf.md
       - Semi-NMF: optimization/seminmf.md
diff --git a/overcomplete/__init__.py b/overcomplete/__init__.py
index e67838d..549f4f5 100644
--- a/overcomplete/__init__.py
+++ b/overcomplete/__init__.py
@@ -2,7 +2,7 @@
 Overcomplete: Personal toolbox for experimenting with Dictionary Learning.
 """
 
-__version__ = '0.2.8'
+__version__ = '0.3.0'
 
 from .optimization import (SkPCA, SkICA, SkNMF, SkKMeans,
diff --git a/pyproject.toml b/pyproject.toml
index a3d0649..2cab1cb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "Overcomplete"
-version = "0.2.8"
+version = "0.3.0"
 description = "Toolbox for experimenting with (Overcomplete) Dictionary Learning for Vision model"
 authors = ["Thomas Fel "]
 license = "MIT"
@@ -34,7 +34,7 @@ numkdoc = "*"
 
 # versioning
 [tool.bumpversion]
-current_version = "0.2.8"
+current_version = "0.3.0"
 commit = true
 tag = true