diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 0e1a0b9e0..b29389a73 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -9,13 +9,15 @@ // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations // under the License. + +use std::cmp::Eq; +use std::hash::Hash; + use hashbrown::HashMap; use petgraph::algo; -use petgraph::graph::NodeIndex; use petgraph::visit::{ - EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers, - Visitable, + EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNodeIdentifiers, Visitable, }; use petgraph::Directed; @@ -51,7 +53,6 @@ type LongestPathResult = Result>, T)>, E>; /// # Example /// ``` /// use petgraph::graph::DiGraph; -/// use petgraph::graph::NodeIndex; /// use petgraph::Directed; /// use rustworkx_core::dag_algo::longest_path; /// @@ -69,14 +70,10 @@ type LongestPathResult = Result>, T)>, E>; /// ``` pub fn longest_path(graph: G, mut weight_fn: F) -> LongestPathResult where - G: GraphProp - + IntoNodeIdentifiers - + IntoNeighborsDirected - + IntoEdgesDirected - + Visitable - + GraphBase, + G: GraphProp + IntoNodeIdentifiers + IntoEdgesDirected + Visitable, F: FnMut(G::EdgeRef) -> Result, T: Num + Zero + PartialOrd + Copy, + ::NodeId: Hash + Eq + PartialOrd, { let mut path: Vec> = Vec::new(); let nodes = match algo::toposort(graph, None) { @@ -88,12 +85,12 @@ where return Ok(Some((path, T::zero()))); } - let mut dist: HashMap = HashMap::with_capacity(nodes.len()); // Stores the distance and the previous node + let mut dist: HashMap = HashMap::with_capacity(nodes.len()); // Stores the distance and the previous node // Iterate over nodes in topological order for node in nodes { let parents = graph.edges_directed(node, petgraph::Direction::Incoming); - let mut incoming_path: Vec<(T, NodeIndex)> = Vec::new(); // Stores the distance and the previous node for each parent + let mut incoming_path: Vec<(T, G::NodeId)> = Vec::new(); // Stores the distance and the previous node for each parent for p_edge in parents { let p_node = p_edge.source(); let weight: T = weight_fn(p_edge)?; @@ -101,7 +98,7 @@ where incoming_path.push((length, p_node)); } // Determine the maximum distance and corresponding parent node - let max_path: (T, NodeIndex) = incoming_path + let max_path: (T, G::NodeId) = incoming_path .into_iter() .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap()) .unwrap_or((T::zero(), node)); // If there are no incoming edges, the distance is zero @@ -114,7 +111,7 @@ where .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) .unwrap(); let mut v = *first; - let mut u: Option = None; + let mut u: Option = None; // Backtrack from this node to find the path while u.map_or(true, |u| u != v) { path.push(v); diff --git a/src/dag_algo/longest_path.rs b/src/dag_algo/longest_path.rs deleted file mode 100644 index bdb1cf91a..000000000 --- a/src/dag_algo/longest_path.rs +++ /dev/null @@ -1,64 +0,0 @@ -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -use crate::{digraph, DAGHasCycle}; -use rustworkx_core::dag_algo::longest_path as core_longest_path; - -use petgraph::stable_graph::{EdgeReference, NodeIndex}; -use petgraph::visit::EdgeRef; - -use pyo3::prelude::*; - -use num_traits::{Num, Zero}; - -/// Calculate the longest path in a directed acyclic graph (DAG). -/// -/// This function interfaces with the Python `PyDiGraph` object to compute the longest path -/// using the provided weight function. -/// -/// # Arguments -/// * `graph`: Reference to a `PyDiGraph` object. -/// * `weight_fn`: A callable that takes the source node index, target node index, and the weight -/// object and returns the weight of the edge as a `PyResult`. -/// -/// # Type Parameters -/// * `F`: Type of the weight function. -/// * `T`: The type of the edge weight. Must implement `Num`, `Zero`, `PartialOrd`, and `Copy`. -/// -/// # Returns -/// * `PyResult<(Vec, T)>` representing the longest path as a sequence of node indices and its total weight. -pub fn longest_path(graph: &digraph::PyDiGraph, mut weight_fn: F) -> PyResult<(Vec, T)> -where - F: FnMut(usize, usize, &PyObject) -> PyResult, - T: Num + Zero + PartialOrd + Copy, -{ - let dag = &graph.graph; - - // Create a new weight function that matches the required signature - let edge_cost = |edge_ref: EdgeReference<'_, PyObject>| -> Result { - let source = edge_ref.source().index(); - let target = edge_ref.target().index(); - let weight = edge_ref.weight(); - weight_fn(source, target, weight) - }; - - let (path, path_weight) = match core_longest_path(dag, edge_cost) { - Ok(Some((path, path_weight))) => ( - path.into_iter().map(NodeIndex::index).collect(), - path_weight, - ), - Ok(None) => return Err(DAGHasCycle::new_err("The graph contains a cycle")), - Err(e) => return Err(e), - }; - - Ok((path, path_weight)) -} diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index 854c91b5c..206fa9b45 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -10,8 +10,6 @@ // License for the specific language governing permissions and limitations // under the License. -mod longest_path; - use super::DictMap; use hashbrown::{HashMap, HashSet}; use indexmap::IndexSet; @@ -22,6 +20,7 @@ use std::collections::BinaryHeap; use super::iterators::NodeIndices; use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph}; +use rustworkx_core::dag_algo::longest_path as core_longest_path; use rustworkx_core::traversal::dfs_edges; use pyo3::exceptions::PyValueError; @@ -32,8 +31,54 @@ use pyo3::Python; use petgraph::algo; use petgraph::graph::NodeIndex; use petgraph::prelude::*; +use petgraph::stable_graph::EdgeReference; use petgraph::visit::NodeCount; +use num_traits::{Num, Zero}; + +/// Calculate the longest path in a directed acyclic graph (DAG). +/// +/// This function interfaces with the Python `PyDiGraph` object to compute the longest path +/// using the provided weight function. +/// +/// # Arguments +/// * `graph`: Reference to a `PyDiGraph` object. +/// * `weight_fn`: A callable that takes the source node index, target node index, and the weight +/// object and returns the weight of the edge as a `PyResult`. +/// +/// # Type Parameters +/// * `F`: Type of the weight function. +/// * `T`: The type of the edge weight. Must implement `Num`, `Zero`, `PartialOrd`, and `Copy`. +/// +/// # Returns +/// * `PyResult<(Vec, T)>` representing the longest path as a sequence of node indices and its total weight. +fn longest_path(graph: &digraph::PyDiGraph, mut weight_fn: F) -> PyResult<(Vec, T)> +where + F: FnMut(usize, usize, &PyObject) -> PyResult, + T: Num + Zero + PartialOrd + Copy, +{ + let dag = &graph.graph; + + // Create a new weight function that matches the required signature + let edge_cost = |edge_ref: EdgeReference<'_, PyObject>| -> Result { + let source = edge_ref.source().index(); + let target = edge_ref.target().index(); + let weight = edge_ref.weight(); + weight_fn(source, target, weight) + }; + + let (path, path_weight) = match core_longest_path(dag, edge_cost) { + Ok(Some((path, path_weight))) => ( + path.into_iter().map(NodeIndex::index).collect(), + path_weight, + ), + Ok(None) => return Err(DAGHasCycle::new_err("The graph contains a cycle")), + Err(e) => return Err(e), + }; + + Ok((path, path_weight)) +} + /// Return a pair of [`petgraph::Direction`] values corresponding to the "forwards" and "backwards" /// direction of graph traversal, based on whether the graph is being traved forwards (following /// the edges) or backward (reversing along edges). The order of returns is (forwards, backwards). @@ -82,7 +127,7 @@ pub fn dag_longest_path( } }; Ok(NodeIndices { - nodes: longest_path::longest_path(graph, edge_weight_callable)?.0, + nodes: longest_path(graph, edge_weight_callable)?.0, }) } @@ -121,7 +166,7 @@ pub fn dag_longest_path_length( None => Ok(1), } }; - let (_, path_weight) = longest_path::longest_path(graph, edge_weight_callable)?; + let (_, path_weight) = longest_path(graph, edge_weight_callable)?; Ok(path_weight) } @@ -163,7 +208,7 @@ pub fn dag_weighted_longest_path( Ok(float_res) }; Ok(NodeIndices { - nodes: longest_path::longest_path(graph, edge_weight_callable)?.0, + nodes: longest_path(graph, edge_weight_callable)?.0, }) } @@ -204,7 +249,7 @@ pub fn dag_weighted_longest_path_length( } Ok(float_res) }; - let (_, path_weight) = longest_path::longest_path(graph, edge_weight_callable)?; + let (_, path_weight) = longest_path(graph, edge_weight_callable)?; Ok(path_weight) }