Skip to content

Commit

Permalink
feat: DataflowParent trait for getting inner signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Dec 20, 2023
1 parent 268f120 commit c04910b
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 64 deletions.
29 changes: 7 additions & 22 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ use super::{
BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
};

use crate::ops::{self, BasicBlock, OpType};
use crate::ops::handle::NodeHandle;
use crate::ops::{self, dataflow::DataflowParent, BasicBlock, OpType};
use crate::{
extension::{ExtensionRegistry, ExtensionSet},
types::FunctionType,
};
use crate::{hugr::views::HugrView, types::TypeRow};
use crate::{ops::handle::NodeHandle, types::Type};

use crate::Node;
use crate::{
Expand Down Expand Up @@ -150,13 +150,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
self.hugr_mut().add_node_with_parent(parent, op)
}?;

BlockBuilder::create(
self.hugr_mut(),
block_n,
tuple_sum_rows,
other_outputs,
inputs,
)
BlockBuilder::create(self.hugr_mut(), block_n)
}

/// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs`
Expand Down Expand Up @@ -248,18 +242,9 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
) -> Result<(), BuildError> {
Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs))
}
fn create(
base: B,
block_n: Node,
tuple_sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
inputs: TypeRow,
) -> Result<Self, BuildError> {
// The node outputs a TupleSum before the data outputs of the block node
let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows);
let mut node_outputs = vec![tuple_sum_type];
node_outputs.extend_from_slice(&other_outputs);
let signature = FunctionType::new(inputs, TypeRow::from(node_outputs));
fn create(base: B, block_n: Node) -> Result<Self, BuildError> {
let block_op = base.get_optype(block_n).as_basic_block().unwrap();
let signature = block_op.inner_signature();
let inp_ex = base
.as_ref()
.get_nodetype(block_n)
Expand Down Expand Up @@ -304,7 +289,7 @@ impl BlockBuilder<Hugr> {

let base = Hugr::new(NodeType::new(op, input_extensions));
let root = base.root();
Self::create(base, root, tuple_sum_rows, other_outputs, inputs)
Self::create(base, root)
}

/// [Set outputs](BlockBuilder::set_outputs) and [finish_hugr](`BlockBuilder::finish_hugr`).
Expand Down
1 change: 1 addition & 0 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::custom::{ExternalOp, OpaqueOp};
use crate::ops::dataflow::DataflowParent;
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle};
use crate::ops::{LeafOp, OpType};

Expand Down
21 changes: 12 additions & 9 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ use portgraph::dot::{DotFormat, EdgeStyle, NodeStyle, PortStyle};
use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView};

use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE};
use crate::ops::dataflow::DataflowParent;
use crate::ops::handle::NodeHandle;
use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG};
use crate::ops::{OpName, OpTag, OpTrait, OpType};
#[rustversion::since(1.75)] // uses impl in return position
use crate::types::Type;
use crate::types::{EdgeKind, FunctionType, PolyFuncType};
use crate::types::{EdgeKind, FunctionType};
use crate::{Direction, IncomingPort, Node, OutgoingPort, Port};
#[rustversion::since(1.75)] // uses impl in return position
use itertools::Either;
Expand Down Expand Up @@ -336,14 +337,16 @@ pub trait HugrView: sealed::HugrInternals {

/// For function-like HUGRs (DFG, FuncDefn, FuncDecl), report the function
/// type. Otherwise return None.
fn get_function_type(&self) -> Option<PolyFuncType> {
fn get_function_type(&self) -> Option<FunctionType> {
let op = self.get_nodetype(self.root());
match &op.op {
OpType::DFG(DFG { signature }) => Some(signature.clone().into()),
OpType::FuncDecl(FuncDecl { signature, .. })
| OpType::FuncDefn(FuncDefn { signature, .. }) => Some(signature.clone()),
_ => None,
}
Some(match &op.op {
OpType::DFG(dfg) => dfg.inner_signature(),
OpType::Case(case) => case.inner_signature(),
OpType::BasicBlock(block) => block.inner_signature(),
OpType::FuncDecl(decl) => decl.inner_signature(),
OpType::FuncDefn(def) => def.inner_signature(),
_ => return None,
})
}

/// Return a wrapper over the view that can be used in petgraph algorithms.
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,12 @@ pub(super) mod test {

assert_eq!(
region.get_function_type(),
Some(FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into())
Some(FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]))
);
let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?;
assert_eq!(
inner_region.get_function_type(),
Some(FunctionType::new(type_row![NAT], type_row![NAT]).into())
Some(FunctionType::new(type_row![NAT], type_row![NAT]))
);

Ok(())
Expand Down
6 changes: 1 addition & 5 deletions src/hugr/views/sibling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,7 @@ mod test {
fn flat_mut(mut simple_dfg_hugr: Hugr) {
simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap();
let root = simple_dfg_hugr.root();
let signature = simple_dfg_hugr
.get_function_type()
.unwrap()
.instantiate(&[], &PRELUDE_REGISTRY)
.unwrap();
let signature = simple_dfg_hugr.get_function_type().unwrap();

let sib_mut = SiblingMut::<CfgID>::try_new(&mut simple_dfg_hugr, root);
assert_eq!(
Expand Down
33 changes: 27 additions & 6 deletions src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::extension::ExtensionSet;
use crate::types::{EdgeKind, FunctionType, Type, TypeRow};
use crate::{type_row, Direction};

use super::dataflow::DataflowOpTrait;
use super::dataflow::{DataflowOpTrait, DataflowParent};
use super::OpTag;
use super::{impl_op_name, OpName, OpTrait, StaticTag};

Expand Down Expand Up @@ -145,6 +145,26 @@ impl StaticTag for BasicBlock {
const TAG: OpTag = OpTag::BasicBlock;
}

impl DataflowParent for BasicBlock {
fn inner_signature(&self) -> FunctionType {
match self {
BasicBlock::DFB {
inputs,
other_outputs,
tuple_sum_rows,
..
} => {
// The node outputs a TupleSum before the data outputs of the block node
let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows.clone());
let mut node_outputs = vec![tuple_sum_type];
node_outputs.extend_from_slice(other_outputs);
FunctionType::new(inputs.clone(), TypeRow::from(node_outputs))
}
BasicBlock::Exit { cfg_outputs } => FunctionType::new(type_row![], cfg_outputs.clone()),
}
}
}

impl OpTrait for BasicBlock {
/// The description of the operation.
fn description(&self) -> &str {
Expand Down Expand Up @@ -223,6 +243,12 @@ impl StaticTag for Case {
const TAG: OpTag = OpTag::Case;
}

impl DataflowParent for Case {
fn inner_signature(&self) -> FunctionType {
self.signature.clone()
}
}

impl OpTrait for Case {
fn description(&self) -> &str {
"A case node inside a conditional"
Expand All @@ -247,11 +273,6 @@ impl Case {
pub fn dataflow_output(&self) -> &TypeRow {
&self.signature.output
}

/// The signature of the dataflow sibling graph contained in the [`Case`]
pub fn inner_signature(&self) -> FunctionType {
self.signature.clone()
}
}

fn tuple_sum_first(tuple_sum_row: &TypeRow, rest: &TypeRow) -> TypeRow {
Expand Down
15 changes: 14 additions & 1 deletion src/ops/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ impl LoadConstant {
}
}

/// Operations that is the parent of a dataflow graph.
pub trait DataflowParent {
/// Signature of the inner dataflow graph.
fn inner_signature(&self) -> FunctionType;
}

/// A simply nested dataflow graph.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DFG {
Expand All @@ -231,6 +237,13 @@ pub struct DFG {
}

impl_op_name!(DFG);

impl DataflowParent for DFG {
fn inner_signature(&self) -> FunctionType {
self.signature.clone()
}
}

impl DataflowOpTrait for DFG {
const TAG: OpTag = OpTag::Dfg;

Expand All @@ -239,6 +252,6 @@ impl DataflowOpTrait for DFG {
}

fn signature(&self) -> FunctionType {
self.signature.clone()
self.inner_signature()
}
}
17 changes: 16 additions & 1 deletion src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
use smol_str::SmolStr;

use crate::types::{EdgeKind, PolyFuncType};
use crate::types::{EdgeKind, FunctionType, PolyFuncType};
use crate::types::{Type, TypeBound};

use super::dataflow::DataflowParent;
use super::StaticTag;
use super::{impl_op_name, OpTag, OpTrait};

Expand Down Expand Up @@ -43,6 +44,13 @@ impl_op_name!(FuncDefn);
impl StaticTag for FuncDefn {
const TAG: OpTag = OpTag::FuncDefn;
}

impl DataflowParent for FuncDefn {
fn inner_signature(&self) -> FunctionType {
self.signature.body().clone()
}
}

impl OpTrait for FuncDefn {
fn description(&self) -> &str {
"A function definition"
Expand Down Expand Up @@ -70,6 +78,13 @@ impl_op_name!(FuncDecl);
impl StaticTag for FuncDecl {
const TAG: OpTag = OpTag::Function;
}

impl DataflowParent for FuncDecl {
fn inner_signature(&self) -> FunctionType {
self.signature.body().clone()
}
}

impl OpTrait for FuncDecl {
fn description(&self) -> &str {
"External function declaration, linked at runtime"
Expand Down
15 changes: 6 additions & 9 deletions src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use thiserror::Error;

use crate::types::{FunctionType, Type, TypeRow};

use super::dataflow::DataflowParent;
use super::{impl_validate_op, BasicBlock, OpTag, OpTrait, OpType, ValidateOp};

/// A set of property flags required for an operation.
Expand Down Expand Up @@ -79,8 +80,8 @@ impl ValidateOp for super::FuncDefn {
) -> Result<(), ChildrenValidationError> {
// We check type-variables are declared in `validate_subtree`, so here
// we can just assume all type variables are valid regardless of binders.
let FunctionType { input, output, .. } = self.signature.body();
validate_io_nodes(input, output, "function definition", children)
let FunctionType { input, output, .. } = self.inner_signature();
validate_io_nodes(&input, &output, "function definition", children)
}
}

Expand Down Expand Up @@ -136,7 +137,7 @@ impl ValidateOp for super::Conditional {
let case_op = optype
.as_case()
.expect("Child check should have already checked valid ops.");
let sig = &case_op.signature;
let sig = &case_op.inner_signature();
if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs {
return Err(ChildrenValidationError::ConditionalCaseSignature {
child,
Expand Down Expand Up @@ -343,12 +344,8 @@ impl ValidateOp for super::Case {
&self,
children: impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)>,
) -> Result<(), ChildrenValidationError> {
validate_io_nodes(
&self.signature.input,
&self.signature.output,
"Conditional",
children,
)
let sig = self.inner_signature();
validate_io_nodes(&sig.input, &sig.output, "Conditional", children)
}
}

Expand Down
32 changes: 23 additions & 9 deletions src/types/check.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
//! Logic for checking values against types.
use thiserror::Error;

use crate::{values::Value, HugrView};
use crate::{
ops::{FuncDecl, FuncDefn, OpType},
values::Value,
Hugr, HugrView,
};

use super::{CustomType, Type, TypeEnum};
use super::{CustomType, PolyFuncType, Type, TypeEnum};

/// Struct for custom type check fails.
#[derive(Clone, Debug, PartialEq, Eq, Error)]
Expand Down Expand Up @@ -45,6 +49,22 @@ pub enum ConstTypeError {
CustomCheckFail(#[from] CustomCheckFailure),
}

fn check_ts(v: &Hugr, t: &PolyFuncType) -> bool {
// exact signature equality, in future this may need to be
// relaxed to be compatibility checks between the signatures.
let root_op = v.get_optype(v.root());
if let OpType::FuncDecl(FuncDecl { signature, .. })
| OpType::FuncDefn(FuncDefn { signature, .. }) = root_op
{
signature == t
} else {
v.get_function_type().is_some_and(|ft| {
let value_ts: PolyFuncType = ft.into();
&value_ts == t
})
}
}

impl Type {
/// Check that a [`Value`] is a valid instance of this [`Type`].
///
Expand All @@ -57,13 +77,7 @@ impl Type {
e_val.0.check_custom_type(e)?;
Ok(())
}
(TypeEnum::Function(t), Value::Function { hugr: v })
if v.get_function_type().is_some_and(|f| **t == f) =>
{
// exact signature equality, in future this may need to be
// relaxed to be compatibility checks between the signatures.
Ok(())
}
(TypeEnum::Function(t), Value::Function { hugr: v }) if check_ts(v, t) => Ok(()),
(TypeEnum::Tuple(t), Value::Tuple { vs: t_v }) => {
if t.len() != t_v.len() {
return Err(ConstTypeError::TupleWrongLength);
Expand Down

0 comments on commit c04910b

Please sign in to comment.