44from pathlib import Path
55from unittest .mock import Mock , patch
66
7- from spd .clustering .merge_history import MergeHistory
8- from spd .clustering .s2_clustering import _save_merge_history_to_wandb
9- from spd .clustering .s3_normalize_histories import normalize_and_save
7+ from spd .clustering .merge_config import MergeConfig
8+ from spd .clustering .merge_history import MergeHistory , MergeHistoryEnsemble
9+ from spd .clustering .pipeline .s2_clustering import _save_merge_history_to_wandb
10+ from spd .clustering .pipeline .s3_normalize_histories import normalize_and_save
1011
1112
1213def test_wandb_url_parsing_short_format ():
@@ -48,8 +49,6 @@ def test_wandb_url_parsing_short_format():
4849
4950def test_merge_history_ensemble ():
5051 """Test that MergeHistoryEnsemble can handle multiple histories."""
51- from spd .clustering .merge_config import MergeConfig
52- from spd .clustering .merge_history import MergeHistoryEnsemble
5352
5453 # Create test merge histories
5554 config = MergeConfig (
@@ -79,7 +78,6 @@ def test_merge_history_ensemble():
7978
8079def test_save_merge_history_to_wandb ():
8180 """Test that _save_merge_history_to_wandb creates the expected artifact."""
82- from spd .clustering .merge_config import MergeConfig
8381
8482 # Create a real MergeHistory
8583 config = MergeConfig (
@@ -102,7 +100,7 @@ def test_save_merge_history_to_wandb():
102100 history_path = Path (tmp_dir ) / "test_history.zip"
103101 history .save (history_path )
104102
105- with patch ("spd.clustering.s2_clustering.wandb.Artifact" ) as mock_artifact_class :
103+ with patch ("spd.clustering.pipeline. s2_clustering.wandb.Artifact" ) as mock_artifact_class :
106104 mock_artifact_class .return_value = mock_artifact
107105
108106 # Call the function
@@ -127,7 +125,6 @@ def test_save_merge_history_to_wandb():
127125
128126def test_wandb_url_field_in_merge_history ():
129127 """Test that MergeHistory can store and serialize wandb_url."""
130- from spd .clustering .merge_config import MergeConfig
131128
132129 # Create a simple config
133130 config = MergeConfig (
@@ -148,5 +145,5 @@ def test_wandb_url_field_in_merge_history():
148145 history .save (save_path )
149146 loaded_history = MergeHistory .read (save_path )
150147
151- assert loaded_history
152- assert loaded_history .merges .group_idxs .shape == (0 , 5 )
148+ assert loaded_history is not None
149+ assert loaded_history .merges .group_idxs .shape == (10 , 5 ) # (iters, n_components )
0 commit comments