@@ -32,6 +32,12 @@ class NormalizedEnsemble:
3232 metadata : dict [str , Any ]
3333
3434
def _write_text_to_path_and_return(path: Path, data: str) -> Path:
    """Write ``data`` to ``path``, creating missing parent directories first.

    Args:
        path: Destination file. Its parent directories are created when absent.
        data: Text to write; existing file content is replaced.

    Returns:
        The same ``path`` object, so a caller can write and return in one
        expression.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Pin the encoding: without it ``write_text`` uses the locale's preferred
    # encoding, so non-ASCII text would produce different bytes per platform.
    path.write_text(data, encoding="utf-8")
    return path
39+
40+
3541class ClusteringStorage :
3642 """Handles all file I/O operations for the clustering pipeline.
3743
@@ -97,11 +103,11 @@ def distances_dir(self) -> Path:
@property
def run_config_file(self) -> Path:
    """Location of the run-config file: ``run_path / RUN_CONFIG_FILE``."""
    return self.run_path / self.RUN_CONFIG_FILE
100-
106+
@property
def dataset_config_file(self) -> Path:
    """Location of the dataset-config file: ``dataset_dir / DATASET_CONFIG_FILE``."""
    return self.dataset_dir / self.DATASET_CONFIG_FILE
104-
110+
@property
def ensemble_meta_file(self) -> Path:
    """Location of the ensemble metadata file: ``ensemble_dir / ENSEMBLE_META_FILE``."""
    return self.ensemble_dir / self.ENSEMBLE_META_FILE
@@ -111,22 +117,27 @@ def ensemble_array_file(self) -> Path:
111117 return self .ensemble_dir / self .ENSEMBLE_ARRAY_FILE
112118
113119 # dynamic
120+
def batch_path(self, batch_idx: int) -> Path:
    """Return the file path for batch ``batch_idx``.

    The file name is ``BATCH_FILE_FMT`` formatted with ``batch_idx`` and is
    placed under ``batches_dir``. Nothing is created on disk.
    """
    file_name = self.BATCH_FILE_FMT.format(batch_idx=batch_idx)
    return self.batches_dir / file_name
116123
def history_path(self, batch_id: str) -> Path:
    """Return the merge-history file path for ``batch_id``.

    The path is ``histories_dir / <HISTORY_FILE_FMT formatted with batch_id>
    / MERGE_HISTORY_FILE``. Nothing is created on disk.
    """
    per_batch_dir = self.histories_dir / self.HISTORY_FILE_FMT.format(batch_id=batch_id)
    return per_batch_dir / self.MERGE_HISTORY_FILE
119130
120131 # Batch storage methods
121132
def save_dataset_config(self, config: dict[str, Any]) -> Path:
    """Serialize ``config`` as indented JSON into ``dataset_config_file``.

    Writing is delegated to ``_write_text_to_path_and_return``, which creates
    the parent directory when it is missing.

    Returns:
        The path the config was written to.
    """
    serialized = json.dumps(config, indent=2)
    return _write_text_to_path_and_return(self.dataset_config_file, serialized)
126137
def save_batch(self, batch: Tensor, batch_idx: int) -> Path:
    """Save ``batch`` as a compressed ``.npz`` archive and return its path.

    The tensor is moved to CPU, converted to a NumPy array and stored under
    the key ``input_ids``. The target's parent directory is created if needed.

    Returns:
        The path computed by ``self.batch_path(batch_idx)``.
    """
    batch_path: Path = self.batch_path(batch_idx)
    batch_path.parent.mkdir(parents=True, exist_ok=True)

    # NOTE(review): np.savez_compressed appends ".npz" when the file name lacks
    # it, so the returned path only matches the written file if BATCH_FILE_FMT
    # already ends in ".npz" — confirm against the class constant.
    np.savez_compressed(batch_path, input_ids=batch.cpu().numpy())
    return batch_path
@@ -152,20 +163,13 @@ def get_batch_paths(self) -> list[Path]:
152163 # History storage methods
153164
def save_history(self, history: MergeHistory, batch_id: str) -> Path:
    """Persist ``history`` at the location given by ``history_path(batch_id)``.

    The per-batch directory is created if needed, then writing is delegated to
    ``history.save``.

    Returns:
        The path passed to ``history.save``.
    """
    history_path: Path = self.history_path(batch_id)
    history_path.parent.mkdir(parents=True, exist_ok=True)
    history.save(history_path)
    return history_path
161170
def load_history(self, batch_id: str) -> MergeHistory:
    """Read the merge history stored for ``batch_id``.

    The file location comes from ``history_path(batch_id)`` and reading is
    delegated to ``MergeHistory.read``.
    """
    source = self.history_path(batch_id)
    return MergeHistory.read(source)
169173
def get_history_paths(self) -> list[Path]:
    """Return the sorted merge-history files found under ``histories_dir``.

    Only files named ``MERGE_HISTORY_FILE`` that sit exactly one directory
    below ``histories_dir`` are matched.
    """
    pattern = f"*/{self.MERGE_HISTORY_FILE}"
    return sorted(self.histories_dir.glob(pattern))
@@ -174,7 +178,6 @@ def load_histories(self) -> list[MergeHistory]:
174178 return [MergeHistory .read (path ) for path in self .get_history_paths ()]
175179
176180 # Ensemble storage methods
177-
178181 def save_ensemble (self , ensemble : NormalizedEnsemble ) -> tuple [Path , Path ]:
179182 """Save normalized ensemble data"""
180183 self .ensemble_dir .mkdir (parents = True , exist_ok = True )
@@ -202,6 +205,5 @@ def load_distances(self, method: DistancesMethod) -> DistancesArray:
202205 return data ["distances" ]
203206
def save_run_config(self, config: RunConfig) -> Path:
    """Write ``config`` as indented JSON to ``run_config_file``.

    The JSON text comes from ``config.model_dump_json(indent=2)``.

    Returns:
        The path the config was written to.
    """
    # Evaluate the property once so the written and returned path are the same
    # object.
    config_file: Path = self.run_config_file
    # Create the parent directory first, as the other save_* methods do, so
    # saving does not depend on the run directory already existing.
    config_file.parent.mkdir(parents=True, exist_ok=True)
    config_file.write_text(config.model_dump_json(indent=2))
    return config_file
0 commit comments