@@ -21,7 +21,7 @@ def split_dataset_lm(
2121 model_path : str ,
2222 n_batches : int ,
2323 batch_size : int ,
24- output_path : Path ,
24+ output_dir : Path ,
2525 save_file_fmt : str ,
2626 cfg_file_fmt : str ,
2727) -> list [Path ]:
@@ -72,9 +72,9 @@ def split_dataset_lm(
7272 )
7373
7474 # make dirs
75- output_path .mkdir (parents = True , exist_ok = True )
75+ output_dir .mkdir (parents = True , exist_ok = True )
7676 (
77- output_path
77+ output_dir
7878 / save_file_fmt .format (batch_size = batch_size , batch_idx = "XX" , n_batches = f"{ n_batches :02d} " )
7979 ).parent .mkdir (parents = True , exist_ok = True )
8080 # iterate over the requested number of batches and save them
@@ -86,7 +86,7 @@ def split_dataset_lm(
8686 ):
8787 if batch_idx >= n_batches :
8888 break
89- batch_path : Path = output_path / save_file_fmt .format (
89+ batch_path : Path = output_dir / save_file_fmt .format (
9090 batch_size = batch_size ,
9191 batch_idx = f"{ batch_idx :02d} " ,
9292 n_batches = f"{ n_batches :02d} " ,
@@ -98,7 +98,7 @@ def split_dataset_lm(
9898 output_paths .append (batch_path )
9999
100100 # save a config file
101- cfg_path : Path = output_path / cfg_file_fmt .format (batch_size = batch_size )
101+ cfg_path : Path = output_dir / cfg_file_fmt .format (batch_size = batch_size )
102102 cfg_data : dict [str , Any ] = dict (
103103 # args to this function
104104 model_path = model_path ,
@@ -110,7 +110,7 @@ def split_dataset_lm(
110110 tokenizer_type = str (getattr (_tokenizer , "__class__" , None )),
111111 # files we saved
112112 output_files = [str (p ) for p in output_paths ],
113- output_dir = str (output_path ),
113+ output_dir = str (output_dir ),
114114 output_file_fmt = save_file_fmt ,
115115 cfg_file_fmt = cfg_file_fmt ,
116116 cfg_file = str (cfg_path ),
@@ -127,7 +127,7 @@ def split_dataset_resid_mlp(
127127 model_path : str ,
128128 n_batches : int ,
129129 batch_size : int ,
130- output_path : Path ,
130+ output_dir : Path ,
131131 save_file_fmt : str ,
132132 cfg_file_fmt : str ,
133133) -> list [Path ]:
@@ -168,9 +168,9 @@ def split_dataset_resid_mlp(
168168 dataloader = DatasetGeneratedDataLoader (dataset , batch_size = batch_size , shuffle = False )
169169
170170 # make dirs
171- output_path .mkdir (parents = True , exist_ok = True )
171+ output_dir .mkdir (parents = True , exist_ok = True )
172172 (
173- output_path
173+ output_dir
174174 / save_file_fmt .format (batch_size = batch_size , batch_idx = "XX" , n_batches = f"{ n_batches :02d} " )
175175 ).parent .mkdir (parents = True , exist_ok = True )
176176
@@ -186,7 +186,7 @@ def split_dataset_resid_mlp(
186186 if batch_idx >= n_batches :
187187 break
188188
189- batch_path : Path = output_path / save_file_fmt .format (
189+ batch_path : Path = output_dir / save_file_fmt .format (
190190 batch_size = batch_size ,
191191 batch_idx = f"{ batch_idx :02d} " ,
192192 n_batches = f"{ n_batches :02d} " ,
@@ -198,7 +198,7 @@ def split_dataset_resid_mlp(
198198 output_paths .append (batch_path )
199199
200200 # save the config file
201- cfg_path : Path = output_path / cfg_file_fmt .format (batch_size = batch_size )
201+ cfg_path : Path = output_dir / cfg_file_fmt .format (batch_size = batch_size )
202202 cfg_data : dict [str , Any ] = dict (
203203 # args to this function
204204 model_path = model_path ,
@@ -208,7 +208,7 @@ def split_dataset_resid_mlp(
208208 resid_mlp_dataset_kwargs = resid_mlp_dataset_kwargs ,
209209 # files we saved
210210 output_files = [str (p ) for p in output_paths ],
211- output_dir = str (output_path ),
211+ output_dir = str (output_dir ),
212212 output_file_fmt = save_file_fmt ,
213213 cfg_file_fmt = cfg_file_fmt ,
214214 cfg_file = str (cfg_path ),
@@ -223,7 +223,7 @@ def split_dataset_resid_mlp(
223223
224224def split_and_save_dataset (
225225 config : MergeRunConfig ,
226- output_path : Path ,
226+ output_dir : Path ,
227227 save_file_fmt : str ,
228228 cfg_file_fmt : str ,
229229) -> list [Path ]:
@@ -234,7 +234,7 @@ def split_and_save_dataset(
234234 model_path = config .model_path ,
235235 n_batches = config .n_batches ,
236236 batch_size = config .batch_size ,
237- output_path = output_path ,
237+ output_dir = output_dir ,
238238 save_file_fmt = save_file_fmt ,
239239 cfg_file_fmt = cfg_file_fmt ,
240240 )
@@ -243,7 +243,7 @@ def split_and_save_dataset(
243243 model_path = config .model_path ,
244244 n_batches = config .n_batches ,
245245 batch_size = config .batch_size ,
246- output_path = output_path ,
246+ output_dir = output_dir ,
247247 save_file_fmt = save_file_fmt ,
248248 cfg_file_fmt = cfg_file_fmt ,
249249 )
0 commit comments