diff --git a/pynndescent/pynndescent_.py b/pynndescent/pynndescent_.py index ec87742..e19834f 100644 --- a/pynndescent/pynndescent_.py +++ b/pynndescent/pynndescent_.py @@ -109,27 +109,33 @@ def make_nn_descent(dist, dist_args): specialised for the given distance metric and metric arguments. Numba doesn't support higher order functions directly, but we can instead JIT compile the version of NN-descent for any given metric. - Parameters ---------- dist: function A numba JITd distance function which, given two arrays computes a dissimilarity between them. - dist_args: tuple Any extra arguments that need to be passed to the distance function beyond the two arrays to be compared. - Returns ------- A numba JITd function for nearest neighbor descent computation that is specialised to the given metric. """ - @numba.njit(fastmath=True) - def nn_descent(data, n_neighbors, rng_state, max_candidates=50, - n_iters=10, delta=0.001, rho=0.5, - rp_tree_init=True, leaf_array=None, verbose=False): + @numba.njit(parallel=True) + def nn_descent( + data, + n_neighbors, + rng_state, + max_candidates=50, + n_iters=10, + delta=0.001, + rho=0.5, + rp_tree_init=True, + leaf_array=None, + verbose=False, + ): n_vertices = data.shape[0] current_graph = make_heap(data.shape[0], n_neighbors) @@ -142,59 +148,49 @@ def nn_descent(data, n_neighbors, rng_state, max_candidates=50, if rp_tree_init: for n in range(leaf_array.shape[0]): - tried = set([(-1, -1)]) for i in range(leaf_array.shape[1]): if leaf_array[n, i] < 0: break for j in range(i + 1, leaf_array.shape[1]): if leaf_array[n, j] < 0: break - if (leaf_array[n, i], leaf_array[n, j]) in tried: - continue - d = dist(data[leaf_array[n, i]], data[leaf_array[n, j]], - *dist_args) - heap_push(current_graph, leaf_array[n, i], d, - leaf_array[n, j], - 1) - heap_push(current_graph, leaf_array[n, j], d, - leaf_array[n, i], - 1) - tried.add((leaf_array[n, i], leaf_array[n, j])) + d = dist( + data[leaf_array[n, i]], data[leaf_array[n, j]], *dist_args + ) + heap_push( + current_graph, leaf_array[n, i], d, leaf_array[n, j], 1 + ) + heap_push( + current_graph, leaf_array[n, j], d, leaf_array[n, i], 1 + ) for n in range(n_iters): + if verbose: + print("\t", n, " / ", n_iters) - (new_candidate_neighbors, - old_candidate_neighbors) = build_candidates(current_graph, - n_vertices, - n_neighbors, - max_candidates, - rng_state, rho) + candidate_neighbors = build_candidates( + current_graph, n_vertices, n_neighbors, max_candidates, rng_state + ) c = 0 for i in range(n_vertices): for j in range(max_candidates): - p = int(new_candidate_neighbors[0, i, j]) - if p < 0: + p = int(candidate_neighbors[0, i, j]) + if p < 0 or tau_rand(rng_state) < rho: continue - for k in range(j, max_candidates): - q = int(new_candidate_neighbors[0, i, k]) - if q < 0: - continue - - d = dist(data[p], data[q], *dist_args) - c += heap_push(current_graph, p, d, q, 1) - c += heap_push(current_graph, q, d, p, 1) - for k in range(max_candidates): - q = int(old_candidate_neighbors[0, i, k]) - if q < 0: + q = int(candidate_neighbors[0, i, k]) + if ( + q < 0 + or not candidate_neighbors[2, i, j] + and not candidate_neighbors[2, i, k] + ): continue d = dist(data[p], data[q], *dist_args) c += heap_push(current_graph, p, d, q, 1) c += heap_push(current_graph, q, d, p, 1) - if c <= delta * n_neighbors * data.shape[0]: break diff --git a/pynndescent/utils.py b/pynndescent/utils.py index a57645e..dc7d660 100644 --- a/pynndescent/utils.py +++ b/pynndescent/utils.py @@ -373,29 +373,61 @@ def smallest_flagged(heap, row): @numba.njit(parallel=True) -def build_candidates(current_graph, n_vertices, n_neighbors, max_candidates, - rng_state, rho=0.5): +def build_candidates(current_graph, n_vertices, n_neighbors, max_candidates, rng_state): """Build a heap of candidate neighbors for nearest neighbor descent. For each vertex the candidate neighbors are any current neighbors, and any vertices that have the vertex as one of their nearest neighbors. - Parameters ---------- current_graph: heap The current state of the graph for nearest neighbor descent. - n_vertices: int The total number of vertices in the graph. - n_neighbors: int The number of neighbor edges per node in the current graph. - max_candidates: int The maximum number of new candidate neighbors. - rng_state: array of int64, shape (3,) The internal state of the rng + Returns + ------- + candidate_neighbors: A heap with an array of (randomly sorted) candidate + neighbors for each vertex in the graph. + """ + candidate_neighbors = make_heap(n_vertices, max_candidates) + for i in range(n_vertices): + for j in range(n_neighbors): + if current_graph[0, i, j] < 0: + continue + idx = current_graph[0, i, j] + isn = current_graph[2, i, j] + d = tau_rand(rng_state) + heap_push(candidate_neighbors, i, d, idx, isn) + heap_push(candidate_neighbors, idx, d, i, isn) + current_graph[2, i, j] = 0 + + return candidate_neighbors + +@numba.njit(parallel=True) +def new_build_candidates( + current_graph, n_vertices, n_neighbors, max_candidates, rng_state, rho=0.5 +): # pragma: no cover + """Build a heap of candidate neighbors for nearest neighbor descent. For + each vertex the candidate neighbors are any current neighbors, and any + vertices that have the vertex as one of their nearest neighbors. + Parameters + ---------- + current_graph: heap + The current state of the graph for nearest neighbor descent. + n_vertices: int + The total number of vertices in the graph. + n_neighbors: int + The number of neighbor edges per node in the current graph. + max_candidates: int + The maximum number of new candidate neighbors. + rng_state: array of int64, shape (3,) + The internal state of the rng Returns ------- candidate_neighbors: A heap with an array of (randomly sorted) candidate @@ -420,7 +452,7 @@ def build_candidates(current_graph, n_vertices, n_neighbors, max_candidates, heap_push(old_candidate_neighbors, i, d, idx, isn) heap_push(old_candidate_neighbors, idx, d, i, isn) - if c > 0 : + if c > 0: current_graph[2, i, j] = 0 return new_candidate_neighbors, old_candidate_neighbors