From 529fc330479a1eb752213436b1d7d9c54c4214dd Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 27 Sep 2023 16:55:28 +0100 Subject: [PATCH 1/2] Add an optional direction check when querying a port index --- src/builder/build_traits.rs | 6 +++--- src/hugr.rs | 23 ++++++++++++++++++++++- src/hugr/hugrmut.rs | 9 ++++++--- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 7ef288e2c..2b267aa66 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -28,7 +28,7 @@ use super::{ tail_loop::TailLoopBuilder, BuildError, Wire, }; -use crate::Hugr; +use crate::{Direction, Hugr}; use crate::hugr::HugrMut; @@ -662,8 +662,8 @@ fn wire_up( dst: Node, dst_port: impl PortIndex, ) -> Result { - let src_port = src_port.index(); - let dst_port = dst_port.index(); + let src_port = src_port.try_index(Direction::Outgoing)?; + let dst_port = dst_port.try_index(Direction::Incoming)?; let base = data_builder.hugr_mut(); let src_offset = Port::new_outgoing(src_port); diff --git a/src/hugr.rs b/src/hugr.rs index 6f78fe43f..2622a109d 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -214,6 +214,14 @@ pub struct Port { pub trait PortIndex { /// Returns the offset of the port. fn index(self) -> usize; + /// Returns the offset of the port, doing a sanity check on the expected direction. + #[inline(always)] + fn try_index(self, _dir: Direction) -> Result + where + Self: Sized, + { + Ok(self.index()) + } } /// The direction of a port. @@ -398,6 +406,16 @@ impl PortIndex for Port { fn index(self) -> usize { self.offset.index() } + #[inline(always)] + fn try_index(self, dir: Direction) -> Result + where + Self: Sized, + { + match dir == self.direction() { + true => Ok(self.index()), + false => Err(HugrError::InvalidPortDirection(dir)), + } + } } impl PortIndex for usize { @@ -416,7 +434,7 @@ impl Wire { /// Create a new wire from a node and a port. #[inline] pub fn new(node: Node, port: Port) -> Self { - Self(node, port.index()) + Self(node, port.try_index(Direction::Outgoing).unwrap()) } /// The node that this wire is connected to. @@ -484,6 +502,9 @@ pub enum HugrError { /// The node doesn't exist. #[error("Invalid node {0:?}.")] InvalidNode(Node), + /// An invalid port was specified. + #[error("Invalid port direction {0:?}.")] + InvalidPortDirection(Direction), } #[cfg(feature = "pyo3")] diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index a02d822c9..86bd095fc 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -266,9 +266,12 @@ where dst: Node, dst_port: impl PortIndex, ) -> Result<(), HugrError> { - self.as_mut() - .graph - .link_nodes(src.index, src_port.index(), dst.index, dst_port.index())?; + self.as_mut().graph.link_nodes( + src.index, + src_port.try_index(Direction::Outgoing)?, + dst.index, + dst_port.try_index(Direction::Incoming)?, + )?; Ok(()) } From ca9e3d55c8b197255a0860e96a3f4106c801ba8a Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 28 Sep 2023 14:21:21 +0100 Subject: [PATCH 2/2] feat: port incoming / outgoing traits --- src/builder/build_traits.rs | 27 ++++---- src/hugr.rs | 129 +++++++++++++++++++++++++++--------- src/hugr/hugrmut.rs | 14 ++-- 3 files changed, 118 insertions(+), 52 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 2b267aa66..a1411dcd2 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -1,7 +1,7 @@ use crate::hugr::hugrmut::InsertionResult; use crate::hugr::validate::InterGraphEdgeError; use crate::hugr::views::HugrView; -use crate::hugr::{Node, NodeMetadata, Port, PortIndex, ValidationError}; +use crate::hugr::{IncomingPort, Node, NodeMetadata, OutgoingPort, Port, ValidationError}; use crate::ops::{self, LeafOp, OpTrait, OpType}; use std::iter; @@ -28,7 +28,7 @@ use super::{ tail_loop::TailLoopBuilder, BuildError, Wire, }; -use crate::{Direction, Hugr}; +use crate::Hugr; use crate::hugr::HugrMut; @@ -658,27 +658,26 @@ fn wire_up_inputs( fn wire_up( data_builder: &mut T, src: Node, - src_port: impl PortIndex, + src_port: impl TryInto, dst: Node, - dst_port: impl PortIndex, + dst_port: impl TryInto, ) -> Result { - let src_port = src_port.try_index(Direction::Outgoing)?; - let dst_port = dst_port.try_index(Direction::Incoming)?; + let src_port = Port::try_new_outgoing(src_port)?; + let dst_port = Port::try_new_incoming(dst_port)?; let base = data_builder.hugr_mut(); - let src_offset = Port::new_outgoing(src_port); let src_parent = base.get_parent(src); let dst_parent = base.get_parent(dst); let local_source = src_parent == dst_parent; - if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_offset).unwrap() { + if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() { if !local_source { // Non-local value sources require a state edge to an ancestor of dst if !typ.copyable() { let val_err: ValidationError = InterGraphEdgeError::NonCopyableData { from: src, - from_offset: Port::new_outgoing(src_port), + from_offset: src_port, to: dst, - to_offset: Port::new_incoming(dst_port), + to_offset: dst_port, ty: EdgeKind::Value(typ), } .into(); @@ -694,9 +693,9 @@ fn wire_up( else { let val_err: ValidationError = InterGraphEdgeError::NoRelation { from: src, - from_offset: Port::new_outgoing(src_port), + from_offset: src_port, to: dst, - to_offset: Port::new_incoming(dst_port), + to_offset: dst_port, } .into(); return Err(val_err.into()); @@ -705,7 +704,7 @@ fn wire_up( // TODO: Avoid adding duplicate edges // This should be easy with https://github.com/CQCL-DEV/hugr/issues/130 base.add_other_edge(src, src_sibling)?; - } else if !typ.copyable() & base.linked_ports(src, src_offset).next().is_some() { + } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() { // Don't copy linear edges. return Err(BuildError::NoCopyLinear(typ)); } @@ -719,7 +718,7 @@ fn wire_up( data_builder .hugr_mut() .get_optype(dst) - .port_kind(Port::new_incoming(dst_port)) + .port_kind(dst_port) .unwrap(), EdgeKind::Value(_) )) diff --git a/src/hugr.rs b/src/hugr.rs index 2622a109d..294f99fea 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -208,20 +208,21 @@ pub struct Port { } /// A trait for getting the undirected index of a port. -/// -/// This allows functions to admit both [`Port`]s and explicit `usize`s for -/// identifying port offsets. pub trait PortIndex { /// Returns the offset of the port. fn index(self) -> usize; - /// Returns the offset of the port, doing a sanity check on the expected direction. - #[inline(always)] - fn try_index(self, _dir: Direction) -> Result - where - Self: Sized, - { - Ok(self.index()) - } +} + +/// A port in the incoming direction. +#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)] +pub struct IncomingPort { + index: u16, +} + +/// A port in the outgoing direction. +#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)] +pub struct OutgoingPort { + index: u16, } /// The direction of a port. @@ -380,18 +381,36 @@ impl Port { /// Creates a new incoming port. #[inline] - pub fn new_incoming(port: usize) -> Self { - Self { - offset: portgraph::PortOffset::new_incoming(port), - } + pub fn new_incoming(port: impl Into) -> Self { + Self::try_new_incoming(port).unwrap() } /// Creates a new outgoing port. #[inline] - pub fn new_outgoing(port: usize) -> Self { - Self { - offset: portgraph::PortOffset::new_outgoing(port), - } + pub fn new_outgoing(port: impl Into) -> Self { + Self::try_new_outgoing(port).unwrap() + } + + /// Creates a new incoming port. + #[inline] + pub fn try_new_incoming(port: impl TryInto) -> Result { + let Ok(port) = port.try_into() else { + return Err(HugrError::InvalidPortDirection(Direction::Outgoing)); + }; + Ok(Self { + offset: portgraph::PortOffset::new_incoming(port.index()), + }) + } + + /// Creates a new outgoing port. + #[inline] + pub fn try_new_outgoing(port: impl TryInto) -> Result { + let Ok(port) = port.try_into() else { + return Err(HugrError::InvalidPortDirection(Direction::Incoming)); + }; + Ok(Self { + offset: portgraph::PortOffset::new_outgoing(port.index()), + }) } /// Returns the direction of the port. @@ -406,16 +425,6 @@ impl PortIndex for Port { fn index(self) -> usize { self.offset.index() } - #[inline(always)] - fn try_index(self, dir: Direction) -> Result - where - Self: Sized, - { - match dir == self.direction() { - true => Ok(self.index()), - false => Err(HugrError::InvalidPortDirection(dir)), - } - } } impl PortIndex for usize { @@ -425,6 +434,64 @@ impl PortIndex for usize { } } +impl PortIndex for IncomingPort { + #[inline(always)] + fn index(self) -> usize { + self.index as usize + } +} + +impl PortIndex for OutgoingPort { + #[inline(always)] + fn index(self) -> usize { + self.index as usize + } +} + +impl From for IncomingPort { + #[inline(always)] + fn from(index: usize) -> Self { + Self { + index: index as u16, + } + } +} + +impl From for OutgoingPort { + #[inline(always)] + fn from(index: usize) -> Self { + Self { + index: index as u16, + } + } +} + +impl TryFrom for IncomingPort { + type Error = HugrError; + #[inline(always)] + fn try_from(port: Port) -> Result { + match port.direction() { + Direction::Incoming => Ok(Self { + index: port.index() as u16, + }), + dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)), + } + } +} + +impl TryFrom for OutgoingPort { + type Error = HugrError; + #[inline(always)] + fn try_from(port: Port) -> Result { + match port.direction() { + Direction::Outgoing => Ok(Self { + index: port.index() as u16, + }), + dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)), + } + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] /// A DataFlow wire, defined by a Value-kind output port of a node // Stores node and offset to output port @@ -433,8 +500,8 @@ pub struct Wire(Node, usize); impl Wire { /// Create a new wire from a node and a port. #[inline] - pub fn new(node: Node, port: Port) -> Self { - Self(node, port.try_index(Direction::Outgoing).unwrap()) + pub fn new(node: Node, port: impl TryInto) -> Self { + Self(node, Port::try_new_outgoing(port).unwrap().index()) } /// The node that this wire is connected to. diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index 86bd095fc..94f6a1305 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -14,7 +14,7 @@ use crate::{Hugr, Port}; use self::sealed::HugrMutInternals; use super::views::SiblingSubgraph; -use super::{NodeMetadata, PortIndex, Rewrite}; +use super::{IncomingPort, NodeMetadata, OutgoingPort, PortIndex, Rewrite}; /// Functions for low-level building of a HUGR. pub trait HugrMut: HugrView + HugrMutInternals { @@ -110,9 +110,9 @@ pub trait HugrMut: HugrView + HugrMutInternals { fn connect( &mut self, src: Node, - src_port: impl PortIndex, + src_port: impl TryInto, dst: Node, - dst_port: impl PortIndex, + dst_port: impl TryInto, ) -> Result<(), HugrError> { self.valid_node(src)?; self.valid_node(dst)?; @@ -262,15 +262,15 @@ where fn connect( &mut self, src: Node, - src_port: impl PortIndex, + src_port: impl TryInto, dst: Node, - dst_port: impl PortIndex, + dst_port: impl TryInto, ) -> Result<(), HugrError> { self.as_mut().graph.link_nodes( src.index, - src_port.try_index(Direction::Outgoing)?, + Port::try_new_outgoing(src_port)?.index(), dst.index, - dst_port.try_index(Direction::Incoming)?, + Port::try_new_incoming(dst_port)?.index(), )?; Ok(()) }