diff --git a/docs/source/api.rst b/docs/source/api.rst index 2512e60f7..db30369de 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -149,6 +149,7 @@ Connectivity and Cycles rustworkx.weakly_connected_components rustworkx.is_weakly_connected rustworkx.cycle_basis + rustworkx.simple_cycles rustworkx.digraph_find_cycle rustworkx.articulation_points rustworkx.biconnected_components diff --git a/releasenotes/notes/simple_cycles-b2d64c2d02a1974a.yaml b/releasenotes/notes/simple_cycles-b2d64c2d02a1974a.yaml new file mode 100644 index 000000000..eff46ca71 --- /dev/null +++ b/releasenotes/notes/simple_cycles-b2d64c2d02a1974a.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added a new function, :func:`~.simple_cycles`, which is an implementation of + `Johnson's algorithm `__ for finding all + elementary cycles in a directed graph. diff --git a/src/connectivity/johnson_simple_cycles.rs b/src/connectivity/johnson_simple_cycles.rs new file mode 100644 index 000000000..bef2a01b4 --- /dev/null +++ b/src/connectivity/johnson_simple_cycles.rs @@ -0,0 +1,309 @@ +// Licensed under the apache license, version 2.0 (the "license"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use hashbrown::{HashMap, HashSet}; +use indexmap::IndexSet; + +use crate::digraph::PyDiGraph; +use crate::StablePyGraph; +use petgraph::algo::kosaraju_scc; +use petgraph::graph::NodeIndex; +use petgraph::stable_graph::StableDiGraph; +use petgraph::visit::EdgeRef; +use petgraph::visit::IntoEdgeReferences; +use petgraph::visit::IntoNodeReferences; +use petgraph::visit::NodeFiltered; +use petgraph::Directed; + +use pyo3::iter::IterNextOutput; +use pyo3::prelude::*; + +use crate::iterators::NodeIndices; + +fn build_subgraph( + graph: &StablePyGraph, + nodes: &[NodeIndex], +) -> (StableDiGraph<(), ()>, HashMap) { + let node_set: HashSet = nodes.iter().copied().collect(); + let mut node_map: HashMap = HashMap::with_capacity(nodes.len()); + let node_filter = |node: NodeIndex| -> bool { node_set.contains(&node) }; + // Overallocates edges, but not a big deal as this is temporary for the lifetime of the + // subgraph + let mut out_graph = StableDiGraph::<(), ()>::with_capacity(nodes.len(), graph.edge_count()); + let filtered = NodeFiltered(&graph, node_filter); + for node in filtered.node_references() { + let new_node = out_graph.add_node(()); + node_map.insert(node.0, new_node); + } + for edge in filtered.edge_references() { + let new_source = *node_map.get(&edge.source()).unwrap(); + let new_target = *node_map.get(&edge.target()).unwrap(); + out_graph.add_edge(new_source, new_target, ()); + } + (out_graph, node_map) +} + +#[pyclass(module = "rustworkx")] +pub struct SimpleCycleIter { + graph_clone: StablePyGraph, + scc: Vec>, + self_cycles: Option>, + path: Vec, + blocked: HashSet, + closed: HashSet, + block: HashMap>, + stack: Vec<(NodeIndex, IndexSet)>, + start_node: NodeIndex, + node_map: HashMap, + reverse_node_map: HashMap, + subgraph: StableDiGraph<(), ()>, +} + +impl SimpleCycleIter { + pub fn new(graph: &PyDiGraph) -> Self { + // Copy graph to remove self edges before running johnson's algorithm + let mut graph_clone = graph.graph.clone(); + + // For compatibility with networkx manually insert self cycles and filter + // from Johnson's algorithm + let self_cycles_vec: Vec = graph_clone + .node_indices() + .filter(|n| graph_clone.neighbors(*n).any(|x| x == *n)) + .collect(); + for node in &self_cycles_vec { + while let Some(edge_index) = graph_clone.find_edge(*node, *node) { + graph_clone.remove_edge(edge_index); + } + } + let self_cycles = if self_cycles_vec.is_empty() { + None + } else { + Some(self_cycles_vec) + }; + let strongly_connected_components: Vec> = kosaraju_scc(&graph_clone) + .into_iter() + .filter(|component| component.len() > 1) + .collect(); + SimpleCycleIter { + graph_clone, + scc: strongly_connected_components, + self_cycles, + path: Vec::new(), + blocked: HashSet::new(), + closed: HashSet::new(), + block: HashMap::new(), + stack: Vec::new(), + start_node: NodeIndex::new(std::u32::MAX as usize), + node_map: HashMap::new(), + reverse_node_map: HashMap::new(), + subgraph: StableDiGraph::new(), + } + } +} + +fn unblock( + node: NodeIndex, + blocked: &mut HashSet, + block: &mut HashMap>, +) { + let mut stack: IndexSet = IndexSet::new(); + stack.insert(node); + while let Some(stack_node) = stack.pop() { + if blocked.remove(&stack_node) { + match block.get_mut(&stack_node) { + // stack.update(block[stack_node]): + Some(block_set) => { + block_set.drain().for_each(|n| { + stack.insert(n); + }); + } + // If block doesn't have stack_node treat it as an empty set + // (so no updates to stack) and populate it with an empty + // set. + None => { + block.insert(stack_node, HashSet::new()); + } + } + blocked.remove(&stack_node); + } + } +} + +#[allow(clippy::too_many_arguments)] +fn process_stack( + start_node: NodeIndex, + stack: &mut Vec<(NodeIndex, IndexSet)>, + path: &mut Vec, + closed: &mut HashSet, + blocked: &mut HashSet, + block: &mut HashMap>, + subgraph: &StableDiGraph<(), ()>, + reverse_node_map: &HashMap, +) -> Option> { + while let Some((this_node, neighbors)) = stack.last_mut() { + if let Some(next_node) = neighbors.pop() { + if next_node == start_node { + // Out path in input graph basis + let mut out_path: Vec = Vec::with_capacity(path.len()); + for n in path { + out_path.push(reverse_node_map[n].index()); + closed.insert(*n); + } + return Some(IterNextOutput::Yield(NodeIndices { nodes: out_path })); + } else if blocked.insert(next_node) { + path.push(next_node); + stack.push(( + next_node, + subgraph + .neighbors(next_node) + .collect::>(), + )); + closed.remove(&next_node); + blocked.insert(next_node); + continue; + } + } + if neighbors.is_empty() { + if closed.contains(this_node) { + unblock(*this_node, blocked, block); + } else { + for neighbor in subgraph.neighbors(*this_node) { + let block_neighbor = block.entry(neighbor).or_insert_with(HashSet::new); + block_neighbor.insert(*this_node); + } + } + stack.pop(); + path.pop(); + } + } + None +} + +#[pymethods] +impl SimpleCycleIter { + fn __iter__(slf: PyRef) -> Py { + slf.into() + } + + fn __next__(mut slf: PyRefMut) -> PyResult> { + if slf.self_cycles.is_some() { + let self_cycles = slf.self_cycles.as_mut().unwrap(); + let cycle_node = self_cycles.pop().unwrap(); + if self_cycles.is_empty() { + slf.self_cycles = None; + } + return Ok(IterNextOutput::Yield(NodeIndices { + nodes: vec![cycle_node.index()], + })); + } + // Restore previous state if it exists + let mut stack: Vec<(NodeIndex, IndexSet)> = std::mem::take(&mut slf.stack); + let mut path: Vec = std::mem::take(&mut slf.path); + let mut closed: HashSet = std::mem::take(&mut slf.closed); + let mut blocked: HashSet = std::mem::take(&mut slf.blocked); + let mut block: HashMap> = std::mem::take(&mut slf.block); + let mut subgraph: StableDiGraph<(), ()> = std::mem::take(&mut slf.subgraph); + let mut reverse_node_map: HashMap = + std::mem::take(&mut slf.reverse_node_map); + let mut node_map: HashMap = std::mem::take(&mut slf.node_map); + + if let Some(res) = process_stack( + slf.start_node, + &mut stack, + &mut path, + &mut closed, + &mut blocked, + &mut block, + &subgraph, + &reverse_node_map, + ) { + // Store internal state on yield + slf.stack = stack; + slf.path = path; + slf.closed = closed; + slf.blocked = blocked; + slf.block = block; + slf.subgraph = subgraph; + slf.reverse_node_map = reverse_node_map; + slf.node_map = node_map; + return Ok(res); + } else { + subgraph.remove_node(slf.start_node); + slf.scc + .extend(kosaraju_scc(&subgraph).into_iter().filter_map(|scc| { + if scc.len() > 1 { + let res = scc + .iter() + .map(|n| reverse_node_map[n]) + .collect::>(); + Some(res) + } else { + None + } + })); + } + while let Some(mut scc) = slf.scc.pop() { + let temp = build_subgraph(&slf.graph_clone, &scc); + subgraph = temp.0; + node_map = temp.1; + reverse_node_map = node_map.iter().map(|(k, v)| (*v, *k)).collect(); + // start_node, path, blocked, closed, block and stack all in subgraph basis + slf.start_node = node_map[&scc.pop().unwrap()]; + path = vec![slf.start_node]; + blocked = path.iter().copied().collect(); + // Nodes in cycle all + closed = HashSet::new(); + block = HashMap::new(); + stack = vec![( + slf.start_node, + subgraph + .neighbors(slf.start_node) + .collect::>(), + )]; + if let Some(res) = process_stack( + slf.start_node, + &mut stack, + &mut path, + &mut closed, + &mut blocked, + &mut block, + &subgraph, + &reverse_node_map, + ) { + // Store internal state on yield + slf.stack = stack; + slf.path = path; + slf.closed = closed; + slf.blocked = blocked; + slf.block = block; + slf.subgraph = subgraph; + slf.reverse_node_map = reverse_node_map; + slf.node_map = node_map; + return Ok(res); + } + subgraph.remove_node(slf.start_node); + slf.scc + .extend(kosaraju_scc(&subgraph).into_iter().filter_map(|scc| { + if scc.len() > 1 { + let res = scc + .iter() + .map(|n| reverse_node_map[n]) + .collect::>(); + Some(res) + } else { + None + } + })); + } + Ok(IterNextOutput::Return("Ended")) + } +} diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs index 4def48ae3..8b2bb5e61 100644 --- a/src/connectivity/mod.rs +++ b/src/connectivity/mod.rs @@ -14,6 +14,7 @@ mod all_pairs_all_simple_paths; mod core_number; +mod johnson_simple_cycles; use super::{digraph, get_edge_iter_with_weights, graph, weight_callable, InvalidNode, NullGraph}; @@ -126,6 +127,23 @@ pub fn cycle_basis(graph: &graph::PyGraph, root: Option) -> Vec johnson_simple_cycles::SimpleCycleIter { + johnson_simple_cycles::SimpleCycleIter::new(graph) +} + /// Compute the strongly connected components for a directed graph /// /// This function is implemented using Kosaraju's algorithm diff --git a/src/lib.rs b/src/lib.rs index 39ee874d7..a4586c7aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -424,6 +424,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(undirected_gnm_random_graph))?; m.add_wrapped(wrap_pyfunction!(random_geometric_graph))?; m.add_wrapped(wrap_pyfunction!(cycle_basis))?; + m.add_wrapped(wrap_pyfunction!(simple_cycles))?; m.add_wrapped(wrap_pyfunction!(strongly_connected_components))?; m.add_wrapped(wrap_pyfunction!(digraph_dfs_edges))?; m.add_wrapped(wrap_pyfunction!(graph_dfs_edges))?; diff --git a/tests/rustworkx_tests/digraph/test_simple_cycles.py b/tests/rustworkx_tests/digraph/test_simple_cycles.py new file mode 100644 index 000000000..2da8f0a76 --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_simple_cycles.py @@ -0,0 +1,69 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import rustworkx + + +class TestSimpleCycles(unittest.TestCase): + def test_simple_cycles(self): + edges = [(0, 0), (0, 1), (0, 2), (1, 2), (2, 0), (2, 1), (2, 2)] + graph = rustworkx.PyDiGraph() + graph.extend_from_edge_list(edges) + expected = [[0], [0, 1, 2], [0, 2], [1, 2], [2]] + res = list(rustworkx.simple_cycles(graph)) + self.assertEqual(len(res), len(expected)) + for cycle in res: + self.assertIn(sorted(cycle), expected) + + def test_mesh_graph(self): + # Test taken from Table 2 in the Johnson Algorithm paper + # which shows the number of cycles in a complete graph of + # 2 to 9 nodes and the time to calculate it on a s370/168 + # The table in question is a benchmark comparing the runtime + # to tarjan's algorithm, but it gives us a good test with + # a known value (networkX does this too) + num_circuits = [1, 5, 20, 84, 409, 2365, 16064] + for n, c in zip(range(2, 9), num_circuits): + with self.subTest(n=n): + graph = rustworkx.generators.directed_mesh_graph(n) + res = list(rustworkx.simple_cycles(graph)) + self.assertEqual(len(res), c) + + def test_empty_graph(self): + self.assertEqual( + list(rustworkx.simple_cycles(rustworkx.PyDiGraph())), + [], + ) + + def test_figure_1(self): + # This graph tests figured 1 from the Johnson's algorithm paper + for k in range(3, 10): + with self.subTest(k=k): + graph = rustworkx.PyDiGraph() + edge_list = [] + for n in range(2, k + 2): + edge_list.append((1, n)) + edge_list.append((n, k + 2)) + edge_list.append((2 * k + 1, 1)) + for n in range(k + 2, 2 * k + 2): + edge_list.append((n, 2 * k + 2)) + edge_list.append((n, n + 1)) + edge_list.append((2 * k + 3, k + 2)) + for n in range(2 * k + 3, 3 * k + 3): + edge_list.append((2 * k + 2, n)) + edge_list.append((n, 3 * k + 3)) + edge_list.append((3 * k + 3, 2 * k + 2)) + graph.extend_from_edge_list(edge_list) + cycles = list(rustworkx.simple_cycles(graph)) + self.assertEqual(len(cycles), 3 * k)