Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loosen trait constraints and simplify structure for longest_path #1195

Merged
merged 1 commit into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions rustworkx-core/src/dag_algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use std::cmp::Eq;
use std::hash::Hash;

use hashbrown::HashMap;

use petgraph::algo;
use petgraph::graph::NodeIndex;
use petgraph::visit::{
EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers,
Visitable,
EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNodeIdentifiers, Visitable,
};
use petgraph::Directed;

Expand Down Expand Up @@ -51,7 +53,6 @@ type LongestPathResult<G, T, E> = Result<Option<(Vec<NodeId<G>>, T)>, E>;
/// # Example
/// ```
/// use petgraph::graph::DiGraph;
/// use petgraph::graph::NodeIndex;
/// use petgraph::Directed;
/// use rustworkx_core::dag_algo::longest_path;
///
Expand All @@ -69,14 +70,10 @@ type LongestPathResult<G, T, E> = Result<Option<(Vec<NodeId<G>>, T)>, E>;
/// ```
pub fn longest_path<G, F, T, E>(graph: G, mut weight_fn: F) -> LongestPathResult<G, T, E>
where
G: GraphProp<EdgeType = Directed>
+ IntoNodeIdentifiers
+ IntoNeighborsDirected
+ IntoEdgesDirected
+ Visitable
+ GraphBase<NodeId = NodeIndex>,
G: GraphProp<EdgeType = Directed> + IntoNodeIdentifiers + IntoEdgesDirected + Visitable,
F: FnMut(G::EdgeRef) -> Result<T, E>,
T: Num + Zero + PartialOrd + Copy,
<G as GraphBase>::NodeId: Hash + Eq + PartialOrd,
{
let mut path: Vec<NodeId<G>> = Vec::new();
let nodes = match algo::toposort(graph, None) {
Expand All @@ -88,20 +85,20 @@ where
return Ok(Some((path, T::zero())));
}

let mut dist: HashMap<NodeIndex, (T, NodeIndex)> = HashMap::with_capacity(nodes.len()); // Stores the distance and the previous node
let mut dist: HashMap<G::NodeId, (T, G::NodeId)> = HashMap::with_capacity(nodes.len()); // Stores the distance and the previous node

// Iterate over nodes in topological order
for node in nodes {
let parents = graph.edges_directed(node, petgraph::Direction::Incoming);
let mut incoming_path: Vec<(T, NodeIndex)> = Vec::new(); // Stores the distance and the previous node for each parent
let mut incoming_path: Vec<(T, G::NodeId)> = Vec::new(); // Stores the distance and the previous node for each parent
for p_edge in parents {
let p_node = p_edge.source();
let weight: T = weight_fn(p_edge)?;
let length = dist[&p_node].0 + weight;
incoming_path.push((length, p_node));
}
// Determine the maximum distance and corresponding parent node
let max_path: (T, NodeIndex) = incoming_path
let max_path: (T, G::NodeId) = incoming_path
.into_iter()
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
.unwrap_or((T::zero(), node)); // If there are no incoming edges, the distance is zero
Expand All @@ -114,7 +111,7 @@ where
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap();
let mut v = *first;
let mut u: Option<NodeIndex> = None;
let mut u: Option<G::NodeId> = None;
// Backtrack from this node to find the path
while u.map_or(true, |u| u != v) {
path.push(v);
Expand Down
64 changes: 0 additions & 64 deletions src/dag_algo/longest_path.rs

This file was deleted.

57 changes: 51 additions & 6 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
// License for the specific language governing permissions and limitations
// under the License.

mod longest_path;

use super::DictMap;
use hashbrown::{HashMap, HashSet};
use indexmap::IndexSet;
Expand All @@ -22,6 +20,7 @@ use std::collections::BinaryHeap;
use super::iterators::NodeIndices;
use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph};

use rustworkx_core::dag_algo::longest_path as core_longest_path;
use rustworkx_core::traversal::dfs_edges;

use pyo3::exceptions::PyValueError;
Expand All @@ -32,8 +31,54 @@ use pyo3::Python;
use petgraph::algo;
use petgraph::graph::NodeIndex;
use petgraph::prelude::*;
use petgraph::stable_graph::EdgeReference;
use petgraph::visit::NodeCount;

use num_traits::{Num, Zero};

/// Calculate the longest path in a directed acyclic graph (DAG).
///
/// This function interfaces with the Python `PyDiGraph` object to compute the longest path
/// using the provided weight function.
///
/// # Arguments
/// * `graph`: Reference to a `PyDiGraph` object.
/// * `weight_fn`: A callable that takes the source node index, target node index, and the weight
/// object and returns the weight of the edge as a `PyResult<T>`.
///
/// # Type Parameters
/// * `F`: Type of the weight function.
/// * `T`: The type of the edge weight. Must implement `Num`, `Zero`, `PartialOrd`, and `Copy`.
///
/// # Returns
/// * `PyResult<(Vec<G::NodeId>, T)>` representing the longest path as a sequence of node indices and its total weight.
fn longest_path<F, T>(graph: &digraph::PyDiGraph, mut weight_fn: F) -> PyResult<(Vec<usize>, T)>
where
F: FnMut(usize, usize, &PyObject) -> PyResult<T>,
T: Num + Zero + PartialOrd + Copy,
{
let dag = &graph.graph;

// Create a new weight function that matches the required signature
let edge_cost = |edge_ref: EdgeReference<'_, PyObject>| -> Result<T, PyErr> {
let source = edge_ref.source().index();
let target = edge_ref.target().index();
let weight = edge_ref.weight();
weight_fn(source, target, weight)
};

let (path, path_weight) = match core_longest_path(dag, edge_cost) {
Ok(Some((path, path_weight))) => (
path.into_iter().map(NodeIndex::index).collect(),
path_weight,
),
Ok(None) => return Err(DAGHasCycle::new_err("The graph contains a cycle")),
Err(e) => return Err(e),
};

Ok((path, path_weight))
}

/// Return a pair of [`petgraph::Direction`] values corresponding to the "forwards" and "backwards"
/// direction of graph traversal, based on whether the graph is being traved forwards (following
/// the edges) or backward (reversing along edges). The order of returns is (forwards, backwards).
Expand Down Expand Up @@ -82,7 +127,7 @@ pub fn dag_longest_path(
}
};
Ok(NodeIndices {
nodes: longest_path::longest_path(graph, edge_weight_callable)?.0,
nodes: longest_path(graph, edge_weight_callable)?.0,
})
}

Expand Down Expand Up @@ -121,7 +166,7 @@ pub fn dag_longest_path_length(
None => Ok(1),
}
};
let (_, path_weight) = longest_path::longest_path(graph, edge_weight_callable)?;
let (_, path_weight) = longest_path(graph, edge_weight_callable)?;
Ok(path_weight)
}

Expand Down Expand Up @@ -163,7 +208,7 @@ pub fn dag_weighted_longest_path(
Ok(float_res)
};
Ok(NodeIndices {
nodes: longest_path::longest_path(graph, edge_weight_callable)?.0,
nodes: longest_path(graph, edge_weight_callable)?.0,
})
}

Expand Down Expand Up @@ -204,7 +249,7 @@ pub fn dag_weighted_longest_path_length(
}
Ok(float_res)
};
let (_, path_weight) = longest_path::longest_path(graph, edge_weight_callable)?;
let (_, path_weight) = longest_path(graph, edge_weight_callable)?;
Ok(path_weight)
}

Expand Down