Skip to content

Commit bc37ae3

Browse files
committed
wip
1 parent 3ea89ae commit bc37ae3

File tree

5 files changed

+77
-46
lines changed

5 files changed

+77
-46
lines changed

spd/clustering/clustering_pipeline.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,24 @@
77

88
from pathlib import Path
99

10-
from spd.clustering.merge_run_config import MergeRunConfig
10+
from pydantic import BaseModel
11+
12+
from spd.clustering.merge_run_config import RunFilePaths, MergeRunConfig
13+
14+
15+
class RunRecord(BaseModel):
    """Snapshot of the arguments a clustering run was launched with."""

    # Full merge-run configuration used for the run.
    merge_run_config: MergeRunConfig
    # Directory the run's outputs are written under.
    output_dir: Path
    # Devices used for clustering (e.g. "cuda:0") — presumably torch device
    # strings; verify against caller.
    devices: list[str]
    # Maximum number of concurrent clustering processes.
    max_concurrency: int
    # Whether plots are generated for the run.
    plot: bool
1121

1222

1323
def main(
1424
config: MergeRunConfig,
1525
base_path: Path,
16-
n_workers: int,
1726
devices: list[str],
27+
workers_per_device: int,
1828
):
1929
"""
2030
The following is (hopefully) correct (though see there's some repetition I'd like to change)
@@ -42,38 +52,29 @@ def main(
4252
create_clustering_report,
4353
)
4454

45-
output_dir = base_path / config.config_identifier
46-
47-
histories_path = output_dir / "merge_histories"
48-
histories_path.mkdir(parents=True, exist_ok=True)
49-
50-
distances_dir = output_dir / "distances"
51-
distances_dir.mkdir(parents=True, exist_ok=True)
55+
run_path = base_path / config.config_identifier
56+
histories_path = run_path / "merge_histories"
57+
dataset_dir = run_path / "dataset"
58+
distances_dir = run_path / "distances"
59+
run_config_path = run_path / "run_config.json"
5260

53-
# TODO see if we actually need this
54-
# run_config_path = output_dir / "run_config.json"
55-
# run_config_path.write_text(
56-
# json.dumps(
57-
# dict(merge_run_config=config.model_dump(mode="json"), base_path=str(base_path), devices=devices, max_concurrency=n_workers, plot=True, # can we remove this? repo_root=str(REPO_ROOT), run_id=config.config_identifier, run_path=str(output_dir),),
58-
# indent="\t",
59-
# )
60-
# )
61-
# print(f"Run config saved to {run_config_path}")
61+
print(f"Run config saved to {run_config_path}")
62+
run_config_path.write_text(config.model_dump_json(indent=2))
6263

6364
print(f"Splitting dataset into {config.n_batches} batches...")
6465
data_files = split_and_save_dataset(
6566
config=config,
66-
output_path=output_dir,
67+
output_dir=dataset_dir,
6768
save_file_fmt="batch_{batch_idx}.npz",
6869
cfg_file_fmt="config.json", # just a place we save a raw dict of metadata
6970
)
7071

71-
print(f"Processing {len(data_files)} batches with {n_workers} workers...")
72+
print(f"Processing {len(data_files)} batches with {workers_per_device} workers per device...")
7273
results = process_batches_parallel(
7374
data_files=data_files,
7475
config=config,
7576
output_base_dir=histories_path,
76-
n_workers=n_workers,
77+
workers_per_device=workers_per_device,
7778
devices=devices,
7879
)
7980

spd/clustering/merge_run_config.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,47 @@
3535
}
3636

3737

38+
class RunFilePaths:
    """Standard on-disk layout for a clustering run's outputs.

    All paths are derived lazily from the run's root directory, so the
    object is cheap to construct and has no side effects until
    :meth:`scaffold` is called.
    """

    def __init__(self, run_path: Path):
        # Root directory for this run's outputs.
        self.run_path = run_path

    @property
    def histories_path(self) -> Path:
        """Directory holding per-batch merge histories."""
        return self.run_path / "merge_histories"

    @property
    def dataset_dir(self) -> Path:
        """Directory holding the split dataset batches."""
        return self.run_path / "dataset"

    @property
    def distances_dir(self) -> Path:
        """Directory holding distance outputs."""
        return self.run_path / "distances"

    @property
    def run_config_path(self) -> Path:
        """File the serialized run configuration is written to."""
        return self.run_path / "run_config.json"

    def scaffold(self) -> None:
        """Create the run's output directories.

        Uses ``parents=True`` so a not-yet-existing ``run_path`` (freshly
        composed as ``base_path / config_identifier``) is created instead of
        raising ``FileNotFoundError``.
        """
        self.histories_path.mkdir(parents=True, exist_ok=True)
        self.distances_dir.mkdir(parents=True, exist_ok=True)
57+
58+
3859
class MergeRunConfig(MergeConfig):
3960
"""Configuration for a complete merge clustering run.
4061
4162
Extends MergeConfig with parameters for model, dataset, and batch configuration.
4263
CLI parameters (base_path, devices, workers_per_device) are included as config fields.
4364
"""
4465

66+
base_path: Path = Field(
67+
...,
68+
description="Base path for saving clustering outputs",
69+
)
70+
workers_per_device: int = Field(
71+
...,
72+
description="Maximum number of concurrent clustering processes per device",
73+
)
74+
devices: list[str] = Field(
75+
...,
76+
description="Devices to use for clustering",
77+
)
78+
4579
model_path: str = Field(
4680
description="WandB path to the model (format: wandb:entity/project/run_id)",
4781
)

spd/clustering/s1_split_dataset.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def split_dataset_lm(
2121
model_path: str,
2222
n_batches: int,
2323
batch_size: int,
24-
output_path: Path,
24+
output_dir: Path,
2525
save_file_fmt: str,
2626
cfg_file_fmt: str,
2727
) -> list[Path]:
@@ -72,9 +72,9 @@ def split_dataset_lm(
7272
)
7373

7474
# make dirs
75-
output_path.mkdir(parents=True, exist_ok=True)
75+
output_dir.mkdir(parents=True, exist_ok=True)
7676
(
77-
output_path
77+
output_dir
7878
/ save_file_fmt.format(batch_size=batch_size, batch_idx="XX", n_batches=f"{n_batches:02d}")
7979
).parent.mkdir(parents=True, exist_ok=True)
8080
# iterate over the requested number of batches and save them
@@ -86,7 +86,7 @@ def split_dataset_lm(
8686
):
8787
if batch_idx >= n_batches:
8888
break
89-
batch_path: Path = output_path / save_file_fmt.format(
89+
batch_path: Path = output_dir / save_file_fmt.format(
9090
batch_size=batch_size,
9191
batch_idx=f"{batch_idx:02d}",
9292
n_batches=f"{n_batches:02d}",
@@ -98,7 +98,7 @@ def split_dataset_lm(
9898
output_paths.append(batch_path)
9999

100100
# save a config file
101-
cfg_path: Path = output_path / cfg_file_fmt.format(batch_size=batch_size)
101+
cfg_path: Path = output_dir / cfg_file_fmt.format(batch_size=batch_size)
102102
cfg_data: dict[str, Any] = dict(
103103
# args to this function
104104
model_path=model_path,
@@ -110,7 +110,7 @@ def split_dataset_lm(
110110
tokenizer_type=str(getattr(_tokenizer, "__class__", None)),
111111
# files we saved
112112
output_files=[str(p) for p in output_paths],
113-
output_dir=str(output_path),
113+
output_dir=str(output_dir),
114114
output_file_fmt=save_file_fmt,
115115
cfg_file_fmt=cfg_file_fmt,
116116
cfg_file=str(cfg_path),
@@ -127,7 +127,7 @@ def split_dataset_resid_mlp(
127127
model_path: str,
128128
n_batches: int,
129129
batch_size: int,
130-
output_path: Path,
130+
output_dir: Path,
131131
save_file_fmt: str,
132132
cfg_file_fmt: str,
133133
) -> list[Path]:
@@ -168,9 +168,9 @@ def split_dataset_resid_mlp(
168168
dataloader = DatasetGeneratedDataLoader(dataset, batch_size=batch_size, shuffle=False)
169169

170170
# make dirs
171-
output_path.mkdir(parents=True, exist_ok=True)
171+
output_dir.mkdir(parents=True, exist_ok=True)
172172
(
173-
output_path
173+
output_dir
174174
/ save_file_fmt.format(batch_size=batch_size, batch_idx="XX", n_batches=f"{n_batches:02d}")
175175
).parent.mkdir(parents=True, exist_ok=True)
176176

@@ -186,7 +186,7 @@ def split_dataset_resid_mlp(
186186
if batch_idx >= n_batches:
187187
break
188188

189-
batch_path: Path = output_path / save_file_fmt.format(
189+
batch_path: Path = output_dir / save_file_fmt.format(
190190
batch_size=batch_size,
191191
batch_idx=f"{batch_idx:02d}",
192192
n_batches=f"{n_batches:02d}",
@@ -198,7 +198,7 @@ def split_dataset_resid_mlp(
198198
output_paths.append(batch_path)
199199

200200
# save the config file
201-
cfg_path: Path = output_path / cfg_file_fmt.format(batch_size=batch_size)
201+
cfg_path: Path = output_dir / cfg_file_fmt.format(batch_size=batch_size)
202202
cfg_data: dict[str, Any] = dict(
203203
# args to this function
204204
model_path=model_path,
@@ -208,7 +208,7 @@ def split_dataset_resid_mlp(
208208
resid_mlp_dataset_kwargs=resid_mlp_dataset_kwargs,
209209
# files we saved
210210
output_files=[str(p) for p in output_paths],
211-
output_dir=str(output_path),
211+
output_dir=str(output_dir),
212212
output_file_fmt=save_file_fmt,
213213
cfg_file_fmt=cfg_file_fmt,
214214
cfg_file=str(cfg_path),
@@ -223,7 +223,7 @@ def split_dataset_resid_mlp(
223223

224224
def split_and_save_dataset(
225225
config: MergeRunConfig,
226-
output_path: Path,
226+
output_dir: Path,
227227
save_file_fmt: str,
228228
cfg_file_fmt: str,
229229
) -> list[Path]:
@@ -234,7 +234,7 @@ def split_and_save_dataset(
234234
model_path=config.model_path,
235235
n_batches=config.n_batches,
236236
batch_size=config.batch_size,
237-
output_path=output_path,
237+
output_dir=output_dir,
238238
save_file_fmt=save_file_fmt,
239239
cfg_file_fmt=cfg_file_fmt,
240240
)
@@ -243,7 +243,7 @@ def split_and_save_dataset(
243243
model_path=config.model_path,
244244
n_batches=config.n_batches,
245245
batch_size=config.batch_size,
246-
output_path=output_path,
246+
output_dir=output_dir,
247247
save_file_fmt=save_file_fmt,
248248
cfg_file_fmt=cfg_file_fmt,
249249
)

spd/clustering/s2_clustering.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,19 @@ class ClusteringResult:
2727
wandb_url: str | None
2828

2929

30-
# TODO consider making this a generator
3130
def process_batches_parallel(
3231
config: MergeRunConfig,
3332
data_files: list[Path],
3433
output_base_dir: Path,
35-
n_workers: int,
34+
workers_per_device: int,
3635
devices: list[str],
3736
) -> list[ClusteringResult]:
38-
devices = devices or ["cuda:0"]
39-
40-
# Create worker arguments with device assignment
4137
worker_args = [
4238
(config, data_path, output_base_dir, devices[i % len(devices)])
4339
for i, data_path in enumerate(data_files)
4440
]
4541

46-
with Pool(n_workers) as pool:
42+
with Pool(workers_per_device * len(devices)) as pool:
4743
results = pool.map(_worker_fn, worker_args)
4844

4945
return results

spd/clustering/scripts/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def cli():
3333
help="comma-separated list of devices to use for clustering (e.g., 'cuda:0,cuda:1')",
3434
)
3535
parser.add_argument(
36-
"--max-concurrency",
36+
"--workers-per-device",
3737
"-x",
3838
type=int,
39-
default=None,
40-
help="Maximum number of concurrent clustering processes (default: all devices)",
39+
default=1,
40+
help="Maximum number of concurrent clustering processes per device (default: 1)",
4141
)
4242
args = parser.parse_args()
4343

@@ -53,7 +53,7 @@ def cli():
5353
config=MergeRunConfig.from_file(args.config),
5454
base_path=args.base_path,
5555
devices=devices,
56-
n_workers=args.max_concurrency if args.max_concurrency is not None else len(devices),
56+
workers_per_device=args.workers_per_device,
5757
)
5858

5959

0 commit comments

Comments
 (0)