Skip to content

Commit

Permalink
fix: Improve convexity checking and fix test (#585)
Browse files Browse the repository at this point in the history
Fixes #518 .

I changed the `non_convex_subgraph` test because the graph it was
constructing seemed to meet the criteria for convexity, and replaced it
with another one which doesn't. But please check -- I was a little
confused by the definitions and may be missing something.
  • Loading branch information
cqc-alec authored Oct 2, 2023
1 parent bb09fd4 commit e03a2ed
Showing 1 changed file with 101 additions and 14 deletions.
115 changes: 101 additions & 14 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,10 @@ fn get_edge_type<H: HugrView>(hugr: &H, ports: &[(Node, Port)]) -> Option<Type>

/// Whether a subgraph is valid.
///
/// Does NOT check for convexity.
/// Verifies that input and output ports are valid subgraph boundaries, i.e. they belong
/// to nodes within the subgraph and are linked to at least one node outside of the subgraph.
/// This does NOT check convexity proper, i.e. whether the set of nodes form a convex
/// induced graph.
fn validate_subgraph<H: HugrView>(
hugr: &H,
nodes: &[Node],
Expand Down Expand Up @@ -517,16 +520,34 @@ fn validate_subgraph<H: HugrView>(
}

let mut ports_inside = inputs.iter().flatten().chain(outputs).copied();
let mut ports_outside = ports_inside
.clone()
.flat_map(|(n, p)| hugr.linked_ports(n, p));
// Check incoming & outgoing ports have target resp. source inside
let nodes = nodes.iter().copied().collect::<HashSet<_>>();
if ports_inside.any(|(n, _)| !nodes.contains(&n)) {
return Err(InvalidSubgraph::InvalidBoundary);
}
// Check incoming & outgoing ports have source resp. target outside
if ports_outside.any(|(n, _)| nodes.contains(&n)) {
// Check that every inside port has at least one linked port outside.
if ports_inside.any(|(n, p)| hugr.linked_ports(n, p).all(|(n1, _)| nodes.contains(&n1))) {
return Err(InvalidSubgraph::InvalidBoundary);
}
// Check that every incoming port of a node in the subgraph whose source is not in the subgraph
// belongs to inputs.
if nodes.clone().into_iter().any(|n| {
hugr.node_inputs(n).any(|p| {
hugr.linked_ports(n, p).any(|(n1, _)| {
!nodes.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p)))
})
})
}) {
return Err(InvalidSubgraph::NotConvex);
}
// Check that every outgoing port of a node in the subgraph whose target is not in the subgraph
// belongs to outputs.
if nodes.clone().into_iter().any(|n| {
hugr.node_outputs(n).any(|p| {
hugr.linked_ports(n, p)
.any(|(n1, _)| !nodes.contains(&n1) && !outputs.contains(&(n, p)))
})
}) {
return Err(InvalidSubgraph::NotConvex);
}

Expand Down Expand Up @@ -662,10 +683,13 @@ mod tests {
hugr::views::{HierarchyView, SiblingGraph},
hugr::HugrMut,
ops::{
handle::{FuncID, NodeHandle},
handle::{DfgID, FuncID, NodeHandle},
OpType,
},
std_extensions::{logic::test::and_op, quantum::test::cx_gate},
std_extensions::{
logic::test::{and_op, not_op},
quantum::test::cx_gate,
},
type_row,
};

Expand Down Expand Up @@ -712,6 +736,23 @@ mod tests {
Ok((hugr, func_id.node()))
}

fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> {
let mut mod_builder = ModuleBuilder::new();
let func =
mod_builder.declare("test", FunctionType::new_linear(type_row![BOOL_T]).pure())?;
let func_id = {
let mut dfg = mod_builder.define_declaration(&func)?;
let outs1 = dfg.add_dataflow_op(not_op(), dfg.input_wires())?;
let outs2 = dfg.add_dataflow_op(not_op(), outs1.outputs())?;
let outs3 = dfg.add_dataflow_op(not_op(), outs2.outputs())?;
dfg.finish_with_outputs(outs3.outputs())?
};
let hugr = mod_builder
.finish_prelude_hugr()
.map_err(|e| -> BuildError { e.into() })?;
Ok((hugr, func_id.node()))
}

/// A HUGR with a copy
fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> {
let mut mod_builder = ModuleBuilder::new();
Expand Down Expand Up @@ -854,19 +895,41 @@ mod tests {

#[test]
fn non_convex_subgraph() {
let (hugr, func_root) = build_3not_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, _out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let not1 = hugr.output_neighbours(inp).exactly_one().unwrap();
let not2 = hugr.output_neighbours(not1).exactly_one().unwrap();
let not3 = hugr.output_neighbours(not2).exactly_one().unwrap();
let not1_inp = hugr.node_inputs(not1).next().unwrap();
let not1_out = hugr.node_outputs(not1).next().unwrap();
let not3_inp = hugr.node_inputs(not3).next().unwrap();
let not3_out = hugr.node_outputs(not3).next().unwrap();
assert!(matches!(
SiblingSubgraph::try_new(
vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]],
vec![(not1, not1_out), (not3, not3_out)],
&func
),
Err(InvalidSubgraph::NotConvex)
));
}

#[test]
fn invalid_boundary() {
let (hugr, func_root) = build_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
let snd_cx_edge = hugr.node_inputs(out).next().unwrap();
// All graph but one edge
let cx_edges_in = hugr.node_outputs(inp);
let cx_edges_out = hugr.node_inputs(out);
// All graph but the CX
assert!(matches!(
SiblingSubgraph::try_new(
vec![vec![(out, snd_cx_edge)]],
vec![(inp, first_cx_edge)],
cx_edges_out.map(|p| vec![(out, p)]).collect(),
cx_edges_in.map(|p| (inp, p)).collect(),
&func,
),
Err(InvalidSubgraph::NotConvex)
Err(InvalidSubgraph::InvalidBoundary)
));
}

Expand Down Expand Up @@ -894,4 +957,28 @@ mod tests {

Ok(())
}

#[test]
fn edge_both_output_and_copy() {
// https://github.com/CQCL-DEV/hugr/issues/518
let one_bit = type_row![BOOL_T];
let two_bit = type_row![BOOL_T, BOOL_T];

let mut builder =
DFGBuilder::new(FunctionType::new(one_bit.clone(), two_bit.clone())).unwrap();
let inw = builder.input_wires().exactly_one().unwrap();
let outw1 = builder
.add_dataflow_op(not_op(), [inw])
.unwrap()
.out_wire(0);
let outw2 = builder
.add_dataflow_op(and_op(), [inw, outw1])
.unwrap()
.outputs();
let outw = [outw1].into_iter().chain(outw2);
let h = builder.finish_hugr_with_outputs(outw, &EMPTY_REG).unwrap();
let view = SiblingGraph::<DfgID>::try_new(&h, h.root()).unwrap();
let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap();
assert_eq!(subg.nodes().len(), 2);
}
}

0 comments on commit e03a2ed

Please sign in to comment.