Skip to content

Commit

Permalink
add find_predecessor_node_by_edge. (#756)
Browse files Browse the repository at this point in the history
* add find_adjacent_predecessor_node_by_edge.

* linting

* predecessor test

* linting

* Update src/digraph.rs

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* Update src/digraph.rs

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* add negative test

* Update src/digraph.rs

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* Update tests/rustworkx_tests/digraph/test_edges.py

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* Update tests/rustworkx_tests/digraph/test_edges.py

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* Update tests/rustworkx_tests/digraph/test_edges.py

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* Update tests/rustworkx_tests/digraph/test_edges.py

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* Update src/digraph.rs

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* add reno note

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>
  • Loading branch information
ewinston and mtreinish authored Dec 13, 2022
1 parent 1875b34 commit b8360df
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Add a method
:meth:`~rustworkx.DiGraph.find_predecessor_node_by_edge` to get
the immediate predecessor of a node which is connected by the
specified edge.
36 changes: 36 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1724,6 +1724,42 @@ impl PyDiGraph {
Err(NoSuitableNeighbors::new_err("No suitable neighbor"))
}

/// Find a source node with a specific edge
///
/// This method is used to find a predecessor of
/// a given node given an edge condition.
///
/// :param int node: The node to use as the source of the search
/// :param callable predicate: A python callable that will take a single
/// parameter, the edge object, and will return a boolean if the
/// edge matches or not
///
/// :returns: The node object that has an edge from it to the provided
/// node index which matches the provided condition
#[pyo3(text_signature = "(self, node, predicate, /)")]
pub fn find_predecessor_node_by_edge(
&self,
py: Python,
node: usize,
predicate: PyObject,
) -> PyResult<&PyObject> {
let predicate_callable = |a: &PyObject| -> PyResult<PyObject> {
let res = predicate.call1(py, (a,))?;
Ok(res.to_object(py))
};
let index = NodeIndex::new(node);
let dir = petgraph::Direction::Incoming;
let edges = self.graph.edges_directed(index, dir);
for edge in edges {
let edge_predicate_raw = predicate_callable(edge.weight())?;
let edge_predicate: bool = edge_predicate_raw.extract(py)?;
if edge_predicate {
return Ok(self.graph.node_weight(edge.source()).unwrap());
}
}
Err(NoSuitableNeighbors::new_err("No suitable neighbor"))
}

/// Generate a dot file from the graph
///
/// :param node_attr: A callable that will take in a node data object
Expand Down
26 changes: 26 additions & 0 deletions tests/rustworkx_tests/digraph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,32 @@ def compare_edges(edge):
with self.assertRaises(rustworkx.NoSuitableNeighbors):
dag.find_adjacent_node_by_edge(node_a, compare_edges)

def test_find_predecessor_node_by_edge(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
node_b = dag.add_child(node_a, "b", "a to b")
node_c = dag.add_child(node_b, "c", "b to c")
dag.add_child(node_c, "d", "c to d")

def compare_edges(edge):
return "a to b" == edge

res = dag.find_predecessor_node_by_edge(node_b, compare_edges)
self.assertEqual("a", res)

def test_find_predecessor_node_by_edge_no_match(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
node_b = dag.add_child(node_a, "b", "a to b")
node_c = dag.add_child(node_b, "c", "b to c")
dag.add_child(node_c, "d", "c to d")

def compare_edges(edge):
return "b to c" == edge

with self.assertRaises(rustworkx.NoSuitableNeighbors):
dag.find_predecessor_node_by_edge(node_b, compare_edges)

def test_add_edge_from(self):
dag = rustworkx.PyDAG()
nodes = list(range(4))
Expand Down

0 comments on commit b8360df

Please sign in to comment.