Skip to content

Commit

Permalink
fix flann distances
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeff committed Feb 19, 2019
1 parent 2b25337 commit 8cc3539
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions pygsp/graphs/nngraphs/nngraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,10 @@ def _knn_flann(features, k, metric, order, params):
# seems to work best).
neighbors, distances = index.nn_index(features, k+1)
index.free_index()
if metric == 'euclidean': # flann returns squared distances
if metric == 'euclidean':
np.sqrt(distances, out=distances)
elif metric == 'minkowski':
np.power(distances, 1/order, out=distances)
return neighbors, distances


Expand Down Expand Up @@ -185,17 +187,22 @@ def _radius_flann(features, radius, metric, order, params):
cfl.set_distance_type(metric, order=order)
index = cfl.FLANNIndex()
index.build_index(features, **params)
D = []
NN = []
for k in range(n_vertices):
nn, d = index.nn_radius(features[k, :], radius**2)
D.append(d)
NN.append(nn)
distances = []
neighbors = []
if metric == 'euclidean':
radius = radius**2
elif metric == 'minkowski':
radius = radius**order
for vertex in range(n_vertices):
neighbor, distance = index.nn_radius(features[vertex, :], radius)
distances.append(distance)
neighbors.append(neighbor)
index.free_index()
if metric == 'euclidean':
# Flann returns squared distances.
D = list(map(np.sqrt, D))
return NN, D
distances = list(map(np.sqrt, distances))
elif metric == 'minkowski':
distances = list(map(lambda d: np.power(d, 1/order), distances))
return neighbors, distances


_nn_functions = {
Expand Down

0 comments on commit 8cc3539

Please sign in to comment.