diff --git a/src/hugr.rs b/src/hugr.rs index 8a6302428..2cb9ecf44 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -2,27 +2,25 @@ mod hugrmut; +pub mod rewrite; pub mod serialize; pub mod typecheck; pub mod validate; pub mod view; -use std::collections::HashMap; - pub(crate) use self::hugrmut::HugrMut; pub use self::validate::ValidationError; use derive_more::From; +pub use rewrite::{Replace, ReplaceError, Rewrite, SimpleReplacement, SimpleReplacementError}; + use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle}; use portgraph::multiportgraph::MultiPortGraph; -use portgraph::{Hierarchy, LinkView, NodeIndex, PortView, UnmanagedDenseMap}; +use portgraph::{Hierarchy, LinkView, PortView, UnmanagedDenseMap}; use thiserror::Error; pub use self::view::HugrView; -use crate::ops::tag::OpTag; -use crate::ops::{OpName, OpTrait, OpType}; -use crate::replacement::{SimpleReplacement, SimpleReplacementError}; -use crate::rewrite::{Rewrite, RewriteError}; +use crate::ops::{OpName, OpType}; use crate::types::EdgeKind; /// The Hugr data structure. @@ -81,187 +79,9 @@ pub struct Wire(Node, usize); /// Public API for HUGRs. impl Hugr { - /// Apply a simple replacement operation to the HUGR. - pub fn apply_simple_replacement( - &mut self, - r: SimpleReplacement, - ) -> Result<(), SimpleReplacementError> { - // 1. Check the parent node exists and is a DFG node. - if self.get_optype(r.parent).tag() != OpTag::Dfg { - return Err(SimpleReplacementError::InvalidParentNode()); - } - // 2. Check that all the to-be-removed nodes are children of it and are leaves. - for node in &r.removal { - if self.hierarchy.parent(node.index) != Some(r.parent.index) - || self.hierarchy.has_children(node.index) - { - return Err(SimpleReplacementError::InvalidRemovedNode()); - } - } - // 3. Do the replacement. - // 3.1. Add copies of all replacement nodes and edges to self. Exclude Input/Output nodes. - // Create map from old NodeIndex (in r.replacement) to new NodeIndex (in self). - let mut index_map: HashMap = HashMap::new(); - let replacement_nodes = r - .replacement - .children(r.replacement.root()) - .collect::>(); - // slice of nodes omitting Input and Output: - let replacement_inner_nodes = &replacement_nodes[2..]; - for &node in replacement_inner_nodes { - // Check there are no const inputs. - if !r - .replacement - .get_optype(node) - .signature() - .const_input - .is_empty() - { - return Err(SimpleReplacementError::InvalidReplacementNode()); - } - } - let self_output_node_index = self.children(r.parent).nth(1).unwrap(); - let replacement_output_node = *replacement_nodes.get(1).unwrap(); - for &node in replacement_inner_nodes { - // Add the nodes. - let op: &OpType = r.replacement.get_optype(node); - let new_node_index = self - .add_op_after(self_output_node_index, op.clone()) - .unwrap(); - index_map.insert(node.index, new_node_index.index); - } - // Add edges between all newly added nodes matching those in replacement. - // TODO This will probably change when implicit copies are implemented. - for &node in replacement_inner_nodes { - let new_node_index = index_map.get(&node.index).unwrap(); - for node_successor in r.replacement.output_neighbours(node) { - if r.replacement.get_optype(node_successor).tag() != OpTag::Output { - let new_node_successor_index = index_map.get(&node_successor.index).unwrap(); - for connection in r - .replacement - .graph - .get_connections(node.index, node_successor.index) - { - let src_offset = r - .replacement - .graph - .port_offset(connection.0) - .unwrap() - .index(); - let tgt_offset = r - .replacement - .graph - .port_offset(connection.1) - .unwrap() - .index(); - self.graph - .link_nodes( - *new_node_index, - src_offset, - *new_node_successor_index, - tgt_offset, - ) - .ok(); - } - } - } - } - // 3.2. For each p = r.nu_inp[q] such that q is not an Output port, add an edge from the - // predecessor of p to (the new copy of) q. - for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &r.nu_inp { - if r.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output { - let new_inp_node_index = index_map.get(&rep_inp_node.index).unwrap(); - // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) - let rem_inp_port_index = self - .graph - .port_index(rem_inp_node.index, rem_inp_port.offset) - .unwrap(); - let rem_inp_predecessor_port_index = - self.graph.port_link(rem_inp_port_index).unwrap().port(); - let new_inp_port_index = self - .graph - .port_index(*new_inp_node_index, rep_inp_port.offset) - .unwrap(); - self.graph.unlink_port(rem_inp_predecessor_port_index); - self.graph - .link_ports(rem_inp_predecessor_port_index, new_inp_port_index) - .ok(); - } - } - // 3.3. For each q = r.nu_out[p] such that the predecessor of q is not an Input port, add an - // edge from (the new copy of) the predecessor of q to p. - for ((rem_out_node, rem_out_port), rep_out_port) in &r.nu_out { - let rem_out_port_index = self - .graph - .port_index(rem_out_node.index, rem_out_port.offset) - .unwrap(); - let rep_out_port_index = r - .replacement - .graph - .port_index(replacement_output_node.index, rep_out_port.offset) - .unwrap(); - let rep_out_predecessor_port_index = - r.replacement.graph.port_link(rep_out_port_index).unwrap(); - let rep_out_predecessor_node_index = r - .replacement - .graph - .port_node(rep_out_predecessor_port_index) - .unwrap(); - if r.replacement - .get_optype(rep_out_predecessor_node_index.into()) - .tag() - != OpTag::Input - { - let rep_out_predecessor_port_offset = r - .replacement - .graph - .port_offset(rep_out_predecessor_port_index) - .unwrap(); - let new_out_node_index = index_map.get(&rep_out_predecessor_node_index).unwrap(); - let new_out_port_index = self - .graph - .port_index(*new_out_node_index, rep_out_predecessor_port_offset) - .unwrap(); - self.graph.unlink_port(rem_out_port_index); - self.graph - .link_ports(new_out_port_index, rem_out_port_index) - .ok(); - } - } - // 3.4. For each q = r.nu_out[p1], p0 = r.nu_inp[q], add an edge from the predecessor of p0 - // to p1. - for ((rem_out_node, rem_out_port), &rep_out_port) in &r.nu_out { - let rem_inp_nodeport = r.nu_inp.get(&(replacement_output_node, rep_out_port)); - if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { - // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port): - let rem_inp_port_index = self - .graph - .port_index(rem_inp_node.index, rem_inp_port.offset) - .unwrap(); - let rem_inp_predecessor_port_index = - self.graph.port_link(rem_inp_port_index).unwrap().port(); - let rem_out_port_index = self - .graph - .port_index(rem_out_node.index, rem_out_port.offset) - .unwrap(); - self.graph.unlink_port(rem_inp_port_index); - self.graph.unlink_port(rem_out_port_index); - self.graph - .link_ports(rem_inp_predecessor_port_index, rem_out_port_index) - .ok(); - } - } - // 3.5. Remove all nodes in r.removal and edges between them. - for node in &r.removal { - self.graph.remove_node(node.index); - self.hierarchy.remove(node.index); - } - Ok(()) - } - /// Applies a rewrite to the graph. - pub fn apply_rewrite(self, _rewrite: Rewrite) -> Result<(), RewriteError> { - unimplemented!() + pub fn apply_rewrite(&mut self, rw: impl Rewrite) -> Result<(), E> { + rw.apply(self) } /// Return dot string showing underlying graph and hierarchy side by side. diff --git a/src/hugr/rewrite.rs b/src/hugr/rewrite.rs new file mode 100644 index 000000000..dda3cde8d --- /dev/null +++ b/src/hugr/rewrite.rs @@ -0,0 +1,62 @@ +//! Rewrite operations on the HUGR - replacement, outlining, etc. + +pub mod replace; +pub mod simple_replace; +use std::mem; + +use crate::Hugr; +pub use replace::{OpenHugr, Replace, ReplaceError}; +pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; + +/// An operation that can be applied to mutate a Hugr +pub trait Rewrite { + /// The type of Error with which this Rewrite may fail + type Error: std::error::Error; + + /// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err. + /// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned. + const UNCHANGED_ON_FAILURE: bool; + + /// Checks whether the rewrite would succeed on the specified Hugr. + /// If this call succeeds, [self.apply] should also succeed on the same `h` + /// If this calls fails, [self.apply] would fail with the same error. + fn verify(&self, h: &Hugr) -> Result<(), Self::Error>; + + /// Mutate the specified Hugr, or fail with an error. + /// If [self.unchanged_on_failure] is true, then `h` must be unchanged if Err is returned. + /// See also [self.verify] + /// # Panics + /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is, + /// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())` + /// being preferred. + fn apply(self, h: &mut Hugr) -> Result<(), Self::Error>; +} + +/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure) +pub struct Transactional { + underlying: R, +} + +// Note we might like to constrain R to Rewrite but this +// is not yet supported, https://github.com/rust-lang/rust/issues/92827 +impl Rewrite for Transactional { + type Error = R::Error; + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &Hugr) -> Result<(), Self::Error> { + self.underlying.verify(h) + } + + fn apply(self, h: &mut Hugr) -> Result<(), Self::Error> { + if R::UNCHANGED_ON_FAILURE { + return self.underlying.apply(h); + } + let backup = h.clone(); + let r = self.underlying.apply(h); + if r.is_err() { + // drop the old h, it was undefined + let _ = mem::replace(h, backup); + } + r + } +} diff --git a/src/rewrite/rewrite.rs b/src/hugr/rewrite/replace.rs similarity index 81% rename from src/rewrite/rewrite.rs rename to src/hugr/rewrite/replace.rs index 4a285ae53..fe25f6f52 100644 --- a/src/rewrite/rewrite.rs +++ b/src/hugr/rewrite/replace.rs @@ -1,5 +1,6 @@ #![allow(missing_docs)] -//! Rewrite operations on Hugr graphs. +//! Replace operations on Hugr graphs. This is a nonfunctional +//! dummy implementation just to demonstrate design principles. use std::collections::HashMap; @@ -7,6 +8,7 @@ use portgraph::substitute::OpenGraph; use portgraph::{NodeIndex, PortIndex}; use thiserror::Error; +use super::Rewrite; use crate::Hugr; /// A subset of the nodes in a graph, and the ports that it is connected to. @@ -77,7 +79,7 @@ pub type ParentsMap = HashMap; /// Includes the new weights for the nodes in the replacement graph. #[derive(Debug, Clone)] #[allow(unused)] -pub struct Rewrite { +pub struct Replace { /// The subgraph to be replaced. subgraph: BoundedSubgraph, /// The replacement graph. @@ -86,7 +88,7 @@ pub struct Rewrite { parents: ParentsMap, } -impl Rewrite { +impl Replace { /// Creates a new rewrite operation. pub fn new( subgraph: BoundedSubgraph, @@ -114,30 +116,42 @@ impl Rewrite { ) } + pub fn verify_convexity(&self) -> Result<(), ReplaceError> { + unimplemented!() + } + + pub fn verify_boundaries(&self) -> Result<(), ReplaceError> { + unimplemented!() + } +} + +impl Rewrite for Replace { + type Error = ReplaceError; + const UNCHANGED_ON_FAILURE: bool = false; + /// Checks that the rewrite is valid. /// /// This includes having a convex subgraph (TODO: include definition), and /// having matching numbers of ports on the boundaries. - pub fn verify(&self) -> Result<(), RewriteError> { + /// TODO not clear this implementation really provides much guarantee about [self.apply] + /// but this class is not really working anyway. + fn verify(&self, _h: &Hugr) -> Result<(), ReplaceError> { self.verify_convexity()?; self.verify_boundaries()?; Ok(()) } - pub fn verify_convexity(&self) -> Result<(), RewriteError> { - todo!() - } - - pub fn verify_boundaries(&self) -> Result<(), RewriteError> { - todo!() + /// Performs a Replace operation on the graph. + fn apply(self, _h: &mut Hugr) -> Result<(), ReplaceError> { + unimplemented!() } } /// Error generated when a rewrite fails. #[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum RewriteError { - /// The rewrite failed because the boundary defined by the - /// [`Rewrite`] could not be matched to the dangling ports of the +pub enum ReplaceError { + /// The replacement failed because the boundary defined by the + /// [`Replace`] could not be matched to the dangling ports of the /// [`OpenHugr`]. #[error("The boundary defined by the rewrite could not be matched to the dangling ports of the OpenHugr")] BoundarySize(#[source] portgraph::substitute::RewriteError), @@ -152,7 +166,7 @@ pub enum RewriteError { NotConvex(), } -impl From for RewriteError { +impl From for ReplaceError { fn from(e: portgraph::substitute::RewriteError) -> Self { match e { portgraph::substitute::RewriteError::BoundarySize => Self::BoundarySize(e), diff --git a/src/replacement/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs similarity index 59% rename from src/replacement/simple_replace.rs rename to src/hugr/rewrite/simple_replace.rs index 62cad3c87..8b01c70b5 100644 --- a/src/replacement/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -2,7 +2,14 @@ use std::collections::{HashMap, HashSet}; -use crate::{hugr::Node, Hugr, Port}; +use portgraph::{LinkView, NodeIndex, PortView}; + +use crate::hugr::{HugrMut, HugrView}; +use crate::{ + hugr::{Node, Rewrite}, + ops::{tag::OpTag, OpTrait, OpType}, + Hugr, Port, +}; use thiserror::Error; /// Specification of a simple replacement operation. @@ -41,6 +48,191 @@ impl SimpleReplacement { } } +impl Rewrite for SimpleReplacement { + type Error = SimpleReplacementError; + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, _h: &Hugr) -> Result<(), SimpleReplacementError> { + unimplemented!() + } + + fn apply(self, h: &mut Hugr) -> Result<(), SimpleReplacementError> { + // 1. Check the parent node exists and is a DFG node. + if h.get_optype(self.parent).tag() != OpTag::Dfg { + return Err(SimpleReplacementError::InvalidParentNode()); + } + // 2. Check that all the to-be-removed nodes are children of it and are leaves. + for node in &self.removal { + if h.hierarchy.parent(node.index) != Some(self.parent.index) + || h.hierarchy.has_children(node.index) + { + return Err(SimpleReplacementError::InvalidRemovedNode()); + } + } + // 3. Do the replacement. + // 3.1. Add copies of all replacement nodes and edges to h. Exclude Input/Output nodes. + // Create map from old NodeIndex (in self.replacement) to new NodeIndex (in self). + let mut index_map: HashMap = HashMap::new(); + let replacement_nodes = self + .replacement + .children(self.replacement.root()) + .collect::>(); + // slice of nodes omitting Input and Output: + let replacement_inner_nodes = &replacement_nodes[2..]; + for &node in replacement_inner_nodes { + // Check there are no const inputs. + if !self + .replacement + .get_optype(node) + .signature() + .const_input + .is_empty() + { + return Err(SimpleReplacementError::InvalidReplacementNode()); + } + } + let self_output_node_index = h.children(self.parent).nth(1).unwrap(); + let replacement_output_node = *replacement_nodes.get(1).unwrap(); + for &node in replacement_inner_nodes { + // Add the nodes. + let op: &OpType = self.replacement.get_optype(node); + let new_node_index = h.add_op_after(self_output_node_index, op.clone()).unwrap(); + index_map.insert(node.index, new_node_index.index); + } + // Add edges between all newly added nodes matching those in replacement. + // TODO This will probably change when implicit copies are implemented. + for &node in replacement_inner_nodes { + let new_node_index = index_map.get(&node.index).unwrap(); + for node_successor in self.replacement.output_neighbours(node) { + if self.replacement.get_optype(node_successor).tag() != OpTag::Output { + let new_node_successor_index = index_map.get(&node_successor.index).unwrap(); + for connection in self + .replacement + .graph + .get_connections(node.index, node_successor.index) + { + let src_offset = self + .replacement + .graph + .port_offset(connection.0) + .unwrap() + .index(); + let tgt_offset = self + .replacement + .graph + .port_offset(connection.1) + .unwrap() + .index(); + h.graph + .link_nodes( + *new_node_index, + src_offset, + *new_node_successor_index, + tgt_offset, + ) + .unwrap(); + } + } + } + } + // 3.2. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the + // predecessor of p to (the new copy of) q. + for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp { + if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output { + let new_inp_node_index = index_map.get(&rep_inp_node.index).unwrap(); + // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) + let rem_inp_port_index = h + .graph + .port_index(rem_inp_node.index, rem_inp_port.offset) + .unwrap(); + let rem_inp_predecessor_port_index = + h.graph.port_link(rem_inp_port_index).unwrap().port(); + let new_inp_port_index = h + .graph + .port_index(*new_inp_node_index, rep_inp_port.offset) + .unwrap(); + h.graph.unlink_port(rem_inp_predecessor_port_index); + h.graph + .link_ports(rem_inp_predecessor_port_index, new_inp_port_index) + .unwrap(); + } + } + // 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an + // edge from (the new copy of) the predecessor of q to p. + for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out { + let rem_out_port_index = h + .graph + .port_index(rem_out_node.index, rem_out_port.offset) + .unwrap(); + let rep_out_port_index = self + .replacement + .graph + .port_index(replacement_output_node.index, rep_out_port.offset) + .unwrap(); + let rep_out_predecessor_port_index = self + .replacement + .graph + .port_link(rep_out_port_index) + .unwrap(); + let rep_out_predecessor_node_index = self + .replacement + .graph + .port_node(rep_out_predecessor_port_index) + .unwrap(); + if self + .replacement + .get_optype(rep_out_predecessor_node_index.into()) + .tag() + != OpTag::Input + { + let rep_out_predecessor_port_offset = self + .replacement + .graph + .port_offset(rep_out_predecessor_port_index) + .unwrap(); + let new_out_node_index = index_map.get(&rep_out_predecessor_node_index).unwrap(); + let new_out_port_index = h + .graph + .port_index(*new_out_node_index, rep_out_predecessor_port_offset) + .unwrap(); + h.graph.unlink_port(rem_out_port_index); + h.graph + .link_ports(new_out_port_index, rem_out_port_index) + .unwrap(); + } + } + // 3.4. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 + // to p1. + for ((rem_out_node, rem_out_port), &rep_out_port) in &self.nu_out { + let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port)); + if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { + // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port): + let rem_inp_port_index = h + .graph + .port_index(rem_inp_node.index, rem_inp_port.offset) + .unwrap(); + let rem_inp_predecessor_port_index = + h.graph.port_link(rem_inp_port_index).unwrap().port(); + let rem_out_port_index = h + .graph + .port_index(rem_out_node.index, rem_out_port.offset) + .unwrap(); + h.graph.unlink_port(rem_inp_port_index); + h.graph.unlink_port(rem_out_port_index); + h.graph + .link_ports(rem_inp_predecessor_port_index, rem_out_port_index) + .unwrap(); + } + } + // 3.5. Remove all nodes in self.removal and edges between them. + for node in &self.removal { + h.graph.remove_node(node.index); + h.hierarchy.remove(node.index); + } + Ok(()) + } +} + /// Error from a [`SimpleReplacement`] operation. #[derive(Debug, Clone, Error, PartialEq, Eq)] pub enum SimpleReplacementError { @@ -226,7 +418,7 @@ mod test { nu_inp, nu_out, }; - h.apply_simple_replacement(r).ok(); + h.apply_rewrite(r).unwrap(); // Expect [DFG] to be replaced with: // ┌───┐┌───┐ // ┤ H ├┤ H ├──■── @@ -303,7 +495,7 @@ mod test { nu_inp, nu_out, }; - h.apply_simple_replacement(r).ok(); + h.apply_rewrite(r).unwrap(); // Expect [DFG] to be replaced with: // ┌───┐┌───┐ // ┤ H ├┤ H ├ diff --git a/src/lib.rs b/src/lib.rs index 143ea29ff..73e441c9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,13 +15,11 @@ pub mod extensions; pub mod hugr; pub mod macros; pub mod ops; -pub mod replacement; pub mod resource; -pub mod rewrite; pub mod types; mod utils; -pub use crate::hugr::{Direction, Hugr, Node, Port, Wire}; -pub use crate::replacement::SimpleReplacement; +pub use crate::hugr::{ + Direction, Hugr, Node, Port, Replace, ReplaceError, SimpleReplacement, Wire, +}; pub use crate::resource::Resource; -pub use crate::rewrite::{Rewrite, RewriteError}; diff --git a/src/rewrite.rs b/src/rewrite.rs deleted file mode 100644 index dfc2182bf..000000000 --- a/src/rewrite.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! Pattern matching and rewrite operations on the HUGR. - -pub mod pattern; -#[allow(clippy::module_inception)] // TODO: Rename? -pub mod rewrite; - -pub use rewrite::{OpenHugr, Rewrite, RewriteError}; diff --git a/src/rewrite/pattern.rs b/src/rewrite/pattern.rs deleted file mode 100644 index 85a3b8a4b..000000000 --- a/src/rewrite/pattern.rs +++ /dev/null @@ -1,10 +0,0 @@ -#![allow(missing_docs)] -//! Pattern matching operations on a HUGR. - -#[cfg(feature = "pyo3")] -use pyo3::prelude::*; - -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "pyo3", pyclass)] -#[non_exhaustive] -pub struct Pattern {}