|
1 | | -""" |
2 | | -Orchestration layer - clean clustering pipeline coordination. |
3 | | -
|
4 | | -Replaces the original 370+ line subprocess/FD system with simple multiprocessing.Pool. |
5 | | -Each batch loads its own model and WandB run to match original design. |
6 | | -""" |
7 | | - |
8 | 1 | import argparse |
9 | 2 | from pathlib import Path |
10 | 3 |
|
| 4 | +from spd.clustering.clustering_pipeline import main |
11 | 5 | from spd.clustering.merge_run_config import MergeRunConfig |
12 | | -from spd.clustering.s1_split_dataset import split_and_save_dataset |
13 | | -from spd.clustering.s2_clustering import process_batches_parallel |
14 | | -from spd.clustering.s3_normalize_histories import normalize_and_ensemble_and_save |
15 | | -from spd.clustering.s4_compute_distances import ( |
16 | | - compute_and_save_distances_new, |
17 | | - create_clustering_report, |
18 | | -) |
19 | 6 | from spd.settings import REPO_ROOT |
20 | 7 |
|
21 | 8 |
|
22 | | -def main( |
23 | | - config: MergeRunConfig, |
24 | | - base_path: Path, |
25 | | - n_workers: int, |
26 | | - devices: list[str], |
27 | | -): |
28 | | - """ |
 29 | | -    The following is (hopefully) correct (though see there's some repetition I'd like to change) |
30 | | -
|
31 | | - base_dir/ |
32 | | - {config.config_identifier}/ |
33 | | - merge_histories/ |
34 | | - {config.config_identifier}-data_{batch_id}/ |
35 | | - merge_history.zip |
36 | | - plots/ |
37 | | - activations_raw.pdf |
38 | | - activations_concat.pdf |
39 | | - activations_coact.pdf |
40 | | - activations_coact_log.pdf |
41 | | - merge_iteration.pdf |
42 | | - distances/ |
43 | | - figures/ |
44 | | - run_config.json |
45 | | - """ |
46 | | - |
47 | | - output_dir = base_path / config.config_identifier |
48 | | - |
49 | | - histories_path = output_dir / "merge_histories" |
50 | | - histories_path.mkdir(parents=True, exist_ok=True) |
51 | | - |
52 | | - # figures_path = output_dir / "figures" |
53 | | - # figures_path.mkdir(parents=True, exist_ok=True) |
54 | | - |
55 | | - distances_dir = output_dir / "distances" |
56 | | - distances_dir.mkdir(parents=True, exist_ok=True) |
57 | | - |
58 | | - # TODO see if we actually need this |
59 | | - # run_config_path = output_dir / "run_config.json" |
60 | | - # run_config_path.write_text( |
61 | | - # json.dumps( |
62 | | - # dict(merge_run_config=config.model_dump(mode="json"), base_path=str(base_path), devices=devices, max_concurrency=n_workers, plot=True, # can we remove this? repo_root=str(REPO_ROOT), run_id=config.config_identifier, run_path=str(output_dir),), |
63 | | - # indent="\t", |
64 | | - # ) |
65 | | - # ) |
66 | | - # print(f"Run config saved to {run_config_path}") |
67 | | - |
68 | | - print(f"Splitting dataset into {config.n_batches} batches...") |
69 | | - data_files = split_and_save_dataset( |
70 | | - config=config, |
71 | | - output_path=output_dir, |
72 | | - save_file_fmt="batch_{batch_idx}.npz", |
73 | | - cfg_file_fmt="config.json", # just a place we save a raw dict of metadata |
74 | | - ) |
75 | | - |
76 | | - print(f"Processing {len(data_files)} batches with {n_workers} workers...") |
77 | | - results = process_batches_parallel( |
78 | | - data_files=data_files, |
79 | | - config=config, |
80 | | - output_base_dir=histories_path, |
81 | | - n_workers=n_workers, |
82 | | - devices=devices, |
83 | | - ) |
84 | | - |
85 | | - enseble_merge_arr_path = normalize_and_ensemble_and_save( |
86 | | - history_paths=[r.history_save_path for r in results], |
87 | | - distances_dir=distances_dir, |
88 | | - ) |
89 | | - |
90 | | - distances = compute_and_save_distances_new( |
91 | | - merges_path=enseble_merge_arr_path, |
92 | | - method="perm_invariant_hamming", |
93 | | - ) |
94 | | - |
95 | | - create_clustering_report( |
96 | | - distances=distances, |
97 | | - method="perm_invariant_hamming", |
98 | | - wandb_urls=[r.wandb_url for r in results if r.wandb_url], # Gross - clean up, |
99 | | - config_identifier=config.config_identifier, |
100 | | - ) |
101 | | - |
102 | | - |
103 | 9 | def cli(): |
104 | 10 | """Command-line interface for clustering.""" |
105 | 11 | parser = argparse.ArgumentParser( |
@@ -138,6 +44,7 @@ def cli(): |
138 | 44 | # Parse devices |
139 | 45 | if args.devices is None: |
140 | 46 | import torch |
| 47 | + |
141 | 48 | devices = ["cuda" if torch.cuda.is_available() else "cpu"] |
142 | 49 | else: |
143 | 50 | devices = args.devices.split(",") |
|
0 commit comments