diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 3d7f0fc90..311fe781c 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -14,6 +14,7 @@ categories = ["compilers"] [dependencies] hugr-core = { path = "../hugr-core", version = "0.13.3" } +ascent = { version = "0.7.0" } itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } @@ -25,3 +26,6 @@ extension_inference = ["hugr-core/extension_inference"] [dev-dependencies] rstest = { workspace = true } +proptest = { workspace = true } +proptest-derive = { workspace = true } +proptest-recurse = { version = "0.5.0" } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs new file mode 100644 index 000000000..dd3e6d2c0 --- /dev/null +++ b/hugr-passes/src/dataflow.rs @@ -0,0 +1,110 @@ +#![warn(missing_docs)] +//! Dataflow analysis of Hugrs. + +mod datalog; +pub use datalog::Machine; +mod value_row; + +mod results; +pub use results::{AnalysisResults, TailLoopTermination}; + +mod partial_value; +pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; + +use hugr_core::ops::constant::OpaqueValue; +use hugr_core::ops::{ExtensionOp, Value}; +use hugr_core::types::TypeArg; +use hugr_core::{Hugr, HugrView, Node}; + +/// Clients of the dataflow framework (particular analyses, such as constant folding) +/// must implement this trait (including providing an appropriate domain type `V`). +pub trait DFContext: ConstLoader + std::ops::Deref { + /// Type of view contained within this context. (Ideally we'd constrain + /// by `std::ops::Deref` but that's not stable yet.) + type View: HugrView; + + /// Given lattice values for each input, update lattice values for the (dataflow) outputs. + /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. 
+ /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] + /// which is the correct value to leave if nothing can be deduced about that output. + /// (The default does nothing, i.e. leaves `Top` for all outputs.) + /// + /// [MakeTuple]: hugr_core::extension::prelude::MakeTuple + /// [UnpackTuple]: hugr_core::extension::prelude::UnpackTuple + fn interpret_leaf_op( + &self, + _node: Node, + _e: &ExtensionOp, + _ins: &[PartialValue], + _outs: &mut [PartialValue], + ) { + } +} + +/// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. +/// Implementors will likely want to override some/all of [Self::value_from_opaque], +/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// are "correct" but maximally conservative (minimally informative). +pub trait ConstLoader { + /// Produces a [PartialValue] from a constant. The default impl (expected + /// to be appropriate in most cases) traverses [Sum](Value::Sum) constants + /// to their leaves ([Value::Extension] and [Value::Function]), + /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], + /// and builds nested [PartialValue::new_variant] to represent the structure. + fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { + traverse_value(self, n, &mut Vec::new(), cst) + } + + /// Produces an abstract value from an [OpaqueValue], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_opaque(&self, _node: Node, _fields: &[usize], _val: &OpaqueValue) -> Option { + None + } + + /// Produces an abstract value from a Hugr in a [Value::Function], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. 
+ fn value_from_const_hugr(&self, _node: Node, _fields: &[usize], _h: &Hugr) -> Option { + None + } + + /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node + /// (that has been loaded via a [LoadFunction]), if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + /// + /// [FuncDefn]: hugr_core::ops::FuncDefn + /// [FuncDecl]: hugr_core::ops::FuncDecl + /// [LoadFunction]: hugr_core::ops::LoadFunction + fn value_from_function(&self, _node: Node, _type_args: &[TypeArg]) -> Option { + None + } +} + +fn traverse_value( + s: &(impl ConstLoader + ?Sized), + n: Node, + fields: &mut Vec, + cst: &Value, +) -> PartialValue { + match cst { + Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { + let elems = values.iter().enumerate().map(|(idx, elem)| { + fields.push(idx); + let r = traverse_value(s, n, fields, elem); + fields.pop(); + r + }); + PartialValue::new_variant(*tag, elems) + } + Value::Extension { e } => s + .value_from_opaque(n, fields, e) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + Value::Function { hugr } => s + .value_from_const_hugr(n, fields, hugr) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + } +} + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs new file mode 100644 index 000000000..33ab4524d --- /dev/null +++ b/hugr-passes/src/dataflow/datalog.rs @@ -0,0 +1,346 @@ +//! [ascent] datalog implementation of analysis. 
+ +use std::collections::hash_map::RandomState; +use std::collections::HashSet; // Moves to std::hash in Rust 1.76 + +use ascent::lattice::BoundedLattice; +use itertools::Itertools; + +use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; +use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; + +use super::value_row::ValueRow; +use super::{AbstractValue, AnalysisResults, DFContext, PartialValue}; + +type PV = PartialValue; + +/// Basic structure for performing an analysis. Usage: +/// 1. Get a new instance via [Self::default()] +/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values +/// 3. Call [Self::run] to produce [AnalysisResults] +pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); + +/// derived-Default requires the context to be Defaultable, which is unnecessary +impl Default for Machine { + fn default() -> Self { + Self(Default::default()) + } +} + +impl Machine { + /// Provide initial values for some wires. + // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? + pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { + self.0.extend( + h.linked_inputs(wire.node(), wire.source()) + .map(|(n, inp)| (n, inp, value.clone())), + ); + } + + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// given initial values for some of the root node inputs. + /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, + /// but should handle other containers.) + /// The context passed in allows interpretation of leaf operations. 
+ pub fn run>( + mut self, + context: C, + in_values: impl IntoIterator)>, + ) -> AnalysisResults { + let root = context.root(); + self.0 + .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis + // (Consider: for a conditional that selects *either* the unknown input *or* value V, + // analysis must produce Top == we-know-nothing, not `V` !) + let mut need_inputs: HashSet<_, RandomState> = HashSet::from_iter( + (0..context.signature(root).unwrap_or_default().input_count()).map(IncomingPort::from), + ); + self.0.iter().for_each(|(n, p, _)| { + if n == &root { + need_inputs.remove(p); + } + }); + for p in need_inputs { + self.0.push((root, p, PartialValue::Top)); + } + // Note/TODO, if analysis is running on a subregion then we should do similar + // for any nonlocal edges providing values from outside the region. + run_datalog(context, self.0) + } +} + +pub(super) fn run_datalog>( + ctx: C, + in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, +) -> AnalysisResults { + // ascent-(macro-)generated code generates a bunch of warnings, + // keep code in here to a minimum. + #![allow( + clippy::clone_on_copy, + clippy::unused_enumerate_index, + clippy::collapsible_if + )] + let all_results = ascent::ascent_run! 
{ + pub(super) struct AscentProgram; + relation node(Node); // the node exists in the hugr + relation in_wire(Node, IncomingPort); // the node has this incoming port, of `EdgeKind::Value` + relation out_wire(Node, OutgoingPort); // the node has this outgoing port, of `EdgeKind::Value` + relation parent_of_node(Node, Node); // the first node is parent of the second + relation input_child(Node, Node); // the first node has 1st child, the second node, that is its `Input` + relation output_child(Node, Node); // the first node has 2nd child, the second node, that is its `Output` + lattice out_wire_value(Node, OutgoingPort, PV); // the node produces, on the outgoing port, the value PV + lattice in_wire_value(Node, IncomingPort, PV); // the node receives, on the incoming port, the value PV + lattice node_in_value_row(Node, ValueRow); // the node's inputs are the ValueRow + + node(n) <-- for n in ctx.nodes(); + + in_wire(n, p) <-- node(n), for (p,_) in ctx.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in ctx.out_value_types(*n); // (and likewise) + + parent_of_node(parent, child) <-- + node(child), if let Some(parent) = ctx.get_parent(*child); + + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = ctx.get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = ctx.get_io(*parent); + + // Initialize all wires to bottom + out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); + + // Outputs to inputs + in_wire_value(n, ip, v) <-- in_wire(n, ip), + if let Some((m, op)) = ctx.single_linked_output(*n, *ip), + out_wire_value(m, op, v); + + // Prepopulate in_wire_value from in_wire_value_proto. 
+ in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); + in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), + node(n), + if let Some(sig) = ctx.signature(*n), + if sig.input_ports().contains(p); + + // Assemble node_in_value_row from in_wire_value's + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); + node_in_value_row(n, ValueRow::single_known(ctx.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + + // Interpret leaf ops + out_wire_value(n, p, v) <-- + node(n), + let op_t = ctx.get_optype(*n), + if !op_t.is_container(), + if let Some(sig) = op_t.dataflow_signature(), + node_in_value_row(n, vs), + if let Some(outs) = propagate_leaf_op(&ctx, *n, &vs[..], sig.output_count()), + for (p, v) in (0..).map(OutgoingPort::from).zip(outs); + + // DFG -------------------- + relation dfg_node(Node); // is a `DFG` + dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); + + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + input_child(dfg, i), in_wire_value(dfg, p, v); + + out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + output_child(dfg, o), in_wire_value(o, p, v); + + // TailLoop -------------------- + // inputs of tail loop propagate to Input node of child region + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), + if ctx.get_optype(*tl).is_tail_loop(), + input_child(tl, i), + in_wire_value(tl, p, v); + + // Output node of child region propagate to Input node of child region + out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), + input_child(tl, in_n), + output_child(tl, out_n), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... 
+ // ...and select just what's possible for CONTINUE_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), + for (out_p, v) in fields.enumerate(); + + // Output node of child region propagate to outputs of tail loop + out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), + output_child(tl, out_n), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... + // ... and select just what's possible for BREAK_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), + for (out_p, v) in fields.enumerate(); + + // Conditional -------------------- + // case_node(cond, i, case): `cond` is a `Conditional` and its `i`'th child (a `Case`) is `case`: + relation case_node(Node, usize, Node); + case_node(cond, i, case) <-- node(cond), + if ctx.get_optype(*cond).is_conditional(), + for (i, case) in ctx.children(*cond).enumerate(), + if ctx.get_optype(case).is_case(); + + // inputs of conditional propagate into case nodes + out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- + case_node(cond, case_index, case), + input_child(case, i_node), + node_in_value_row(cond, in_row), + let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), + if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + for (out_p, v) in fields.enumerate(); + + // outputs of case nodes propagate to outputs of conditional *if* case reachable + out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- + case_node(cond, _i, case), + case_reachable(cond, case), + output_child(case, o), + in_wire_value(o, o_p, v); + + // case_reachable(cond, case): in `Conditional` `cond`, child `Case` `case` is reachable given our knowledge of predicate: + relation case_reachable(Node, Node); + case_reachable(cond, case) <-- case_node(cond, i, case), + in_wire_value(cond, IncomingPort::from(0), v), + if v.supports_tag(*i); + + // CFG 
-------------------- + relation cfg_node(Node); // the node is a `CFG` + cfg_node(n) <-- node(n), if ctx.get_optype(*n).is_cfg(); + + // bb_reachable(cfg, bb): in `CFG` `cfg`, basic block `bb` is reachable given our knowledge of predicates: + relation bb_reachable(Node, Node); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(); + bb_reachable(cfg, bb) <-- cfg_node(cfg), + bb_reachable(cfg, pred), + output_child(pred, pred_out), + in_wire_value(pred_out, IncomingPort::from(0), predicate), + for (tag, bb) in ctx.output_neighbours(*pred).enumerate(), + if predicate.supports_tag(tag); + + // Inputs of CFG propagate to entry block + out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- + cfg_node(cfg), + if let Some(entry) = ctx.children(*cfg).next(), + input_child(entry, i_node), + in_wire_value(cfg, p, v); + + // _cfg_succ_dest(cfg, succ, dest): in `CFG` `cfg`, values fed along a control-flow edge to `succ` + // come out of Value outports of `dest`: + relation _cfg_succ_dest(Node, Node, Node); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.children(*cfg).nth(1); + _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), + for blk in ctx.children(*cfg), + if ctx.get_optype(blk).is_dataflow_block(), + input_child(blk, inp); + + // Outputs of each reachable block propagated to successor block or CFG itself + out_wire_value(dest, OutgoingPort::from(out_p), v) <-- + bb_reachable(cfg, pred), + if let Some(df_block) = ctx.get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in ctx.output_neighbours(*pred).enumerate(), + output_child(pred, out_n), + _cfg_succ_dest(cfg, succ, dest), + node_in_value_row(out_n, out_in_row), + if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + for (out_p, v) in fields.enumerate(); + + // Call -------------------- + relation func_call(Node, Node); // func_call(call, func_defn): `call` is a `Call` to `FuncDefn` `func_defn` + func_call(call, func_defn) <-- + node(call), + if ctx.get_optype(*call).is_call(), + if let Some(func_defn) = ctx.static_source(*call); 
+ + out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + input_child(func, inp), + in_wire_value(call, p, v); + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + output_child(func, outp), + in_wire_value(outp, p, v); + }; + let out_wire_values = all_results + .out_wire_value + .iter() + .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(); + AnalysisResults { + ctx, + out_wire_values, + in_wire_value: all_results.in_wire_value, + case_reachable: all_results.case_reachable, + bb_reachable: all_results.bb_reachable, + } +} + +fn propagate_leaf_op( + ctx: &impl DFContext, + n: Node, + ins: &[PV], + num_outs: usize, +) -> Option> { + match ctx.get_optype(n) { + // Handle basics here. We could instead leave these to DFContext, + // but at least we'd want these impls to be easily reusable. + op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( + 0, + ins.iter().cloned(), + )])), + op if op.cast::().is_some() => { + let elem_tys = op.cast::().unwrap().0; + let tup = ins.iter().exactly_one().unwrap(); + tup.variant_values(0, elem_tys.len()) + .map(ValueRow::from_iter) + } + OpType::Tag(t) => Some(ValueRow::from_iter([PV::new_variant( + t.tag, + ins.iter().cloned(), + )])), + OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent + OpType::Call(_) => None, // handled via Input/Output of FuncDefn + OpType::Const(_) => None, // handled by LoadConstant: + OpType::LoadConstant(load_op) => { + assert!(ins.is_empty()); // static edge, so need to find constant + let const_node = ctx + .single_linked_output(n, load_op.constant_port()) + .unwrap() + .0; + let const_val = ctx.get_optype(const_node).as_const().unwrap().value(); + Some(ValueRow::single_known( + 1, + 0, + ctx.value_from_const(n, const_val), + )) + } + OpType::LoadFunction(load_op) => { + assert!(ins.is_empty()); // static edge + let func_node = ctx + .single_linked_output(n, 
load_op.function_port()) + .unwrap() + .0; + // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself + Some(ValueRow::single_known( + 1, + 0, + ctx.value_from_function(func_node, &load_op.type_args) + .map_or(PV::Top, PV::Value), + )) + } + OpType::ExtensionOp(e) => { + // Interpret op using DFContext + let init = if ins.iter().contains(&PartialValue::Bottom) { + // So far we think one or more inputs can't happen. + // So, don't pollute outputs with Top, and wait for better knowledge of inputs. + PartialValue::Bottom + } else { + // If we can't figure out anything about the outputs, assume nothing (they still happen!) + PartialValue::Top + }; + let mut outs = vec![init; num_outs]; + // It might be nice to convert these to [(IncomingPort, Value)], or some concrete value, + // for the context, but PV contains more information, and try_into_value may fail. + ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); + Some(ValueRow::from_iter(outs)) + } + o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + } +} diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs new file mode 100644 index 000000000..992a72444 --- /dev/null +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -0,0 +1,720 @@ +use ascent::lattice::BoundedLattice; +use ascent::Lattice; +use hugr_core::ops::Value; +use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use itertools::{zip_eq, Itertools}; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use thiserror::Error; + +/// Trait for an underlying domain of abstract values which can form the *elements* of a +/// [PartialValue] and thus be used in dataflow analysis. +pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { + /// Computes the join of two values (i.e. towards `Top`), if this is representable + /// within the underlying domain. + /// Otherwise return `None` (i.e. 
an instruction to use [PartialValue::Top]). + /// + /// The default checks equality between `self` and `other` and returns `self` if + /// the two are identical, otherwise `None`. + fn try_join(self, other: Self) -> Option { + (self == other).then_some(self) + } + + /// Computes the meet of two values (i.e. towards `Bottom`), if this is representable + /// within the underlying domain. + /// Otherwise return `None` (i.e. an instruction to use [PartialValue::Bottom]). + /// + /// The default checks equality between `self` and `other` and returns `self` if + /// the two are identical, otherwise `None`. + fn try_meet(self, other: Self) -> Option { + (self == other).then_some(self) + } +} + +/// Represents a sum with a single/known tag, abstracted over the representation of the elements. +/// (Identical to [Sum](hugr_core::ops::constant::Sum) except for the type abstraction.) +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Sum { + /// The tag index of the variant. + pub tag: usize, + /// The value of the variant. + /// + /// Sum variants are always a row of values, hence the Vec. + pub values: Vec, + /// The full type of the Sum, including the other variants. + pub st: SumType, +} + +/// A representation of a value of [SumType], that may have one or more possible tags, +/// with a [PartialValue] representation of each element-value of each possible tag. +#[derive(PartialEq, Clone, Eq)] +pub struct PartialSum(pub HashMap>>); + +impl PartialSum { + /// New instance for a single known tag. + /// (Multi-tag instances can be created via [Self::try_join_mut].) + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { + Self(HashMap::from([(tag, Vec::from_iter(values))])) + } + + /// The number of possible variants we know about. (NOT the number + /// of tags possible for the value's type, whatever [SumType] that might be.) 
+ pub fn num_variants(&self) -> usize { + self.0.len() + } + + fn assert_invariants(&self) { + assert_ne!(self.num_variants(), 0); + for pv in self.0.values().flat_map(|x| x.iter()) { + pv.assert_invariants(); + } + } +} + +impl PartialSum { + /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns + /// whether `self` has changed. + /// + /// Fails (without mutation) with the conflicting tag if any common rows have different lengths. + pub fn try_join_mut(&mut self, other: Self) -> Result { + for (k, v) in &other.0 { + if self.0.get(k).is_some_and(|row| row.len() != v.len()) { + return Err(*k); + } + } + let mut changed = false; + + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + self.0.insert(k, v); + changed = true; + } + } + Ok(changed) + } + + /// Mutates self according to lattice meet operation (towards `Bottom`). If successful, + /// returns whether `self` has changed. 
+ /// + /// # Errors + /// Fails without mutation, either: + /// * `Some(tag)` if the two [PartialSum]s both had rows with that `tag` but of different lengths + /// * `None` if the two instances had no rows in common (i.e., the result is "Bottom") + pub fn try_meet_mut(&mut self, other: Self) -> Result> { + let mut changed = false; + let mut keys_to_remove = vec![]; + for (k, v) in self.0.iter() { + match other.0.get(k) { + None => keys_to_remove.push(*k), + Some(o_v) => { + if v.len() != o_v.len() { + return Err(Some(*k)); + } + } + } + } + if keys_to_remove.len() == self.0.len() { + return Err(None); + } + for (k, v) in other.0 { + // Tags present only in `other` are not in `self` and cannot survive the meet, + // so they are skipped here: they must not be reported as a change to `self`. + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.meet_mut(rhs); + } + } + } + for k in keys_to_remove { + // Every key collected above came from `self`, so each removal is a real change. + changed |= self.0.remove(&k).is_some(); + } + Ok(changed) + } + + /// Whether this sum might have the specified tag + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } + + /// Turns this instance into a [Sum] of some "concrete" value type `V2`, + /// *if* this PartialSum has exactly one possible tag. + /// + /// # Errors + /// + /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] + /// supporting the single possible tag with the correct number of elements and no row variables; + /// or if converting a child element failed via [PartialValue::try_into_value]. 
+ pub fn try_into_value( + self, + typ: &Type, + ) -> Result, ExtractValueError> + where + V2: TryFrom + TryFrom, Error = SE>, + { + let Ok((k, v)) = self.0.iter().exactly_one() else { + return Err(ExtractValueError::MultipleVariants(self)); + }; + if let TypeEnum::Sum(st) = typ.as_type_enum() { + if let Some(r) = st.get_variant(*k) { + if let Ok(r) = TypeRow::try_from(r.clone()) { + if v.len() == r.len() { + return Ok(Sum { + tag: *k, + values: zip_eq(v, r.iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>()?, + st: st.clone(), + }); + } + } + } + } + Err(ExtractValueError::BadSumType { + typ: typ.clone(), + tag: *k, + num_elements: v.len(), + }) + } +} + +/// An error converting a [PartialValue] or [PartialSum] into a concrete value type +/// via [PartialValue::try_into_value] or [PartialSum::try_into_value] +#[derive(Clone, Debug, PartialEq, Eq, Error)] +#[allow(missing_docs)] +pub enum ExtractValueError { + #[error("PartialSum value had multiple possible tags: {0}")] + MultipleVariants(PartialSum), + #[error("Value contained `Bottom`")] + ValueIsBottom, + #[error("Value contained `Top`")] + ValueIsTop, + #[error("Could not convert element from abstract value into concrete: {0}")] + CouldNotConvert(V, #[source] VE), + #[error("Could not build Sum from concrete element values")] + CouldNotBuildSum(#[source] SE), + #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] + BadSumType { + typ: Type, + tag: usize, + num_elements: usize, + }, +} + +impl PartialSum { + /// If this Sum might have the specified `tag`, get the elements inside that tag. 
+ pub fn variant_values(&self, variant: usize) -> Option>> { + self.0.get(&variant).cloned() + } +} + +/// Partial order agreeing with the lattice operations: `self` is at most `other` iff every tag of +/// `self` is also a tag of `other` and, for each shared tag, the rows have the same length and +/// every element of `self` is at most the corresponding element of `other`. +/// Sums that are not related this way in either direction (e.g. disjoint tag sets) give `None`. +impl PartialOrd for PartialSum { + fn partial_cmp(&self, other: &Self) -> Option { + // Relate the tag sets first: subset is Less, superset is Greater, neither is incomparable. + // (No allocation sized by the largest tag, and no panic when both maps are empty.) + let self_in_other = self.0.keys().all(|k| other.0.contains_key(k)); + let other_in_self = other.0.keys().all(|k| self.0.contains_key(k)); + let mut acc = match (self_in_other, other_in_self) { + (true, true) => Ordering::Equal, + (true, false) => Ordering::Less, + (false, true) => Ordering::Greater, + (false, false) => return None, + }; + // Then fold in every element of every shared row (product order, independent of the + // map's iteration order): one disagreement in direction makes the sums incomparable. + for (k, lhs) in &self.0 { + let Some(rhs) = other.0.get(k) else { + continue; + }; + if lhs.len() != rhs.len() { + return None; + } + for (l, r) in lhs.iter().zip(rhs) { + match (acc, l.partial_cmp(r)?) { + (_, Ordering::Equal) => {} + (Ordering::Equal, ord) => acc = ord, + (prev, ord) if prev == ord => {} + _ => return None, + } + } + } + Some(acc) + } +} + +impl std::fmt::Debug for PartialSum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Hash for PartialSum { + fn hash(&self, state: &mut H) { + // `HashMap` iteration order is unspecified and may differ between two equal maps, + // so feed the entries to the hasher in ascending tag order: equal sums must hash equally. + let mut rows: Vec<_> = self.0.iter().collect(); + rows.sort_unstable_by_key(|(k, _)| **k); + for (k, v) in rows { + k.hash(state); + v.hash(state); + } + } +} + +/// Wraps some underlying representation (knowledge) of values into a lattice +/// for use in dataflow analysis, including that an instance may be a [PartialSum] +/// of values of the underlying representation +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PartialValue { + /// No possibilities known (so far) + Bottom, + /// A single value (of the underlying representation) + Value(V), + /// Sum (with at least one, perhaps several, possible tags) of underlying values + PartialSum(PartialSum), + /// Might be more than one distinct value of the underlying type `V` + Top, +} + +impl From for PartialValue { + fn from(v: V) -> Self { + Self::Value(v) + } +} + +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { + Self::PartialSum(v) + } +} + +impl PartialValue { + fn assert_invariants(&self) { + if let Self::PartialSum(ps) = self { + ps.assert_invariants(); + } + } + + /// New instance of a sum with a single known 
tag. + pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::new_variant(tag, values).into() + } + + /// New instance of unit type (i.e. the only possible value, with no contents) + pub fn new_unit() -> Self { + Self::new_variant(0, []) + } +} + +impl PartialValue { + /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. + /// + /// # Panics + /// + /// if the value is believed, for that tag, to have a number of values other than `len` + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + let vals = match self { + PartialValue::Bottom | PartialValue::Value(_) => return None, + PartialValue::PartialSum(ps) => ps.variant_values(tag)?, + PartialValue::Top => vec![PartialValue::Top; len], + }; + assert_eq!(vals.len(), len); + Some(vals) + } + + /// Tells us whether this value might be a Sum with the specified `tag` + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom | PartialValue::Value(_) => false, + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// Turns this instance into some "concrete" value type `V2`, *if* it is a single value, + /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by + /// [PartialSum::try_into_value]. + /// + /// # Errors + /// + /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) + /// that could not be converted into a [Sum] by [PartialSum::try_into_value] (e.g. if `typ` is + /// incorrect), or if that [Sum] could not be converted into a `V2`. 
+ pub fn try_into_value(self, typ: &Type) -> Result> + where + V2: TryFrom + TryFrom, Error = SE>, + { + match self { + Self::Value(v) => V2::try_from(v.clone()) + .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), + Self::PartialSum(ps) => { + let v = ps.try_into_value(typ)?; + V2::try_from(v).map_err(ExtractValueError::CouldNotBuildSum) + } + Self::Top => Err(ExtractValueError::ValueIsTop), + Self::Bottom => Err(ExtractValueError::ValueIsBottom), + } + } +} + +impl TryFrom> for Value { + type Error = ConstTypeError; + + fn try_from(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } +} + +impl Lattice for PartialValue { + fn join_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); + match (&*self, other) { + (Self::Top, _) => false, + (_, other @ Self::Top) => { + *self = other; + true + } + (_, Self::Bottom) => false, + (Self::Bottom, other) => { + *self = other; + true + } + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { + Some(h3) => { + let ch = h3 != *h1; + *self = Self::Value(h3); + ch + } + None => { + *self = Self::Top; + true + } + }, + (Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = self else { + unreachable!() + }; + match ps1.try_join_mut(ps2) { + Ok(ch) => ch, + Err(_) => { + *self = Self::Top; + true + } + } + } + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + *self = Self::Top; + true + } + } + } + + fn meet_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); + match (&*self, other) { + (Self::Bottom, _) => false, + (_, other @ Self::Bottom) => { + *self = other; + true + } + (_, Self::Top) => false, + (Self::Top, other) => { + *self = other; + true + } + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_meet(h2) { + Some(h3) => { + let ch = h3 != *h1; + *self = Self::Value(h3); + ch + } + None => { + *self = Self::Bottom; + true + } + }, + (Self::PartialSum(_), Self::PartialSum(ps2)) 
=> { + let Self::PartialSum(ps1) = self else { + unreachable!() + }; + match ps1.try_meet_mut(ps2) { + Ok(ch) => ch, + Err(_) => { + *self = Self::Bottom; + true + } + } + } + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + *self = Self::Bottom; + true + } + } + } +} + +impl BoundedLattice for PartialValue { + fn top() -> Self { + Self::Top + } + + fn bottom() -> Self { + Self::Bottom + } +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + match (self, other) { + (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), + (Self::Top, Self::Top) => Some(Ordering::Equal), + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), + _ => None, + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use ascent::{lattice::BoundedLattice, Lattice}; + use itertools::{zip_eq, Itertools as _}; + use prop::sample::subsequence; + use proptest::prelude::*; + + use proptest_recurse::{StrategyExt, StrategySet}; + + use super::{AbstractValue, PartialSum, PartialValue}; + + #[derive(Debug, PartialEq, Eq, Clone)] + enum TestSumType { + Branch(Vec>>), + /// None => unit, Some => TestValue <= this *usize* + Leaf(Option), + } + + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct TestValue(usize); + + impl AbstractValue for TestValue {} + + #[derive(Clone)] + struct SumTypeParams { + depth: usize, + desired_size: usize, + expected_branch_size: usize, + } + + impl Default for SumTypeParams { + fn default() -> Self { + Self { + depth: 5, + desired_size: 20, + expected_branch_size: 5, + } + } + } + + impl TestSumType { + fn check_value(&self, pv: &PartialValue) -> bool { + match (self, pv) { + (_, 
PartialValue::Bottom) | (_, PartialValue::Top) => true, + (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), + (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::Branch(sop), PartialValue::PartialSum(ps)) => { + for (k, v) in &ps.0 { + if *k >= sop.len() { + return false; + } + let prod = &sop[*k]; + if prod.len() != v.len() { + return false; + } + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.check_value(rhs)) { + return false; + } + } + true + } + _ => false, + } + } + } + + impl Arbitrary for TestSumType { + type Parameters = SumTypeParams; + type Strategy = SBoxedStrategy; + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { + use proptest::collection::vec; + let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); + let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; + leaf_strat.prop_mutually_recursive( + params.depth as u32, + params.desired_size as u32, + params.expected_branch_size as u32, + set, + move |set| { + let params2 = params.clone(); + vec( + vec( + set.get::(move |set| arb(params2, set)) + .prop_map(Arc::new), + 1..=params.expected_branch_size, + ), + 1..=params.expected_branch_size, + ) + .prop_map(TestSumType::Branch) + .sboxed() + }, + ) + } + + arb(params, &mut StrategySet::default()) + } + } + + fn single_sum_strat( + tag: usize, + elems: Vec>, + ) -> impl Strategy> { + elems + .iter() + .map(Arc::as_ref) + .map(any_partial_value_of_type) + .collect::>() + .prop_map(move |elems| PartialSum::new_variant(tag, elems)) + } + + fn partial_sum_strat( + variants: &[Vec>], + ) -> impl Strategy> { + // We have to clone the `variants` here but only as far as the Vec>> + let tagged_variants = variants.iter().cloned().enumerate().collect::>(); + // The type annotation here (and the .boxed() enabling it) are just for documentation + let sum_variants_strat: BoxedStrategy>> = + 
subsequence(tagged_variants, 1..=variants.len()) + .prop_flat_map(|selected_variants| { + selected_variants + .into_iter() + .map(|(tag, elems)| single_sum_strat(tag, elems)) + .collect::>() + }) + .boxed(); + sum_variants_strat.prop_map(|psums: Vec>| { + let mut psums = psums.into_iter(); + let first = psums.next().unwrap(); + psums.fold(first, |mut a, b| { + a.try_join_mut(b).unwrap(); + a + }) + }) + } + + fn any_partial_value_of_type( + ust: &TestSumType, + ) -> impl Strategy> { + match ust { + TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), + TestSumType::Leaf(Some(i)) => (0..*i) + .prop_map(TestValue) + .prop_map(PartialValue::from) + .boxed(), + TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), + } + } + + fn any_partial_value_with( + params: ::Parameters, + ) -> impl Strategy> { + any_with::(params).prop_flat_map(|t| any_partial_value_of_type(&t)) + } + + fn any_partial_value() -> impl Strategy> { + any_partial_value_with(Default::default()) + } + + fn any_partial_values() -> impl Strategy; N]> { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(&ust)) + .collect_vec(), + ) + .unwrap() + }) + } + + fn any_typed_partial_value() -> impl Strategy)> { + any::() + .prop_flat_map(|t| any_partial_value_of_type(&t).prop_map(move |v| (t.clone(), v))) + } + + proptest! 
{ + #[test] + fn partial_value_type((tst, pv) in any_typed_partial_value()) { + prop_assert!(tst.check_value(&pv)) + } + + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); + } + + #[test] + fn meet_join_self_noop(v1 in any_partial_value()) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + prop_assert!(meet == v2.clone().meet(v1.clone()), "meet not symmetric"); + prop_assert!(meet == meet.clone().meet(v1.clone()), "repeated meet should be a no-op"); + prop_assert!(meet == meet.clone().meet(v2.clone()), "repeated meet should be a no-op"); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + prop_assert!(join == v2.clone().join(v1.clone()), "join not symmetric"); + prop_assert!(join == join.clone().join(v1.clone()), "repeated join should be a no-op"); + prop_assert!(join == join.clone().join(v2.clone()), "repeated join should be a no-op"); + } + + #[test] + fn lattice_associative([v1, v2, v3] in any_partial_values()) { + let a = v1.clone().meet(v2.clone()).meet(v3.clone()); + let b = v1.clone().meet(v2.clone().meet(v3.clone())); + 
prop_assert!(a==b, "meet not associative"); + + let a = v1.clone().join(v2.clone()).join(v3.clone()); + let b = v1.clone().join(v2.clone().join(v3.clone())); + prop_assert!(a==b, "join not associative") + } + } +} diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs new file mode 100644 index 000000000..6c90e33b3 --- /dev/null +++ b/hugr-passes/src/dataflow/results.rs @@ -0,0 +1,137 @@ +use std::collections::HashMap; + +use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; + +use super::{partial_value::ExtractValueError, AbstractValue, DFContext, PartialValue, Sum}; + +/// Results of a dataflow analysis, packaged with the Hugr for easy inspection. +/// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). +pub struct AnalysisResults> { + pub(super) ctx: C, + pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, + pub(super) case_reachable: Vec<(Node, Node)>, + pub(super) bb_reachable: Vec<(Node, Node)>, + pub(super) out_wire_values: HashMap>, +} + +impl> AnalysisResults { + /// Allows to use the [HugrView] contained within + pub fn hugr(&self) -> &C::View { + &self.ctx + } + + /// Gets the lattice value computed for the given wire + pub fn read_out_wire(&self, w: Wire) -> Option> { + self.out_wire_values.get(&w).cloned() + } + + /// Tells whether a [TailLoop] node can terminate, i.e. whether + /// `Break` and/or `Continue` tags may be returned by the nested DFG. + /// Returns `None` if the specified `node` is not a [TailLoop]. + /// + /// [TailLoop]: hugr_core::ops::TailLoop + pub fn tail_loop_terminates(&self, node: Node) -> Option { + self.hugr().get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr().get_io(node).unwrap(); + Some(TailLoopTermination::from_control_value( + self.in_wire_value + .iter() + .find_map(|(n, p, v)| (*n == out && p.index() == 0).then_some(v)) + .unwrap(), + )) + } + + /// Tells whether a [Case] node is reachable, i.e. 
whether the predicate + /// to its parent [Conditional] may possibly have the tag corresponding to the [Case]. + /// Returns `None` if the specified `case` is not a [Case], or is not within a [Conditional] + /// (e.g. a [Case]-rooted Hugr). + /// + /// [Case]: hugr_core::ops::Case + /// [Conditional]: hugr_core::ops::Conditional + pub fn case_reachable(&self, case: Node) -> Option { + self.hugr().get_optype(case).as_case()?; + let cond = self.hugr().get_parent(case)?; + self.hugr().get_optype(cond).as_conditional()?; + Some( + self.case_reachable + .iter() + .any(|(cond2, case2)| &cond == cond2 && &case == case2), + ) + } + + /// Tells us if a block ([DataflowBlock] or [ExitBlock]) in a [CFG] is known + /// to be reachable. (Returns `None` if argument is not a child of a CFG.) + /// + /// [CFG]: hugr_core::ops::CFG + /// [DataflowBlock]: hugr_core::ops::DataflowBlock + /// [ExitBlock]: hugr_core::ops::ExitBlock + pub fn bb_reachable(&self, bb: Node) -> Option { + let cfg = self.hugr().get_parent(bb)?; // Not really required...?? + self.hugr().get_optype(cfg).as_cfg()?; + let t = self.hugr().get_optype(bb); + if !t.is_dataflow_block() && !t.is_exit_block() { + return None; + }; + Some( + self.bb_reachable + .iter() + .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), + ) + } + + /// Reads a concrete representation of the value on an output wire, if the lattice value + /// computed for the wire can be turned into such. (The lattice value must be either a + /// [PartialValue::Value] or a [PartialValue::PartialSum] with a single possible tag.) 
+ /// + /// # Errors + /// `None` if the analysis did not produce a result for that wire + /// `Some(e)` if conversion to a concrete value failed with error `e`, see [PartialValue::try_into_value] + /// + /// # Panics + /// + /// If a [Type](hugr_core::types::Type) for the specified wire could not be extracted from the Hugr + pub fn try_read_wire_value( + &self, + w: Wire, + ) -> Result>> + where + V2: TryFrom + TryFrom, Error = SE>, + { + let v = self.read_out_wire(w).ok_or(None)?; + let (_, typ) = self + .hugr() + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + v.try_into_value(&typ).map_err(Some) + } +} + +/// Tells whether a loop iterates (never, always, sometimes) +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +pub enum TailLoopTermination { + /// The loop never exits (is an infinite loop); no value is ever + /// returned out of the loop. (aka, Bottom.) + // TODO what about a loop that never exits OR continues because of a nested infinite loop? + NeverBreaks, + /// The loop never iterates (so is equivalent to a [DFG](hugr_core::ops::DFG), + /// modulo untupling of the control value) + NeverContinues, + /// The loop might iterate and/or exit. 
(aka, Top) + BreaksAndContinues, +} + +impl TailLoopTermination { + fn from_control_value(v: &PartialValue) -> Self { + let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); + if may_break { + if may_continue { + Self::BreaksAndContinues + } else { + Self::NeverContinues + } + } else { + Self::NeverBreaks + } + } +} diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs new file mode 100644 index 000000000..c0fbf395a --- /dev/null +++ b/hugr-passes/src/dataflow/test.rs @@ -0,0 +1,485 @@ +use ascent::{lattice::BoundedLattice, Lattice}; + +use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; +use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; +use hugr_core::ops::handle::DfgID; +use hugr_core::ops::TailLoop; +use hugr_core::{ + builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, + extension::{ + prelude::{UnpackTuple, BOOL_T}, + ExtensionSet, EMPTY_REG, + }, + ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, + type_row, + types::{Signature, SumType, Type}, + HugrView, +}; +use hugr_core::{Hugr, Wire}; +use rstest::{fixture, rstest}; + +use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; + +// ------- Minimal implementation of DFContext and AbstractValue ------- +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum Void {} + +impl AbstractValue for Void {} + +struct TestContext(H); + +impl std::ops::Deref for TestContext { + type Target = H; + fn deref(&self) -> &H { + &self.0 + } +} +impl ConstLoader for TestContext {} +impl DFContext for TestContext { + type View = H; +} + +// This allows testing creation of tuple/sum Values (only) +impl From for Value { + fn from(v: Void) -> Self { + match v {} + } +} + +fn pv_false() -> PartialValue { + PartialValue::new_variant(0, []) +} + +fn pv_true() -> PartialValue { + PartialValue::new_variant(1, []) +} + +fn pv_true_or_false() -> PartialValue { + 
pv_true().join(pv_false()) +} + +#[test] +fn test_make_tuple() { + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let results = Machine::default().run(TestContext(hugr), []); + + let x: Value = results.try_read_wire_value(v3).unwrap(); + assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); +} + +#[test] +fn test_unpack_tuple_const() { + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); + let v = builder.add_load_value(Value::tuple([Value::false_val(), Value::true_val()])); + let [o1, o2] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let results = Machine::default().run(TestContext(hugr), []); + + let o1_r: Value = results.try_read_wire_value(o1).unwrap(); + assert_eq!(o1_r, Value::false_val()); + let o2_r: Value = results.try_read_wire_value(o2).unwrap(); + assert_eq!(o2_r, Value::true_val()); +} + +#[test] +fn test_tail_loop_never_iterates() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let r_v = Value::unit_sum(3, 6).unwrap(); + let r_w = builder.add_load_value(r_v.clone()); + let tag = Tag::new( + TailLoop::BREAK_TAG, + vec![type_row![], r_v.get_type().into()], + ); + let tagged = builder.add_dataflow_op(tag, [r_w]).unwrap(); + + let tlb = builder + .tail_loop_builder([], [], vec![r_v.get_type()].into()) + .unwrap(); + let tail_loop = tlb.finish_with_outputs(tagged.out_wire(0), []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let results = Machine::default().run(TestContext(hugr), []); + + let o_r: Value = results.try_read_wire_value(tl_o).unwrap(); + assert_eq!(o_r, r_v); + assert_eq!( + 
Some(TailLoopTermination::NeverContinues), + results.tail_loop_terminates(tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_always_iterates() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let r_w = builder.add_load_value( + Value::sum( + TailLoop::CONTINUE_TAG, + [], + SumType::new([type_row![], BOOL_T.into()]), + ) + .unwrap(), + ); + let true_w = builder.add_load_value(Value::true_val()); + + let tlb = builder + .tail_loop_builder([], [(BOOL_T, true_w)], vec![BOOL_T].into()) + .unwrap(); + + // r_w has tag 0, so we always continue; + // we put true in our "other_output", but we should not propagate this to + // output because r_w never supports 1. + let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap(); + + let [tl_o1, tl_o2] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let results = Machine::default().run(TestContext(&hugr), []); + + let o_r1 = results.read_out_wire(tl_o1).unwrap(); + assert_eq!(o_r1, PartialValue::bottom()); + let o_r2 = results.read_out_wire(tl_o2).unwrap(); + assert_eq!(o_r2, PartialValue::bottom()); + assert_eq!( + Some(TailLoopTermination::NeverBreaks), + results.tail_loop_terminates(tail_loop.node()) + ); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); +} + +#[test] +fn test_tail_loop_two_iters() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let tlb = builder + .tail_loop_builder_exts( + [], + [(BOOL_T, false_w), (BOOL_T, true_w)], + type_row![], + ExtensionSet::new(), + ) + .unwrap(); + assert_eq!( + tlb.loop_signature().unwrap().signature(), + Signature::new_endo(type_row![BOOL_T, BOOL_T]) + ); + let [in_w1, in_w2] = tlb.input_wires_arr(); + let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let [o_w1, 
o_w2] = tail_loop.outputs_arr(); + + let results = Machine::default().run(TestContext(&hugr), []); + + let o_r1 = results.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, pv_true_or_false()); + let o_r2 = results.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, pv_true_or_false()); + assert_eq!( + Some(TailLoopTermination::BreaksAndContinues), + results.tail_loop_terminates(tail_loop.node()) + ); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); +} + +#[test] +fn test_tail_loop_containing_conditional() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let control_variants = vec![type_row![BOOL_T;2]; 2]; + let control_t = Type::new_sum(control_variants.clone()); + let body_out_variants = vec![control_t.clone().into(), type_row![BOOL_T; 2]]; + + let init = builder.add_load_value( + Value::sum( + 0, + [Value::false_val(), Value::true_val()], + SumType::new(control_variants.clone()), + ) + .unwrap(), + ); + + let mut tlb = builder + .tail_loop_builder([(control_t, init)], [], type_row![BOOL_T; 2]) + .unwrap(); + let tl = tlb.loop_signature().unwrap().clone(); + let [in_w] = tlb.input_wires_arr(); + + // Branch on in_wire, so first iter 0(false, true)... 
+ let mut cond = tlb + .conditional_builder( + (control_variants.clone(), in_w), + [], + Type::new_sum(body_out_variants.clone()).into(), + ) + .unwrap(); + let mut case0_b = cond.case_builder(0).unwrap(); + let [a, b] = case0_b.input_wires_arr(); + // Builds value for next iter as 1(true, false) by flipping arguments + let [next_input] = case0_b + .add_dataflow_op(Tag::new(1, control_variants), [b, a]) + .unwrap() + .outputs_arr(); + let cont = case0_b.make_continue(tl.clone(), [next_input]).unwrap(); + case0_b.finish_with_outputs([cont]).unwrap(); + // Second iter 1(true, false) => exit with (true, false) + let mut case1_b = cond.case_builder(1).unwrap(); + let loop_res = case1_b.make_break(tl, case1_b.input_wires()).unwrap(); + case1_b.finish_with_outputs([loop_res]).unwrap(); + let [r] = cond.finish_sub_container().unwrap().outputs_arr(); + + let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let [o_w1, o_w2] = tail_loop.outputs_arr(); + + let results = Machine::default().run(TestContext(&hugr), []); + + let o_r1 = results.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, pv_true()); + let o_r2 = results.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, pv_false()); + assert_eq!( + Some(TailLoopTermination::BreaksAndContinues), + results.tail_loop_terminates(tail_loop.node()) + ); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); +} + +#[test] +fn test_conditional() { + let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; + let cond_t = Type::new_sum(variants.clone()); + let mut builder = DFGBuilder::new(Signature::new(cond_t, type_row![])).unwrap(); + let [arg_w] = builder.input_wires_arr(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let mut cond_builder = builder + .conditional_builder( + (variants, arg_w), + [(BOOL_T, true_w)], + type_row!(BOOL_T, BOOL_T), + ) + .unwrap(); + // will be 
unreachable + let case1_b = cond_builder.case_builder(0).unwrap(); + let case1 = case1_b.finish_with_outputs([false_w, false_w]).unwrap(); + + let case2_b = cond_builder.case_builder(1).unwrap(); + let [c2a] = case2_b.input_wires_arr(); + let case2 = case2_b.finish_with_outputs([false_w, c2a]).unwrap(); + + let case3_b = cond_builder.case_builder(2).unwrap(); + let [c3_1, _c3_2] = case3_b.input_wires_arr(); + let case3 = case3_b.finish_with_outputs([c3_1, false_w]).unwrap(); + + let cond = cond_builder.finish_sub_container().unwrap(); + + let [cond_o1, cond_o2] = cond.outputs_arr(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let arg_pv = PartialValue::new_variant(1, []).join(PartialValue::new_variant( + 2, + [PartialValue::new_variant(0, [])], + )); + let results = Machine::default().run(TestContext(hugr), [(0.into(), arg_pv)]); + + let cond_r1: Value = results.try_read_wire_value(cond_o1).unwrap(); + assert_eq!(cond_r1, Value::false_val()); + assert!(results.try_read_wire_value::(cond_o2).is_err()); + + assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only + assert_eq!(results.case_reachable(case2.node()), Some(true)); + assert_eq!(results.case_reachable(case3.node()), Some(true)); + assert_eq!(results.case_reachable(cond.node()), None); +} + +// A Hugr being a function on bools: (x, y) => (x XOR y, x AND y) +#[fixture] +fn xor_and_cfg() -> Hugr { + // Entry branch on first arg, passes arguments on unchanged + // /T F\ + // A --T-> B A(x=true, y) branch on second arg, passing (first arg == true, false) + // \F / B(w,v) => X(v,w) + // > X < + // Inputs received: + // Entry A B X + // F,F - F,F F,F + // F,T - F,T T,F + // T,F T,F - T,F + // T,T T,T T,F F,T + let mut builder = + CFGBuilder::new(Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T; 2])).unwrap(); + + // entry (x, y) => (if x then A else B)(x=true, y) + let entry = builder + .entry_builder(vec![type_row![]; 2], type_row![BOOL_T;2]) + 
.unwrap(); + let [in_x, in_y] = entry.input_wires_arr(); + let entry = entry.finish_with_outputs(in_x, [in_x, in_y]).unwrap(); + + // A(x==true, y) => (if y then B else X)(x, false) + let mut a = builder + .block_builder( + type_row![BOOL_T; 2], + vec![type_row![]; 2], + type_row![BOOL_T; 2], + ) + .unwrap(); + let [in_x, in_y] = a.input_wires_arr(); + let false_w1 = a.add_load_value(Value::false_val()); + let a = a.finish_with_outputs(in_y, [in_x, false_w1]).unwrap(); + + // B(w, v) => X(v, w) + let mut b = builder + .block_builder(type_row![BOOL_T; 2], [type_row![]], type_row![BOOL_T; 2]) + .unwrap(); + let [in_w, in_v] = b.input_wires_arr(); + let [control] = b + .add_dataflow_op(Tag::new(0, vec![type_row![]]), []) + .unwrap() + .outputs_arr(); + let b = b.finish_with_outputs(control, [in_v, in_w]).unwrap(); + + let x = builder.exit_block(); + + let [fals, tru]: [usize; 2] = [0, 1]; + builder.branch(&entry, tru, &a).unwrap(); // if true + builder.branch(&entry, fals, &b).unwrap(); // if false + builder.branch(&a, tru, &b).unwrap(); // if true + builder.branch(&a, fals, &x).unwrap(); // if false + builder.branch(&b, 0, &x).unwrap(); + builder.finish_hugr(&EMPTY_REG).unwrap() +} + +#[rstest] +#[case(pv_true(), pv_true(), pv_false(), pv_true())] +#[case(pv_true(), pv_false(), pv_true(), pv_false())] +#[case(pv_true(), pv_true_or_false(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true(), PartialValue::Top, pv_true_or_false(), pv_true_or_false())] +#[case(pv_false(), pv_true(), pv_true(), pv_false())] +#[case(pv_false(), pv_false(), pv_false(), pv_false())] +#[case(pv_false(), pv_true_or_false(), pv_true_or_false(), pv_false())] +#[case(pv_false(), PartialValue::Top, PartialValue::Top, pv_false())] // if !inp0 then out0=inp1 +#[case(pv_true_or_false(), pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false(), pv_true_or_false())] +#[case(PartialValue::Top, pv_true(), pv_true_or_false(), 
PartialValue::Top)] +#[case(PartialValue::Top, pv_false(), PartialValue::Top, PartialValue::Top)] +fn test_cfg( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] out0: PartialValue, + #[case] out1: PartialValue, + xor_and_cfg: Hugr, +) { + let root = xor_and_cfg.root(); + let results = Machine::default().run( + TestContext(xor_and_cfg), + [(0.into(), inp0), (1.into(), inp1)], + ); + + assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0); + assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1); +} + +#[rstest] +#[case(pv_true(), pv_true(), pv_true())] +#[case(pv_false(), pv_false(), pv_false())] +#[case(pv_true(), pv_false(), pv_true_or_false())] // Two calls alias +fn test_call( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] out: PartialValue, +) { + let mut builder = DFGBuilder::new(Signature::new_endo(type_row![BOOL_T; 2])).unwrap(); + let func_bldr = builder + .define_function("id", Signature::new_endo(BOOL_T)) + .unwrap(); + let [v] = func_bldr.input_wires_arr(); + let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); + let [a, b] = builder.input_wires_arr(); + let [a2] = builder + .call(func_defn.handle(), &[], [a], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let [b2] = builder + .call(func_defn.handle(), &[], [b], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let hugr = builder + .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) + .unwrap(); + + let results = Machine::default().run(TestContext(&hugr), [(0.into(), inp0), (1.into(), inp1)]); + + let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); + // The two calls alias so both results will be the same: + assert_eq!(res0, out); + assert_eq!(res1, out); +} + +#[test] +fn test_region() { + let mut builder = + DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T;2])).unwrap(); + let [in_w] = builder.input_wires_arr(); + let cst_w = builder.add_load_const(Value::false_val()); + let 
nested = builder + .dfg_builder(Signature::new_endo(type_row![BOOL_T; 2]), [in_w, cst_w]) + .unwrap(); + let nested_ins = nested.input_wires(); + let nested = nested.finish_with_outputs(nested_ins).unwrap(); + let hugr = builder + .finish_prelude_hugr_with_outputs(nested.outputs()) + .unwrap(); + let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); + let whole_hugr_results = Machine::default().run(TestContext(&hugr), [(0.into(), pv_true())]); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(pv_false()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 1)), + Some(pv_false()) + ); + + let subview = DescendantsGraph::::try_new(&hugr, nested.node()).unwrap(); + // Do not provide a value on the second input (constant false in the whole hugr, above) + let sub_hugr_results = Machine::default().run(TestContext(subview), [(0.into(), pv_true())]); + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(PartialValue::Top) + ); + for w in [0, 1] { + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(hugr.root(), w)), + None + ); + } +} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs new file mode 100644 index 000000000..0d8bc15a6 --- /dev/null +++ b/hugr-passes/src/dataflow/value_row.rs @@ -0,0 +1,101 @@ +// Wrap a (known-length) row of values into a lattice. 
+
+use std::{
+    cmp::Ordering,
+    ops::{Index, IndexMut},
+};
+
+use ascent::{lattice::BoundedLattice, Lattice};
+use itertools::zip_eq;
+
+use super::{AbstractValue, PartialValue};
+
+// NOTE(review): this copy of the patch had lost its line breaks and every
+// `<...>` generic argument list. The type parameters and bounds below are
+// reconstructed from how the items are used - confirm against the original.
+
+/// A fixed-length row of [PartialValue]s, treated as a single lattice element
+/// by combining rows position by position.
+#[derive(PartialEq, Clone, Debug, Eq, Hash)]
+pub(super) struct ValueRow<V>(Vec<PartialValue<V>>);
+
+impl<V: AbstractValue> ValueRow<V> {
+    /// Creates a row of `len` elements, each being [PartialValue::bottom].
+    pub fn new(len: usize) -> Self {
+        Self(vec![PartialValue::bottom(); len])
+    }
+
+    /// Creates a row of `len` elements that is bottom everywhere except at
+    /// position `idx`, which holds `v`.
+    ///
+    /// # Panics
+    ///
+    /// If `idx >= len`.
+    pub fn single_known(len: usize, idx: usize, v: PartialValue<V>) -> Self {
+        assert!(idx < len);
+        let mut r = Self::new(len);
+        r.0[idx] = v;
+        r
+    }
+
+    /// The first value in this ValueRow must be a sum;
+    /// returns an iterator over the elements of the specified variant of said
+    /// first value, followed by (clones of) the rest of the values in this row.
+    ///
+    /// Returns `None` whenever [PartialValue::variant_values] does for the
+    /// first value with the given `variant` and `len`.
+    ///
+    /// # Panics
+    ///
+    /// If this row is empty (there is no first value to unpack).
+    pub fn unpack_first(
+        &self,
+        variant: usize,
+        len: usize,
+    ) -> Option<impl Iterator<Item = PartialValue<V>>> {
+        let vals = self[0].variant_values(variant, len)?;
+        // The tail is copied so that the returned iterator does not borrow `self`.
+        Some(vals.into_iter().chain(self.0[1..].to_owned()))
+    }
+}
+
+impl<V> FromIterator<PartialValue<V>> for ValueRow<V> {
+    fn from_iter<T: IntoIterator<Item = PartialValue<V>>>(iter: T) -> Self {
+        Self(iter.into_iter().collect())
+    }
+}
+
+impl<V: AbstractValue> PartialOrd for ValueRow<V> {
+    /// Pointwise (product) order, matching `join_mut` / `meet_mut` below:
+    /// a row is `<=` another iff every element is `<=` its counterpart.
+    ///
+    /// `Vec`'s own `partial_cmp` is lexicographic and would e.g. report
+    /// `[bottom, top] < [top, bottom]`, although those rows are incomparable
+    /// (their join is `[top, top]`). Rows of different lengths, and rows
+    /// containing any incomparable pair of elements, are incomparable.
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        if self.0.len() != other.0.len() {
+            return None;
+        }
+        self.0
+            .iter()
+            .zip(&other.0)
+            .try_fold(Ordering::Equal, |acc, (a, b)| {
+                match (acc, a.partial_cmp(b)?) {
+                    // `Equal` is neutral; otherwise all elements must agree.
+                    (Ordering::Equal, o) | (o, Ordering::Equal) => Some(o),
+                    (x, y) if x == y => Some(x),
+                    _ => None,
+                }
+            })
+    }
+}
+
+impl<V: AbstractValue> Lattice for ValueRow<V> {
+    /// Joins `other` into `self` element by element; returns whether any
+    /// element changed.
+    ///
+    /// # Panics
+    ///
+    /// If the two rows have different lengths.
+    fn join_mut(&mut self, other: Self) -> bool {
+        assert_eq!(self.0.len(), other.0.len());
+        let mut changed = false;
+        for (v1, v2) in zip_eq(&mut self.0, other.0) {
+            // Non-short-circuiting `|=`: every element must be joined.
+            changed |= v1.join_mut(v2);
+        }
+        changed
+    }
+
+    /// Meets `other` into `self` element by element; returns whether any
+    /// element changed.
+    ///
+    /// # Panics
+    ///
+    /// If the two rows have different lengths.
+    fn meet_mut(&mut self, other: Self) -> bool {
+        assert_eq!(self.0.len(), other.0.len());
+        let mut changed = false;
+        for (v1, v2) in zip_eq(&mut self.0, other.0) {
+            // Non-short-circuiting `|=`: every element must be met.
+            changed |= v1.meet_mut(v2);
+        }
+        changed
+    }
+}
+
+impl<V> IntoIterator for ValueRow<V> {
+    type Item = PartialValue<V>;
+
+    type IntoIter = <Vec<PartialValue<V>> as IntoIterator>::IntoIter;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
+// Indexing is forwarded to the underlying `Vec`, so a row accepts the same
+// index types (`usize`, ranges, ...) with the same panics.
+impl<V, Idx> Index<Idx> for ValueRow<V>
+where
+    Vec<PartialValue<V>>: Index<Idx>,
+{
+    type Output = <Vec<PartialValue<V>> as Index<Idx>>::Output;
+
+    fn index(&self, index: Idx) -> &Self::Output {
+        self.0.index(index)
+    }
+}
+
+impl<V, Idx> IndexMut<Idx> for ValueRow<V>
+where
+    Vec<PartialValue<V>>: IndexMut<Idx>,
+{
+    fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
+        self.0.index_mut(index)
+    }
+}
diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs
index 13dd47776..06781f7c5 100644
--- a/hugr-passes/src/lib.rs
+++ b/hugr-passes/src/lib.rs
@@ -1,6 +1,7 @@
 //! Compilation passes acting on the HUGR program representation.
 
 pub mod const_fold;
+pub mod dataflow;
 pub mod force_order;
 mod half_node;
 pub mod lower;