diff --git a/kmapper/kmapper.py b/kmapper/kmapper.py
index 7c87ee66..7c5dcfb3 100644
--- a/kmapper/kmapper.py
+++ b/kmapper/kmapper.py
@@ -827,7 +827,7 @@ def data_from_cluster_id(self, cluster_id, graph, data):
         else:
             return np.array([])
 
-    def clusters_from_cover(self, cube_ids, graph):
+    def find_nodes(self, cube_ids, graph):
         """Returns the clusters and their members from the subset of the cover spanned by the given cube_ids
 
         Parameters
@@ -839,7 +839,7 @@ def clusters_from_cover(self, cube_ids, graph):
 
         Returns
         -------
-        clusters : dict
+        nodes : dict
             cluster membership indexed by cluster ID (subset of `graph["nodes"]`).
 
         """
@@ -850,6 +850,56 @@ def clusters_from_cover(self, cube_ids, graph):
                 clusters[cluster_id] = cluster_members
         return clusters
 
+    def nearest_nodes(self, newdata, graph, cover, data, nn):
+        """Returns the clusters nearest to the `newdata` using the given NearestNeighbors algorithm
+
+        Parameters
+        ----------
+        newdata : Numpy array
+            New dataset. Accepts both 1-D and 2-D array.
+        graph : dict
+            The resulting dictionary after applying map().
+        cover : kmapper.Cover
+            The cover used to build graph.
+        data : Numpy array
+            Original dataset.
+        nn : NearestNeighbors
+            Scikit-learn NearestNeighbors instance to use.
+
+        Returns
+        -------
+        cluster_ids : numpy array
+            Unique cluster IDs; empty if `newdata` has no rows or falls in no cube or node.
+
+        """
+        if len(newdata.shape) == 1:
+            # Promote a single sample to a one-row 2-D array.
+            newdata = newdata[np.newaxis]
+        if len(newdata) == 0:
+            # np.concatenate raises ValueError on an empty sequence.
+            return np.empty((0,))
+
+        cube_ids = np.concatenate([cover.find(row) for row in newdata])
+        if len(cube_ids) == 0:
+            return np.empty((0,))
+
+        nodes = self.find_nodes(cube_ids, graph)
+        if len(nodes) == 0:
+            return np.empty((0,))
+
+        # Stack the members of every candidate node, remembering each row's node ID.
+        nn_data = []
+        nn_cluster_ids = []
+        for cluster_id, cluster_members in nodes.items():
+            cluster_data = data[cluster_members]
+            nn_data.append(cluster_data)
+            nn_cluster_ids.append([cluster_id]*len(cluster_data))
+        nn_data = np.vstack(nn_data)
+        nn_cluster_ids = np.concatenate(nn_cluster_ids)
+        nn.fit(nn_data)
+        nn_ids = nn.kneighbors(newdata, return_distance=False)
+        return np.unique(nn_cluster_ids[nn_ids])
+
     def _process_projection_tuple(self, projection):
         # Detect if projection is a tuple (for prediction functions)
         # TODO: multi-label models
diff --git a/test/test_mapper.py b/test/test_mapper.py
index 4d8c03b8..a23f7204 100644
--- a/test/test_mapper.py
+++ b/test/test_mapper.py
@@ -75,25 +75,58 @@ def test_wrong_id(self):
         mems = mapper.data_from_cluster_id("new node", graph, data)
         np.testing.assert_array_equal(mems, np.array([]))
 
-    def test_clusters_from_cover(self):
+    def test_find_nodes(self):
         mapper = KeplerMapper(verbose=1)
         data = np.random.rand(100, 2)
 
         graph = mapper.map(data)
         cube_ids = mapper.cover.find(data[0])
-        mems = mapper.clusters_from_cover(cube_ids, graph)
+        mems = mapper.find_nodes(cube_ids, graph)
         assert len(mems) > 0
         for cluster_id, cluster_members in mems.items():
             np.testing.assert_array_equal(cluster_members, graph["nodes"][cluster_id])
 
-    def test_no_clusters_from_cover(self):
+    def test_node_not_found(self):
         mapper = KeplerMapper(verbose=1)
         data = np.random.rand(100, 2)
 
         graph = mapper.map(data)
-        mems = mapper.clusters_from_cover([999], graph)
+        mems = mapper.find_nodes([999], graph)
         assert len(mems) == 0
 
+    def test_nearest_nodes_1(self):
+        mapper = KeplerMapper(verbose=1)
+        data = np.random.rand(100, 2)
+
+        graph = mapper.map(data)
+        nn = neighbors.NearestNeighbors(n_neighbors=1)
+        expected_id, members = next(iter(graph["nodes"].items()))
+        newdata = data[members[-1]]
+        ids = mapper.nearest_nodes(newdata, graph, mapper.cover, data, nn)
+        np.testing.assert_array_equal(ids, [expected_id])
+
+    def test_nearest_nodes_2(self):
+        mapper = KeplerMapper(verbose=1)
+        data = np.random.rand(100, 2)
+
+        graph = mapper.map(data)
+        nn = neighbors.NearestNeighbors(n_neighbors=1)
+        expected_clusters = [(cluster_id, members) for cluster_id, members in graph['nodes'].items()][:2]
+
+        cluster_id1 = expected_clusters[0][0]
+        newdata1 = data[expected_clusters[0][1][-1]]
+        ids = mapper.nearest_nodes(newdata1, graph, mapper.cover, data, nn)
+        np.testing.assert_array_equal(ids, [cluster_id1])
+
+        cluster_id2 = expected_clusters[1][0]
+        newdata2 = data[expected_clusters[1][1][-1]]
+        ids = mapper.nearest_nodes(newdata2, graph, mapper.cover, data, nn)
+        np.testing.assert_array_equal(ids, [cluster_id2])
+
+        newdata = np.vstack([newdata1, newdata2])
+        ids = mapper.nearest_nodes(newdata, graph, mapper.cover, data, nn)
+        np.testing.assert_array_equal(ids, np.unique([cluster_id1, cluster_id2]))
+
 class TestMap:
     def test_simplices(self):
         mapper = KeplerMapper()