Skip to content

Commit f2ecda1

Browse files
committed
wip storage
1 parent cc202ce commit f2ecda1

File tree

3 files changed: 24 additions and 22 deletions.

spd/clustering/math/merge_distances.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections.abc import Callable
2-
from pathlib import Path
32
from typing import Literal
43

54
import numpy as np

spd/clustering/merge_history.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def save(self, path: Path, wandb_url: str | None = None) -> None:
150150
wandb_url=wandb_url,
151151
c_components=self.c_components,
152152
n_iters_current=self.n_iters_current,
153+
labels=self.labels,
153154
)
154155
),
155156
)

spd/clustering/storage.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ class NormalizedEnsemble:
3232
metadata: dict[str, Any]
3333

3434

35+
def _write_text_to_path_and_return(path: Path, data: str) -> Path:
    """Write ``data`` to ``path`` as UTF-8 text and return ``path``.

    Missing parent directories are created first, so callers do not have to
    prepare the directory tree themselves. An existing file at ``path`` is
    overwritten.

    Args:
        path: Destination file.
        data: Text content to write.

    Returns:
        The same ``path`` that was passed in, so the call can be used
        directly in a ``return`` statement.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Pin the encoding: without it, write_text falls back to the locale's
    # preferred encoding, so the bytes on disk would vary across platforms.
    path.write_text(data, encoding="utf-8")
    return path
39+
40+
3541
class ClusteringStorage:
3642
"""Handles all file I/O operations for the clustering pipeline.
3743
@@ -97,11 +103,11 @@ def distances_dir(self) -> Path:
97103
@property
98104
def run_config_file(self) -> Path:
99105
return self.run_path / self.RUN_CONFIG_FILE
100-
106+
101107
@property
102108
def dataset_config_file(self) -> Path:
103109
return self.dataset_dir / self.DATASET_CONFIG_FILE
104-
110+
105111
@property
106112
def ensemble_meta_file(self) -> Path:
107113
return self.ensemble_dir / self.ENSEMBLE_META_FILE
@@ -111,22 +117,27 @@ def ensemble_array_file(self) -> Path:
111117
return self.ensemble_dir / self.ENSEMBLE_ARRAY_FILE
112118

113119
# dynamic
120+
114121
def batch_path(self, batch_idx: int) -> Path:
115122
return self.batches_dir / self.BATCH_FILE_FMT.format(batch_idx=batch_idx)
116123

117124
def history_path(self, batch_id: str) -> Path:
118-
return self.histories_dir / self.HISTORY_FILE_FMT.format(batch_id=batch_id) / self.MERGE_HISTORY_FILE
125+
return (
126+
self.histories_dir
127+
/ self.HISTORY_FILE_FMT.format(batch_id=batch_id)
128+
/ self.MERGE_HISTORY_FILE
129+
)
119130

120131
# Batch storage methods
121132

122133
def save_dataset_config(self, config: dict[str, Any]) -> Path:
123-
self.dataset_dir.mkdir(parents=True, exist_ok=True)
124-
self.dataset_config_file.write_text(json.dumps(config, indent=2))
125-
return self.dataset_config_file
134+
return _write_text_to_path_and_return(
135+
self.dataset_config_file, json.dumps(config, indent=2)
136+
)
126137

127138
def save_batch(self, batch: Tensor, batch_idx: int) -> Path:
128-
self.batches_dir.mkdir(parents=True, exist_ok=True)
129139
batch_path: Path = self.batch_path(batch_idx)
140+
batch_path.parent.mkdir(parents=True, exist_ok=True)
130141

131142
np.savez_compressed(batch_path, input_ids=batch.cpu().numpy())
132143
return batch_path
@@ -152,20 +163,13 @@ def get_batch_paths(self) -> list[Path]:
152163
# History storage methods
153164

154165
def save_history(self, history: MergeHistory, batch_id: str) -> Path:
155-
history_dir: Path = self.histories_dir / self.HISTORY_FILE_FMT.format(batch_id=batch_id)
156-
history_dir.mkdir(parents=True, exist_ok=True)
157-
158-
history_path: Path = history_dir / self.MERGE_HISTORY_FILE
166+
history_path: Path = self.history_path(batch_id)
167+
history_path.parent.mkdir(parents=True, exist_ok=True)
159168
history.save(history_path)
160169
return history_path
161170

162171
def load_history(self, batch_id: str) -> MergeHistory:
163-
history_path = (
164-
self.histories_dir
165-
/ self.HISTORY_FILE_FMT.format(batch_id=batch_id)
166-
/ self.MERGE_HISTORY_FILE
167-
)
168-
return MergeHistory.read(history_path)
172+
return MergeHistory.read(self.history_path(batch_id))
169173

170174
def get_history_paths(self) -> list[Path]:
171175
return sorted(self.histories_dir.glob(f"*/{self.MERGE_HISTORY_FILE}"))
@@ -174,7 +178,6 @@ def load_histories(self) -> list[MergeHistory]:
174178
return [MergeHistory.read(path) for path in self.get_history_paths()]
175179

176180
# Ensemble storage methods
177-
178181
def save_ensemble(self, ensemble: NormalizedEnsemble) -> tuple[Path, Path]:
179182
"""Save normalized ensemble data"""
180183
self.ensemble_dir.mkdir(parents=True, exist_ok=True)
@@ -202,6 +205,5 @@ def load_distances(self, method: DistancesMethod) -> DistancesArray:
202205
return data["distances"]
203206

204207
def save_run_config(self, config: RunConfig) -> Path:
205-
config_path = self.run_path / self.RUN_CONFIG_FILE
206-
config_path.write_text(config.model_dump_json(indent=2))
207-
return config_path
208+
self.run_config_file.write_text(config.model_dump_json(indent=2))
209+
return self.run_config_file

0 commit comments

Comments
 (0)