Skip to content

Commit 238fa31

Browse files
committed
fix pih tests
1 parent fce4007 commit 238fa31

File tree

1 file changed

+53
-25
lines changed

1 file changed

+53
-25
lines changed

tests/clustering/math/test_perm_invariant_hamming.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,30 +23,35 @@ def test_identity() -> None:
2323
"""a == b should give distance 0."""
2424
a = np.array([0, 1, 2, 1, 0])
2525
b = a.copy()
26-
d, _ = perm_invariant_hamming_matrix(a, b)
27-
assert d == 0
26+
X = np.array([a, b])
27+
D = perm_invariant_hamming_matrix(X)
28+
# Distance between row 1 and row 0 should be 0
29+
assert D[1, 0] == 0
2830

2931

3032
def test_all_one_group() -> None:
3133
"""All rows belong to one group in both arrays (possibly different labels)."""
3234
a = np.zeros(10, dtype=int)
3335
b = np.ones(10, dtype=int) # different label but identical grouping
34-
d, _ = perm_invariant_hamming_matrix(a, b)
35-
assert d == 0
36+
X = np.array([a, b])
37+
D = perm_invariant_hamming_matrix(X)
38+
assert D[1, 0] == 0
3639

3740

3841
def test_permuted_labels() -> None:
3942
a = np.array([0, 2, 1, 1, 0])
4043
b = np.array([1, 0, 0, 2, 1])
41-
d, _ = perm_invariant_hamming_matrix(a, b)
42-
assert d == 1
44+
X = np.array([a, b])
45+
D = perm_invariant_hamming_matrix(X)
46+
assert D[1, 0] == 1
4347

4448

4549
def test_swap_two_labels() -> None:
4650
a = np.array([0, 0, 1, 1])
4751
b = np.array([1, 1, 0, 0])
48-
d, _ = perm_invariant_hamming_matrix(a, b)
49-
assert d == 0
52+
X = np.array([a, b])
53+
D = perm_invariant_hamming_matrix(X)
54+
assert D[1, 0] == 0
5055

5156

5257
def test_random_small_bruteforce() -> None:
@@ -56,40 +61,63 @@ def test_random_small_bruteforce() -> None:
5661
k = 3
5762
a = rng.integers(0, k, size=n)
5863
b = rng.integers(0, k, size=n)
59-
d_alg, _ = perm_invariant_hamming_matrix(a, b)
64+
X = np.array([a, b])
65+
D = perm_invariant_hamming_matrix(X)
66+
d_alg = D[1, 0]
6067
d_true = brute_force_min_hamming(a, b)
6168
assert d_alg == d_true
6269

6370

6471
def test_shape_mismatch() -> None:
6572
a = np.array([0, 1, 2])
6673
b = np.array([0, 1])
67-
with pytest.raises(AssertionError):
68-
perm_invariant_hamming_matrix(a, b)
74+
with pytest.raises((ValueError, IndexError)):
75+
# This should fail when trying to create the matrix due to shape mismatch
76+
X = np.array([a, b])
77+
perm_invariant_hamming_matrix(X)
6978

7079

71-
def test_return_mapping() -> None:
72-
"""Verify the returned mapping is correct."""
80+
def test_matrix_multiple_pairs() -> None:
81+
"""Test the matrix function with multiple label vectors."""
7382
a = np.array([0, 0, 1, 1])
74-
b = np.array([2, 2, 3, 3])
75-
d, mapping = perm_invariant_hamming_matrix(a, b, return_mapping=True)
76-
assert d == 0
77-
assert mapping[0] == 2
78-
assert mapping[1] == 3
83+
b = np.array([2, 2, 3, 3]) # Should be distance 0 (perfect mapping)
84+
c = np.array([0, 1, 0, 1]) # Should be distance 2 from both a and b
85+
X = np.array([a, b, c])
86+
D = perm_invariant_hamming_matrix(X)
7987

88+
assert D[1, 0] == 0 # a and b should have distance 0
89+
assert D[2, 0] == 2 # a and c should have distance 2
90+
assert D[2, 1] == 2 # b and c should have distance 2
8091

81-
def test_return_mapping_false() -> None:
82-
"""Test return_mapping=False."""
92+
93+
def test_matrix_upper_triangle_nan() -> None:
94+
"""Test that upper triangle and diagonal are NaN."""
8395
a = np.array([0, 1, 0])
8496
b = np.array([1, 0, 1])
85-
d, mapping = perm_invariant_hamming_matrix(a, b, return_mapping=False)
86-
assert d == 0
87-
assert mapping is None
97+
c = np.array([0, 0, 1])
98+
X = np.array([a, b, c])
99+
D = perm_invariant_hamming_matrix(X)
100+
101+
# Diagonal should be NaN
102+
assert np.isnan(D[0, 0])
103+
assert np.isnan(D[1, 1])
104+
assert np.isnan(D[2, 2])
105+
106+
# Upper triangle should be NaN
107+
assert np.isnan(D[0, 1])
108+
assert np.isnan(D[0, 2])
109+
assert np.isnan(D[1, 2])
110+
111+
# Lower triangle should have actual distances
112+
assert not np.isnan(D[1, 0])
113+
assert not np.isnan(D[2, 0])
114+
assert not np.isnan(D[2, 1])
88115

89116

90117
def test_unused_labels() -> None:
91118
"""Test when arrays don't use all labels 0..k-1."""
92119
a = np.array([0, 0, 3, 3]) # skips 1, 2
93120
b = np.array([1, 1, 2, 2])
94-
d, _ = perm_invariant_hamming_matrix(a, b)
95-
assert d == 0
121+
X = np.array([a, b])
122+
D = perm_invariant_hamming_matrix(X)
123+
assert D[1, 0] == 0

0 commit comments

Comments
 (0)