2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
-current_version = 0.2.8
+current_version = 0.3.0
commit = True
tag = True

2 changes: 1 addition & 1 deletion README.md
@@ -72,7 +72,7 @@ This work has been made possible in part by the generous support provided by the

# Additional Resources

-For a complete LLM implementation of the SAE, we strongly recommend exploring the following resources. The [Sparsify library by EleutherAI](https://github.com/EleutherAI/sparsify) provides a comprehensive toolset for implementing the SAE. The original TopK implementation is available through [OpenAI's Sparse Autoencoder](https://github.com/openai/sparse_autoencoder). Additionally, [SAE Lens](https://github.com/jbloomAus/SAELens) is an excellent resource, especially if you are interested in using the [SAE-vis](https://github.com/callummcdougall/sae_vis).
+For a complete LLM implementation of the SAE, we strongly recommend exploring the following resources. The [Sparsify library by EleutherAI](https://github.com/EleutherAI/sparsify) provides a comprehensive toolset for implementing the SAE. The original TopK implementation is available through [OpenAI's Sparse Autoencoder](https://github.com/openai/sparse_autoencoder). Additionally, [SAE Lens](https://github.com/jbloomAus/SAELens) is an excellent resource, especially if you are interested in using [SAE-vis](https://github.com/callummcdougall/sae_vis). Finally, closer to our work, for those interested in the vision domain, [ViT-Prisma](https://github.com/Prisma-Multimodal/ViT-Prisma) is an excellent mechanistic interpretability library for Vision Transformers (ViTs) that supports activation caching and SAE training.

# Related Publications

3 changes: 1 addition & 2 deletions docs/metrics.md
@@ -51,7 +51,7 @@ wasserstein_dist = wasserstein_1d(x, x_hat)
- `r2_score(x, x_hat)`: Measures reconstruction accuracy.

### **Sparsity Metrics**
-- `sparsity(x)`: Alias for `l0(x)`.
+- `l0(x)`: Cardinality of the support of `x`.
- `sparsity_eps(x, threshold)`: L0 with an epsilon threshold.
- `kappa_4(x)`: Kurtosis-based sparsity measure.
- `dead_codes(x)`: Identifies unused codes in a dictionary.
@@ -75,7 +75,6 @@ For further details, refer to the module documentation.
{{overcomplete.metrics.relative_avg_l2_loss}}
{{overcomplete.metrics.relative_avg_l1_loss}}
{{overcomplete.metrics.l0}}
-{{overcomplete.metrics.sparsity}}
{{overcomplete.metrics.l1_l2_ratio}}
{{overcomplete.metrics.hoyer}}
{{overcomplete.metrics.kappa_4}}
40 changes: 40 additions & 0 deletions docs/saes/archetypal_saes.md
@@ -0,0 +1,40 @@
# Archetypal Sparse Autoencoders (TopK & Jump)

Archetypal SAEs combine the archetypal dictionary constraint with familiar sparse encoders:
- **RATopKSAE**: TopK selection with archetypal atoms.
- **RAJumpSAE**: JumpReLU selection with archetypal atoms.

Dictionary atoms stay close to convex combinations of provided data points (controlled by `delta` and an optional multiplier), stabilizing training and improving interpretability. This is the SAE-form of the [Archetypal SAE](https://arxiv.org/abs/2502.12892) idea.

## Basic Usage
```python
import torch
from overcomplete.sae import RATopKSAE, RAJumpSAE

points = torch.randn(2_000, 768) # e.g. k-means centroids or sampled activations

ra_topk = RATopKSAE(
    input_shape=768,
    nb_concepts=10_000,
    points=points,
    top_k=20,
    delta=1.0,           # relaxation radius
    use_multiplier=True  # learnable scaling of the archetypal hull
)

ra_jump = RAJumpSAE(
    input_shape=768,
    nb_concepts=10_000,
    points=points,
    bandwidth=1e-3,
    delta=1.5
)
```

Tips:
- Provide reasonably diverse `points` (e.g., k-means cluster centers) for stable archetypes.
- `use_multiplier` lets atoms scale beyond the convex hull; set it to `False` to keep atoms strictly within the relaxed hull.
- All standard training utilities (`train_sae`, custom losses) work unchanged.
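To make the constraint concrete, here is a minimal sketch of one way an archetypal dictionary can be parameterized; the names (`W`, `Lambda`) and the projection details are illustrative assumptions, not the library's internals:

```python
import torch

# Illustrative sketch: atoms are convex combinations of anchor points,
# plus a relaxation term whose norm is capped at `delta`.
points = torch.randn(2_000, 768)    # anchor points P, shape [m, d]
W = torch.randn(10_000, 2_000)      # learnable logits, shape [c, m]
Lambda = torch.zeros(10_000, 768)   # learnable relaxation, shape [c, d]
delta = 1.0

A = torch.softmax(W, dim=-1)        # rows on the simplex -> convex weights
scale = delta / Lambda.norm(dim=-1, keepdim=True).clamp(min=delta)
D = A @ points + Lambda * scale     # each atom stays within `delta` of conv(P)
```

With `use_multiplier=True`, a learnable scaling on top of `D` lets atoms grow beyond the hull; the sketch above corresponds to the tighter `use_multiplier=False` regime.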

{{overcomplete.sae.rasae.RATopKSAE | num_parents=1, skip_methods=fit}}
{{overcomplete.sae.rasae.RAJumpSAE | num_parents=1, skip_methods=fit}}
39 changes: 39 additions & 0 deletions docs/saes/dead_codes.md
@@ -0,0 +1,39 @@
# Handling Dead Codes in SAEs

One of the most persistent challenges in training Sparse Autoencoders (SAEs) is the issue of **dead dictionary elements** (often called "dead codes" or feature collapse). These are latent features that, early in training, fall into a regime where they never activate (i.e., they have zero magnitude), preventing them from receiving gradients and learning useful representations.

Below is a simple auxiliary loss to gently nudge inactive atoms back into use while keeping the main reconstruction objective intact.

## How it works
1. **Identify Dead Codes:** It calculates a boolean mask for features that have not fired a single time across the current batch.
2. **Boost Pre-activations:** It isolates the "pre-codes" (the values *before* the activation function like ReLU or TopK is applied) for these dead atoms.
3. **Revive:** It subtracts these pre-activation values from the loss. Since the optimizer minimizes the loss, this effectively **pushes the pre-activations in the positive direction**, making them more likely to cross the activation threshold in future steps.

## Recommended Auxiliary Loss

```python
def criterion(x, x_hat, pre_codes, codes):
    # 1. Standard reconstruction loss (MSE)
    loss = (x - x_hat).square().mean()

    # 2. Identify dead codes
    # is_dead has shape [dict_size] and is 1.0 when a code never fires in the batch
    is_dead = ((codes > 0).sum(dim=0) == 0).float().detach()

    # 3. Calculate re-animation term
    # We want to maximize the pre_codes of dead atoms to push them > 0.
    # Therefore, we subtract their mean value from the loss.
    reanim_loss = (pre_codes * is_dead[None, :]).mean()

    # 4. Combine
    loss -= reanim_loss * 1e-3  # Keep this factor small

    return loss
```

### Guidance for Implementation

* **Coefficient Sensitivity:** Use a **small coefficient** (e.g., `1e-4` to `1e-3`) so the reconstruction error remains the dominant term. If the coefficient is too high, the model may hallucinate features just to satisfy the auxiliary loss.
* **Monitoring:** Monitor the `dead_codes` metric (e.g., via `overcomplete.metrics.dead_codes`) to confirm the auxiliary term is reducing the dead count without simply creating "dense" noise; see the sketch after this list.
* **Scheduling:** This is primarily useful during the early to mid stages of training. You can anneal the coefficient to `0` once dictionary utilization stabilizes.
* **Compatibility:** This auxiliary loss pairs well with any SAE variant (TopK, JumpReLU, standard ReLU) provided you have access to the `pre_codes`.
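
Here is a minimal monitoring sketch; it assumes `encode` returns the codes as its second output and that `sae.nb_concepts` and `loader` exist in your setup (both are assumptions, adapt to your API):

```python
import torch

# Count atoms that fire at least once across several batches; a single
# batch can misreport rarely-firing atoms as dead.
fired = torch.zeros(sae.nb_concepts, dtype=torch.bool)
with torch.no_grad():
    for x in loader:
        _, codes = sae.encode(x)
        fired |= (codes > 0).any(dim=0).cpu()

dead_fraction = 1.0 - fired.float().mean().item()
print(f"dead codes: {dead_fraction:.1%}")
```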
37 changes: 37 additions & 0 deletions docs/saes/mp.md
@@ -0,0 +1,37 @@
# Matching Pursuit SAE (MpSAE)

MpSAE replaces thresholding with a **greedy matching pursuit** loop: at each step it picks the atom most correlated with the residual, updates the codes, and subtracts the atom’s contribution, yielding sparse codes that track reconstruction progress. We encourage reading [^1] for the full method.

## Basic Usage
```python
import torch
from overcomplete import MpSAE

# define a Matching Pursuit SAE with input dimension 512, 4k concepts;
# k = number of pursuit steps, dropout optionally masks atoms at each step
sae = MpSAE(512, 4_096, k=4, dropout=0.1)

x = torch.randn(64, 512)
residual, codes = sae.encode(x)
```
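
For intuition, here is a self-contained sketch of the pursuit loop itself, assuming unit-norm atoms and non-negative codes; MpSAE's actual implementation (per-step dropout, batching details) may differ:

```python
import torch

def matching_pursuit(x, D, k):
    # x: inputs [n, d]; D: unit-norm dictionary atoms [c, d]
    codes = torch.zeros(x.shape[0], D.shape[0], device=x.device)
    residual = x.clone()
    for _ in range(k):
        corr = residual @ D.T                     # correlation with each atom [n, c]
        best = corr.argmax(dim=-1, keepdim=True)  # greedy pick per sample [n, 1]
        coef = corr.gather(1, best)               # selected correlation [n, 1]
        codes.scatter_add_(1, best, coef)         # accumulate the code
        residual -= coef * D[best.squeeze(1)]     # remove the atom's contribution
    return residual, codes
```

For a unit-norm atom with coefficient `coef = <residual, atom>`, each step reduces the squared residual norm by `coef**2`, which is why `k` directly trades sparsity against reconstruction quality.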

## Advanced: auxiliary loss to revive dead codes

To ensure high dictionary utilization in MP-SAE, we strongly recommend adding an auxiliary loss term.
Here is an example of such a loss:

```python
def criterion(x, x_hat, residual, z, d):
    # z: codes [n, c]; d: dictionary atoms [c, d]
    recon_loss = ((x - x_hat) ** 2).mean()

    # atoms whose max activation over the batch is ~0 are considered dead
    revive_mask = (z.amax(dim=0) < 1e-2).detach()  # shape: [c]

    # only intervene when enough atoms are dead
    if revive_mask.sum() > 10:
        projected = residual @ d.T  # shape: [n, c]
        # maximizing dead atoms' projection onto the residual revives them
        revive_term = projected[:, revive_mask].mean()
        recon_loss -= revive_term * 1e-2

    return recon_loss
```
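
To wire this into training, one option is a plain loop like the sketch below; the `decode` and `get_dictionary` calls are assumptions about the API that you should verify against the library:

```python
import torch

optimizer = torch.optim.Adam(sae.parameters(), lr=3e-4)

for x in loader:  # hypothetical loader of activation batches
    residual, z = sae.encode(x)
    x_hat = sae.decode(z)  # assumed decode method
    loss = criterion(x, x_hat, residual, z, sae.get_dictionary())  # assumed accessor
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```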

{{overcomplete.sae.mp_sae.MpSAE | num_parents=1, skip_methods=fit}}

[^1]: [From Flat to Hierarchical: Extracting Sparse Representations with Matching Pursuit](https://arxiv.org/abs/2506.03093).
29 changes: 29 additions & 0 deletions docs/saes/omp.md
@@ -0,0 +1,29 @@
# Orthogonal Matching Pursuit SAE (OMPSAE)

OMPSAE uses **orthogonal matching pursuit** for sparse coding. Each iteration picks the atom most correlated with the current residual, then re-solves a non-negative least squares (NNLS) problem over all selected atoms to refine the codes. This tighter refit often improves reconstruction over plain matching pursuit. For background, see [From Flat to Hierarchical: Extracting Sparse Representations with Matching Pursuit](https://arxiv.org/abs/2506.03093).

## Basic Usage
```python
import torch
from overcomplete.sae import OMPSAE

x = torch.randn(64, 512)
sae = OMPSAE(
    input_shape=512,
    nb_concepts=4_096,
    k=4,          # pursuit steps
    max_iter=15,  # NNLS iterations
    dropout=0.1,  # optional atom dropout
    encoder_module="identity",
    device="cuda"
)

residual, codes = sae.encode(x)
```

Notes:
- `encode` returns `(residual, codes)`; residual is the reconstruction error after pursuit steps.
- Set `dropout` to randomly mask atoms each iteration.
- Inputs must be 1D features (no 3D/4D tensors); `k` and `max_iter` must be positive.
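
For intuition, here is a self-contained sketch of the greedy selection plus non-negative refit — a simple projected-gradient NNLS with an illustrative step size, not OMPSAE's exact solver:

```python
import torch

def omp(x, D, k=4, max_iter=15, lr=0.1):
    # x: inputs [n, d]; D: unit-norm dictionary atoms [c, d]
    n, c = x.shape[0], D.shape[0]
    support = torch.zeros(n, c, dtype=torch.bool, device=x.device)
    codes = torch.zeros(n, c, device=x.device)
    for _ in range(k):
        corr = (x - codes @ D) @ D.T    # correlation with current residual [n, c]
        corr[support] = float('-inf')   # never re-pick an already-selected atom
        support.scatter_(1, corr.argmax(dim=-1, keepdim=True), True)
        # refit: NNLS over the selected atoms via projected gradient steps
        for _ in range(max_iter):
            grad = (codes @ D - x) @ D.T
            codes = ((codes - lr * grad) * support).clamp(min=0)
    return x - codes @ D, codes
```

The refit is what distinguishes OMP from plain MP: after every selection, all active coefficients are re-optimized jointly against the input instead of being frozen.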

{{overcomplete.sae.omp_sae.OMPSAE | num_parents=1, skip_methods=fit}}
6 changes: 5 additions & 1 deletion mkdocs.yml
@@ -39,7 +39,11 @@ nav:
- TopK: saes/topk.md
- JumpReLU: saes/jumprelu.md
- BatchTopK: saes/batchtopk.md
-- Archetypal: saes/archetypal.md
+- Archetypal Dictionary: saes/archetypal.md
+- Archetypal SAEs: saes/archetypal_saes.md
+- Matching Pursuit: saes/mp.md
+- Orthogonal Matching Pursuit: saes/omp.md
+- Dead Codes: saes/dead_codes.md
- Optimization:
- NMF: optimization/nmf.md
- Semi-NMF: optimization/seminmf.md
2 changes: 1 addition & 1 deletion overcomplete/__init__.py
@@ -2,7 +2,7 @@
Overcomplete: Personal toolbox for experimenting with Dictionary Learning.
"""

-__version__ = '0.2.8'
+__version__ = '0.3.0'


from .optimization import (SkPCA, SkICA, SkNMF, SkKMeans,
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "Overcomplete"
version = "0.2.8"
version = "0.3.0"
description = "Toolbox for experimenting with (Overcomplete) Dictionary Learning for Vision model"
authors = ["Thomas Fel <tfel@g.harvard.edu>"]
license = "MIT"
@@ -34,7 +34,7 @@ numkdoc = "*"

# versioning
[tool.bumpversion]
current_version = "0.2.8"
current_version = "0.3.0"
commit = true
tag = true
