Skip to content

Commit 3ea89ae

Browse files
committed
wip
1 parent b44abe8 commit 3ea89ae

File tree

11 files changed

+4682
-143
lines changed

11 files changed

+4682
-143
lines changed

oli.patch

Lines changed: 4571 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
Orchestration layer - clean clustering pipeline coordination.
3+
4+
Replaces the original 370+ line subprocess/FD system with simple multiprocessing.Pool.
5+
Each batch loads its own model and WandB run to match original design.
6+
"""
7+
8+
from pathlib import Path
9+
10+
from spd.clustering.merge_run_config import MergeRunConfig
11+
12+
13+
def main(
    config: MergeRunConfig,
    base_path: Path,
    n_workers: int,
    devices: list[str],
) -> None:
    """Run the full clustering pipeline: split, cluster, ensemble, distances, report.

    Expected output layout under ``base_path`` (though see: there's some
    repetition here I'd like to change):

        base_dir/
            {config.config_identifier}/
                merge_histories/
                    {config.config_identifier}-data_{batch_id}/
                        merge_history.zip
                        plots/
                            activations_raw.pdf
                            activations_concat.pdf
                            activations_coact.pdf
                            activations_coact_log.pdf
                            merge_iteration.pdf
                distances/
                figures/
                run_config.json

    Args:
        config: Merge-run configuration; ``config.config_identifier`` names
            the per-run output directory.
        base_path: Root directory under which all pipeline outputs are written.
        n_workers: Number of parallel worker processes for batch clustering.
        devices: Device strings (e.g. ``"cuda:0"``) distributed across workers.
    """
    # Local imports keep module import cheap and free of heavy side effects.
    from spd.clustering.s1_split_dataset import split_and_save_dataset
    from spd.clustering.s2_clustering import process_batches_parallel
    from spd.clustering.s3_normalize_histories import normalize_and_ensemble_and_save
    from spd.clustering.s4_compute_distances import (
        compute_and_save_distances_new,
        create_clustering_report,
    )

    # Single source of truth for the distance metric -- used by both the
    # distance computation and the report, which must agree.
    distance_method: str = "perm_invariant_hamming"

    output_dir = base_path / config.config_identifier

    histories_path = output_dir / "merge_histories"
    histories_path.mkdir(parents=True, exist_ok=True)

    distances_dir = output_dir / "distances"
    distances_dir.mkdir(parents=True, exist_ok=True)

    # TODO: decide whether to persist a run_config.json snapshot in
    # `output_dir` (a commented-out draft existed here; see git history).

    # Stage 1: split the dataset into per-batch .npz files.
    print(f"Splitting dataset into {config.n_batches} batches...")
    data_files = split_and_save_dataset(
        config=config,
        output_path=output_dir,
        save_file_fmt="batch_{batch_idx}.npz",
        cfg_file_fmt="config.json",  # just a place we save a raw dict of metadata
    )

    # Stage 2: cluster every batch in parallel; each worker loads its own
    # model and WandB run, matching the original design.
    print(f"Processing {len(data_files)} batches with {n_workers} workers...")
    results = process_batches_parallel(
        data_files=data_files,
        config=config,
        output_base_dir=histories_path,
        n_workers=n_workers,
        devices=devices,
    )

    # Stage 3: normalize per-batch merge histories and save the ensembled
    # merge array. (Renamed from the misspelled `enseble_merge_arr_path`.)
    ensemble_merge_arr_path = normalize_and_ensemble_and_save(
        history_paths=[r.history_save_path for r in results],
        distances_dir=distances_dir,
    )

    # Stage 4: pairwise distances between batch clusterings, then the report.
    distances = compute_and_save_distances_new(
        merges_path=ensemble_merge_arr_path,
        method=distance_method,
    )

    create_clustering_report(
        distances=distances,
        method=distance_method,
        wandb_urls=[r.wandb_url for r in results if r.wandb_url],  # TODO clean up
        config_identifier=config.config_identifier,
    )

spd/clustering/math/perm_invariant_hamming.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from jaxtyping import Float, Int
3+
from scipy.optimize import linear_sum_assignment
34

45

56
def perm_invariant_hamming_matrix(

spd/clustering/merge.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,8 @@ def _wandb_iter_log(
259259
)
260260

261261
if iter_idx > 0 and iter_idx % config.intervals["artifact"] == 0:
262-
with tempfile.TemporaryFile() as tmp_file:
263-
file: Path = Path(tmp_file.name)
264-
file.parent.mkdir(parents=True, exist_ok=True)
262+
with tempfile.NamedTemporaryFile() as tmp_file:
263+
file = Path(tmp_file.name)
265264
merge_history.save(file)
266265
artifact = wandb.Artifact(
267266
name=f"merge_hist_iter.{batch_id}.iter_{iter_idx}",

spd/clustering/s25.py

Whitespace-only changes.

spd/clustering/s2_clustering.py

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import wandb
99
from jaxtyping import Int
1010
from torch import Tensor
11-
from tqdm import tqdm
1211
from wandb.sdk.wandb_run import Run
1312

1413
from spd.clustering.activations import component_activations, process_activations
@@ -28,10 +27,6 @@ class ClusteringResult:
2827
wandb_url: str | None
2928

3029

31-
def _worker_fn(args: tuple[MergeRunConfig, Path, Path, str]) -> ClusteringResult:
32-
return run_clustering(*args)
33-
34-
3530
# TODO consider making this a generator
3631
def process_batches_parallel(
3732
config: MergeRunConfig,
@@ -48,22 +43,17 @@ def process_batches_parallel(
4843
for i, data_path in enumerate(data_files)
4944
]
5045

51-
# Simple pool without initializer
52-
# with Pool(n_workers) as pool:
53-
# # Process batches with progress bar
54-
# results = list(
55-
# tqdm(
56-
# pool.imap(_worker_fn, worker_args),
57-
# total=len(data_files),
58-
# desc="Processing batches",
59-
# )
60-
# )
61-
results = [_worker_fn(args) for args in worker_args]
46+
with Pool(n_workers) as pool:
47+
results = pool.map(_worker_fn, worker_args)
6248

6349
return results
6450

6551

66-
def run_clustering(
52+
def _worker_fn(args: tuple[MergeRunConfig, Path, Path, str]) -> ClusteringResult:
53+
return _run_clustering(*args)
54+
55+
56+
def _run_clustering(
6757
config: MergeRunConfig,
6858
data_path: Path,
6959
output_base_dir: Path,
@@ -167,29 +157,6 @@ def _setup_wandb(
167157
return run
168158

169159

170-
def _save_merge_history_to_wandb(
171-
run: Run,
172-
history_path: Path,
173-
batch_id: str,
174-
config_identifier: str,
175-
history: MergeHistory,
176-
):
177-
artifact = wandb.Artifact(
178-
name=f"merge_history_{batch_id}",
179-
type="merge_history",
180-
description=f"Merge history for batch {batch_id}",
181-
metadata={
182-
"batch_name": batch_id,
183-
"config_identifier": config_identifier,
184-
"n_iters_current": history.n_iters_current,
185-
"filename": history_path,
186-
},
187-
)
188-
# Add both files before logging the artifact
189-
artifact.add_file(str(history_path))
190-
run.log_artifact(artifact)
191-
192-
193160
def _log_merge_history_plots_to_wandb(run: Run, history: MergeHistory):
194161
fig_cs = plot_merge_history_cluster_sizes(history=history)
195162

spd/clustering/s3_normalize_histories.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ def normalize_and_ensemble_and_save(
4646
ZANJ().save(ensemble, path_hist_ensemble)
4747
logger.info(f"Ensemble saved to {path_hist_ensemble}")
4848

49-
return enseble_merge_arr_path
49+
return enseble_merge_arr_path

spd/clustering/scripts/main.py

Lines changed: 2 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,11 @@
1-
"""
2-
Orchestration layer - clean clustering pipeline coordination.
3-
4-
Replaces the original 370+ line subprocess/FD system with simple multiprocessing.Pool.
5-
Each batch loads its own model and WandB run to match original design.
6-
"""
7-
81
import argparse
92
from pathlib import Path
103

4+
from spd.clustering.clustering_pipeline import main
115
from spd.clustering.merge_run_config import MergeRunConfig
12-
from spd.clustering.s1_split_dataset import split_and_save_dataset
13-
from spd.clustering.s2_clustering import process_batches_parallel
14-
from spd.clustering.s3_normalize_histories import normalize_and_ensemble_and_save
15-
from spd.clustering.s4_compute_distances import (
16-
compute_and_save_distances_new,
17-
create_clustering_report,
18-
)
196
from spd.settings import REPO_ROOT
207

218

22-
def main(
23-
config: MergeRunConfig,
24-
base_path: Path,
25-
n_workers: int,
26-
devices: list[str],
27-
):
28-
"""
29-
The following is (hopefully) correct (thought see there's some repetition I'd like to change)
30-
31-
base_dir/
32-
{config.config_identifier}/
33-
merge_histories/
34-
{config.config_identifier}-data_{batch_id}/
35-
merge_history.zip
36-
plots/
37-
activations_raw.pdf
38-
activations_concat.pdf
39-
activations_coact.pdf
40-
activations_coact_log.pdf
41-
merge_iteration.pdf
42-
distances/
43-
figures/
44-
run_config.json
45-
"""
46-
47-
output_dir = base_path / config.config_identifier
48-
49-
histories_path = output_dir / "merge_histories"
50-
histories_path.mkdir(parents=True, exist_ok=True)
51-
52-
# figures_path = output_dir / "figures"
53-
# figures_path.mkdir(parents=True, exist_ok=True)
54-
55-
distances_dir = output_dir / "distances"
56-
distances_dir.mkdir(parents=True, exist_ok=True)
57-
58-
# TODO see if we actually need this
59-
# run_config_path = output_dir / "run_config.json"
60-
# run_config_path.write_text(
61-
# json.dumps(
62-
# dict(merge_run_config=config.model_dump(mode="json"), base_path=str(base_path), devices=devices, max_concurrency=n_workers, plot=True, # can we remove this? repo_root=str(REPO_ROOT), run_id=config.config_identifier, run_path=str(output_dir),),
63-
# indent="\t",
64-
# )
65-
# )
66-
# print(f"Run config saved to {run_config_path}")
67-
68-
print(f"Splitting dataset into {config.n_batches} batches...")
69-
data_files = split_and_save_dataset(
70-
config=config,
71-
output_path=output_dir,
72-
save_file_fmt="batch_{batch_idx}.npz",
73-
cfg_file_fmt="config.json", # just a place we save a raw dict of metadata
74-
)
75-
76-
print(f"Processing {len(data_files)} batches with {n_workers} workers...")
77-
results = process_batches_parallel(
78-
data_files=data_files,
79-
config=config,
80-
output_base_dir=histories_path,
81-
n_workers=n_workers,
82-
devices=devices,
83-
)
84-
85-
enseble_merge_arr_path = normalize_and_ensemble_and_save(
86-
history_paths=[r.history_save_path for r in results],
87-
distances_dir=distances_dir,
88-
)
89-
90-
distances = compute_and_save_distances_new(
91-
merges_path=enseble_merge_arr_path,
92-
method="perm_invariant_hamming",
93-
)
94-
95-
create_clustering_report(
96-
distances=distances,
97-
method="perm_invariant_hamming",
98-
wandb_urls=[r.wandb_url for r in results if r.wandb_url], # Gross - clean up,
99-
config_identifier=config.config_identifier,
100-
)
101-
102-
1039
def cli():
10410
"""Command-line interface for clustering."""
10511
parser = argparse.ArgumentParser(
@@ -138,6 +44,7 @@ def cli():
13844
# Parse devices
13945
if args.devices is None:
14046
import torch
47+
14148
devices = ["cuda" if torch.cuda.is_available() else "cpu"]
14249
else:
14350
devices = args.devices.split(",")

spd/clustering/sweep.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
import matplotlib.cm as cm
66
import matplotlib.pyplot as plt
77
import numpy as np
8-
import torch
98
from matplotlib.colors import LogNorm
109
from matplotlib.lines import Line2D
11-
from tqdm import tqdm
1210

1311
from spd.clustering.merge_config import MergeConfig
1412
from spd.clustering.merge_history import MergeHistory

spd/models/component_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def from_path(cls, path: ModelPath) -> "SPDRunInfo":
6868

6969
return cls(checkpoint_path=comp_model_path, config=config)
7070

71+
7172
# TODO encapsulate Gates in a separate class (containing sigmoid type and sampling mode)
7273
class ComponentModel(LoadableModule):
7374
"""Wrapper around an arbitrary model for running SPD.

0 commit comments

Comments
 (0)