Skip to content

Commit c74a058

Browse files
anirudtperimosocordiae
authored andcommitted
adds code for safe return in case of no impostors for lmnn (#36)
adds code for safe return in case of no impostors for lmnn, fixes #17
1 parent 9f602e6 commit c74a058

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

metric_learn/lmnn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def fit(self, X, labels):
6363

6464
target_neighbors = self._select_targets()
6565
impostors = self._find_impostors(target_neighbors[:,-1])
66+
if len(impostors) == 0:
67+
# L has already been initialized to an identity matrix of requisite shape
68+
return
6669

6770
# sum outer products
6871
dfG = _sum_outer_products(self.X, target_neighbors.flatten(),
@@ -203,6 +206,9 @@ def _find_impostors(self, furthest_neighbors):
203206
tmp = np.ravel_multi_index((i,j), shape)
204207
i,j = np.unravel_index(np.unique(tmp), shape)
205208
impostors.append(np.vstack((in_inds[j], out_inds[i])))
209+
if len(impostors) == 0:
210+
# No impostors detected
211+
return impostors
206212
return np.hstack(impostors)
207213

208214

0 commit comments

Comments
 (0)