# Directory structure constants (relative to the run directory)
DATASET_DIR: str = "dataset"
BATCHES_DIR: str = "batches"
HISTORIES_DIR: str = "merge_histories"
ENSEMBLE_DIR: str = "ensemble"
DISTANCES_DIR: str = "distances"

# File naming constants.
# NOTE(review): the pasted source had stray spaces inside several of these
# literals (e.g. "run_config .json"); they are normalized here — confirm
# against existing on-disk artifacts before deploying.
RUN_CONFIG_FILE: str = "run_config.json"
DATASET_CONFIG_FILE: str = "dataset_config.json"
ENSEMBLE_META_FILE: str = "ensemble_meta.json"
ENSEMBLE_ARRAY_FILE: str = "ensemble_merge_array.npz"
BATCH_FILE_FMT: str = "batch_{batch_idx:02d}.npz"
HISTORY_FILE_FMT: str = "data_{batch_id}"
MERGE_HISTORY_FILE: str = "merge_history.zip"
DISTANCES_FILE_FMT: str = "distances.{method}.npz"
5959 def __init__ (self , base_path : Path , run_identifier : str | None = None ):
6060 """Initialize storage with base path and optional run identifier.
@@ -63,7 +63,7 @@ def __init__(self, base_path: Path, run_identifier: str | None = None):
6363 base_path: Root directory for all storage operations
6464 run_identifier: Optional identifier to create a subdirectory for this run
6565 """
66- self .base_path = base_path
66+ self .base_path : Path = base_path
6767 if run_identifier :
6868 self .run_path = base_path / run_identifier
6969 else :
@@ -72,122 +72,94 @@ def __init__(self, base_path: Path, run_identifier: str | None = None):
7272 # Ensure base directory exists
7373 self .run_path .mkdir (parents = True , exist_ok = True )
7474
# directories
@property
def dataset_dir(self) -> Path:
    """Directory holding the tokenized dataset and its config."""
    return self.run_path / self.DATASET_DIR

@property
def batches_dir(self) -> Path:
    """Directory (inside ``dataset_dir``) holding the saved token batches."""
    return self.dataset_dir / self.BATCHES_DIR

@property
def histories_dir(self) -> Path:
    """Directory holding the per-batch merge histories."""
    return self.run_path / self.HISTORIES_DIR

@property
def ensemble_dir(self) -> Path:
    """Directory holding the normalized-ensemble artifacts."""
    return self.run_path / self.ENSEMBLE_DIR

@property
def distances_dir(self) -> Path:
    """Directory holding the saved distance matrices."""
    return self.run_path / self.DISTANCES_DIR
# files
@property
def run_config_file(self) -> Path:
    """Path of the run configuration JSON file."""
    return self.run_path / self.RUN_CONFIG_FILE

@property
def dataset_config_file(self) -> Path:
    """Path of the dataset configuration JSON file."""
    return self.dataset_dir / self.DATASET_CONFIG_FILE

@property
def ensemble_meta_file(self) -> Path:
    """Path of the ensemble metadata JSON file."""
    return self.ensemble_dir / self.ENSEMBLE_META_FILE

@property
def ensemble_array_file(self) -> Path:
    """Path of the compressed ensemble merge-array file."""
    return self.ensemble_dir / self.ENSEMBLE_ARRAY_FILE

# dynamic
def batch_path(self, batch_idx: int) -> Path:
    """Path of the ``.npz`` file for batch ``batch_idx``."""
    name = self.BATCH_FILE_FMT.format(batch_idx=batch_idx)
    return self.batches_dir / name

def history_path(self, batch_id: str) -> Path:
    """Path of the merge-history archive for batch ``batch_id``."""
    subdir = self.HISTORY_FILE_FMT.format(batch_id=batch_id)
    return self.histories_dir / subdir / self.MERGE_HISTORY_FILE
# Batch storage methods

def save_dataset_config(self, config: dict[str, Any]) -> Path:
    """Write the dataset configuration as pretty-printed JSON.

    Args:
        config: Dataset configuration dictionary.

    Returns:
        Path to the saved configuration file.
    """
    self.dataset_dir.mkdir(parents=True, exist_ok=True)
    target = self.dataset_config_file
    target.write_text(json.dumps(config, indent=2))
    return target
115126
def save_batch(self, batch: Tensor, batch_idx: int) -> Path:
    """Persist one batch of token ids as a compressed ``.npz`` file.

    Args:
        batch: Token-id tensor to save (moved to CPU before writing).
        batch_idx: Index used to name the file.

    Returns:
        Path of the written batch file.
    """
    target = self.batch_path(batch_idx)
    self.batches_dir.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(target, input_ids=batch.cpu().numpy())
    return target
122133
def save_batches(self, batches: Iterator[Tensor], config: dict[str, Any]) -> list[Path]:
    """Save the dataset config, then every batch in iteration order.

    Args:
        batches: Iterator of token-id tensors.
        config: Dataset configuration written before any batch.

    Returns:
        Paths of the written batch files, in batch order.
    """
    self.save_dataset_config(config)
    return [self.save_batch(batch, idx) for idx, batch in enumerate(batches)]
def load_batch(self, batch_path: Path) -> "Int[Tensor, 'batch_size n_ctx']":
    """Load a batch of token ids previously written by ``save_batch``.

    Args:
        batch_path: Path to the ``.npz`` batch file.

    Returns:
        The ``input_ids`` tensor stored in the file.
    """
    # np.load on an .npz returns an NpzFile that keeps the file handle open;
    # context-manage it so the handle is closed promptly rather than at GC.
    with np.load(batch_path) as data:
        return torch.tensor(data["input_ids"])
145148
def get_batch_paths(self) -> list[Path]:
    """Return all saved batch files, sorted by name (i.e. by batch index)."""
    found = list(self.batches_dir.glob("batch_*.npz"))
    found.sort()
    return found
162151
# History storage methods

def save_history(self, history: "MergeHistory", batch_id: str) -> Path:
    """Save the merge history for a batch.

    Uses the ``history_path`` helper (``<histories_dir>/data_<batch_id>/
    merge_history.zip``) instead of re-deriving the path inline, so the
    layout is defined in exactly one place.

    Args:
        history: MergeHistory object to save.
        batch_id: Identifier for the batch.

    Returns:
        Path to the saved history archive.
    """
    target = self.history_path(batch_id)
    target.parent.mkdir(parents=True, exist_ok=True)
    history.save(target)
    return target
181161
def load_history(self, batch_id: str) -> "MergeHistory":
    """Load the merge history previously saved for ``batch_id``.

    Uses the same ``history_path`` helper as ``save_history`` so read and
    write locations cannot drift apart.
    """
    return MergeHistory.read(self.history_path(batch_id))
def get_history_paths(self) -> list[Path]:
    """Return every saved merge-history archive, sorted by path."""
    pattern = f"*/{self.MERGE_HISTORY_FILE}"
    return sorted(self.histories_dir.glob(pattern))
def load_histories(self) -> "list[MergeHistory]":
    """Load every saved merge history, in ``get_history_paths`` order."""
    return list(map(MergeHistory.read, self.get_history_paths()))
213175
# Ensemble storage methods

def save_ensemble(self, ensemble: "NormalizedEnsemble") -> tuple[Path, Path]:
    """Save normalized ensemble data.

    Writes the ensemble metadata as pretty-printed JSON and the merge
    array as a compressed ``.npz``.

    Returns:
        ``(metadata_path, array_path)``.
    """
    self.ensemble_dir.mkdir(parents=True, exist_ok=True)

    meta_path = self.ensemble_meta_file
    meta_path.write_text(json.dumps(ensemble.metadata, indent=2))

    array_path = self.ensemble_array_file
    np.savez_compressed(array_path, merges=ensemble.merge_array)

    return meta_path, array_path
236191
def save_distances(self, distances: "DistancesArray", method: "DistancesMethod") -> Path:
    """Save a distances array as a compressed ``.npz`` named by ``method``."""
    self.distances_dir.mkdir(parents=True, exist_ok=True)
    target = self.distances_dir / self.DISTANCES_FILE_FMT.format(method=method)
    np.savez_compressed(target, distances=distances)
    return target
243198
def load_distances(self, method: "DistancesMethod") -> "DistancesArray":
    """Load the distances array previously saved for ``method``."""
    path = self.distances_dir / self.DISTANCES_FILE_FMT.format(method=method)
    # Context-manage the NpzFile so its file handle is closed promptly
    # (np.load on .npz keeps the archive open until GC otherwise).
    with np.load(path) as data:
        return data["distances"]
248203
def save_run_config(self, config: "RunConfig") -> Path:
    """Save the run configuration as pretty-printed JSON.

    Uses the ``run_config_file`` property instead of rebuilding the path
    inline, matching how the other ``*_file`` accessors are used.

    Returns:
        Path to the saved run-config file.
    """
    target = self.run_config_file
    target.write_text(config.model_dump_json(indent=2))
    return target
0 commit comments