diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 0ea1a434b..81e0cacff 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -475,7 +475,10 @@ fn get_edge_type(hugr: &H, ports: &[(Node, Port)]) -> Option /// Whether a subgraph is valid. /// -/// Does NOT check for convexity. +/// Verifies that input and output ports are valid subgraph boundaries, i.e. they belong +/// to nodes within the subgraph and are linked to at least one node outside of the subgraph. +/// This does NOT check convexity proper, i.e. whether the set of nodes form a convex +/// induced graph. fn validate_subgraph( hugr: &H, nodes: &[Node], @@ -517,16 +520,34 @@ fn validate_subgraph( } let mut ports_inside = inputs.iter().flatten().chain(outputs).copied(); - let mut ports_outside = ports_inside - .clone() - .flat_map(|(n, p)| hugr.linked_ports(n, p)); // Check incoming & outgoing ports have target resp. source inside let nodes = nodes.iter().copied().collect::>(); if ports_inside.any(|(n, _)| !nodes.contains(&n)) { return Err(InvalidSubgraph::InvalidBoundary); } - // Check incoming & outgoing ports have source resp. target outside - if ports_outside.any(|(n, _)| nodes.contains(&n)) { + // Check that every inside port has at least one linked port outside. + if ports_inside.any(|(n, p)| hugr.linked_ports(n, p).all(|(n1, _)| nodes.contains(&n1))) { + return Err(InvalidSubgraph::InvalidBoundary); + } + // Check that every incoming port of a node in the subgraph whose source is not in the subgraph + // belongs to inputs. + if nodes.clone().into_iter().any(|n| { + hugr.node_inputs(n).any(|p| { + hugr.linked_ports(n, p).any(|(n1, _)| { + !nodes.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p))) + }) + }) + }) { + return Err(InvalidSubgraph::NotConvex); + } + // Check that every outgoing port of a node in the subgraph whose target is not in the subgraph + // belongs to outputs. + if nodes.clone().into_iter().any(|n| { + hugr.node_outputs(n).any(|p| { + hugr.linked_ports(n, p) + .any(|(n1, _)| !nodes.contains(&n1) && !outputs.contains(&(n, p))) + }) + }) { return Err(InvalidSubgraph::NotConvex); } @@ -662,10 +683,13 @@ mod tests { hugr::views::{HierarchyView, SiblingGraph}, hugr::HugrMut, ops::{ - handle::{FuncID, NodeHandle}, + handle::{DfgID, FuncID, NodeHandle}, OpType, }, - std_extensions::{logic::test::and_op, quantum::test::cx_gate}, + std_extensions::{ + logic::test::{and_op, not_op}, + quantum::test::cx_gate, + }, type_row, }; @@ -712,6 +736,23 @@ mod tests { Ok((hugr, func_id.node())) } + fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> { + let mut mod_builder = ModuleBuilder::new(); + let func = + mod_builder.declare("test", FunctionType::new_linear(type_row![BOOL_T]).pure())?; + let func_id = { + let mut dfg = mod_builder.define_declaration(&func)?; + let outs1 = dfg.add_dataflow_op(not_op(), dfg.input_wires())?; + let outs2 = dfg.add_dataflow_op(not_op(), outs1.outputs())?; + let outs3 = dfg.add_dataflow_op(not_op(), outs2.outputs())?; + dfg.finish_with_outputs(outs3.outputs())? + }; + let hugr = mod_builder + .finish_prelude_hugr() + .map_err(|e| -> BuildError { e.into() })?; + Ok((hugr, func_id.node())) + } + /// A HUGR with a copy fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); @@ -854,19 +895,41 @@ mod tests { #[test] fn non_convex_subgraph() { + let (hugr, func_root) = build_3not_hugr().unwrap(); + let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap(); + let (inp, _out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); + let not1 = hugr.output_neighbours(inp).exactly_one().unwrap(); + let not2 = hugr.output_neighbours(not1).exactly_one().unwrap(); + let not3 = hugr.output_neighbours(not2).exactly_one().unwrap(); + let not1_inp = hugr.node_inputs(not1).next().unwrap(); + let not1_out = hugr.node_outputs(not1).next().unwrap(); + let not3_inp = hugr.node_inputs(not3).next().unwrap(); + let not3_out = hugr.node_outputs(not3).next().unwrap(); + assert!(matches!( + SiblingSubgraph::try_new( + vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]], + vec![(not1, not1_out), (not3, not3_out)], + &func + ), + Err(InvalidSubgraph::NotConvex) + )); + } + + #[test] + fn invalid_boundary() { let (hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap(); let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); - let first_cx_edge = hugr.node_outputs(inp).next().unwrap(); - let snd_cx_edge = hugr.node_inputs(out).next().unwrap(); - // All graph but one edge + let cx_edges_in = hugr.node_outputs(inp); + let cx_edges_out = hugr.node_inputs(out); + // All graph but the CX assert!(matches!( SiblingSubgraph::try_new( - vec![vec![(out, snd_cx_edge)]], - vec![(inp, first_cx_edge)], + cx_edges_out.map(|p| vec![(out, p)]).collect(), + cx_edges_in.map(|p| (inp, p)).collect(), &func, ), - Err(InvalidSubgraph::NotConvex) + Err(InvalidSubgraph::InvalidBoundary) )); } @@ -894,4 +957,28 @@ mod tests { Ok(()) } + + #[test] + fn edge_both_output_and_copy() { + // https://github.com/CQCL-DEV/hugr/issues/518 + let one_bit = type_row![BOOL_T]; + let two_bit = type_row![BOOL_T, BOOL_T]; + + let mut builder = + DFGBuilder::new(FunctionType::new(one_bit.clone(), two_bit.clone())).unwrap(); + let inw = builder.input_wires().exactly_one().unwrap(); + let outw1 = builder + .add_dataflow_op(not_op(), [inw]) + .unwrap() + .out_wire(0); + let outw2 = builder + .add_dataflow_op(and_op(), [inw, outw1]) + .unwrap() + .outputs(); + let outw = [outw1].into_iter().chain(outw2); + let h = builder.finish_hugr_with_outputs(outw, &EMPTY_REG).unwrap(); + let view = SiblingGraph::::try_new(&h, h.root()).unwrap(); + let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap(); + assert_eq!(subg.nodes().len(), 2); + } }