Skip to content

Commit 901cebd

Browse files
committed
fixing tests
1 parent 2b4feee commit 901cebd

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

spd/clustering/merge_history.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,19 @@ def normalized(self) -> tuple[MergesArray, dict[str, Any]]:
328328
# For now, keep using dbg_tensor for overlap_stats analysis
329329
dbg_tensor(overlap_stats)
330330

331+
# Convert any Path objects to strings for JSON serialization
332+
history_metadatas = []
333+
for history in self.data:
334+
if history.meta is not None:
335+
meta_copy = history.meta.copy()
336+
# Convert Path objects to strings
337+
for key, value in meta_copy.items():
338+
if isinstance(value, Path):
339+
meta_copy[key] = str(value)
340+
history_metadatas.append(meta_copy)
341+
else:
342+
history_metadatas.append(None)
343+
331344
return (
332345
merges_array,
333346
dict(
@@ -336,7 +349,7 @@ def normalized(self) -> tuple[MergesArray, dict[str, Any]]:
336349
n_iters=self.n_iters,
337350
c_components=c_components,
338351
config=self.config.model_dump(mode="json"),
339-
history_metadatas=[history.meta for history in self.data],
352+
history_metadatas=history_metadatas,
340353
),
341354
)
342355

tests/clustering/test_merge_integration.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def test_merge_with_popping(self):
111111

112112
# Check results
113113
assert history is not None
114-
assert history.merges.k_groups[0].item() == n_components
114+
# First entry is after first merge, so should be n_components - 1
115+
assert history.merges.k_groups[0].item() == n_components - 1
115116
# Final group count depends on pops, but should be less than initial
116117
assert history.merges.k_groups[-1].item() < n_components
117118

@@ -192,8 +193,8 @@ def test_merge_with_small_components(self):
192193
component_labels=component_labels,
193194
)
194195

195-
# Should start with 3 components
196-
assert history.merges.k_groups[0].item() == 3
196+
# First entry is after first merge, so should be 3 - 1 = 2
197+
assert history.merges.k_groups[0].item() == 2
197198
# Early stopping may occur at 2 groups, so final count could be 2 or 3
198199
assert history.merges.k_groups[-1].item() >= 2
199200
assert history.merges.k_groups[-1].item() <= 3

tests/clustering/test_wandb_integration.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from pathlib import Path
55
from unittest.mock import Mock, patch
66

7-
from spd.clustering.merge_history import MergeHistory
8-
from spd.clustering.s2_clustering import _save_merge_history_to_wandb
9-
from spd.clustering.s3_normalize_histories import normalize_and_save
7+
from spd.clustering.merge_config import MergeConfig
8+
from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble
9+
from spd.clustering.pipeline.s2_clustering import _save_merge_history_to_wandb
10+
from spd.clustering.pipeline.s3_normalize_histories import normalize_and_save
1011

1112

1213
def test_wandb_url_parsing_short_format():
@@ -48,8 +49,6 @@ def test_wandb_url_parsing_short_format():
4849

4950
def test_merge_history_ensemble():
5051
"""Test that MergeHistoryEnsemble can handle multiple histories."""
51-
from spd.clustering.merge_config import MergeConfig
52-
from spd.clustering.merge_history import MergeHistoryEnsemble
5352

5453
# Create test merge histories
5554
config = MergeConfig(
@@ -79,7 +78,6 @@ def test_merge_history_ensemble():
7978

8079
def test_save_merge_history_to_wandb():
8180
"""Test that _save_merge_history_to_wandb creates the expected artifact."""
82-
from spd.clustering.merge_config import MergeConfig
8381

8482
# Create a real MergeHistory
8583
config = MergeConfig(
@@ -102,7 +100,7 @@ def test_save_merge_history_to_wandb():
102100
history_path = Path(tmp_dir) / "test_history.zip"
103101
history.save(history_path)
104102

105-
with patch("spd.clustering.s2_clustering.wandb.Artifact") as mock_artifact_class:
103+
with patch("spd.clustering.pipeline.s2_clustering.wandb.Artifact") as mock_artifact_class:
106104
mock_artifact_class.return_value = mock_artifact
107105

108106
# Call the function
@@ -127,7 +125,6 @@ def test_save_merge_history_to_wandb():
127125

128126
def test_wandb_url_field_in_merge_history():
129127
"""Test that MergeHistory can store and serialize wandb_url."""
130-
from spd.clustering.merge_config import MergeConfig
131128

132129
# Create a simple config
133130
config = MergeConfig(
@@ -148,5 +145,5 @@ def test_wandb_url_field_in_merge_history():
148145
history.save(save_path)
149146
loaded_history = MergeHistory.read(save_path)
150147

151-
assert loaded_history
152-
assert loaded_history.merges.group_idxs.shape == (0, 5)
148+
assert loaded_history is not None
149+
assert loaded_history.merges.group_idxs.shape == (10, 5) # (iters, n_components)

0 commit comments

Comments
 (0)