Commit 6633e69

wip
1 parent e809810 commit 6633e69

File tree

3 files changed: +53 -41 lines changed


spd/clustering/pipeline/clustering_pipeline.py

Lines changed: 12 additions & 4 deletions
@@ -5,26 +5,33 @@
 Each batch loads its own model and WandB run to match original design.
 """
 
+from collections.abc import Iterator
+from typing import Any
+
 from spd.clustering.merge_run_config import RunConfig
 
 
 def main(config: RunConfig) -> None:
-    from spd.clustering.consts import DistancesArray, MergesArray
+    from spd.clustering.consts import DistancesArray, DistancesMethod, MergesArray
     from spd.clustering.math.merge_distances import (
         compute_distances,
     )
-    from spd.clustering.pipeline.s1_split_dataset import split_dataset
+    from spd.clustering.pipeline.s1_split_dataset import BatchTensor, split_dataset
     from spd.clustering.pipeline.s2_clustering import ClusteringResult, process_batches_parallel
     from spd.clustering.pipeline.s3_normalize_histories import normalize_and_save
     from spd.clustering.pipeline.s4_compute_distances import create_clustering_report
     from spd.clustering.pipeline.storage import ClusteringStorage
 
-    storage = ClusteringStorage(base_path=config.base_path, run_identifier=config.config_identifier)
+    storage: ClusteringStorage = ClusteringStorage(
+        base_path=config.base_path, run_identifier=config.config_identifier
+    )
 
     print(f"Run record saved to {storage.run_config_file}")
     storage.save_run_config(config)
 
     print(f"Splitting dataset into {config.n_batches} batches...")
+    batches: Iterator[BatchTensor]
+    dataset_config: dict[str, Any]
     batches, dataset_config = split_dataset(config=config)
     storage.save_batches(batches=batches, config=dataset_config)
 
@@ -40,7 +47,8 @@ def main(config: RunConfig) -> None:
 
     normalized_merge_array: MergesArray = normalize_and_save(storage=storage)
 
-    method = "perm_invariant_hamming"
+    # TODO: read method from config
+    method: DistancesMethod = "perm_invariant_hamming"
     distances: DistancesArray = compute_distances(
         normalized_merge_array=normalized_merge_array,
         method=method,
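
For context, the whole pipeline above is driven from a single RunConfig. A minimal sketch of invoking it, using only the constructor fields visible elsewhere in this commit (the batch counts are illustrative, and any further RunConfig fields such as a base_path override are assumptions):

# Sketch: drive the updated clustering pipeline end to end.
# RunConfig fields beyond those shown in this commit are assumptions.
from spd.clustering.merge_config import MergeConfig
from spd.clustering.merge_run_config import RunConfig
from spd.clustering.pipeline.clustering_pipeline import main

config = RunConfig(
    merge_config=MergeConfig(),
    model_path="wandb:goodfire/spd/runs/ioprgffh",
    task_name="lm",
    n_batches=4,   # illustrative value
    batch_size=8,  # illustrative value
)
main(config)  # split -> cluster -> normalize -> compute distances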

tests/clustering/scripts/cluster_resid_mlp.py

Lines changed: 23 additions & 22 deletions
@@ -70,10 +70,14 @@
 # %%
 # Get component activations
 # ============================================================
+# Get a single batch from the dataloader
+BATCH_DATA: tuple[Tensor, Tensor] = next(iter(DATALOADER))
+BATCH: Tensor = BATCH_DATA[0]
+
 COMPONENT_ACTS: dict[str, Tensor] = component_activations(
     model=MODEL,
     device=DEVICE,
-    dataloader=DATALOADER,
+    batch=BATCH,
     sigmoid_type="hard",
 )
 
@@ -88,13 +92,13 @@
 PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations(
     COMPONENT_ACTS,
     filter_dead_threshold=FILTER_DEAD_THRESHOLD,
-    sort_components=False,  # Test the new sorting functionality
 )
 
 
 plot_activations(
     processed_activations=PROCESSED_ACTIVATIONS,
-    save_pdf=False,
+    save_dir=None,
+    wandb_run=None,
 )
 
 # %%
@@ -113,27 +117,24 @@
 
 
 def _plot_func(
-    costs: torch.Tensor,
-    # merge_history: MergeHistory,
-    current_merge: Any,
     current_coact: torch.Tensor,
-    # current_act_mask: torch.Tensor,
-    i: int,
-    # k_groups: int,
-    # activation_mask_orig: torch.Tensor,
     component_labels: list[str],
-    # sweep_params: dict[str, Any],
-    **kwargs: Any,
+    current_merge: Any,
+    costs: torch.Tensor,
+    merge_history: MergeHistory,
+    iter_idx: int,
+    k_groups: int,
+    merge_pair_cost: float,
+    mdl_loss: float,
+    mdl_loss_norm: float,
+    diag_acts: torch.Tensor,
 ) -> None:
-    assert kwargs
-    if (i % 50 == 0 and i > 0) or i == 1:
-        # latest = merge_history.latest()
-        # latest['merges'].plot()
+    if (iter_idx % 50 == 0 and iter_idx > 0) or iter_idx == 1:
         plot_merge_iteration(
             current_merge=current_merge,
             current_coact=current_coact,
             costs=costs,
-            iteration=i,
+            iteration=iter_idx,
             component_labels=component_labels,
             show=True,  # Show the plot interactively
         )
@@ -144,7 +145,7 @@ def _plot_func(
     batch_id="batch_0",
     activations=PROCESSED_ACTIVATIONS.activations,
     component_labels=PROCESSED_ACTIVATIONS.labels,
-    log_callback=None,
+    log_callback=_plot_func,
 )
 
 # %%
@@ -162,18 +163,18 @@ def _plot_func(
 
 # Modern approach: run merge_iteration multiple times to create ensemble
 ENSEMBLE_SIZE: int = 4
-histories: list[MergeHistory] = []
+HISTORIES: list[MergeHistory] = []
 for i in range(ENSEMBLE_SIZE):
-    history: MergeHistory = merge_iteration(
+    HISTORY: MergeHistory = merge_iteration(
         merge_config=MERGE_CFG,
         batch_id=f"batch_{i}",
         activations=PROCESSED_ACTIVATIONS.activations,
         component_labels=PROCESSED_ACTIVATIONS.labels,
         log_callback=None,
     )
-    histories.append(history)
+    HISTORIES.append(HISTORY)
 
-ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=histories)
+ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES)
 
 DISTANCES = ENSEMBLE.get_distances(method="perm_invariant_hamming")
 
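
The rewritten _plot_func replaces the old **kwargs catch-all with explicit parameters, which suggests merge_iteration now invokes log_callback with these exact keyword arguments on each iteration. A minimal no-op callback matching that contract, as a sketch under that assumption:

# Sketch of a callback satisfying the new log_callback contract.
# Assumes merge_iteration calls it with these keywords; verify against
# the merge_iteration implementation before relying on this.
from typing import Any

import torch

from spd.clustering.merge_history import MergeHistory


def noop_log_callback(
    current_coact: torch.Tensor,
    component_labels: list[str],
    current_merge: Any,
    costs: torch.Tensor,
    merge_history: MergeHistory,
    iter_idx: int,
    k_groups: int,
    merge_pair_cost: float,
    mdl_loss: float,
    mdl_loss_norm: float,
    diag_acts: torch.Tensor,
) -> None:
    # Print scalar losses occasionally instead of plotting.
    if iter_idx % 100 == 0:
        print(f"iter={iter_idx} k={k_groups} mdl={mdl_loss:.4f} (norm={mdl_loss_norm:.4f})")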

tests/clustering/scripts/cluster_ss.py

Lines changed: 18 additions & 15 deletions
@@ -1,7 +1,6 @@
 # %%
 
 import matplotlib.pyplot as plt
-import numpy as np
 import torch
 from jaxtyping import Int
 from muutils.dbg import dbg_auto
@@ -15,7 +14,8 @@
 from spd.clustering.merge import merge_iteration
 from spd.clustering.merge_config import MergeConfig
 from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble
-from spd.clustering.pipeline.s1_split_dataset import split_dataset_lm
+from spd.clustering.merge_run_config import RunConfig
+from spd.clustering.pipeline.s1_split_dataset import split_dataset
 from spd.clustering.plotting.activations import plot_activations
 from spd.clustering.plotting.merge import plot_dists_distribution
 from spd.models.component_model import ComponentModel, SPDRunInfo
@@ -31,23 +31,25 @@
 # ============================================================
 MODEL_PATH: str = "wandb:goodfire/spd/runs/ioprgffh"
 
-_, DATA_CFG = split_dataset_lm(
-    model_path=MODEL_PATH,
-    n_batches=1,
-    batch_size=2,
-)
-DATASET_PATH: str = DATA_CFG["output_files"][0]
-
 SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH)
 MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path)
 MODEL.to(DEVICE)
 SPD_CONFIG = SPD_RUN.config
 
+# Use split_dataset with RunConfig to get real data
+CONFIG: RunConfig = RunConfig(
+    merge_config=MergeConfig(),
+    model_path=MODEL_PATH,
+    task_name="lm",
+    n_batches=1,
+    batch_size=2,
+)
+BATCHES, _ = split_dataset(config=CONFIG)
 
 # %%
 # Load data batch
 # ============================================================
-DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = torch.tensor(np.load(DATASET_PATH)["input_ids"])
+DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES)
 
 # %%
 # Get component activations
@@ -75,7 +77,8 @@
 
 plot_activations(
     processed_activations=PROCESSED_ACTIVATIONS,
-    save_pdf=False,
+    save_dir=None,
+    wandb_run=None,
 )
 
 # %%
@@ -94,18 +97,18 @@
 
 # Modern approach: run merge_iteration multiple times to create ensemble
 ENSEMBLE_SIZE: int = 2
-histories: list[MergeHistory] = []
+HISTORIES: list[MergeHistory] = []
 for i in range(ENSEMBLE_SIZE):
-    history: MergeHistory = merge_iteration(
+    HISTORY: MergeHistory = merge_iteration(
         merge_config=MERGE_CFG,
         batch_id=f"batch_{i}",
         activations=PROCESSED_ACTIVATIONS.activations,
         component_labels=PROCESSED_ACTIVATIONS.labels,
         log_callback=None,
    )
-    histories.append(history)
+    HISTORIES.append(HISTORY)
 
-ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=histories)
+ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES)
 
 
 # %%
