Skip to content

Commit

Permalink
Add method: has_node (#1169)
Browse files Browse the repository at this point in the history
* Add method: has_node

clippy

Add tests

Update src/digraph.rs

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

Update src/graph.rs

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

Add type annotations

* Add release note

* Add test

---------

Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com>
  • Loading branch information
haoxins and IvanIsCoding authored Apr 24, 2024
1 parent 229f38f commit 7302c8b
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 6 deletions.
7 changes: 7 additions & 0 deletions releasenotes/notes/add-has_node-method-9e6b91bf79e60f50.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Added a method :meth:`~rustworkx.PyGraph.has_node`
to the
:class:`~rustworkx.PyGraph` and :class:`~rustworkx.PyDiGraph`
classes to check if a node is in the graph.
2 changes: 2 additions & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,7 @@ class PyGraph(Generic[_S, _T]):
def get_edge_data_by_index(self, edge_index: int, /) -> _T: ...
def get_edge_endpoints_by_index(self, edge_index: int, /) -> tuple[int, int]: ...
def get_node_data(self, node: int, /) -> _S: ...
def has_node(self, node: int, /) -> bool: ...
def has_edge(self, node_a: int, node_b: int, /) -> bool: ...
def has_parallel_edges(self) -> bool: ...
def in_edges(self, node: int, /) -> WeightedEdgeList[_T]: ...
Expand Down Expand Up @@ -1303,6 +1304,7 @@ class PyDiGraph(Generic[_S, _T]):
def get_node_data(self, node: int, /) -> _S: ...
def get_edge_data_by_index(self, edge_index: int, /) -> _T: ...
def get_edge_endpoints_by_index(self, edge_index: int, /) -> tuple[int, int]: ...
def has_node(self, node: int, /) -> bool: ...
def has_edge(self, node_a: int, node_b: int, /) -> bool: ...
def has_parallel_edges(self) -> bool: ...
def in_degree(self, node: int, /) -> int: ...
Expand Down
18 changes: 15 additions & 3 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,18 @@ impl PyDiGraph {
self.node_indices()
}

/// Return True if there is a node in the graph.
///
/// :param int node: The node index to check
///
/// :returns: True if there is a node false if there is no node
/// :rtype: bool
#[pyo3(text_signature = "(self, node, /)")]
pub fn has_node(&self, node: usize) -> bool {
let index = NodeIndex::new(node);
self.graph.contains_node(index)
}

/// Return True if there is an edge from node_a to node_b.
///
/// :param int node_a: The source node index to check for an edge
Expand Down Expand Up @@ -3059,7 +3071,7 @@ impl PyDiGraph {
/// required to return a boolean value stating whether the node's data payload fits some criteria.
///
/// For example::
///
///
/// from rustworkx import PyDiGraph
///
/// graph = PyDiGraph()
Expand Down Expand Up @@ -3107,8 +3119,8 @@ impl PyDiGraph {
/// def my_filter_function(edge):
/// if edge:
/// return edge == 'B'
/// return False
///
/// return False
///
/// indices = graph.filter_edges(my_filter_function)
/// assert indices == [1]
///
Expand Down
18 changes: 15 additions & 3 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,18 @@ impl PyGraph {
self.node_indices()
}

/// Return True if there is a node.
///
/// :param int node: The index for the node
///
/// :returns: True if there is a node false if there is no node
/// :rtype: bool
#[pyo3(text_signature = "(self, node, /)")]
pub fn has_node(&self, node: usize) -> bool {
let index = NodeIndex::new(node);
self.graph.contains_node(index)
}

/// Return True if there is an edge between ``node_a`` and ``node_b``.
///
/// :param int node_a: The index for the first node
Expand Down Expand Up @@ -2039,7 +2051,7 @@ impl PyGraph {
/// required to return a boolean value stating whether the node's data payload fits some criteria.
///
/// For example::
///
///
/// from rustworkx import PyGraph
///
/// graph = PyGraph()
Expand Down Expand Up @@ -2087,8 +2099,8 @@ impl PyGraph {
/// def my_filter_function(edge):
/// if edge:
/// return edge == 'B'
/// return False
///
/// return False
///
/// indices = graph.filter_edges(my_filter_function)
/// assert indices == [1]
///
Expand Down
6 changes: 6 additions & 0 deletions tests/digraph/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def test_remove_nodes_from_with_invalid_index(self):
self.assertEqual(["a"], res)
self.assertEqual([0], dag.node_indexes())

def test_has_node(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
self.assertTrue(dag.has_node(node_a))
self.assertFalse(dag.has_node(node_a + 1))

def test_remove_nodes_retain_edges_single_edge(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
Expand Down
6 changes: 6 additions & 0 deletions tests/graph/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,9 @@ def test_remove_node_delitem_invalid_index(self):
res = graph.nodes()
self.assertEqual(["a", "b", "c"], res)
self.assertEqual([0, 1, 2], graph.node_indexes())

def test_has_node(self):
graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
self.assertTrue(graph.has_node(node_a))
self.assertFalse(graph.has_node(node_a + 1))

0 comments on commit 7302c8b

Please sign in to comment.