Skip to content

Commit 94f3f92

Browse files
committed
better path handling
1 parent 238fa31 commit 94f3f92

File tree

3 files changed

+48
-15
lines changed

3 files changed

+48
-15
lines changed

spd/clustering/clustering_pipeline.py

Lines changed: 44 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,43 @@
99

1010
from spd.clustering.merge_run_config import RunConfig
1111

12+
# Canonical sub-path names used by the clustering pipeline, expressed
# relative to the per-run root directory.  Keys intentionally mirror the
# property names exposed on PipelinePaths below.
PIPELINE_PATHS: dict[str, str] = {
    "run_record_path": "run_record.json",
    "histories_dir": "merge_histories",
    "dataset_dir": "dataset",
    "ensemble_dir": "ensemble",
    "distances_dir": "distances",
}


class PipelinePaths:
    """Resolves the on-disk layout of a single clustering pipeline run.

    All locations are derived lazily from the run root
    (``config.base_path / config.config_identifier``); constructing this
    object performs no filesystem access and creates no directories.
    """

    def __init__(self, config: RunConfig) -> None:
        # Kept as an attribute so every property can re-derive paths from
        # the live config rather than caching them at construction time.
        self.config: RunConfig = config

    def _resolve(self, key: str) -> Path:
        # Single join point: every named location is the run root plus the
        # relative name registered in PIPELINE_PATHS.
        return self.run_path / PIPELINE_PATHS[key]

    @property
    def run_path(self) -> Path:
        # Root directory for this run; everything else nests under it.
        return self.config.base_path / self.config.config_identifier

    @property
    def run_record_path(self) -> Path:
        return self._resolve("run_record_path")

    @property
    def histories_dir(self) -> Path:
        return self._resolve("histories_dir")

    @property
    def dataset_dir(self) -> Path:
        return self._resolve("dataset_dir")

    @property
    def ensemble_dir(self) -> Path:
        return self._resolve("ensemble_dir")

    @property
    def distances_dir(self) -> Path:
        return self._resolve("distances_dir")
48+
1249

1350
def main(config: RunConfig) -> None:
1451
from spd.clustering.math.merge_distances import (
@@ -21,39 +58,33 @@ def main(config: RunConfig) -> None:
2158
from spd.clustering.s3_normalize_histories import normalize_and_save
2259
from spd.clustering.s4_compute_distances import create_clustering_report
2360

24-
# TODO: factor these out into dataclass or something
25-
run_path: Path = config.base_path / config.config_identifier
26-
run_record_path: Path = run_path / "run_record.json"
27-
histories_dir: Path = run_path / "merge_histories"
28-
dataset_dir: Path = run_path / "dataset"
29-
ensemble_dir: Path = run_path / "ensemble"
30-
distances_dir: Path = run_path / "distances"
61+
paths: PipelinePaths = PipelinePaths(config=config)
3162

32-
print(f"Run record saved to {run_record_path}")
33-
run_record_path.write_text(config.model_dump_json(indent=2))
63+
print(f"Run record saved to {paths.run_record_path}")
64+
paths.run_record_path.write_text(config.model_dump_json(indent=2))
3465

3566
print(f"Splitting dataset into {config.n_batches} batches...")
36-
data_files: list[Path] = split_and_save_dataset(config=config, output_dir=dataset_dir)
67+
data_files: list[Path] = split_and_save_dataset(config=config, output_dir=paths.dataset_dir)
3768

3869
print(
3970
f"Processing {len(data_files)} batches with {config.workers_per_device} workers per device..."
4071
)
4172
results: list[ClusteringResult] = process_batches_parallel(
4273
data_files=data_files,
4374
config=config,
44-
output_dir=histories_dir,
75+
output_dir=paths.histories_dir,
4576
workers_per_device=config.workers_per_device,
4677
devices=config.devices,
4778
)
4879

4980
normalized_merge_array: MergesArray = normalize_and_save(
5081
history_paths=[r.history_save_path for r in results],
51-
output_dir=ensemble_dir,
82+
output_dir=paths.ensemble_dir,
5283
)
5384

5485
distances: DistancesArray = compute_and_save_distances(
5586
normalized_merge_array=normalized_merge_array,
56-
output_dir=distances_dir,
87+
output_dir=paths.distances_dir,
5788
)
5889

5990
wandb_urls: list[str] = [r.wandb_url for r in results if r.wandb_url] # Gross - clean up

spd/clustering/pipeline/__init__.py

Whitespace-only changes.

tests/clustering/test_merge_integration.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,8 @@ def test_merge_with_range_sampler(self):
3939
# Check results
4040
assert history is not None
4141
assert len(history.merges.k_groups) > 0
42-
assert history.merges.k_groups[0].item() == n_components
42+
# First entry is after first merge, so should be n_components - 1
43+
assert history.merges.k_groups[0].item() == n_components - 1
4344
# After iterations, should have fewer groups (merges reduce count)
4445
# Exact count depends on early stopping conditions
4546
assert history.merges.k_groups[-1].item() < n_components
@@ -75,7 +76,8 @@ def test_merge_with_mcmc_sampler(self):
7576
# Check results
7677
assert history is not None
7778
assert len(history.merges.k_groups) > 0
78-
assert history.merges.k_groups[0].item() == n_components
79+
# First entry is after first merge, so should be n_components - 1
80+
assert history.merges.k_groups[0].item() == n_components - 1
7981
# Should have fewer groups after iterations
8082
assert history.merges.k_groups[-1].item() < n_components
8183
assert history.merges.k_groups[-1].item() >= 2

0 commit comments

Comments (0)