99
1010from spd .clustering .merge_run_config import RunConfig
1111
# Relative locations (under a run's root directory) of every artifact the
# clustering pipeline produces. Keys match the PipelinePaths property names
# that resolve them against the run root, so the two stay in sync by name.
PIPELINE_PATHS: dict[str, str] = {
    "run_record_path": "run_record.json",
    "histories_dir": "merge_histories",
    "dataset_dir": "dataset",
    "ensemble_dir": "ensemble",
    "distances_dir": "distances",
}
19+
20+
class PipelinePaths:
    """Resolve the on-disk layout of a single clustering pipeline run.

    Every path is derived lazily on attribute access from the run root
    (``config.base_path / config.config_identifier``) and the relative
    names registered in ``PIPELINE_PATHS``. No directories or files are
    created here — this class only computes locations.
    """

    def __init__(self, config: RunConfig) -> None:
        # Only the config is stored; paths are recomputed on each access so
        # they always reflect the config's current base_path / identifier.
        self.config: RunConfig = config

    def _artifact_path(self, key: str) -> Path:
        # Single join point for all named artifacts; keeps the five
        # properties below from repeating the same expression.
        return self.run_path / PIPELINE_PATHS[key]

    @property
    def run_path(self) -> Path:
        """Root directory of this run: ``base_path / config_identifier``."""
        return self.config.base_path / self.config.config_identifier

    @property
    def run_record_path(self) -> Path:
        """JSON file the pipeline writes the serialized run config to."""
        return self._artifact_path("run_record_path")

    @property
    def histories_dir(self) -> Path:
        """Directory receiving per-batch merge histories."""
        return self._artifact_path("histories_dir")

    @property
    def dataset_dir(self) -> Path:
        """Directory receiving the dataset split into batches."""
        return self._artifact_path("dataset_dir")

    @property
    def ensemble_dir(self) -> Path:
        """Directory receiving the normalized merge ensemble."""
        return self._artifact_path("ensemble_dir")

    @property
    def distances_dir(self) -> Path:
        """Directory receiving the computed distances array."""
        return self._artifact_path("distances_dir")
48+
1249
1350def main (config : RunConfig ) -> None :
1451 from spd .clustering .math .merge_distances import (
@@ -21,39 +58,33 @@ def main(config: RunConfig) -> None:
2158 from spd .clustering .s3_normalize_histories import normalize_and_save
2259 from spd .clustering .s4_compute_distances import create_clustering_report
2360
24- # TODO: factor these out into dataclass or something
25- run_path : Path = config .base_path / config .config_identifier
26- run_record_path : Path = run_path / "run_record.json"
27- histories_dir : Path = run_path / "merge_histories"
28- dataset_dir : Path = run_path / "dataset"
29- ensemble_dir : Path = run_path / "ensemble"
30- distances_dir : Path = run_path / "distances"
61+ paths : PipelinePaths = PipelinePaths (config = config )
3162
32- print (f"Run record saved to { run_record_path } " )
33- run_record_path .write_text (config .model_dump_json (indent = 2 ))
63+ print (f"Run record saved to { paths . run_record_path } " )
64+ paths . run_record_path .write_text (config .model_dump_json (indent = 2 ))
3465
3566 print (f"Splitting dataset into { config .n_batches } batches..." )
36- data_files : list [Path ] = split_and_save_dataset (config = config , output_dir = dataset_dir )
67+ data_files : list [Path ] = split_and_save_dataset (config = config , output_dir = paths . dataset_dir )
3768
3869 print (
3970 f"Processing { len (data_files )} batches with { config .workers_per_device } workers per device..."
4071 )
4172 results : list [ClusteringResult ] = process_batches_parallel (
4273 data_files = data_files ,
4374 config = config ,
44- output_dir = histories_dir ,
75+ output_dir = paths . histories_dir ,
4576 workers_per_device = config .workers_per_device ,
4677 devices = config .devices ,
4778 )
4879
4980 normalized_merge_array : MergesArray = normalize_and_save (
5081 history_paths = [r .history_save_path for r in results ],
51- output_dir = ensemble_dir ,
82+ output_dir = paths . ensemble_dir ,
5283 )
5384
5485 distances : DistancesArray = compute_and_save_distances (
5586 normalized_merge_array = normalized_merge_array ,
56- output_dir = distances_dir ,
87+ output_dir = paths . distances_dir ,
5788 )
5889
5990 wandb_urls : list [str ] = [r .wandb_url for r in results if r .wandb_url ] # Gross - clean up
0 commit comments