From ba2d28e3e922a0fee8356305742553db558e422a Mon Sep 17 00:00:00 2001 From: Nicholas Landry Date: Thu, 17 Aug 2023 13:36:33 -0400 Subject: [PATCH 1/2] Added `s` parameter to `neighbors` --- xgi/core/views.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/xgi/core/views.py b/xgi/core/views.py index 8bea38474..ca9c9c812 100644 --- a/xgi/core/views.py +++ b/xgi/core/views.py @@ -330,7 +330,7 @@ def filterby_attr(self, attr, val, mode="eq", missing=None): ) return type(self).from_view(self, bunch) - def neighbors(self, id): + def neighbors(self, id, s=1): """Find the neighbors of an ID. The neighbors of an ID are those IDs that share at least one bipartite ID. @@ -339,6 +339,9 @@ def neighbors(self, id): ---------- id : hashable ID to find neighbors of. + s : int, optional + The intersection size s to be considered neighbors. + By default, 1. Returns ------- set @@ -359,9 +362,17 @@ def neighbors(self, id): {1, 3, 4} """ - return {i for n in self._id_dict[id] for i in self._bi_id_dict[n]}.difference( - {id} - ) + if s == 1: + return { + i for n in self._id_dict[id] for i in self._bi_id_dict[n] + }.difference({id}) + else: + return { + i + for n in self._id_dict[id] + for i in self._bi_id_dict[n] + if len(self._id_dict[id].intersection(self._id_dict[i])) >= s + }.difference({id}) def duplicates(self): """Find IDs that have a duplicate. From ac0dccfccf866bfbf729d06e325fb2ef4eaa4983 Mon Sep 17 00:00:00 2001 From: Nicholas Landry Date: Thu, 21 Sep 2023 15:07:54 -0400 Subject: [PATCH 2/2] Added unit tests --- tests/core/test_views.py | 13 ++++++++----- xgi/core/views.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/core/test_views.py b/tests/core/test_views.py index ba341d794..71883d72a 100644 --- a/tests/core/test_views.py +++ b/tests/core/test_views.py @@ -5,11 +5,10 @@ from xgi.exception import IDNotFound, XGIError -def test_neighbors(edgelist1, edgelist2): - el1 = edgelist1 - el2 = edgelist2 - H1 = xgi.Hypergraph(el1) - H2 = xgi.Hypergraph(el2) +def test_neighbors(edgelist1, edgelist2, edgelist7): + H1 = xgi.Hypergraph(edgelist1) + H2 = xgi.Hypergraph(edgelist2) + H3 = xgi.Hypergraph(edgelist7) assert H1.nodes.neighbors(1) == {2, 3} assert H1.nodes.neighbors(4) == set() assert H1.nodes.neighbors(6) == {5, 7, 8} @@ -20,6 +19,10 @@ def test_neighbors(edgelist1, edgelist2): assert H1.edges.neighbors(1) == set() assert H1.edges.neighbors(2) == {3} assert H1.edges.neighbors(3) == {2} + assert H3.edges.neighbors(0, s=2) == {1} + assert H3.edges.neighbors(1, s=2) == {0, 2} + assert H3.edges.neighbors(2, s=2) == {1} + assert H3.edges.neighbors(3, s=2) == set() def test_edge_order(edgelist3): diff --git a/xgi/core/views.py b/xgi/core/views.py index ca9c9c812..1d4acad8a 100644 --- a/xgi/core/views.py +++ b/xgi/core/views.py @@ -340,8 +340,9 @@ def neighbors(self, id, s=1): id : hashable ID to find neighbors of. s : int, optional - The intersection size s to be considered neighbors. + The intersection size s for two edges or nodes to be considered neighbors. By default, 1. + Returns ------- set