Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Static checking of Port direction #614

Merged
merged 19 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
94527a4
Rm Port::try_new_{incom,outgo}ing; add Port::as_{incom,outgo}ing as o…
acl-cqc Oct 18, 2023
a0498ce
Wire: just store OutgoingPort
acl-cqc Oct 18, 2023
3bb2daa
add_other_edge returns (OutgoingPort, IncomingPort)
acl-cqc Oct 18, 2023
1cac439
InsertIdentity requires IncomingPort
acl-cqc Oct 18, 2023
ad45ec8
Add OutgoingPorts/IncomingPorts, use for node_inputs/node_outputs
acl-cqc Oct 18, 2023
6995f0e
sibling_subgraph - convert to using Incoming/OutgoingPort
acl-cqc Oct 18, 2023
dbbbbf0
Add linked_outputs/linked_inputs
acl-cqc Oct 18, 2023
3604918
simple_replace - convert to Incoming/OutgoingPort
acl-cqc Oct 18, 2023
f86b30e
insert_identity/sibling_subgraph: rm unused Error variants (staticall…
acl-cqc Oct 20, 2023
1e9912d
Remove some unnecessary type annotations
acl-cqc Oct 20, 2023
f30ae99
Remove Port::new_{incom,outgo}ing
acl-cqc Oct 20, 2023
9a43cb8
outline_cfg: another linked_ports -> linked_outputs
acl-cqc Oct 25, 2023
960d0fa
Redefine IncomingPorts / OutgoingPorts as a std::iter::Map
acl-cqc Oct 25, 2023
f70ec2b
Redefine (Incom/Outgo)ingNodePorts similarly
acl-cqc Oct 25, 2023
2efb02c
Update comments on (node,linked)_(inputs,outputs)
acl-cqc Oct 25, 2023
af62fba
Merge remote-tracking branch 'origin/main' into refactor/incoming_out…
acl-cqc Oct 25, 2023
e8c416b
Merge remote-tracking branch 'origin/main' into refactor/incoming_out…
aborgna-q Oct 27, 2023
10e5313
Reduce change
acl-cqc Oct 27, 2023
2101784
sibling_subgraph: use combine_in_out more
acl-cqc Oct 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{IncomingPort, Node, NodeMetadata, OutgoingPort, Port, ValidationError};
use crate::hugr::{IncomingPort, Node, NodeMetadata, OutgoingPort, ValidationError};
use crate::ops::{self, LeafOp, OpTrait, OpType};

use std::iter;
Expand Down Expand Up @@ -368,7 +368,7 @@ pub trait Dataflow: Container {
input_extensions,
),
// Constant wire from the constant value node
vec![Wire::new(const_node, Port::new_outgoing(0))],
vec![Wire::new(const_node, OutgoingPort::from(0))],
)?;

Ok(load_n.out_wire(0))
Expand Down Expand Up @@ -658,12 +658,12 @@ fn wire_up_inputs<T: Dataflow + ?Sized>(
fn wire_up<T: Dataflow + ?Sized>(
data_builder: &mut T,
src: Node,
src_port: impl TryInto<OutgoingPort>,
src_port: impl Into<OutgoingPort>,
dst: Node,
dst_port: impl TryInto<IncomingPort>,
dst_port: impl Into<IncomingPort>,
) -> Result<bool, BuildError> {
let src_port = Port::try_new_outgoing(src_port)?;
let dst_port = Port::try_new_incoming(dst_port)?;
let src_port: OutgoingPort = src_port.into();
let dst_port: IncomingPort = dst_port.into();
let base = data_builder.hugr_mut();

let src_parent = base.get_parent(src);
Expand All @@ -675,9 +675,9 @@ fn wire_up<T: Dataflow + ?Sized>(
if !typ.copyable() {
let val_err: ValidationError = InterGraphEdgeError::NonCopyableData {
from: src,
from_offset: src_port,
from_offset: src_port.into(),
Copy link
Contributor Author

@acl-cqc acl-cqc Oct 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly we could extend NonCopyableData to take an OutgoingPort and an IncomingPort, etc., etc. - by all means shout out cases like that, I am not claiming to have found them all, but OTOH we don't need to get everything in the first PR, we can spread directionality-checking more widely over time.

to: dst,
to_offset: dst_port,
to_offset: dst_port.into(),
ty: EdgeKind::Value(typ),
}
.into();
Expand All @@ -693,9 +693,9 @@ fn wire_up<T: Dataflow + ?Sized>(
else {
let val_err: ValidationError = InterGraphEdgeError::NoRelation {
from: src,
from_offset: src_port,
from_offset: src_port.into(),
to: dst,
to_offset: dst_port,
to_offset: dst_port.into(),
}
.into();
return Err(val_err.into());
Expand Down
12 changes: 5 additions & 7 deletions src/builder/handle.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! Handles to nodes in HUGR used during the building phase.
//!
use crate::{
hugr::OutgoingPort,
ops::{
handle::{BasicBlockID, CaseID, DfgID, FuncID, NodeHandle, TailLoopID},
OpTag,
},
Port,
};
use crate::{Node, Wire};

Expand Down Expand Up @@ -64,7 +64,7 @@ impl<T: NodeHandle> BuildHandle<T> {
/// Retrieve a [`Wire`] corresponding to the given offset.
/// Does not check whether such a wire is valid for this node.
pub fn out_wire(&self, offset: usize) -> Wire {
Wire::new(self.node(), Port::new_outgoing(offset))
Wire::new(self.node(), OutgoingPort::from(offset))
}

#[inline]
Expand Down Expand Up @@ -124,14 +124,12 @@ impl Iterator for Outputs {
fn next(&mut self) -> Option<Self::Item> {
self.range
.next()
.map(|offset| Wire::new(self.node, Port::new_outgoing(offset)))
.map(|offset| Wire::new(self.node, OutgoingPort::from(offset)))
}

#[inline]
fn nth(&mut self, n: usize) -> Option<Self::Item> {
self.range
.nth(n)
.map(|offset| Wire::new(self.node, Port::new_outgoing(offset)))
self.range.nth(n).map(|offset| Wire::new(self.node, offset))
}

#[inline]
Expand All @@ -157,7 +155,7 @@ impl DoubleEndedIterator for Outputs {
fn next_back(&mut self) -> Option<Self::Item> {
self.range
.next_back()
.map(|offset| Wire::new(self.node, Port::new_outgoing(offset)))
.map(|offset| Wire::new(self.node, offset))
}
}

Expand Down
74 changes: 38 additions & 36 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,26 +390,42 @@ impl Port {
/// Creates a new incoming port.
#[inline]
pub fn new_incoming(port: impl Into<IncomingPort>) -> Self {
Self::try_new_incoming(port).unwrap()
Self {
offset: portgraph::PortOffset::new_incoming(port.into().index()),
}
}

/// Converts to an [IncomingPort] if this port is one; else fails with
/// [HugrError::InvalidPortDirection]
#[inline]
pub fn as_incoming(&self) -> Result<IncomingPort, HugrError> {
match self.direction() {
Direction::Incoming => Ok(IncomingPort {
index: self.index() as u16,
}),
dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)),
}
}

/// Creates a new outgoing port.
#[inline]
pub fn new_outgoing(port: impl Into<OutgoingPort>) -> Self {
Self::try_new_outgoing(port).unwrap()
Self {
offset: portgraph::PortOffset::new_outgoing(port.into().index()),
}
}

/// Creates a new incoming port.
/// Converts to an [OutgoingPort] if this port is one; else fails with
/// [HugrError::InvalidPortDirection]
#[inline]
pub fn try_new_incoming(port: impl TryInto<IncomingPort>) -> Result<Self, HugrError> {
let Ok(port) = port.try_into() else {
return Err(HugrError::InvalidPortDirection(Direction::Outgoing));
};
Ok(Self {
offset: portgraph::PortOffset::new_incoming(port.index()),
})
pub fn as_outgoing(&self) -> Result<OutgoingPort, HugrError> {
match self.direction() {
Direction::Outgoing => Ok(OutgoingPort {
index: self.index() as u16,
}),
dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)),
}
}

/// Creates a new outgoing port.
#[inline]
pub fn try_new_outgoing(port: impl TryInto<OutgoingPort>) -> Result<Self, HugrError> {
Expand Down Expand Up @@ -474,29 +490,15 @@ impl From<usize> for OutgoingPort {
}
}

impl TryFrom<Port> for IncomingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Incoming => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Outgoing => Err(HugrError::InvalidPortDirection(dir)),
}
impl From<IncomingPort> for Port {
fn from(value: IncomingPort) -> Self {
Port::new_incoming(value)
}
}

impl TryFrom<Port> for OutgoingPort {
type Error = HugrError;
#[inline(always)]
fn try_from(port: Port) -> Result<Self, Self::Error> {
match port.direction() {
Direction::Outgoing => Ok(Self {
index: port.index() as u16,
}),
dir @ Direction::Incoming => Err(HugrError::InvalidPortDirection(dir)),
}
impl From<OutgoingPort> for Port {
fn from(value: OutgoingPort) -> Self {
Port::new_outgoing(value)
}
}

Expand All @@ -509,13 +511,13 @@ impl NodeIndex for Node {
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// A DataFlow wire, defined by a Value-kind output port of a node
// Stores node and offset to output port
pub struct Wire(Node, usize);
pub struct Wire(Node, OutgoingPort);

impl Wire {
/// Create a new wire from a node and a port.
#[inline]
pub fn new(node: Node, port: impl TryInto<OutgoingPort>) -> Self {
Self(node, Port::try_new_outgoing(port).unwrap().index())
pub fn new(node: Node, port: impl Into<OutgoingPort>) -> Self {
Self(node, port.into())
}

/// The node that this wire is connected to.
Expand All @@ -526,8 +528,8 @@ impl Wire {

/// The output port that this wire is connected to.
#[inline]
pub fn source(&self) -> Port {
Port::new_outgoing(self.1)
pub fn source(&self) -> OutgoingPort {
self.1
}
}

Expand Down
39 changes: 25 additions & 14 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ pub trait HugrMut: HugrMutInternals {
fn connect(
&mut self,
src: Node,
src_port: impl TryInto<OutgoingPort>,
src_port: impl Into<OutgoingPort>,
dst: Node,
dst_port: impl TryInto<IncomingPort>,
dst_port: impl Into<IncomingPort>,
) -> Result<(), HugrError> {
self.valid_node(src)?;
self.valid_node(dst)?;
Expand All @@ -123,7 +123,7 @@ pub trait HugrMut: HugrMutInternals {
///
/// The port is left in place.
#[inline]
fn disconnect(&mut self, node: Node, port: Port) -> Result<(), HugrError> {
fn disconnect(&mut self, node: Node, port: impl Into<Port>) -> Result<(), HugrError> {
self.valid_node(node)?;
self.hugr_mut().disconnect(node, port)
}
Expand All @@ -136,7 +136,11 @@ pub trait HugrMut: HugrMutInternals {
///
/// [`OpTrait::other_input`]: crate::ops::OpTrait::other_input
/// [`OpTrait::other_output`]: crate::ops::OpTrait::other_output
fn add_other_edge(&mut self, src: Node, dst: Node) -> Result<(Port, Port), HugrError> {
fn add_other_edge(
&mut self,
src: Node,
dst: Node,
) -> Result<(OutgoingPort, IncomingPort), HugrError> {
self.valid_node(src)?;
self.valid_node(dst)?;
self.hugr_mut().add_other_edge(src, dst)
Expand Down Expand Up @@ -249,20 +253,21 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
fn connect(
&mut self,
src: Node,
src_port: impl TryInto<OutgoingPort>,
src_port: impl Into<OutgoingPort>,
dst: Node,
dst_port: impl TryInto<IncomingPort>,
dst_port: impl Into<IncomingPort>,
) -> Result<(), HugrError> {
self.as_mut().graph.link_nodes(
src.index,
Port::try_new_outgoing(src_port)?.index(),
Port::new_outgoing(src_port).index(),
dst.index,
Port::try_new_incoming(dst_port)?.index(),
Port::new_incoming(dst_port).index(),
)?;
Ok(())
}

fn disconnect(&mut self, node: Node, port: Port) -> Result<(), HugrError> {
fn disconnect(&mut self, node: Node, port: impl Into<Port>) -> Result<(), HugrError> {
let port = port.into();
let offset = port.offset;
let port = self.as_mut().graph.port_index(node.index, offset).ok_or(
portgraph::LinkError::UnknownOffset {
Expand All @@ -274,15 +279,21 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
Ok(())
}

fn add_other_edge(&mut self, src: Node, dst: Node) -> Result<(Port, Port), HugrError> {
let src_port: Port = self
fn add_other_edge(
&mut self,
src: Node,
dst: Node,
) -> Result<(OutgoingPort, IncomingPort), HugrError> {
let src_port = self
.get_optype(src)
.other_port_index(Direction::Outgoing)
.expect("Source operation has no non-dataflow outgoing edges");
let dst_port: Port = self
.expect("Source operation has no non-dataflow outgoing edges")
.as_outgoing()?;
let dst_port = self
.get_optype(dst)
.other_port_index(Direction::Incoming)
.expect("Destination operation has no non-dataflow incoming edges");
.expect("Destination operation has no non-dataflow incoming edges")
.as_incoming()?;
self.connect(src, src_port, dst, dst_port)?;
Ok((src_port, dst_port))
}
Expand Down
19 changes: 5 additions & 14 deletions src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

use std::iter;

use crate::hugr::{HugrMut, Node};
use crate::hugr::{HugrMut, IncomingPort, Node};
use crate::ops::{LeafOp, OpTag, OpTrait};
use crate::types::EdgeKind;
use crate::{Direction, HugrView, Port};
use crate::HugrView;

use super::Rewrite;

Expand All @@ -18,12 +18,12 @@ pub struct IdentityInsertion {
/// The node following the identity to be inserted.
pub post_node: Node,
/// The port following the identity to be inserted.
pub post_port: Port,
pub post_port: IncomingPort,
}

impl IdentityInsertion {
/// Create a new [`IdentityInsertion`] specification.
pub fn new(post_node: Node, post_port: Port) -> Self {
pub fn new(post_node: Node, post_port: IncomingPort) -> Self {
Self {
post_node,
post_port,
Expand Down Expand Up @@ -71,17 +71,13 @@ impl Rewrite for IdentityInsertion {
unimplemented!()
}
fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, IdentityInsertionError> {
if self.post_port.direction() != Direction::Incoming {
return Err(IdentityInsertionError::PortIsOutput);
}

let kind = h.get_optype(self.post_node).port_kind(self.post_port);
let Some(EdgeKind::Value(ty)) = kind else {
return Err(IdentityInsertionError::InvalidPortKind(kind));
};

let (pre_node, pre_port) = h
.linked_ports(self.post_node, self.post_port)
.linked_outputs(self.post_node, self.post_port)
.exactly_one()
.ok()
.expect("Value kind input can only have one connection.");
Expand Down Expand Up @@ -155,11 +151,6 @@ mod tests {

let final_node = tail.node();

let final_node_output = h.node_outputs(final_node).next().unwrap();
let rw = IdentityInsertion::new(final_node, final_node_output);
let apply_result = h.apply_rewrite(rw);
assert_eq!(apply_result, Err(IdentityInsertionError::PortIsOutput));

let final_node_input = h.node_inputs(final_node).next().unwrap();

let rw = IdentityInsertion::new(final_node, final_node_input);
Expand Down
3 changes: 2 additions & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ impl Rewrite for OutlineCfg {
for (pred, br) in preds {
if !self.blocks.contains(&pred) {
h.disconnect(pred, br).unwrap();
h.connect(pred, br, new_block, 0).unwrap();
h.connect(pred, br.as_outgoing().unwrap(), new_block, 0)
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
.unwrap();
}
}
if entry == outer_entry {
Expand Down
Loading
Loading