@@ -23,30 +23,35 @@ def test_identity() -> None:
2323 """a == b should give distance 0."""
2424 a = np .array ([0 , 1 , 2 , 1 , 0 ])
2525 b = a .copy ()
26- d , _ = perm_invariant_hamming_matrix (a , b )
27- assert d == 0
26+ X = np .array ([a , b ])
27+ D = perm_invariant_hamming_matrix (X )
28+ # Distance between row 1 and row 0 should be 0
29+ assert D [1 , 0 ] == 0
2830
2931
3032def test_all_one_group () -> None :
3133 """All rows belong to one group in both arrays (possibly different labels)."""
3234 a = np .zeros (10 , dtype = int )
3335 b = np .ones (10 , dtype = int ) # different label but identical grouping
34- d , _ = perm_invariant_hamming_matrix (a , b )
35- assert d == 0
36+ X = np .array ([a , b ])
37+ D = perm_invariant_hamming_matrix (X )
38+ assert D [1 , 0 ] == 0
3639
3740
3841def test_permuted_labels () -> None :
3942 a = np .array ([0 , 2 , 1 , 1 , 0 ])
4043 b = np .array ([1 , 0 , 0 , 2 , 1 ])
41- d , _ = perm_invariant_hamming_matrix (a , b )
42- assert d == 1
44+ X = np .array ([a , b ])
45+ D = perm_invariant_hamming_matrix (X )
46+ assert D [1 , 0 ] == 1
4347
4448
4549def test_swap_two_labels () -> None :
4650 a = np .array ([0 , 0 , 1 , 1 ])
4751 b = np .array ([1 , 1 , 0 , 0 ])
48- d , _ = perm_invariant_hamming_matrix (a , b )
49- assert d == 0
52+ X = np .array ([a , b ])
53+ D = perm_invariant_hamming_matrix (X )
54+ assert D [1 , 0 ] == 0
5055
5156
5257def test_random_small_bruteforce () -> None :
@@ -56,40 +61,63 @@ def test_random_small_bruteforce() -> None:
5661 k = 3
5762 a = rng .integers (0 , k , size = n )
5863 b = rng .integers (0 , k , size = n )
59- d_alg , _ = perm_invariant_hamming_matrix (a , b )
64+ X = np .array ([a , b ])
65+ D = perm_invariant_hamming_matrix (X )
66+ d_alg = D [1 , 0 ]
6067 d_true = brute_force_min_hamming (a , b )
6168 assert d_alg == d_true
6269
6370
6471def test_shape_mismatch () -> None :
6572 a = np .array ([0 , 1 , 2 ])
6673 b = np .array ([0 , 1 ])
67- with pytest .raises (AssertionError ):
68- perm_invariant_hamming_matrix (a , b )
74+ with pytest .raises ((ValueError , IndexError )):
75+ # This should fail when trying to create the matrix due to shape mismatch
76+ X = np .array ([a , b ])
77+ perm_invariant_hamming_matrix (X )
6978
7079
71- def test_return_mapping () -> None :
72- """Verify the returned mapping is correct ."""
80+ def test_matrix_multiple_pairs () -> None :
81+ """Test the matrix function with multiple label vectors ."""
7382 a = np .array ([0 , 0 , 1 , 1 ])
74- b = np .array ([2 , 2 , 3 , 3 ])
75- d , mapping = perm_invariant_hamming_matrix (a , b , return_mapping = True )
76- assert d == 0
77- assert mapping [0 ] == 2
78- assert mapping [1 ] == 3
83+ b = np .array ([2 , 2 , 3 , 3 ]) # Should be distance 0 (perfect mapping)
84+ c = np .array ([0 , 1 , 0 , 1 ]) # Should be distance 2 from both a and b
85+ X = np .array ([a , b , c ])
86+ D = perm_invariant_hamming_matrix (X )
7987
88+ assert D [1 , 0 ] == 0 # a and b should have distance 0
89+ assert D [2 , 0 ] == 2 # a and c should have distance 2
90+ assert D [2 , 1 ] == 2 # b and c should have distance 2
8091
81- def test_return_mapping_false () -> None :
82- """Test return_mapping=False."""
92+
93+ def test_matrix_upper_triangle_nan () -> None :
94+ """Test that upper triangle and diagonal are NaN."""
8395 a = np .array ([0 , 1 , 0 ])
8496 b = np .array ([1 , 0 , 1 ])
85- d , mapping = perm_invariant_hamming_matrix (a , b , return_mapping = False )
86- assert d == 0
87- assert mapping is None
97+ c = np .array ([0 , 0 , 1 ])
98+ X = np .array ([a , b , c ])
99+ D = perm_invariant_hamming_matrix (X )
100+
101+ # Diagonal should be NaN
102+ assert np .isnan (D [0 , 0 ])
103+ assert np .isnan (D [1 , 1 ])
104+ assert np .isnan (D [2 , 2 ])
105+
106+ # Upper triangle should be NaN
107+ assert np .isnan (D [0 , 1 ])
108+ assert np .isnan (D [0 , 2 ])
109+ assert np .isnan (D [1 , 2 ])
110+
111+ # Lower triangle should have actual distances
112+ assert not np .isnan (D [1 , 0 ])
113+ assert not np .isnan (D [2 , 0 ])
114+ assert not np .isnan (D [2 , 1 ])
88115
89116
90117def test_unused_labels () -> None :
91118 """Test when arrays don't use all labels 0..k-1."""
92119 a = np .array ([0 , 0 , 3 , 3 ]) # skips 1, 2
93120 b = np .array ([1 , 1 , 2 , 2 ])
94- d , _ = perm_invariant_hamming_matrix (a , b )
95- assert d == 0
121+ X = np .array ([a , b ])
122+ D = perm_invariant_hamming_matrix (X )
123+ assert D [1 , 0 ] == 0
0 commit comments