1+ """
2+ writes a dataset into a directory of batches
3+
directory structure:
dataset/
    dataset_config.json
    batches/
        batch_00.npz
        batch_01.npz
        ...
11+ """
12+
113import json
14+ from collections .abc import Generator
215from pathlib import Path
316from typing import Any
417
1629from spd .experiments .resid_mlp .models import ResidMLP
1730from spd .models .component_model import ComponentModel , SPDRunInfo
1831
32+ CONFIG_FILE_PATH = "dataset_config.json"
33+ BATCHES_DIR_PATH = "batches"
34+ BATCH_FILE_FMT = "batch_{batch_idx:02d}.npz"
35+
1936
20- def split_dataset_lm (
37+ def split_and_save_dataset (config : MergeRunConfig , output_dir : Path ) -> list [Path ]:
38+ """Split a dataset into n_batches of batch_size and save the batches"""
39+ match config .task_name :
40+ case "lm" :
41+ ds , ds_config_dict = _get_dataloader_lm (
42+ model_path = config .model_path ,
43+ batch_size = config .batch_size ,
44+ )
45+ case "resid_mlp" :
46+ ds , ds_config_dict = _get_dataloader_resid_mlp (
47+ model_path = config .model_path ,
48+ batch_size = config .batch_size ,
49+ )
50+ case name :
51+ raise ValueError (
52+ f"Unsupported task name '{ name } '. Supported tasks are 'lm' and 'resid_mlp'. { config .model_path = } , { name = } "
53+ )
54+
55+ # make dirs
56+ output_dir .mkdir (parents = True , exist_ok = True )
57+
58+ dataset_config_path = output_dir / CONFIG_FILE_PATH
59+ dataset_config_path .write_text (json .dumps (ds_config_dict , indent = 2 ))
60+
61+ batches_dir = output_dir / BATCHES_DIR_PATH
62+ # iterate over the requested number of batches and save them
63+ output_paths : list [Path ] = []
64+ for batch_idx , batch in tqdm (
65+ enumerate (ds ),
66+ total = config .n_batches ,
67+ unit = "batch" ,
68+ ):
69+ if batch_idx >= config .n_batches :
70+ break
71+
72+ batch_path : Path = batches_dir / BATCH_FILE_FMT .format (batch_idx = batch_idx )
73+
74+ np .savez_compressed (
75+ batch_path ,
76+ input_ids = batch .cpu ().numpy (),
77+ )
78+ output_paths .append (batch_path )
79+
80+ return output_paths
81+
82+
83+ def _get_dataloader_lm (
2184 model_path : str ,
22- n_batches : int ,
2385 batch_size : int ,
24- output_dir : Path ,
25- save_file_fmt : str ,
26- cfg_file_fmt : str ,
27- ) -> list [Path ]:
86+ ) -> tuple [Generator [torch .Tensor , None , None ], dict [str , Any ]]:
2887 """split up a SS dataset into n_batches of batch_size, returned the saved paths
2988
3089 1. load the config for a SimpleStories SPD Run given by model_path
@@ -42,11 +101,11 @@ def split_dataset_lm(
42101 assert pretrained_model_name is not None
43102 except Exception as e :
44103 raise AttributeError (
45- "Could not find 'pretrained_model_name' in the SPD Run config, but called `split_dataset_lm `"
104+ "Could not find 'pretrained_model_name' in the SPD Run config, but called `_get_dataloader_lm `"
46105 ) from e
47106
48107 assert isinstance (cfg .task_config , LMTaskConfig ), (
49- f"Expected task_config to be of type LMTaskConfig since using `split_dataset_lm `, but got { type (cfg .task_config ) = } "
108+ f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm `, but got { type (cfg .task_config ) = } "
50109 )
51110
52111 dataset_config : DatasetConfig = DatasetConfig (
@@ -71,66 +130,13 @@ def split_dataset_lm(
71130 ddp_world_size = 1 ,
72131 )
73132
74- # make dirs
75- output_dir .mkdir (parents = True , exist_ok = True )
76- (
77- output_dir
78- / save_file_fmt .format (batch_size = batch_size , batch_idx = "XX" , n_batches = f"{ n_batches :02d} " )
79- ).parent .mkdir (parents = True , exist_ok = True )
80- # iterate over the requested number of batches and save them
81- output_paths : list [Path ] = []
82- for batch_idx , batch in tqdm (
83- enumerate (iter (dataloader )),
84- total = n_batches ,
85- unit = "batch" ,
86- ):
87- if batch_idx >= n_batches :
88- break
89- batch_path : Path = output_dir / save_file_fmt .format (
90- batch_size = batch_size ,
91- batch_idx = f"{ batch_idx :02d} " ,
92- n_batches = f"{ n_batches :02d} " ,
93- )
94- np .savez_compressed (
95- batch_path ,
96- input_ids = batch ["input_ids" ].cpu ().numpy (),
97- )
98- output_paths .append (batch_path )
99-
100- # save a config file
101- cfg_path : Path = output_dir / cfg_file_fmt .format (batch_size = batch_size )
102- cfg_data : dict [str , Any ] = dict (
103- # args to this function
104- model_path = model_path ,
105- batch_size = batch_size ,
106- n_batches = n_batches ,
107- # dataset and tokenizer config
108- dataset_config = dataset_config .model_dump (mode = "json" ),
109- tokenizer_path = str (getattr (_tokenizer , "name_or_path" , None )),
110- tokenizer_type = str (getattr (_tokenizer , "__class__" , None )),
111- # files we saved
112- output_files = [str (p ) for p in output_paths ],
113- output_dir = str (output_dir ),
114- output_file_fmt = save_file_fmt ,
115- cfg_file_fmt = cfg_file_fmt ,
116- cfg_file = str (cfg_path ),
117- )
118- cfg_path .parent .mkdir (parents = True , exist_ok = True )
119- cfg_path .write_text (json .dumps (cfg_data , indent = "\t " ))
120-
121- print (f"Saved config to: { cfg_path } " )
122-
123- return output_paths
133+ return (batch ["input_ids" ] for batch in dataloader ), dataset_config .model_dump (mode = "json" )
124134
125135
126- def split_dataset_resid_mlp (
136+ def _get_dataloader_resid_mlp (
127137 model_path : str ,
128- n_batches : int ,
129138 batch_size : int ,
130- output_dir : Path ,
131- save_file_fmt : str ,
132- cfg_file_fmt : str ,
133- ) -> list [Path ]:
139+ ) -> tuple [Generator [torch .Tensor , None , None ], dict [str , Any ]]:
134140 """Split a ResidMLP dataset into n_batches of batch_size and save the batches."""
135141 from spd .experiments .resid_mlp .resid_mlp_dataset import ResidMLPDataset
136142 from spd .utils .data_utils import DatasetGeneratedDataLoader
@@ -143,14 +149,14 @@ def split_dataset_resid_mlp(
143149
144150 with SpinnerContext (message = "Creating ResidMLPDataset..." ):
145151 assert isinstance (cfg .task_config , ResidMLPTaskConfig ), (
146- f"Expected task_config to be of type ResidMLPTaskConfig since using `split_dataset_resid_mlp `, but got { type (cfg .task_config ) = } "
152+ f"Expected task_config to be of type ResidMLPTaskConfig since using `_get_dataloader_resid_mlp `, but got { type (cfg .task_config ) = } "
147153 )
148154 assert isinstance (component_model .target_model , ResidMLP ), (
149- f"Expected patched_model to be of type ResidMLP since using `split_dataset_resid_mlp `, but got { type (component_model .patched_model ) = } "
155+ f"Expected patched_model to be of type ResidMLP since using `_get_dataloader_resid_mlp `, but got { type (component_model .patched_model ) = } "
150156 )
151157
152158 assert isinstance (component_model .target_model .config , ResidMLPModelConfig ), (
153- f"Expected patched_model.config to be of type ResidMLPModelConfig since using `split_dataset_resid_mlp `, but got { type (component_model .target_model .config ) = } "
159+ f"Expected patched_model.config to be of type ResidMLPModelConfig since using `_get_dataloader_resid_mlp `, but got { type (component_model .target_model .config ) = } "
154160 )
155161 resid_mlp_dataset_kwargs : dict [str , Any ] = dict (
156162 n_features = component_model .target_model .config .n_features ,
@@ -167,87 +173,4 @@ def split_dataset_resid_mlp(
167173
168174 dataloader = DatasetGeneratedDataLoader (dataset , batch_size = batch_size , shuffle = False )
169175
170- # make dirs
171- output_dir .mkdir (parents = True , exist_ok = True )
172- (
173- output_dir
174- / save_file_fmt .format (batch_size = batch_size , batch_idx = "XX" , n_batches = f"{ n_batches :02d} " )
175- ).parent .mkdir (parents = True , exist_ok = True )
176-
177- # iterate over the requested number of batches and save them
178- output_paths : list [Path ] = []
179- batch : torch .Tensor
180- # second term in the tuple is same as the first
181- for batch_idx , (batch , _ ) in tqdm (
182- enumerate (iter (dataloader )),
183- total = n_batches ,
184- unit = "batch" ,
185- ):
186- if batch_idx >= n_batches :
187- break
188-
189- batch_path : Path = output_dir / save_file_fmt .format (
190- batch_size = batch_size ,
191- batch_idx = f"{ batch_idx :02d} " ,
192- n_batches = f"{ n_batches :02d} " ,
193- )
194- np .savez_compressed (
195- batch_path ,
196- input_ids = batch .cpu ().numpy (),
197- )
198- output_paths .append (batch_path )
199-
200- # save the config file
201- cfg_path : Path = output_dir / cfg_file_fmt .format (batch_size = batch_size )
202- cfg_data : dict [str , Any ] = dict (
203- # args to this function
204- model_path = model_path ,
205- batch_size = batch_size ,
206- n_batches = n_batches ,
207- # dataset and tokenizer config
208- resid_mlp_dataset_kwargs = resid_mlp_dataset_kwargs ,
209- # files we saved
210- output_files = [str (p ) for p in output_paths ],
211- output_dir = str (output_dir ),
212- output_file_fmt = save_file_fmt ,
213- cfg_file_fmt = cfg_file_fmt ,
214- cfg_file = str (cfg_path ),
215- )
216-
217- cfg_path .parent .mkdir (parents = True , exist_ok = True )
218- cfg_path .write_text (json .dumps (cfg_data , indent = "\t " ))
219- print (f"Saved config to: { cfg_path } " )
220-
221- return output_paths
222-
223-
224- def split_and_save_dataset (
225- config : MergeRunConfig ,
226- output_dir : Path ,
227- save_file_fmt : str ,
228- cfg_file_fmt : str ,
229- ) -> list [Path ]:
230- """Split a dataset into n_batches of batch_size and save the batches"""
231- match config .task_name :
232- case "lm" :
233- return split_dataset_lm (
234- model_path = config .model_path ,
235- n_batches = config .n_batches ,
236- batch_size = config .batch_size ,
237- output_dir = output_dir ,
238- save_file_fmt = save_file_fmt ,
239- cfg_file_fmt = cfg_file_fmt ,
240- )
241- case "resid_mlp" :
242- return split_dataset_resid_mlp (
243- model_path = config .model_path ,
244- n_batches = config .n_batches ,
245- batch_size = config .batch_size ,
246- output_dir = output_dir ,
247- save_file_fmt = save_file_fmt ,
248- cfg_file_fmt = cfg_file_fmt ,
249- )
250- case name :
251- raise ValueError (
252- f"Unsupported task name '{ name } '. Supported tasks are 'lm' and 'resid_mlp'. { config .model_path = } , { name = } "
253- )
176+ return (batch [0 ] for batch in dataloader ), resid_mlp_dataset_kwargs
0 commit comments