Skip to content

Commit 62dfcdf

Browse files
committed
fix tests
1 parent 7e66a3a commit 62dfcdf

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

tests/clustering/test_wandb_integration.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,40 +11,43 @@
1111

1212

1313
def test_wandb_url_parsing_short_format():
14-
"""Test that normalize_and_save can process merge histories."""
15-
# Create temporary merge history files
14+
"""Test that normalize_and_save can process merge histories using storage."""
15+
from spd.clustering.pipeline.storage import ClusteringStorage
16+
17+
# Create temporary directory for storage
1618
with tempfile.TemporaryDirectory() as tmp_dir:
1719
tmp_path = Path(tmp_dir)
1820

19-
# Create mock merge histories
20-
from spd.clustering.merge_config import MergeConfig
21+
# Create ClusteringStorage instance
22+
storage = ClusteringStorage(base_path=tmp_path, run_identifier="test_run")
2123

24+
# Create mock merge histories
2225
config = MergeConfig(
2326
iters=5,
2427
alpha=1.0,
2528
activation_threshold=None,
2629
pop_component_prob=0.0,
2730
)
2831

29-
history_paths = []
32+
# Save histories using storage
3033
for idx in range(2):
3134
history = MergeHistory.from_config(
3235
config=config,
3336
labels=[f"comp{j}" for j in range(5)],
3437
)
35-
history_path = tmp_path / f"history_{idx}.zip"
36-
history.save(history_path)
37-
history_paths.append(history_path)
38+
storage.save_history(history, batch_id=f"batch_{idx:02d}")
3839

39-
# Test normalize_and_save
40-
output_dir = tmp_path / "output"
41-
result = normalize_and_save(history_paths, output_dir)
40+
# Test normalize_and_save with storage
41+
result = normalize_and_save(storage=storage)
4242

4343
# Basic checks
4444
assert result is not None
45-
assert output_dir.exists()
46-
assert (output_dir / "ensemble_meta.json").exists()
47-
assert (output_dir / "ensemble_merge_array.npz").exists()
45+
assert storage.ensemble_meta_file.exists()
46+
assert storage.ensemble_array_file.exists()
47+
48+
# Verify we can load the histories back
49+
loaded_histories = storage.load_histories()
50+
assert len(loaded_histories) == 2
4851

4952

5053
def test_merge_history_ensemble():

0 commit comments

Comments
 (0)