
Commit 7e83333

re-add notebooks as tests
1 parent e850e41 commit 7e83333

File tree

3 files changed: +318 −4 lines changed

tests/clustering/scripts/cluster_resid_mlp.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
# %%
from typing import Any

import matplotlib.pyplot as plt
import torch
from muutils.dbg import dbg_auto
from torch import Tensor

from spd.clustering.activations import (
    ProcessedActivations,
    component_activations,
    process_activations,
)
from spd.clustering.merge import merge_iteration, merge_iteration_ensemble
from spd.clustering.merge_config import MergeConfig
from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble
from spd.clustering.merge_sweep import sweep_multiple_parameters
from spd.clustering.plotting.activations import plot_activations
from spd.clustering.plotting.merge import (
    plot_dists_distribution,
    plot_merge_iteration,
)
from spd.configs import Config
from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset
from spd.models.component_model import ComponentModel, SPDRunInfo
from spd.registry import EXPERIMENT_REGISTRY
from spd.utils.data_utils import DatasetGeneratedDataLoader

DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

# magic autoreload
# %load_ext autoreload
# %autoreload 2

# %%
# Load model
# ============================================================
_CANONICAL_RUN: str | None = EXPERIMENT_REGISTRY["resid_mlp2"].canonical_run
assert _CANONICAL_RUN is not None, "No canonical run found for resid_mlp2 experiment"
SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(_CANONICAL_RUN)
MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path)
MODEL.to(DEVICE)
SPD_CONFIG: Config = SPD_RUN.config

# %%
# Set up dataset and dataloader
# ============================================================
N_SAMPLES: int = 128

DATASET: ResidMLPDataset = ResidMLPDataset(
    n_features=MODEL.target_model.config.n_features,  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
    feature_probability=SPD_CONFIG.task_config.feature_probability,  # pyright: ignore[reportAttributeAccessIssue]
    device=DEVICE,
    calc_labels=False,
    label_type=None,
    act_fn_name=None,
    label_fn_seed=None,
    label_coeffs=None,
    data_generation_type=SPD_CONFIG.task_config.data_generation_type,  # pyright: ignore[reportAttributeAccessIssue]
)

dbg_auto(
    dict(
        n_features=DATASET.n_features,
        feature_probability=DATASET.feature_probability,
        data_generation_type=DATASET.data_generation_type,
    )
)
DATALOADER = DatasetGeneratedDataLoader(DATASET, batch_size=N_SAMPLES, shuffle=False)

# %%
# Get component activations
# ============================================================
COMPONENT_ACTS: dict[str, Tensor] = component_activations(
    model=MODEL,
    device=DEVICE,
    dataloader=DATALOADER,
    sigmoid_type="hard",
)

dbg_auto(COMPONENT_ACTS)

# %%
FILTER_DEAD_THRESHOLD: float = 0.1

# Process activations
# ============================================================
PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations(
    COMPONENT_ACTS,
    filter_dead_threshold=FILTER_DEAD_THRESHOLD,
    sort_components=False,  # set to True to exercise the new sorting functionality
)

plot_activations(
    processed_activations=PROCESSED_ACTIVATIONS,
    save_pdf=False,
)

# %%
# Run the merge iteration
# ============================================================
MERGE_CFG: MergeConfig = MergeConfig(
    activation_threshold=0.1,
    alpha=1,
    iters=int(PROCESSED_ACTIVATIONS.n_components_alive * 0.9),
    merge_pair_sampling_method="range",
    merge_pair_sampling_kwargs={"threshold": 0.0},
    pop_component_prob=0,
    filter_dead_threshold=FILTER_DEAD_THRESHOLD,
)


def _plot_func(
    costs: torch.Tensor,
    # merge_history: MergeHistory,
    current_merge: Any,
    current_coact: torch.Tensor,
    # current_act_mask: torch.Tensor,
    i: int,
    # k_groups: int,
    # activation_mask_orig: torch.Tensor,
    component_labels: list[str],
    # sweep_params: dict[str, Any],
    **kwargs: Any,
) -> None:
    assert kwargs
    if (i % 50 == 0 and i > 0) or i == 1:
        # latest = merge_history.latest()
        # latest['merges'].plot()
        plot_merge_iteration(
            current_merge=current_merge,
            current_coact=current_coact,
            costs=costs,
            iteration=i,
            component_labels=component_labels,
            show=True,  # show the plot interactively
        )


MERGE_HIST: MergeHistory = merge_iteration(
    activations=PROCESSED_ACTIVATIONS.activations,
    merge_config=MERGE_CFG,
    component_labels=PROCESSED_ACTIVATIONS.labels,
    plot_callback=_plot_func,
)

# %%
# Plot merge history
# ============================================================
# plt.hist(MERGE_HIST[270]["merges"].components_per_group, bins=np.linspace(0, 56, 57))
# plt.yscale("log")
# plt.xscale("log")


# %%
# Compute and plot distances in an ensemble
# ============================================================
ENSEMBLE: MergeHistoryEnsemble = merge_iteration_ensemble(
    activations=PROCESSED_ACTIVATIONS.activations,
    component_labels=PROCESSED_ACTIVATIONS.labels,
    merge_config=MERGE_CFG,
    ensemble_size=4,
)

DISTANCES = ENSEMBLE.get_distances(method="perm_invariant_hamming")

plot_dists_distribution(
    distances=DISTANCES,
    mode="points",
    # label="v1"
)
plt.legend()


# %%
# Do sweeps
# ============================================================
SWEEP_RESULTS: dict[str, Any] = sweep_multiple_parameters(
    activations=PROCESSED_ACTIVATIONS.activations,
    parameter_sweeps={
        "alpha": [1, 5],
        # "check_threshold": [0.0001, 0.001, 0.01, 0.1, 0.5],
        # "pop_component_prob": [0.0001, 0.01, 0.5],
    },
    base_config=MERGE_CFG.model_dump(mode="json"),  # pyright: ignore[reportArgumentType]
    component_labels=PROCESSED_ACTIVATIONS.labels,
    ensemble_size=4,
)

# Show all plots
for param_name, (ensembles, fig, ax) in SWEEP_RESULTS.items():  # noqa: B007
    plt.show()
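
Note on the "perm_invariant_hamming" distance requested from ENSEMBLE.get_distances above: ensemble members can arrive at the same grouping under different cluster ids, so the comparison must be invariant to relabeling. Below is a minimal sketch of one such distance; the brute-force search over relabelings and the function name are illustrative assumptions, not the actual MergeHistoryEnsemble implementation.

# Illustrative sketch only; not the spd implementation.
# Hamming distance between two cluster-assignment vectors, minimized over
# all relabelings of one side (brute force; only viable for small label sets).
from itertools import permutations

import numpy as np


def perm_invariant_hamming_sketch(a: np.ndarray, b: np.ndarray) -> float:
    labels_b = np.unique(b)
    best = 1.0
    for perm in permutations(labels_b.tolist()):
        relabel = dict(zip(labels_b.tolist(), perm))
        b_perm = np.array([relabel[x] for x in b.tolist()])
        best = min(best, float(np.mean(a != b_perm)))
    return best


# Two clusterings that differ only by a label swap have distance 0.
assert perm_invariant_hamming_sketch(
    np.array([0, 0, 1, 1, 2]), np.array([1, 1, 0, 0, 2])
) == 0.0
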
tests/clustering/scripts/cluster_ss.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# %%

import matplotlib.pyplot as plt
import numpy as np
import torch
from jaxtyping import Int
from muutils.dbg import dbg_auto
from torch import Tensor

from spd.clustering.activations import (
    ProcessedActivations,
    component_activations,
    process_activations,
)
from spd.clustering.merge import merge_iteration_ensemble
from spd.clustering.merge_config import MergeConfig
from spd.clustering.merge_history import MergeHistoryEnsemble
from spd.clustering.plotting.activations import plot_activations
from spd.clustering.plotting.merge import plot_dists_distribution
from spd.clustering.scripts.s1_split_dataset import split_dataset_lm
from spd.models.component_model import ComponentModel, SPDRunInfo

DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

# magic autoreload
# %load_ext autoreload
# %autoreload 2

# %%
# Load model and dataset
# ============================================================
MODEL_PATH: str = "wandb:goodfire/spd/runs/ioprgffh"

_, DATA_CFG = split_dataset_lm(
    model_path=MODEL_PATH,
    n_batches=1,
    batch_size=2,
)
DATASET_PATH: str = DATA_CFG["output_files"][0]

SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH)
MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path)
MODEL.to(DEVICE)
SPD_CONFIG = SPD_RUN.config


# %%
# Load data batch
# ============================================================
DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = torch.tensor(np.load(DATASET_PATH)["input_ids"])

# %%
# Get component activations
# ============================================================
COMPONENT_ACTS: dict[str, Tensor] = component_activations(
    model=MODEL,
    batch=DATA_BATCH,
    device=DEVICE,
    sigmoid_type="hard",
)

_ = dbg_auto(COMPONENT_ACTS)

# %%
# Process activations
# ============================================================
FILTER_DEAD_THRESHOLD: float = 0.001
FILTER_MODULES: str = "model.layers.0"

PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations(
    activations=COMPONENT_ACTS,
    filter_dead_threshold=FILTER_DEAD_THRESHOLD,
    filter_modules=lambda x: x.startswith(FILTER_MODULES),
    seq_mode="concat",
)

plot_activations(
    processed_activations=PROCESSED_ACTIVATIONS,
    save_pdf=False,
)

# %%
# Compute ensemble merge iterations
# ============================================================
MERGE_CFG: MergeConfig = MergeConfig(
    activation_threshold=0.01,
    alpha=0.01,
    iters=2,
    merge_pair_sampling_method="range",
    merge_pair_sampling_kwargs={"threshold": 0.1},
    pop_component_prob=0,
    module_name_filter=FILTER_MODULES,
    filter_dead_threshold=FILTER_DEAD_THRESHOLD,
)

ENSEMBLE: MergeHistoryEnsemble = merge_iteration_ensemble(
    activations=PROCESSED_ACTIVATIONS.activations,
    component_labels=PROCESSED_ACTIVATIONS.labels,
    merge_config=MERGE_CFG,
    ensemble_size=2,
)


# %%
# Compute and plot distances
# ============================================================
DISTANCES = ENSEMBLE.get_distances()

plot_dists_distribution(
    distances=DISTANCES,
    mode="points",
)
plt.legend()
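
Note on seq_mode="concat" above: language-model activations carry a sequence dimension, and a natural reading of the concat mode is that token positions are folded into the sample axis so each position is treated as an independent sample for clustering. A sketch of that reshape under an assumed (batch, seq, n_components) layout follows; the real handling lives inside process_activations.

# Sketch under an assumed (batch, seq, n_components) layout; not the
# process_activations implementation itself.
import torch

acts = torch.randn(2, 16, 8)  # (batch, seq, n_components), assumed layout
flat = acts.reshape(-1, acts.shape[-1])  # (batch * seq, n_components)
assert flat.shape == (2 * 16, 8)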

tests/clustering/test_clustering_experiments.py

Lines changed: 8 additions & 4 deletions
@@ -6,11 +6,15 @@
 
 import pytest
 
+# Test resource directories
+NOTEBOOK_DIR: Path = Path("tests/clustering/scripts")
+CONFIG_DIR: Path = Path("spd/clustering/configs")
+
 
 @pytest.mark.slow
 def test_cluster_resid_mlp_notebook():
     """Test running the cluster_resid_mlp.py notebook-style script."""
-    script_path = Path("spd/clustering/experiments/cluster_resid_mlp.py")
+    script_path = NOTEBOOK_DIR / "cluster_resid_mlp.py"
     assert script_path.exists(), f"Script not found: {script_path}"
 
     # Run the script as-is
@@ -30,7 +34,7 @@ def test_cluster_resid_mlp_notebook():
 @pytest.mark.slow
 def test_clustering_with_resid_mlp1_config():
     """Test running clustering with test-resid_mlp1.json config."""
-    config_path = Path("spd/clustering/configs/test-resid_mlp1.json")
+    config_path = CONFIG_DIR / "test-resid_mlp1.json"
     assert config_path.exists(), f"Config not found: {config_path}"
 
     # Run the clustering main script with the test config
@@ -54,7 +58,7 @@ def test_clustering_with_resid_mlp1_config():
 @pytest.mark.slow
 def test_cluster_ss_notebook():
     """Test running the cluster_ss.py notebook-style script."""
-    script_path = Path("spd/clustering/experiments/cluster_ss.py")
+    script_path = NOTEBOOK_DIR / "cluster_ss.py"
     assert script_path.exists(), f"Script not found: {script_path}"
 
     # Run the script as-is
@@ -74,7 +78,7 @@ def test_cluster_ss_notebook():
 @pytest.mark.slow
 def test_clustering_with_simplestories_config():
     """Test running clustering with test-simplestories.json config."""
-    config_path = Path("spd/clustering/configs/test-simplestories.json")
+    config_path = CONFIG_DIR / "test-simplestories.json"
     assert config_path.exists(), f"Config not found: {config_path}"
 
     # Run the clustering main script with the test config
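
The test bodies between the hunks are collapsed in this diff; the "# Run the script as-is" context lines suggest each test shells out to the notebook-style script. A hypothetical sketch of such a runner follows; the interpreter invocation, timeout, and headless-matplotlib env var are assumptions, not the repository's actual test code.

# Hypothetical notebook-as-test runner; the actual test bodies are not
# shown in this diff.
import os
import subprocess
import sys
from pathlib import Path


def run_notebook_script(script_path: Path) -> None:
    env = {**os.environ, "MPLBACKEND": "Agg"}  # headless plotting for CI
    result = subprocess.run(
        [sys.executable, str(script_path)],
        env=env,
        capture_output=True,
        text=True,
        timeout=1800,
    )
    assert result.returncode == 0, f"Script failed:\n{result.stderr}"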
