Skip to content

Commit cc202ce

Browse files
committed
wip
1 parent b829f3e commit cc202ce

File tree

1 file changed

+54
-99
lines changed

1 file changed

+54
-99
lines changed

spd/clustering/storage.py

Lines changed: 54 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,21 @@ class ClusteringStorage:
4040
"""
4141

4242
# Directory structure constants (all relative to the run directory)
DATASET_DIR: str = "dataset"
BATCHES_DIR: str = "batches"
HISTORIES_DIR: str = "merge_histories"
ENSEMBLE_DIR: str = "ensemble"
DISTANCES_DIR: str = "distances"

# File naming constants (plain names and str.format templates)
RUN_CONFIG_FILE: str = "run_config.json"
DATASET_CONFIG_FILE: str = "dataset_config.json"
ENSEMBLE_META_FILE: str = "ensemble_meta.json"
ENSEMBLE_ARRAY_FILE: str = "ensemble_merge_array.npz"
BATCH_FILE_FMT: str = "batch_{batch_idx:02d}.npz"  # zero-padded so lexicographic sort == index order
HISTORY_FILE_FMT: str = "data_{batch_id}"
MERGE_HISTORY_FILE: str = "merge_history.zip"
DISTANCES_FILE_FMT: str = "distances.{method}.npz"

5959
def __init__(self, base_path: Path, run_identifier: str | None = None):
6060
"""Initialize storage with base path and optional run identifier.
@@ -63,7 +63,7 @@ def __init__(self, base_path: Path, run_identifier: str | None = None):
6363
base_path: Root directory for all storage operations
6464
run_identifier: Optional identifier to create a subdirectory for this run
6565
"""
66-
self.base_path = base_path
66+
self.base_path: Path = base_path
6767
if run_identifier:
6868
self.run_path = base_path / run_identifier
6969
else:
@@ -72,122 +72,94 @@ def __init__(self, base_path: Path, run_identifier: str | None = None):
7272
# Ensure base directory exists
7373
self.run_path.mkdir(parents=True, exist_ok=True)
7474

# directories
@property
def dataset_dir(self) -> Path:
    """Directory holding the dataset config and batches."""
    return self.run_path / self.DATASET_DIR

@property
def batches_dir(self) -> Path:
    """Directory holding the saved batch .npz files (inside dataset_dir)."""
    return self.dataset_dir / self.BATCHES_DIR

@property
def histories_dir(self) -> Path:
    """Directory holding per-batch merge-history subdirectories."""
    return self.run_path / self.HISTORIES_DIR

@property
def ensemble_dir(self) -> Path:
    """Directory holding ensemble metadata and merge array."""
    return self.run_path / self.ENSEMBLE_DIR

@property
def distances_dir(self) -> Path:
    """Directory holding per-method distance arrays."""
    return self.run_path / self.DISTANCES_DIR

# files
@property
def run_config_file(self) -> Path:
    """Path of the run configuration JSON inside the run directory."""
    return self.run_path / self.RUN_CONFIG_FILE

@property
def dataset_config_file(self) -> Path:
    """Path of the dataset configuration JSON inside the dataset directory."""
    return self.dataset_dir / self.DATASET_CONFIG_FILE

@property
def ensemble_meta_file(self) -> Path:
    """Path of the ensemble metadata JSON."""
    return self.ensemble_dir / self.ENSEMBLE_META_FILE

@property
def ensemble_array_file(self) -> Path:
    """Path of the compressed ensemble merge-array .npz."""
    return self.ensemble_dir / self.ENSEMBLE_ARRAY_FILE
# dynamic
def batch_path(self, batch_idx: int) -> Path:
    """Path of the .npz file holding batch number *batch_idx*."""
    return self.batches_dir / self.BATCH_FILE_FMT.format(batch_idx=batch_idx)

def history_path(self, batch_id: str) -> Path:
    """Path of the merge-history archive for *batch_id* (inside its data_<id> subdir)."""
    subdir: Path = self.histories_dir / self.HISTORY_FILE_FMT.format(batch_id=batch_id)
    return subdir / self.MERGE_HISTORY_FILE
# Batch storage methods

def save_dataset_config(self, config: dict[str, Any]) -> Path:
    """Write the dataset configuration as pretty-printed JSON.

    Creates the dataset directory if needed; returns the written path.
    """
    self.dataset_dir.mkdir(parents=True, exist_ok=True)
    self.dataset_config_file.write_text(json.dumps(config, indent=2))
    return self.dataset_config_file

def save_batch(self, batch: Tensor, batch_idx: int) -> Path:
    """Save one token batch as a compressed .npz file and return its path.

    The tensor is moved to CPU and stored under the "input_ids" key.
    """
    self.batches_dir.mkdir(parents=True, exist_ok=True)
    batch_path: Path = self.batch_path(batch_idx)
    np.savez_compressed(batch_path, input_ids=batch.cpu().numpy())
    return batch_path

def save_batches(self, batches: Iterator[Tensor], config: dict[str, Any]) -> list[Path]:
    """Persist *config* and every batch from *batches*.

    The dataset config is written first; batches are saved in iteration
    order with their enumeration index. Returns the saved batch paths.
    """
    self.save_dataset_config(config)
    return [self.save_batch(batch, idx) for idx, batch in enumerate(batches)]

def load_batch(self, batch_path: Path) -> "Int[Tensor, 'batch_size n_ctx']":
    """Load a batch saved by save_batch and return it as a tensor.

    Reads the "input_ids" array from the .npz archive at *batch_path*.
    """
    data = np.load(batch_path)
    return torch.tensor(data["input_ids"])

146-
def load_batches(self) -> list[Tensor]:
147-
"""Load all batches from the batches directory.
148-
149-
Returns:
150-
List of loaded batch tensors
151-
"""
152-
batch_files = sorted(self.batches_dir.glob("batch_*.npz"))
153-
return [self.load_batch(path) for path in batch_files]
154-
def get_batch_paths(self) -> list[Path]:
    """Return all saved batch files, sorted by name (zero-padded == index order)."""
    return sorted(self.batches_dir.glob("batch_*.npz"))

# History storage methods

def save_history(self, history: "MergeHistory", batch_id: str) -> Path:
    """Save *history* and return the path of the written archive.

    Delegates path construction to history_path() so the layout is
    defined in exactly one place; creates the parent directory first.
    """
    history_path: Path = self.history_path(batch_id)
    history_path.parent.mkdir(parents=True, exist_ok=True)
    history.save(history_path)
    return history_path

def load_history(self, batch_id: str) -> "MergeHistory":
    """Load the merge history previously saved for *batch_id*.

    Uses history_path() so loading reads exactly where save_history wrote.
    """
    return MergeHistory.read(self.history_path(batch_id))
197169

def get_history_paths(self) -> list[Path]:
    """Return sorted paths of every saved merge-history archive."""
    return sorted(self.histories_dir.glob(f"*/{self.MERGE_HISTORY_FILE}"))

def load_histories(self) -> "list[MergeHistory]":
    """Load every saved merge history, in get_history_paths() order."""
    return [MergeHistory.read(path) for path in self.get_history_paths()]

# Ensemble storage methods

def save_ensemble(self, ensemble: "NormalizedEnsemble") -> tuple[Path, Path]:
    """Save normalized ensemble data.

    Writes ensemble.metadata as pretty-printed JSON and
    ensemble.merge_array as a compressed .npz (key "merges").
    Returns (metadata_path, array_path).
    """
    self.ensemble_dir.mkdir(parents=True, exist_ok=True)

    # Save metadata
    metadata_path: Path = self.ensemble_meta_file
    metadata_path.write_text(json.dumps(ensemble.metadata, indent=2))

    # Save merge array
    array_path: Path = self.ensemble_array_file
    np.savez_compressed(array_path, merges=ensemble.merge_array)

    return metadata_path, array_path

def save_distances(self, distances: "DistancesArray", method: "DistancesMethod") -> Path:
    """Save a distances array for *method* (key "distances"); return the written path."""
    self.distances_dir.mkdir(parents=True, exist_ok=True)

    distances_path: Path = self.distances_dir / self.DISTANCES_FILE_FMT.format(method=method)
    np.savez_compressed(distances_path, distances=distances)
    return distances_path

def load_distances(self, method: "DistancesMethod") -> "DistancesArray":
    """Load the distances array previously saved for *method* by save_distances."""
    distances_path: Path = self.distances_dir / self.DISTANCES_FILE_FMT.format(method=method)
    data = np.load(distances_path)
    return data["distances"]

def save_run_config(self, config: "RunConfig") -> Path:
    """Write the run configuration as JSON into the run directory.

    Uses the run_config_file property so the location is defined once;
    *config* is serialized via its model_dump_json (pydantic) method.
    Returns the written path.
    """
    config_path: Path = self.run_config_file
    config_path.write_text(config.model_dump_json(indent=2))
    return config_path

0 commit comments

Comments
 (0)