Skip to content

Commit

Permalink
feat: Make extension delta a parameter of CFG builders (#514)
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor authored Sep 11, 2023
1 parent a295ab8 commit 3f6d3c8
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 52 deletions.
31 changes: 16 additions & 15 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,12 @@ impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeClassifier<T> {
pub(crate) mod test {
use super::*;
use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder};
use crate::extension::prelude::USIZE_T;
use crate::extension::{prelude::USIZE_T, ExtensionSet};

use crate::hugr::views::{HierarchyView, SiblingGraph};
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::Const;
use crate::types::Type;
use crate::types::{FunctionType, Type};
use crate::{type_row, Hugr};
const NAT: Type = USIZE_T;

Expand All @@ -426,13 +426,13 @@ pub(crate) mod test {
// /-> left --\
// entry -> split > merge -> head -> tail -> exit
// \-> right -/ \-<--<-/
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;

let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1)?,
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
&const_unit,
)?;
let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?;
Expand Down Expand Up @@ -611,7 +611,7 @@ pub(crate) mod test {
unit_const: &ConstID,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let split = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 2)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 2)?,
const_pred,
)?;
let merge = build_then_else_merge_from_if(cfg, unit_const, split)?;
Expand All @@ -624,15 +624,15 @@ pub(crate) mod test {
split: BasicBlockID,
) -> Result<BasicBlockID, BuildError> {
let merge = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
unit_const,
)?;
let left = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
unit_const,
)?;
let right = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
unit_const,
)?;
cfg.branch(&split, 0, &left)?;
Expand All @@ -649,7 +649,7 @@ pub(crate) mod test {
header: BasicBlockID,
) -> Result<BasicBlockID, BuildError> {
let tail = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 2)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 2)?,
const_pred,
)?;
cfg.branch(&tail, 1, &header)?;
Expand All @@ -663,7 +663,7 @@ pub(crate) mod test {
unit_const: &ConstID,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let header = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
unit_const,
)?;
let tail = build_loop_from_header(cfg, const_pred, header)?;
Expand All @@ -674,18 +674,19 @@ pub(crate) mod test {
pub fn build_cond_then_loop_cfg(
separate: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 2)?,
cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?,
&pred_const,
)?;
let merge = build_then_else_merge_from_if(&mut cfg_builder, &const_unit, entry)?;
let head = if separate {
let h = n_identity(
cfg_builder.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
&const_unit,
)?;
cfg_builder.branch(&merge, 0, &h)?;
Expand All @@ -708,13 +709,13 @@ pub(crate) mod test {
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
//let sum2_type = Type::new_predicate(2);

let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;

let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1)?,
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
&const_unit,
)?;
let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?;
Expand Down
4 changes: 2 additions & 2 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ pub trait Dataflow: Container {
&mut self,
inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
extension_delta: ExtensionSet,
) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();

Expand All @@ -325,8 +326,7 @@ pub trait Dataflow: Container {
NodeType::open_extensions(ops::CFG {
inputs: inputs.clone(),
outputs: output_types.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
extension_delta,
}),
input_wires,
)?;
Expand Down
69 changes: 44 additions & 25 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,16 @@ impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for CFGBuilder<H> {

impl CFGBuilder<Hugr> {
/// New CFG rooted HUGR builder
pub fn new(input: impl Into<TypeRow>, output: impl Into<TypeRow>) -> Result<Self, BuildError> {
let input = input.into();
let output = output.into();
pub fn new(signature: FunctionType) -> Result<Self, BuildError> {
let cfg_op = ops::CFG {
inputs: input.clone(),
outputs: output.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
inputs: signature.input.clone(),
outputs: signature.output.clone(),
extension_delta: signature.extension_reqs,
};

// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(cfg_op));
let cfg_node = base.root();
CFGBuilder::create(base, cfg_node, input, output)
CFGBuilder::create(base, cfg_node, signature.input, signature.output)
}
}

Expand Down Expand Up @@ -119,24 +115,31 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
inputs: TypeRow,
predicate_variants: Vec<TypeRow>,
extension_delta: ExtensionSet,
other_outputs: TypeRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(inputs, predicate_variants, other_outputs, false)
self.any_block_builder(
inputs,
predicate_variants,
other_outputs,
extension_delta,
false,
)
}

fn any_block_builder(
&mut self,
inputs: TypeRow,
predicate_variants: Vec<TypeRow>,
other_outputs: TypeRow,
extension_delta: ExtensionSet,
entry: bool,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let op = OpType::BasicBlock(BasicBlock::DFB {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
extension_delta,
});
let parent = self.container_node();
let block_n = if entry {
Expand Down Expand Up @@ -165,11 +168,15 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
/// This function will return an error if there is an error adding the node.
pub fn simple_block_builder(
&mut self,
inputs: TypeRow,
outputs: TypeRow,
signature: FunctionType,
n_cases: usize,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.block_builder(inputs, vec![type_row![]; n_cases], outputs)
self.block_builder(
signature.input,
vec![type_row![]; n_cases],
signature.extension_reqs,
signature.output,
)
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
Expand All @@ -183,12 +190,19 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
predicate_variants: Vec<TypeRow>,
other_outputs: TypeRow,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let inputs = self
.inputs
.take()
.ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
self.any_block_builder(inputs, predicate_variants, other_outputs, true)
self.any_block_builder(
inputs,
predicate_variants,
other_outputs,
extension_delta,
true,
)
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
Expand All @@ -201,8 +215,9 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
outputs: TypeRow,
n_cases: usize,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder(vec![type_row![]; n_cases], outputs)
self.entry_builder(vec![type_row![]; n_cases], outputs, extension_delta)
}

/// Returns the exit block of this [`CFGBuilder`].
Expand Down Expand Up @@ -276,6 +291,7 @@ impl BlockBuilder<Hugr> {
inputs: impl Into<TypeRow>,
predicate_variants: impl IntoIterator<Item = TypeRow>,
other_outputs: impl Into<TypeRow>,
extension_delta: ExtensionSet,
) -> Result<Self, BuildError> {
let inputs = inputs.into();
let predicate_variants: Vec<_> = predicate_variants.into_iter().collect();
Expand All @@ -284,11 +300,9 @@ impl BlockBuilder<Hugr> {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
// TODO: make this a parameter
extension_delta: ExtensionSet::new(),
extension_delta,
};

// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(op));
let root = base.root();
Self::create(base, root, predicate_variants, other_outputs, inputs)
Expand Down Expand Up @@ -326,8 +340,11 @@ mod test {
let [int] = func_builder.input_wires_arr();

let cfg_id = {
let mut cfg_builder =
func_builder.cfg_builder(vec![(NAT, int)], type_row![NAT])?;
let mut cfg_builder = func_builder.cfg_builder(
vec![(NAT, int)],
type_row![NAT],
ExtensionSet::new(),
)?;
build_basic_cfg(&mut cfg_builder)?;

cfg_builder.finish_sub_container()?
Expand All @@ -344,7 +361,7 @@ mod test {
}
#[test]
fn basic_cfg_hugr() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
build_basic_cfg(&mut cfg_builder)?;
assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_));

Expand All @@ -355,14 +372,16 @@ mod test {
cfg_builder: &mut CFGBuilder<T>,
) -> Result<(), BuildError> {
let sum2_variants = vec![type_row![NAT], type_row![NAT]];
let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?;
let mut entry_b =
cfg_builder.entry_builder(sum2_variants.clone(), type_row![], ExtensionSet::new())?;
let entry = {
let [inw] = entry_b.input_wires_arr();

let sum = entry_b.make_predicate(1, sum2_variants, [inw])?;
entry_b.finish_with_outputs(sum, [])?
};
let mut middle_b = cfg_builder.simple_block_builder(type_row![NAT], type_row![NAT], 1)?;
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.add_load_const(ops::Const::simple_unary_predicate())?;
let [inw] = middle_b.input_wires_arr();
Expand Down
34 changes: 24 additions & 10 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use itertools::Itertools;
use thiserror::Error;

use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
use crate::extension::PRELUDE_REGISTRY;
use crate::extension::{ExtensionSet, PRELUDE_REGISTRY};
use crate::hugr::rewrite::Rewrite;
use crate::hugr::{HugrMut, HugrView};
use crate::ops;
Expand All @@ -26,10 +26,13 @@ impl OutlineCfg {
}
}

fn compute_entry_exit_outside(
/// Compute the entry and exit nodes of the CFG which contains
/// [`self.blocks`], along with the output neighbour its parent graph and
/// the combined extension_deltas of all of the blocks.
fn compute_entry_exit_outside_extensions(
&self,
h: &impl HugrView,
) -> Result<(Node, Node, Node), OutlineCfgError> {
) -> Result<(Node, Node, Node, ExtensionSet), OutlineCfgError> {
let cfg_n = match self
.blocks
.iter()
Expand All @@ -47,6 +50,7 @@ impl OutlineCfg {
let cfg_entry = h.children(cfg_n).next().unwrap();
let mut entry = None;
let mut exit_succ = None;
let mut extension_delta = ExtensionSet::new();
for &n in self.blocks.iter() {
if n == cfg_entry
|| h.input_neighbours(n)
Expand All @@ -61,6 +65,7 @@ impl OutlineCfg {
}
}
}
extension_delta = extension_delta.union(&o.signature().extension_reqs);
let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s));
match external_succs.at_most_one() {
Ok(None) => (), // No external successors
Expand All @@ -76,7 +81,7 @@ impl OutlineCfg {
};
}
match (entry, exit_succ) {
(Some(e), Some((x, o))) => Ok((e, x, o)),
(Some(e), Some((x, o))) => Ok((e, x, o, extension_delta)),
(None, _) => Err(OutlineCfgError::NoEntryNode),
(_, None) => Err(OutlineCfgError::NoExitNode),
}
Expand All @@ -89,11 +94,12 @@ impl Rewrite for OutlineCfg {

const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> {
self.compute_entry_exit_outside(h)?;
self.compute_entry_exit_outside_extensions(h)?;
Ok(())
}
fn apply(self, h: &mut impl HugrMut) -> Result<(), OutlineCfgError> {
let (entry, exit, outside) = self.compute_entry_exit_outside(h)?;
let (entry, exit, outside, extension_delta) =
self.compute_entry_exit_outside_extensions(h)?;
// 1. Compute signature
// These panic()s only happen if the Hugr would not have passed validate()
let OpType::BasicBlock(BasicBlock::DFB { inputs, .. }) = h.get_optype(entry) else {
Expand All @@ -109,11 +115,19 @@ impl Rewrite for OutlineCfg {

// 2. new_block contains input node, sub-cfg, exit node all connected
let new_block = {
let mut new_block_bldr =
BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap();
let mut new_block_bldr = BlockBuilder::new(
inputs.clone(),
vec![type_row![]],
outputs.clone(),
extension_delta.clone(),
)
.unwrap();
let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap();
cfg.exit_block(); // Makes inner exit block (but no entry block)
// N.B. By invoking the cfg_builder, we're forgetting any input
// extensions that may have existed on the original CFG.
let cfg = new_block_bldr
.cfg_builder(wires_in, outputs, extension_delta)
.unwrap();
let cfg_outputs = cfg.finish_sub_container().unwrap().outputs();
let predicate = new_block_bldr
.add_constant(ops::Const::simple_unary_predicate())
Expand Down

0 comments on commit 3f6d3c8

Please sign in to comment.