|
11 | 11 |
|
12 | 12 |
|
def test_wandb_url_parsing_short_format():
    """Test that normalize_and_save can process merge histories using storage."""
    # NOTE(review): the function name mentions wandb URL parsing, but the body
    # exercises normalize_and_save via ClusteringStorage — consider renaming.
    from spd.clustering.pipeline.storage import ClusteringStorage

    # All artifacts live in a throwaway directory that vanishes after the test.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Storage abstraction rooted at the temp dir.
        storage = ClusteringStorage(base_path=tmp_path, run_identifier="test_run")

        # Minimal merge configuration shared by every mock history.
        config = MergeConfig(
            iters=5,
            alpha=1.0,
            activation_threshold=None,
            pop_component_prob=0.0,
        )

        # Persist two mock histories through the storage layer.
        for batch_idx in range(2):
            mock_history = MergeHistory.from_config(
                config=config,
                labels=[f"comp{j}" for j in range(5)],
            )
            storage.save_history(mock_history, batch_id=f"batch_{batch_idx:02d}")

        # Run the function under test against the populated storage.
        result = normalize_and_save(storage=storage)

        # It should return something and emit the ensemble artifacts.
        assert result is not None
        assert storage.ensemble_meta_file.exists()
        assert storage.ensemble_array_file.exists()

        # Round-trip: the saved histories must be loadable again.
        loaded_histories = storage.load_histories()
        assert len(loaded_histories) == 2
|
49 | 52 |
|
50 | 53 | def test_merge_history_ensemble(): |
|
0 commit comments