Skip to content

Commit

Permalink
Refactor: move rewrite inside hugr, Rewrite -> Replace implem…
Browse files Browse the repository at this point in the history
…enting new 'Rewrite' trait (#119)

* Remove Pattern and Pattern.rs
* Move rewrite module to hugr/rewrite
* Rename old `Rewrite` struct to `Replace` but leave as skeletal
* Add `Rewrite` trait, (parameterized) Hugr::apply_rewrite dispatches to that
   * Associated `Error` type and `unchanged_on_failure: bool`
* unchanged_on_failure as trait associated constant
* Drive-by: simple_replace.rs: change ".ok();"s to unwrap
  • Loading branch information
acl-cqc authored Jun 19, 2023
1 parent 344ef0c commit 96cac0e
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 226 deletions.
194 changes: 7 additions & 187 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,25 @@

mod hugrmut;

pub mod rewrite;
pub mod serialize;
pub mod typecheck;
pub mod validate;
pub mod view;

use std::collections::HashMap;

pub(crate) use self::hugrmut::HugrMut;
pub use self::validate::ValidationError;

use derive_more::From;
pub use rewrite::{Replace, ReplaceError, Rewrite, SimpleReplacement, SimpleReplacementError};

use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle};
use portgraph::multiportgraph::MultiPortGraph;
use portgraph::{Hierarchy, LinkView, NodeIndex, PortView, UnmanagedDenseMap};
use portgraph::{Hierarchy, LinkView, PortView, UnmanagedDenseMap};
use thiserror::Error;

pub use self::view::HugrView;
use crate::ops::tag::OpTag;
use crate::ops::{OpName, OpTrait, OpType};
use crate::replacement::{SimpleReplacement, SimpleReplacementError};
use crate::rewrite::{Rewrite, RewriteError};
use crate::ops::{OpName, OpType};
use crate::types::EdgeKind;

/// The Hugr data structure.
Expand Down Expand Up @@ -81,187 +79,9 @@ pub struct Wire(Node, usize);

/// Public API for HUGRs.
impl Hugr {
/// Apply a simple replacement operation to the HUGR.
pub fn apply_simple_replacement(
&mut self,
r: SimpleReplacement,
) -> Result<(), SimpleReplacementError> {
// 1. Check the parent node exists and is a DFG node.
if self.get_optype(r.parent).tag() != OpTag::Dfg {
return Err(SimpleReplacementError::InvalidParentNode());
}
// 2. Check that all the to-be-removed nodes are children of it and are leaves.
for node in &r.removal {
if self.hierarchy.parent(node.index) != Some(r.parent.index)
|| self.hierarchy.has_children(node.index)
{
return Err(SimpleReplacementError::InvalidRemovedNode());
}
}
// 3. Do the replacement.
// 3.1. Add copies of all replacement nodes and edges to self. Exclude Input/Output nodes.
// Create map from old NodeIndex (in r.replacement) to new NodeIndex (in self).
let mut index_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
let replacement_nodes = r
.replacement
.children(r.replacement.root())
.collect::<Vec<Node>>();
// slice of nodes omitting Input and Output:
let replacement_inner_nodes = &replacement_nodes[2..];
for &node in replacement_inner_nodes {
// Check there are no const inputs.
if !r
.replacement
.get_optype(node)
.signature()
.const_input
.is_empty()
{
return Err(SimpleReplacementError::InvalidReplacementNode());
}
}
let self_output_node_index = self.children(r.parent).nth(1).unwrap();
let replacement_output_node = *replacement_nodes.get(1).unwrap();
for &node in replacement_inner_nodes {
// Add the nodes.
let op: &OpType = r.replacement.get_optype(node);
let new_node_index = self
.add_op_after(self_output_node_index, op.clone())
.unwrap();
index_map.insert(node.index, new_node_index.index);
}
// Add edges between all newly added nodes matching those in replacement.
// TODO This will probably change when implicit copies are implemented.
for &node in replacement_inner_nodes {
let new_node_index = index_map.get(&node.index).unwrap();
for node_successor in r.replacement.output_neighbours(node) {
if r.replacement.get_optype(node_successor).tag() != OpTag::Output {
let new_node_successor_index = index_map.get(&node_successor.index).unwrap();
for connection in r
.replacement
.graph
.get_connections(node.index, node_successor.index)
{
let src_offset = r
.replacement
.graph
.port_offset(connection.0)
.unwrap()
.index();
let tgt_offset = r
.replacement
.graph
.port_offset(connection.1)
.unwrap()
.index();
self.graph
.link_nodes(
*new_node_index,
src_offset,
*new_node_successor_index,
tgt_offset,
)
.ok();
}
}
}
}
// 3.2. For each p = r.nu_inp[q] such that q is not an Output port, add an edge from the
// predecessor of p to (the new copy of) q.
for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &r.nu_inp {
if r.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output {
let new_inp_node_index = index_map.get(&rep_inp_node.index).unwrap();
// add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
let rem_inp_port_index = self
.graph
.port_index(rem_inp_node.index, rem_inp_port.offset)
.unwrap();
let rem_inp_predecessor_port_index =
self.graph.port_link(rem_inp_port_index).unwrap().port();
let new_inp_port_index = self
.graph
.port_index(*new_inp_node_index, rep_inp_port.offset)
.unwrap();
self.graph.unlink_port(rem_inp_predecessor_port_index);
self.graph
.link_ports(rem_inp_predecessor_port_index, new_inp_port_index)
.ok();
}
}
// 3.3. For each q = r.nu_out[p] such that the predecessor of q is not an Input port, add an
// edge from (the new copy of) the predecessor of q to p.
for ((rem_out_node, rem_out_port), rep_out_port) in &r.nu_out {
let rem_out_port_index = self
.graph
.port_index(rem_out_node.index, rem_out_port.offset)
.unwrap();
let rep_out_port_index = r
.replacement
.graph
.port_index(replacement_output_node.index, rep_out_port.offset)
.unwrap();
let rep_out_predecessor_port_index =
r.replacement.graph.port_link(rep_out_port_index).unwrap();
let rep_out_predecessor_node_index = r
.replacement
.graph
.port_node(rep_out_predecessor_port_index)
.unwrap();
if r.replacement
.get_optype(rep_out_predecessor_node_index.into())
.tag()
!= OpTag::Input
{
let rep_out_predecessor_port_offset = r
.replacement
.graph
.port_offset(rep_out_predecessor_port_index)
.unwrap();
let new_out_node_index = index_map.get(&rep_out_predecessor_node_index).unwrap();
let new_out_port_index = self
.graph
.port_index(*new_out_node_index, rep_out_predecessor_port_offset)
.unwrap();
self.graph.unlink_port(rem_out_port_index);
self.graph
.link_ports(new_out_port_index, rem_out_port_index)
.ok();
}
}
// 3.4. For each q = r.nu_out[p1], p0 = r.nu_inp[q], add an edge from the predecessor of p0
// to p1.
for ((rem_out_node, rem_out_port), &rep_out_port) in &r.nu_out {
let rem_inp_nodeport = r.nu_inp.get(&(replacement_output_node, rep_out_port));
if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
// add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
let rem_inp_port_index = self
.graph
.port_index(rem_inp_node.index, rem_inp_port.offset)
.unwrap();
let rem_inp_predecessor_port_index =
self.graph.port_link(rem_inp_port_index).unwrap().port();
let rem_out_port_index = self
.graph
.port_index(rem_out_node.index, rem_out_port.offset)
.unwrap();
self.graph.unlink_port(rem_inp_port_index);
self.graph.unlink_port(rem_out_port_index);
self.graph
.link_ports(rem_inp_predecessor_port_index, rem_out_port_index)
.ok();
}
}
// 3.5. Remove all nodes in r.removal and edges between them.
for node in &r.removal {
self.graph.remove_node(node.index);
self.hierarchy.remove(node.index);
}
Ok(())
}

/// Applies a rewrite to the graph.
pub fn apply_rewrite(self, _rewrite: Rewrite) -> Result<(), RewriteError> {
unimplemented!()
pub fn apply_rewrite<E>(&mut self, rw: impl Rewrite<Error = E>) -> Result<(), E> {
rw.apply(self)
}

/// Return dot string showing underlying graph and hierarchy side by side.
Expand Down
62 changes: 62 additions & 0 deletions src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//! Rewrite operations on the HUGR - replacement, outlining, etc.

pub mod replace;
pub mod simple_replace;
use std::mem;

use crate::Hugr;
pub use replace::{OpenHugr, Replace, ReplaceError};
pub use simple_replace::{SimpleReplacement, SimpleReplacementError};

/// An operation that can be applied to mutate a Hugr
pub trait Rewrite {
/// The type of Error with which this Rewrite may fail
type Error: std::error::Error;

/// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err.
/// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned.
const UNCHANGED_ON_FAILURE: bool;

/// Checks whether the rewrite would succeed on the specified Hugr.
/// If this call succeeds, [self.apply] should also succeed on the same `h`
/// If this calls fails, [self.apply] would fail with the same error.
fn verify(&self, h: &Hugr) -> Result<(), Self::Error>;

/// Mutate the specified Hugr, or fail with an error.
/// If [self.unchanged_on_failure] is true, then `h` must be unchanged if Err is returned.
/// See also [self.verify]
/// # Panics
/// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is,
/// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())`
/// being preferred.
fn apply(self, h: &mut Hugr) -> Result<(), Self::Error>;
}

/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure)
pub struct Transactional<R> {
underlying: R,
}

// Note we might like to constrain R to Rewrite<unchanged_on_failure=false> but this
// is not yet supported, https://github.com/rust-lang/rust/issues/92827
impl<R: Rewrite> Rewrite for Transactional<R> {
type Error = R::Error;
const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &Hugr) -> Result<(), Self::Error> {
self.underlying.verify(h)
}

fn apply(self, h: &mut Hugr) -> Result<(), Self::Error> {
if R::UNCHANGED_ON_FAILURE {
return self.underlying.apply(h);
}
let backup = h.clone();
let r = self.underlying.apply(h);
if r.is_err() {
// drop the old h, it was undefined
let _ = mem::replace(h, backup);
}
r
}
}
42 changes: 28 additions & 14 deletions src/rewrite/rewrite.rs → src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#![allow(missing_docs)]
//! Rewrite operations on Hugr graphs.
//! Replace operations on Hugr graphs. This is a nonfunctional
//! dummy implementation just to demonstrate design principles.

use std::collections::HashMap;

use portgraph::substitute::OpenGraph;
use portgraph::{NodeIndex, PortIndex};
use thiserror::Error;

use super::Rewrite;
use crate::Hugr;

/// A subset of the nodes in a graph, and the ports that it is connected to.
Expand Down Expand Up @@ -77,7 +79,7 @@ pub type ParentsMap = HashMap<NodeIndex, NodeIndex>;
/// Includes the new weights for the nodes in the replacement graph.
#[derive(Debug, Clone)]
#[allow(unused)]
pub struct Rewrite {
pub struct Replace {
/// The subgraph to be replaced.
subgraph: BoundedSubgraph,
/// The replacement graph.
Expand All @@ -86,7 +88,7 @@ pub struct Rewrite {
parents: ParentsMap,
}

impl Rewrite {
impl Replace {
/// Creates a new rewrite operation.
pub fn new(
subgraph: BoundedSubgraph,
Expand Down Expand Up @@ -114,30 +116,42 @@ impl Rewrite {
)
}

pub fn verify_convexity(&self) -> Result<(), ReplaceError> {
unimplemented!()
}

pub fn verify_boundaries(&self) -> Result<(), ReplaceError> {
unimplemented!()
}
}

impl Rewrite for Replace {
type Error = ReplaceError;
const UNCHANGED_ON_FAILURE: bool = false;

/// Checks that the rewrite is valid.
///
/// This includes having a convex subgraph (TODO: include definition), and
/// having matching numbers of ports on the boundaries.
pub fn verify(&self) -> Result<(), RewriteError> {
/// TODO not clear this implementation really provides much guarantee about [self.apply]
/// but this class is not really working anyway.
fn verify(&self, _h: &Hugr) -> Result<(), ReplaceError> {
self.verify_convexity()?;
self.verify_boundaries()?;
Ok(())
}

pub fn verify_convexity(&self) -> Result<(), RewriteError> {
todo!()
}

pub fn verify_boundaries(&self) -> Result<(), RewriteError> {
todo!()
/// Performs a Replace operation on the graph.
fn apply(self, _h: &mut Hugr) -> Result<(), ReplaceError> {
unimplemented!()
}
}

/// Error generated when a rewrite fails.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum RewriteError {
/// The rewrite failed because the boundary defined by the
/// [`Rewrite`] could not be matched to the dangling ports of the
pub enum ReplaceError {
/// The replacement failed because the boundary defined by the
/// [`Replace`] could not be matched to the dangling ports of the
/// [`OpenHugr`].
#[error("The boundary defined by the rewrite could not be matched to the dangling ports of the OpenHugr")]
BoundarySize(#[source] portgraph::substitute::RewriteError),
Expand All @@ -152,7 +166,7 @@ pub enum RewriteError {
NotConvex(),
}

impl From<portgraph::substitute::RewriteError> for RewriteError {
impl From<portgraph::substitute::RewriteError> for ReplaceError {
fn from(e: portgraph::substitute::RewriteError) -> Self {
match e {
portgraph::substitute::RewriteError::BoundarySize => Self::BoundarySize(e),
Expand Down
Loading

0 comments on commit 96cac0e

Please sign in to comment.