Skip to content

Commit d362f64

Browse files
committed
Remove transformers; use OpenCLIP
1 parent 0dcc2ad commit d362f64

13 files changed

+620
-596
lines changed

analysis.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def main(
192192
n_images: number of images to use. Use a smaller number for debugging.
193193
k_top_images: the number of top images to store per neuron.
194194
"""
195-
_, sae, acts_store = saev.utils.Session.from_disk(ckpt_path)
195+
_, sae, acts_store = saev.Session.from_disk(ckpt_path)
196196
get_feature_data(
197197
sae,
198198
acts_store,

logbook.md

+16-3
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,25 @@ With this in mind, there are several minor changes I want to make before I do so
8989

9090
1. Removing `transformer-lens` [done, commit [18612b7](https://github.com/samuelstevens/saev/commit/18612b75988c32ae8ab3db6656b44a442f3f7641)]
9191
2. Removing HookedVisionTransformer [done, commit [c7ba7c7](https://github.com/samuelstevens/saev/commit/c7ba7c72c76472fd8cf2e7b2dc668d03a15b803d)]
92-
3. OpenCLIP instead of huggingface `transformers`
92+
3. OpenCLIP instead of huggingface `transformers` [done, testing]
9393
4. Pre-computing ViT activations
9494

9595
I'm going to do each of these independently using a set of runs as references.
9696

9797
# 10/22/2024
9898

99-
* Removed HookedVisionTransformer (see above)
100-
* Checkpoint [v6jto37s](https://wandb.ai/samuelstevens/saev/runs/wwb20pa0) worked
99+
Removed HookedVisionTransformer (see above)
100+
Checkpoint [v6jto37s](https://wandb.ai/samuelstevens/saev/runs/wwb20pa0) worked for training, analysis, and app data.
101+
102+
Testing an implementation using OpenCLIP instead of `transformers`.
103+
Assuming it works (which seems likely given that the loss curve is identical), then I will pre-compute the activations, save them as a numpy array to disk, and memmap them during training rather than computing them.
104+
I expect this to take a little bit because I had issues with shuffling and such in the analysis step earlier.
105+
I think the best strategy is to work backwards.
106+
The `generate_app_data.py` script doesn't need an activation store at all.
107+
So I will start with the `analysis.py` script and add a new activations store class that implements the same interface as the original (except possibly the constructor).
108+
Then I will verify that the analysis script works correctly.
109+
110+
Only after that will I use the new class in training.
111+
Working with the analysis script is a shorter feedback loop.
112+
113+
# 10/23/2024

main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
torch.backends.cudnn.benchmark = True
1919
torch.backends.cudnn.deterministic = False
2020

21-
saev.training.train(tyro.cli(saev.Config))
21+
saev.train(tyro.cli(saev.Config))

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "saev"
33
version = "0.1.0"
4-
description = "Add your description here"
4+
description = "Sparse autoencoders for vision transformers in PyTorch"
55
readme = "README.md"
66
requires-python = ">=3.11"
77
dependencies = [
@@ -10,10 +10,10 @@ dependencies = [
1010
"einops>=0.8.0",
1111
"jaxtyping>=0.2.34",
1212
"marimo>=0.9.10",
13+
"open-clip-torch>=2.28.0",
1314
"pillow>=11.0.0",
1415
"torch>=2.5.0",
1516
"tqdm>=4.66.5",
16-
"transformers>=4.45.2",
1717
"tyro>=0.8.12",
1818
"wandb>=0.18.5",
1919
]

saev/__init__.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
from . import training, utils
2-
from .activations_store import ActivationsStore
3-
from .config import Config
4-
from .vits import RecordedVit
5-
from .sparse_autoencoder import SparseAutoencoder
1+
from .modeling import (
2+
ActivationsStore,
3+
Config,
4+
RecordedVit,
5+
Session,
6+
SparseAutoencoder,
7+
)
8+
from .training import train
69

710
__all__ = [
811
"ActivationsStore",
912
"Config",
1013
"RecordedVit",
1114
"SparseAutoencoder",
12-
"training",
13-
"utils",
15+
"Session",
16+
"train",
1417
]

saev/activations_store.py

-126
This file was deleted.

saev/config.py

+2-108
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,4 @@
1-
from dataclasses import dataclass
2-
3-
import beartype
4-
import torch
5-
6-
import wandb
7-
8-
9-
@beartype.beartype
10-
@dataclass
11-
class Config:
12-
"""
13-
Configuration for training a sparse autoencoder on a vision transformer.
14-
"""
15-
16-
# Data Generating Function (Model + Training Distribution)
17-
image_width: int = 224
18-
image_height: int = 224
19-
model_name: str = "openai/clip-vit-large-patch14"
20-
module_name: str = "resid"
21-
block_layer: int = -2
22-
dataset_path: str = "ILSVRC/imagenet-1k"
23-
24-
# SAE Parameters
25-
d_in: int = 1024
26-
27-
# Activation Store Parameters
28-
total_training_tokens: int = 2_621_440
29-
n_batches_in_store: int = 15
30-
store_size: int | None = None
31-
vit_batch_size: int = 1024
32-
33-
# SAE Parameters
34-
expansion_factor: int = 64
35-
36-
# Training Parameters
37-
l1_coefficient: float = 0.00008
38-
lr: float = 0.0004
39-
lr_warm_up_steps: int = 500
40-
batch_size: int = 1024
41-
42-
# Resampling protocol args
43-
use_ghost_grads: bool = True
44-
feature_sampling_window: int = 64
45-
resample_batches: int = 32
46-
feature_reinit_scale: float = 0.2
47-
dead_feature_window: int = 64
48-
dead_feature_estimation_method: str = "no_fire"
49-
dead_feature_threshold: float = 1e-6
50-
51-
# WANDB
52-
log_to_wandb: bool = True
53-
wandb_project: str = "saev"
54-
wandb_log_freq: int = 10
55-
56-
# Misc
57-
device: str = "cuda"
58-
seed: int = 42
59-
dtype: torch.dtype = torch.float32
60-
checkpoint_path: str = "checkpoints"
61-
62-
def __post_init__(self):
63-
self.store_size = self.n_batches_in_store * self.batch_size
64-
65-
self.d_sae = self.d_in * self.expansion_factor
66-
67-
self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
68-
69-
self.device = torch.device(self.device)
70-
71-
unique_id = wandb.util.generate_id()
72-
self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"
73-
74-
print(
75-
f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
76-
)
77-
# Print out some useful info:
78-
79-
total_training_steps = self.total_training_tokens // self.batch_size
80-
print(f"Total training steps: {total_training_steps}")
81-
82-
total_wandb_updates = total_training_steps // self.wandb_log_freq
83-
print(f"Total wandb updates: {total_wandb_updates}")
84-
85-
# how many times will we sample dead neurons?
86-
# assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
87-
n_feature_window_samples = total_training_steps // self.feature_sampling_window
88-
print(
89-
f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.batch_size) / 10** 6}"
90-
)
91-
print(
92-
f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.batch_size) / 10** 6}"
93-
)
94-
95-
if self.use_ghost_grads:
96-
print("Using Ghost Grads.")
97-
98-
print(
99-
f"We will reset the sparsity calculation {n_feature_window_samples} times."
100-
)
101-
print(
102-
f"Number of tokens when resampling: {self.resample_batches * self.batch_size}"
103-
)
104-
print(
105-
f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.batch_size:.2e}"
106-
)
107-
1+
from . import modeling
1082

1093
#################
1104
# COMPATIBILITY #
@@ -115,4 +9,4 @@ def __post_init__(self):
1159
# The classes are the same, just named differently.
11610

11711

118-
ViTSAERunnerConfig = Config
12+
ViTSAERunnerConfig = modeling.Config

0 commit comments

Comments
 (0)