Skip to content

Commit e30d680

Browse files
committed
wip
1 parent bc37ae3 commit e30d680

File tree

4 files changed

+96
-165
lines changed

4 files changed

+96
-165
lines changed

spd/clustering/clustering_pipeline.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from pydantic import BaseModel
1111

12-
from spd.clustering.merge_run_config import RunFilePaths, MergeRunConfig
12+
from spd.clustering.merge_run_config import MergeRunConfig
1313

1414

1515
class RunRecord(BaseModel):
@@ -31,6 +31,12 @@ def main(
3131
3232
base_dir/
3333
{config.config_identifier}/
34+
dataset/
35+
dataset_config.json
36+
batches/
37+
batch_00.npz
38+
batch_01.npz
39+
...
3440
merge_histories/
3541
{config.config_identifier}-data_{batch_id}/
3642
merge_history.zip
@@ -53,7 +59,7 @@ def main(
5359
)
5460

5561
run_path = base_path / config.config_identifier
56-
histories_path = run_path / "merge_histories"
62+
histories_dir = run_path / "merge_histories"
5763
dataset_dir = run_path / "dataset"
5864
distances_dir = run_path / "distances"
5965
run_config_path = run_path / "run_config.json"
@@ -62,18 +68,13 @@ def main(
6268
run_config_path.write_text(config.model_dump_json(indent=2))
6369

6470
print(f"Splitting dataset into {config.n_batches} batches...")
65-
data_files = split_and_save_dataset(
66-
config=config,
67-
output_dir=dataset_dir,
68-
save_file_fmt="batch_{batch_idx}.npz",
69-
cfg_file_fmt="config.json", # just a place we save a raw dict of metadata
70-
)
71+
data_files = split_and_save_dataset(config=config, output_dir=dataset_dir)
7172

7273
print(f"Processing {len(data_files)} batches with {workers_per_device} workers per device...")
7374
results = process_batches_parallel(
7475
data_files=data_files,
7576
config=config,
76-
output_base_dir=histories_path,
77+
output_dir=histories_dir,
7778
workers_per_device=workers_per_device,
7879
devices=devices,
7980
)

spd/clustering/merge_run_config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,10 @@ def model_dump_with_properties(self) -> dict[str, Any]:
234234
)
235235

236236
return base_dump
237+
238+
if __name__ == "__main__":
239+
with open("merge_run_config.json", "w") as f:
240+
json.dump(MergeRunConfig.model_json_schema(), f, indent=2)
241+
242+
# config = MergeRunConfig.from_file(Path("data/clustering/configs/1234567890.json"))
243+
# print(config.model_dump_with_properties())

spd/clustering/s1_split_dataset.py

Lines changed: 74 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
1+
"""
2+
writes a dataset into a directory of batches
3+
4+
directory structure:
5+
dataset/
6+
dataset_config.json
7+
batches/
8+
batch_00.npz
9+
batch_01.npz
10+
...
11+
"""
12+
113
import json
14+
from collections.abc import Generator
215
from pathlib import Path
316
from typing import Any
417

@@ -16,15 +29,61 @@
1629
from spd.experiments.resid_mlp.models import ResidMLP
1730
from spd.models.component_model import ComponentModel, SPDRunInfo
1831

32+
CONFIG_FILE_PATH = "dataset_config.json"
33+
BATCHES_DIR_PATH = "batches"
34+
BATCH_FILE_FMT = "batch_{batch_idx:02d}.npz"
35+
1936

20-
def split_dataset_lm(
37+
def split_and_save_dataset(config: MergeRunConfig, output_dir: Path) -> list[Path]:
38+
"""Split a dataset into n_batches of batch_size and save the batches"""
39+
match config.task_name:
40+
case "lm":
41+
ds, ds_config_dict = _get_dataloader_lm(
42+
model_path=config.model_path,
43+
batch_size=config.batch_size,
44+
)
45+
case "resid_mlp":
46+
ds, ds_config_dict = _get_dataloader_resid_mlp(
47+
model_path=config.model_path,
48+
batch_size=config.batch_size,
49+
)
50+
case name:
51+
raise ValueError(
52+
f"Unsupported task name '{name}'. Supported tasks are 'lm' and 'resid_mlp'. {config.model_path=}, {name=}"
53+
)
54+
55+
# make dirs
56+
output_dir.mkdir(parents=True, exist_ok=True)
57+
58+
dataset_config_path = output_dir / CONFIG_FILE_PATH
59+
dataset_config_path.write_text(json.dumps(ds_config_dict, indent=2))
60+
61+
batches_dir = output_dir / BATCHES_DIR_PATH
62+
# iterate over the requested number of batches and save them
63+
output_paths: list[Path] = []
64+
for batch_idx, batch in tqdm(
65+
enumerate(ds),
66+
total=config.n_batches,
67+
unit="batch",
68+
):
69+
if batch_idx >= config.n_batches:
70+
break
71+
72+
batch_path: Path = batches_dir / BATCH_FILE_FMT.format(batch_idx=batch_idx)
73+
74+
np.savez_compressed(
75+
batch_path,
76+
input_ids=batch.cpu().numpy(),
77+
)
78+
output_paths.append(batch_path)
79+
80+
return output_paths
81+
82+
83+
def _get_dataloader_lm(
2184
model_path: str,
22-
n_batches: int,
2385
batch_size: int,
24-
output_dir: Path,
25-
save_file_fmt: str,
26-
cfg_file_fmt: str,
27-
) -> list[Path]:
86+
) -> tuple[Generator[torch.Tensor, None, None], dict[str, Any]]:
2887
"""split up a SS dataset into n_batches of batch_size, returned the saved paths
2988
3089
1. load the config for a SimpleStories SPD Run given by model_path
@@ -42,11 +101,11 @@ def split_dataset_lm(
42101
assert pretrained_model_name is not None
43102
except Exception as e:
44103
raise AttributeError(
45-
"Could not find 'pretrained_model_name' in the SPD Run config, but called `split_dataset_lm`"
104+
"Could not find 'pretrained_model_name' in the SPD Run config, but called `_get_dataloader_lm`"
46105
) from e
47106

48107
assert isinstance(cfg.task_config, LMTaskConfig), (
49-
f"Expected task_config to be of type LMTaskConfig since using `split_dataset_lm`, but got {type(cfg.task_config) = }"
108+
f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }"
50109
)
51110

52111
dataset_config: DatasetConfig = DatasetConfig(
@@ -71,66 +130,13 @@ def split_dataset_lm(
71130
ddp_world_size=1,
72131
)
73132

74-
# make dirs
75-
output_dir.mkdir(parents=True, exist_ok=True)
76-
(
77-
output_dir
78-
/ save_file_fmt.format(batch_size=batch_size, batch_idx="XX", n_batches=f"{n_batches:02d}")
79-
).parent.mkdir(parents=True, exist_ok=True)
80-
# iterate over the requested number of batches and save them
81-
output_paths: list[Path] = []
82-
for batch_idx, batch in tqdm(
83-
enumerate(iter(dataloader)),
84-
total=n_batches,
85-
unit="batch",
86-
):
87-
if batch_idx >= n_batches:
88-
break
89-
batch_path: Path = output_dir / save_file_fmt.format(
90-
batch_size=batch_size,
91-
batch_idx=f"{batch_idx:02d}",
92-
n_batches=f"{n_batches:02d}",
93-
)
94-
np.savez_compressed(
95-
batch_path,
96-
input_ids=batch["input_ids"].cpu().numpy(),
97-
)
98-
output_paths.append(batch_path)
99-
100-
# save a config file
101-
cfg_path: Path = output_dir / cfg_file_fmt.format(batch_size=batch_size)
102-
cfg_data: dict[str, Any] = dict(
103-
# args to this function
104-
model_path=model_path,
105-
batch_size=batch_size,
106-
n_batches=n_batches,
107-
# dataset and tokenizer config
108-
dataset_config=dataset_config.model_dump(mode="json"),
109-
tokenizer_path=str(getattr(_tokenizer, "name_or_path", None)),
110-
tokenizer_type=str(getattr(_tokenizer, "__class__", None)),
111-
# files we saved
112-
output_files=[str(p) for p in output_paths],
113-
output_dir=str(output_dir),
114-
output_file_fmt=save_file_fmt,
115-
cfg_file_fmt=cfg_file_fmt,
116-
cfg_file=str(cfg_path),
117-
)
118-
cfg_path.parent.mkdir(parents=True, exist_ok=True)
119-
cfg_path.write_text(json.dumps(cfg_data, indent="\t"))
120-
121-
print(f"Saved config to: {cfg_path}")
122-
123-
return output_paths
133+
return (batch["input_ids"] for batch in dataloader), dataset_config.model_dump(mode="json")
124134

125135

126-
def split_dataset_resid_mlp(
136+
def _get_dataloader_resid_mlp(
127137
model_path: str,
128-
n_batches: int,
129138
batch_size: int,
130-
output_dir: Path,
131-
save_file_fmt: str,
132-
cfg_file_fmt: str,
133-
) -> list[Path]:
139+
) -> tuple[Generator[torch.Tensor, None, None], dict[str, Any]]:
134140
"""Split a ResidMLP dataset into n_batches of batch_size and save the batches."""
135141
from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset
136142
from spd.utils.data_utils import DatasetGeneratedDataLoader
@@ -143,14 +149,14 @@ def split_dataset_resid_mlp(
143149

144150
with SpinnerContext(message="Creating ResidMLPDataset..."):
145151
assert isinstance(cfg.task_config, ResidMLPTaskConfig), (
146-
f"Expected task_config to be of type ResidMLPTaskConfig since using `split_dataset_resid_mlp`, but got {type(cfg.task_config) = }"
152+
f"Expected task_config to be of type ResidMLPTaskConfig since using `_get_dataloader_resid_mlp`, but got {type(cfg.task_config) = }"
147153
)
148154
assert isinstance(component_model.target_model, ResidMLP), (
149-
f"Expected patched_model to be of type ResidMLP since using `split_dataset_resid_mlp`, but got {type(component_model.patched_model) = }"
155+
f"Expected patched_model to be of type ResidMLP since using `_get_dataloader_resid_mlp`, but got {type(component_model.patched_model) = }"
150156
)
151157

152158
assert isinstance(component_model.target_model.config, ResidMLPModelConfig), (
153-
f"Expected patched_model.config to be of type ResidMLPModelConfig since using `split_dataset_resid_mlp`, but got {type(component_model.target_model.config) = }"
159+
f"Expected patched_model.config to be of type ResidMLPModelConfig since using `_get_dataloader_resid_mlp`, but got {type(component_model.target_model.config) = }"
154160
)
155161
resid_mlp_dataset_kwargs: dict[str, Any] = dict(
156162
n_features=component_model.target_model.config.n_features,
@@ -167,87 +173,4 @@ def split_dataset_resid_mlp(
167173

168174
dataloader = DatasetGeneratedDataLoader(dataset, batch_size=batch_size, shuffle=False)
169175

170-
# make dirs
171-
output_dir.mkdir(parents=True, exist_ok=True)
172-
(
173-
output_dir
174-
/ save_file_fmt.format(batch_size=batch_size, batch_idx="XX", n_batches=f"{n_batches:02d}")
175-
).parent.mkdir(parents=True, exist_ok=True)
176-
177-
# iterate over the requested number of batches and save them
178-
output_paths: list[Path] = []
179-
batch: torch.Tensor
180-
# second term in the tuple is same as the first
181-
for batch_idx, (batch, _) in tqdm(
182-
enumerate(iter(dataloader)),
183-
total=n_batches,
184-
unit="batch",
185-
):
186-
if batch_idx >= n_batches:
187-
break
188-
189-
batch_path: Path = output_dir / save_file_fmt.format(
190-
batch_size=batch_size,
191-
batch_idx=f"{batch_idx:02d}",
192-
n_batches=f"{n_batches:02d}",
193-
)
194-
np.savez_compressed(
195-
batch_path,
196-
input_ids=batch.cpu().numpy(),
197-
)
198-
output_paths.append(batch_path)
199-
200-
# save the config file
201-
cfg_path: Path = output_dir / cfg_file_fmt.format(batch_size=batch_size)
202-
cfg_data: dict[str, Any] = dict(
203-
# args to this function
204-
model_path=model_path,
205-
batch_size=batch_size,
206-
n_batches=n_batches,
207-
# dataset and tokenizer config
208-
resid_mlp_dataset_kwargs=resid_mlp_dataset_kwargs,
209-
# files we saved
210-
output_files=[str(p) for p in output_paths],
211-
output_dir=str(output_dir),
212-
output_file_fmt=save_file_fmt,
213-
cfg_file_fmt=cfg_file_fmt,
214-
cfg_file=str(cfg_path),
215-
)
216-
217-
cfg_path.parent.mkdir(parents=True, exist_ok=True)
218-
cfg_path.write_text(json.dumps(cfg_data, indent="\t"))
219-
print(f"Saved config to: {cfg_path}")
220-
221-
return output_paths
222-
223-
224-
def split_and_save_dataset(
225-
config: MergeRunConfig,
226-
output_dir: Path,
227-
save_file_fmt: str,
228-
cfg_file_fmt: str,
229-
) -> list[Path]:
230-
"""Split a dataset into n_batches of batch_size and save the batches"""
231-
match config.task_name:
232-
case "lm":
233-
return split_dataset_lm(
234-
model_path=config.model_path,
235-
n_batches=config.n_batches,
236-
batch_size=config.batch_size,
237-
output_dir=output_dir,
238-
save_file_fmt=save_file_fmt,
239-
cfg_file_fmt=cfg_file_fmt,
240-
)
241-
case "resid_mlp":
242-
return split_dataset_resid_mlp(
243-
model_path=config.model_path,
244-
n_batches=config.n_batches,
245-
batch_size=config.batch_size,
246-
output_dir=output_dir,
247-
save_file_fmt=save_file_fmt,
248-
cfg_file_fmt=cfg_file_fmt,
249-
)
250-
case name:
251-
raise ValueError(
252-
f"Unsupported task name '{name}'. Supported tasks are 'lm' and 'resid_mlp'. {config.model_path=}, {name=}"
253-
)
176+
return (batch[0] for batch in dataloader), resid_mlp_dataset_kwargs

spd/clustering/s2_clustering.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ class ClusteringResult:
3030
def process_batches_parallel(
3131
config: MergeRunConfig,
3232
data_files: list[Path],
33-
output_base_dir: Path,
33+
output_dir: Path,
3434
workers_per_device: int,
3535
devices: list[str],
3636
) -> list[ClusteringResult]:
3737
worker_args = [
38-
(config, data_path, output_base_dir, devices[i % len(devices)])
38+
(config, data_path, output_dir, devices[i % len(devices)])
3939
for i, data_path in enumerate(data_files)
4040
]
4141

@@ -89,6 +89,9 @@ def _run_clustering(
8989
step=0,
9090
single=True,
9191
)
92+
wandb_url = run.url
93+
else:
94+
wandb_url = None
9295

9396
# Use original activations for raw plots, but filtered data for concat/coact/histograms
9497
logger.info("plotting")
@@ -118,10 +121,7 @@ def _run_clustering(
118121
# run, history_save_path, batch_id, config.config_identifier, history
119122
# )
120123

121-
wandb_url = run.url
122124
run.finish()
123-
else:
124-
wandb_url = None
125125

126126
return ClusteringResult(history_save_path=history_save_path, wandb_url=wandb_url)
127127

0 commit comments

Comments
 (0)