
Commit 2f3c8b9

Add CachedActivationStore

1 parent d362f64 commit 2f3c8b9
7 files changed: +348 −264 lines changed

analysis.py (+20 −22)

```diff
@@ -5,12 +5,12 @@

 import beartype
 import torch
-import tqdm
 import tyro
 from jaxtyping import Float, Int, jaxtyped
 from torch import Tensor

 import saev
+from saev import helpers

 # Fix pickle renaming errors.
 sys.modules["sae_training"] = saev
@@ -39,7 +39,7 @@ def batched_idx(

 @jaxtyped(typechecker=beartype.beartype)
 def get_vit_acts(
-    acts_store: saev.ActivationsStore, n: int
+    acts_store: saev.CachedActivationsStore, n: int
 ) -> tuple[Float[Tensor, "n d_model"], Int[Tensor, " n"]]:
     """
     Args:
@@ -54,7 +54,6 @@ def get_vit_acts(
         batches.append(batch)
         indices.append(i)
         n_seen += len(batch)
-        logger.info("Got batch of size %d (%d total).", len(batch), n_seen)

     batches = torch.cat(batches, dim=0)
     indices = torch.cat(indices, dim=0)
@@ -94,7 +93,7 @@ def get_new_topk(
 @torch.inference_mode()
 def get_feature_data(
     sae: saev.SparseAutoencoder,
-    acts_store: saev.ActivationsStore,
+    acts_store: saev.CachedActivationsStore,
     *,
     n_images: int = 32_768,
     k_top_images: int = 10,
@@ -112,14 +111,14 @@ def get_feature_data(
     torch.cuda.empty_cache()
     sae.eval()

-    if n_images > len(acts_store.dataset):
+    if n_images > len(acts_store):
         logger.warning(
             "The dataset '%s' only has %d images, but you requested %d images.",
             sae.cfg.dataset_path,
-            len(acts_store.dataset),
+            len(acts_store),
             n_images,
         )
-        n_images = len(acts_store.dataset)
+        n_images = len(acts_store)

     top_values = torch.zeros((sae.cfg.d_sae, k_top_images)).to(sae.cfg.device)
     top_indices = torch.zeros((sae.cfg.d_sae, k_top_images), dtype=torch.int)
@@ -128,13 +127,18 @@
     sae_sparsity = torch.zeros((sae.cfg.d_sae,)).to(sae.cfg.device)
     sae_mean_acts = torch.zeros((sae.cfg.d_sae,)).to(sae.cfg.device)

-    n_seen = 0
+    dataloader = torch.utils.data.DataLoader(
+        acts_store,
+        batch_size=images_per_it,
+        shuffle=True,
+        num_workers=sae.cfg.n_workers,
+        drop_last=True,
+    )

-    while n_seen < n_images:
+    for batch in helpers.progress(dataloader):
         torch.cuda.empty_cache()
+        vit_acts, indices = batch

-        # tensor of size [batch, d_resid]
-        vit_acts, indices = get_vit_acts(acts_store, images_per_it)
         # tensor of size [feature_idx, batch]
         sae_acts = get_sae_acts(vit_acts.to(sae.cfg.device), sae).transpose(0, 1)
         del vit_acts
@@ -149,24 +153,18 @@ def get_feature_data(
             top_values, top_indices, values, indices, k_top_images
         )

-        n_seen += images_per_it
-        logger.info("%d/%d (%.1f%%)", n_seen, n_images, n_seen / n_images * 100)
-
     sae_mean_acts /= sae_sparsity
-    sae_sparsity /= n_images
+    sae_sparsity /= len(acts_store)

     # Check if the directory exists
     if not os.path.exists(directory):
         # Create the directory if it does not exist
         os.makedirs(directory)

-    # compute the label tensor
-    top_image_label_indices = torch.tensor([
-        acts_store.dataset[int(index)]["label"]
-        for index in tqdm.tqdm(top_indices.flatten(), desc="Getting labels")
-    ])
-    # Reshape to original dimensions
-    top_image_label_indices = top_image_label_indices.view(top_indices.shape)
+    # Compute the label tensor
+    top_image_label_indices = acts_store.labels[top_indices.view(-1).cpu()].view(
+        top_indices.shape
+    )
     torch.save(top_indices, f"{directory}/max_activating_image_indices.pt")
     torch.save(top_values, f"{directory}/max_activating_image_values.pt")
     torch.save(
```
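After this change, `analysis.py` treats the store as a map-style `torch.utils.data.Dataset`: it calls `len(acts_store)`, iterates `(vit_acts, indices)` batches through a shuffled `DataLoader`, and indexes a precomputed `acts_store.labels` tensor. The commit doesn't include `saev/modeling.py`, so the following is only a minimal sketch of a store satisfying that interface; the constructor arguments and tensor file layout are assumptions, not the actual implementation.

```python
import torch
from torch import Tensor


class CachedActivationsStore(torch.utils.data.Dataset):
    """Hypothetical sketch: serves pre-computed ViT activations from disk."""

    def __init__(self, acts_path: str, labels_path: str):
        # One row per image: activations are [n_images, d_model], labels are [n_images].
        self.acts: Tensor = torch.load(acts_path, map_location="cpu")
        self.labels: Tensor = torch.load(labels_path, map_location="cpu")

    def __getitem__(self, i: int) -> tuple[Tensor, int]:
        # Returning the index alongside the activation lets callers map
        # top-activating rows back to dataset images, as analysis.py does.
        return self.acts[i], i

    def __len__(self) -> int:
        return len(self.acts)
```

Under a shuffled `DataLoader`, each index is visited at most once per epoch, which is what fixes the duplicate top images described in the logbook below.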

generate_app_data.py (+26 −25)

```diff
@@ -20,6 +20,32 @@ def safe_load(path: str) -> object:
     return torch.load(path, map_location="cpu", weights_only=True)


+@beartype.beartype
+def make_img_grid(imgs: list):
+    # Resize to 224x224
+    img_width, img_height = 224, 224
+    imgs = [img.resize((img_width, img_height)).convert("RGB") for img in imgs]
+
+    # Create an image grid
+    grid_size = 4
+    border_size = 2  # White border thickness
+
+    # Create a new image with white background
+    grid_width = grid_size * img_width + (grid_size - 1) * border_size
+    grid_height = grid_size * img_height + (grid_size - 1) * border_size
+    img_grid = Image.new("RGB", (grid_width, grid_height), "white")
+
+    # Paste images in the grid
+    x_offset, y_offset = 0, 0
+    for i, img in enumerate(imgs):
+        img_grid.paste(img, (x_offset, y_offset))
+        x_offset += img_width + border_size
+        if (i + 1) % grid_size == 0:
+            x_offset = 0
+            y_offset += img_height + border_size
+    return img_grid
+
+
 @beartype.beartype
 def main(ckpt_path: str, in_dir: str = "dashboard", out_dir: str = "web_app"):
     """
@@ -85,31 +111,6 @@ def main(ckpt_path: str, in_dir: str = "dashboard", out_dir: str = "web_app"):
     indices = torch.tensor([i for i in range(n_neurons)])
     indices = list(indices[mask])

-    @beartype.beartype
-    def make_img_grid(imgs: list):
-        # Resize to 224x224
-        img_width, img_height = 224, 224
-        imgs = [img.resize((img_width, img_height)).convert("RGB") for img in imgs]
-
-        # Create an image grid
-        grid_size = 4
-        border_size = 2  # White border thickness
-
-        # Create a new image with white background
-        grid_width = grid_size * img_width + (grid_size - 1) * border_size
-        grid_height = grid_size * img_height + (grid_size - 1) * border_size
-        img_grid = Image.new("RGB", (grid_width, grid_height), "white")
-
-        # Paste images in the grid
-        x_offset, y_offset = 0, 0
-        for i, img in enumerate(imgs):
-            img_grid.paste(img, (x_offset, y_offset))
-            x_offset += img_width + border_size
-            if (i + 1) % grid_size == 0:
-                x_offset = 0
-                y_offset += img_height + border_size
-        return img_grid
-
     os.makedirs(f"{out_dir}/neurons", exist_ok=True)
     torch.save(entropies, f"{out_dir}/neurons/entropy.pt")
     for i in tqdm.tqdm(indices, desc="saving highest activating grids"):
```
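Hoisting `make_img_grid` out of `main` makes it importable and testable on its own. A small usage sketch, with solid-color placeholders standing in for a neuron's top-activating images:

```python
from PIL import Image

# 16 dummy images; the real script passes the top-16 images for one neuron.
imgs = [Image.new("RGB", (256, 256), (16 * i, 64, 64)) for i in range(16)]
grid = make_img_grid(imgs)
grid.save("grid.png")  # 4x4 grid of 224x224 tiles with 2px borders: 902x902 px
```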

logbook.md (+22 −2)

```diff
@@ -89,8 +89,8 @@ With this in mind, there are several minor changes I want to make before I do so

 1. Removing `transformer-lens` [done, commit [18612b7](https://github.com/samuelstevens/saev/commit/18612b75988c32ae8ab3db6656b44a442f3f7641)]
 2. Removing HookedVisionTransformer [done, commit [c7ba7c7](https://github.com/samuelstevens/saev/commit/c7ba7c72c76472fd8cf2e7b2dc668d03a15b803d)]
-3. OpenCLIP instead of huggingface `transformers` [done, testing]
-4. Pre-computing ViT activations
+3. OpenCLIP instead of huggingface `transformers` [done, commit [d362f64](https://github.com/samuelstevens/saev/commit/d362f64437b3599f56bb698136712d7590ee897b)]
+4. Pre-computing ViT activations [done, commit [ee79f5b](https://github.com/samuelstevens/saev/commit/ee79f5b84186e655b2e5d485e972fe69bb73dd65)]

 I'm going to do each of these independently using a set of runs as references.

@@ -111,3 +111,23 @@ Only after that will I use the new class in training.
 Working with the analysis script is a shorter feedback loop.

 # 10/23/2024
+
+OpenCLIP instead of transformers works (training, analysis, generate).
+So now I am pre-computing activations.
+I'm waiting on the activations to be saved (~3 hours).
+
+CachedActivationsStore produced some duplicates in the analysis step.
+Why is that?
+
+For example, neuron 78 has the same image for images 6 and 7 (1-indexed; images 5 and 6 if zero-indexed).
+
+Fixed it.
+We no longer randomly sample batches; instead, we use a dataloader and `__getitem__`.
+
+With training, however, the metrics no longer match the reference metrics.
+Why is that?
+We can find out by comparing to the original activations store.
+Likely, we will need to build a custom data order using `np.random.default_rng(seed=cfg.seed)`.
+
+My strategy for calculating the mean activations only used 15 examples instead of 15 x 1024.
+With 15 x 1024 examples, b_dec is better initialized and training works exactly like before.
```
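The custom data order mentioned in the last logbook entry is not part of this commit; below is a minimal sketch of the idea, assuming a `seed` field on the config. The function name is hypothetical.

```python
import numpy as np
import torch


def data_order(n: int, seed: int) -> torch.Tensor:
    # A permutation visits every index exactly once (no duplicates), and a
    # fixed seed reproduces the reference run's data order deterministically.
    rng = np.random.default_rng(seed=seed)
    return torch.from_numpy(rng.permutation(n))


order = data_order(50_000, seed=42)  # shuffled index tensor of length 50,000
```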

saev/__init__.py (+2 −0)

```diff
@@ -1,5 +1,6 @@
 from .modeling import (
     ActivationsStore,
+    CachedActivationsStore,
     Config,
     RecordedVit,
     Session,
@@ -9,6 +10,7 @@

 __all__ = [
     "ActivationsStore",
+    "CachedActivationsStore",
     "Config",
     "RecordedVit",
     "SparseAutoencoder",
```

saev/helpers.py (+53 −0)

```diff
@@ -0,0 +1,53 @@
+import collections.abc
+import logging
+import time
+
+import beartype
+
+
+@beartype.beartype
+class progress:
+    def __init__(self, it, *, every: int = 10, desc: str = "progress", total: int = 0):
+        """
+        Wraps an iterable with a logger like tqdm but doesn't use any control codes to manipulate a progress bar, which doesn't work well when your output is redirected to a file. Instead, simple logging statements are used, but it includes quality-of-life features like iteration speed and predicted time to finish.
+
+        Args:
+            it: Iterable to wrap.
+            every: How many iterations between logging progress.
+            desc: What to name the logger.
+            total: If non-zero, how long the iterable is.
+        """
+        self.it = it
+        self.every = every
+        self.logger = logging.getLogger(desc)
+        self.total = total
+
+    def __iter__(self):
+        start = time.time()
+        for i, obj in enumerate(self.it):
+            yield obj
+
+            if (i + 1) % self.every == 0:
+                now = time.time()
+                duration_s = now - start
+                per_min = (i + 1) / (duration_s / 60)
+
+                if isinstance(self.it, collections.abc.Sized):
+                    pred_min = (len(self) - (i + 1)) / per_min
+                    self.logger.info(
+                        "%d/%d (%.1f%%) | %.1f it/m (expected finish in %.1fm)",
+                        i + 1,
+                        len(self),
+                        (i + 1) / len(self) * 100,
+                        per_min,
+                        pred_min,
+                    )
+                else:
+                    self.logger.info("%d/? | %.1f it/m", i + 1, per_min)
+
+    def __len__(self) -> int:
+        if self.total > 0:
+            return self.total
+
+        # Will throw exception.
+        return len(self.it)
```
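`helpers.progress` drops in where `tqdm.tqdm` was used, as in the `analysis.py` loop above. For example:

```python
import logging
import time

from saev import helpers

logging.basicConfig(level=logging.INFO)

for _ in helpers.progress(range(50), every=10, desc="demo"):
    time.sleep(0.01)
# Emits INFO lines like "10/50 (20.0%) | ... it/m (expected finish in ...m)"
```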
