Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix non-deterministic search #11

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 35 additions & 39 deletions pynndescent/pynndescent_.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,33 @@ def make_nn_descent(dist, dist_args):
specialised for the given distance metric and metric arguments. Numba
doesn't support higher order functions directly, but we can instead JIT
compile the version of NN-descent for any given metric.

Parameters
----------
dist: function
A numba JITd distance function which, given two arrays computes a
dissimilarity between them.

dist_args: tuple
Any extra arguments that need to be passed to the distance function
beyond the two arrays to be compared.

Returns
-------
A numba JITd function for nearest neighbor descent computation that is
specialised to the given metric.
"""

@numba.njit(fastmath=True)
def nn_descent(data, n_neighbors, rng_state, max_candidates=50,
n_iters=10, delta=0.001, rho=0.5,
rp_tree_init=True, leaf_array=None, verbose=False):
@numba.njit(parallel=True)
def nn_descent(
data,
n_neighbors,
rng_state,
max_candidates=50,
n_iters=10,
delta=0.001,
rho=0.5,
rp_tree_init=True,
leaf_array=None,
verbose=False,
):
n_vertices = data.shape[0]

current_graph = make_heap(data.shape[0], n_neighbors)
Expand All @@ -142,59 +148,49 @@ def nn_descent(data, n_neighbors, rng_state, max_candidates=50,

if rp_tree_init:
for n in range(leaf_array.shape[0]):
tried = set([(-1, -1)])
for i in range(leaf_array.shape[1]):
if leaf_array[n, i] < 0:
break
for j in range(i + 1, leaf_array.shape[1]):
if leaf_array[n, j] < 0:
break
if (leaf_array[n, i], leaf_array[n, j]) in tried:
continue
d = dist(data[leaf_array[n, i]], data[leaf_array[n, j]],
*dist_args)
heap_push(current_graph, leaf_array[n, i], d,
leaf_array[n, j],
1)
heap_push(current_graph, leaf_array[n, j], d,
leaf_array[n, i],
1)
tried.add((leaf_array[n, i], leaf_array[n, j]))
d = dist(
data[leaf_array[n, i]], data[leaf_array[n, j]], *dist_args
)
heap_push(
current_graph, leaf_array[n, i], d, leaf_array[n, j], 1
)
heap_push(
current_graph, leaf_array[n, j], d, leaf_array[n, i], 1
)

for n in range(n_iters):
if verbose:
print("\t", n, " / ", n_iters)

(new_candidate_neighbors,
old_candidate_neighbors) = build_candidates(current_graph,
n_vertices,
n_neighbors,
max_candidates,
rng_state, rho)
candidate_neighbors = build_candidates(
current_graph, n_vertices, n_neighbors, max_candidates, rng_state
)

c = 0
for i in range(n_vertices):
for j in range(max_candidates):
p = int(new_candidate_neighbors[0, i, j])
if p < 0:
p = int(candidate_neighbors[0, i, j])
if p < 0 or tau_rand(rng_state) < rho:
continue
for k in range(j, max_candidates):
q = int(new_candidate_neighbors[0, i, k])
if q < 0:
continue

d = dist(data[p], data[q], *dist_args)
c += heap_push(current_graph, p, d, q, 1)
c += heap_push(current_graph, q, d, p, 1)

for k in range(max_candidates):
q = int(old_candidate_neighbors[0, i, k])
if q < 0:
q = int(candidate_neighbors[0, i, k])
if (
q < 0
or not candidate_neighbors[2, i, j]
and not candidate_neighbors[2, i, k]
):
continue

d = dist(data[p], data[q], *dist_args)
c += heap_push(current_graph, p, d, q, 1)
c += heap_push(current_graph, q, d, p, 1)


if c <= delta * n_neighbors * data.shape[0]:
break

Expand Down
48 changes: 40 additions & 8 deletions pynndescent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,29 +373,61 @@ def smallest_flagged(heap, row):


@numba.njit(parallel=True)
def build_candidates(current_graph, n_vertices, n_neighbors, max_candidates,
rng_state, rho=0.5):
def build_candidates(current_graph, n_vertices, n_neighbors, max_candidates, rng_state):
"""Build a heap of candidate neighbors for nearest neighbor descent. For
each vertex the candidate neighbors are any current neighbors, and any
vertices that have the vertex as one of their nearest neighbors.

Parameters
----------
current_graph: heap
The current state of the graph for nearest neighbor descent.

n_vertices: int
The total number of vertices in the graph.

n_neighbors: int
The number of neighbor edges per node in the current graph.

max_candidates: int
The maximum number of new candidate neighbors.

rng_state: array of int64, shape (3,)
The internal state of the rng
Returns
-------
candidate_neighbors: A heap with an array of (randomly sorted) candidate
neighbors for each vertex in the graph.
"""
candidate_neighbors = make_heap(n_vertices, max_candidates)
for i in range(n_vertices):
for j in range(n_neighbors):
if current_graph[0, i, j] < 0:
continue
idx = current_graph[0, i, j]
isn = current_graph[2, i, j]
d = tau_rand(rng_state)
heap_push(candidate_neighbors, i, d, idx, isn)
heap_push(candidate_neighbors, idx, d, i, isn)
current_graph[2, i, j] = 0

return candidate_neighbors


@numba.njit(parallel=True)
def new_build_candidates(
current_graph, n_vertices, n_neighbors, max_candidates, rng_state, rho=0.5
): # pragma: no cover
"""Build a heap of candidate neighbors for nearest neighbor descent. For
each vertex the candidate neighbors are any current neighbors, and any
vertices that have the vertex as one of their nearest neighbors.
Parameters
----------
current_graph: heap
The current state of the graph for nearest neighbor descent.
n_vertices: int
The total number of vertices in the graph.
n_neighbors: int
The number of neighbor edges per node in the current graph.
max_candidates: int
The maximum number of new candidate neighbors.
rng_state: array of int64, shape (3,)
The internal state of the rng
Returns
-------
candidate_neighbors: A heap with an array of (randomly sorted) candidate
Expand All @@ -420,7 +452,7 @@ def build_candidates(current_graph, n_vertices, n_neighbors, max_candidates,
heap_push(old_candidate_neighbors, i, d, idx, isn)
heap_push(old_candidate_neighbors, idx, d, i, isn)

if c > 0 :
if c > 0:
current_graph[2, i, j] = 0

return new_candidate_neighbors, old_candidate_neighbors
Expand Down