Skip to content

Commit

Permalink
fix: const-folding Module keeps at least "main" (#1901)
Browse files Browse the repository at this point in the history
Minimal, non-breaking, fix for #1797, this seems consistent with what
dataflow analysis does.
  • Loading branch information
acl-cqc authored Feb 4, 2025
1 parent d6b8681 commit 7312c62
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 11 deletions.
31 changes: 22 additions & 9 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use hugr_core::{
},
ops::{
constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant,
Value,
OpType, Value,
},
types::{EdgeKind, TypeArg},
HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire,
Expand Down Expand Up @@ -88,8 +88,7 @@ impl ConstantFoldPass {
});

let results = Machine::new(&hugr).run(ConstFoldContext(hugr), inputs);
let mut keep_nodes = HashSet::new();
self.find_needed_nodes(&results, &mut keep_nodes);
let keep_nodes = self.find_needed_nodes(&results);
let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i);

let remove_nodes = hugr
Expand Down Expand Up @@ -145,17 +144,30 @@ impl ConstantFoldPass {
fn find_needed_nodes<H: HugrView>(
&self,
results: &AnalysisResults<ValueHandle, H>,
needed: &mut HashSet<Node>,
) {
let mut q = VecDeque::new();
) -> HashSet<Node> {
let mut needed = HashSet::new();
let h = results.hugr();
q.push_back(h.root());
let mut q = VecDeque::from_iter([h.root()]);
while let Some(n) = q.pop_front() {
if !needed.insert(n) {
continue;
};

if h.get_optype(n).is_cfg() {
if h.get_optype(n).is_module() {
for ch in h.children(n) {
match h.get_optype(ch) {
OpType::AliasDecl(_) | OpType::AliasDefn(_) => {
// Use of these is done via names, rather than following edges.
// We could track these as well but for now be conservative.
q.push_back(ch);
}
OpType::FuncDefn(f) if f.name == "main" => {
// Dataflow analysis will have applied any inputs the 'main' function, so assume reachable.
q.push_back(ch);
}
_ => (),
}
}
} else if h.get_optype(n).is_cfg() {
for bb in h.children(n) {
//if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates
q.push_back(bb);
Expand Down Expand Up @@ -192,6 +204,7 @@ impl ConstantFoldPass {
}
}
}
needed
}
}

Expand Down
54 changes: 52 additions & 2 deletions hugr-passes/src/const_fold/test.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use std::collections::hash_map::RandomState;
use std::collections::HashSet;

use hugr_core::ops::handle::NodeHandle;
use hugr_core::ops::Const;
use hugr_core::std_extensions::arithmetic::{int_ops, int_types};
use itertools::Itertools;
use lazy_static::lazy_static;
use rstest::rstest;

use hugr_core::builder::{
endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
SubContainer,
HugrBuilder, ModuleBuilder, SubContainer,
};
use hugr_core::extension::prelude::{
bool_t, const_ok, error_type, string_type, sum_with_error, ConstError, ConstString, MakeTuple,
Expand All @@ -25,7 +28,7 @@ use hugr_core::std_extensions::arithmetic::{
int_types::{ConstInt, INT_TYPES},
};
use hugr_core::std_extensions::logic::LogicOp;
use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV};
use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV};
use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node};

use crate::dataflow::{partial_from_const, DFContext, PartialValue};
Expand Down Expand Up @@ -1580,3 +1583,50 @@ fn test_cfg(
assert_eq!(output_src, nested);
}
}

#[test]
fn test_module() -> Result<(), Box<dyn std::error::Error>> {
let mut mb = ModuleBuilder::new();
// Define a top-level constant, (only) the second of which can be removed
let c7 = mb.add_constant(Value::from(ConstInt::new_u(5, 7)?));
let c17 = mb.add_constant(Value::from(ConstInt::new_u(5, 17)?));
let ad1 = mb.add_alias_declare("unused", TypeBound::Any)?;
let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?;
let mut main = mb.define_function(
"main",
Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2])
.with_extension_delta(int_types::EXTENSION_ID)
.with_extension_delta(int_ops::EXTENSION_ID),
)?;
let lc7 = main.load_const(&c7);
let lc17 = main.load_const(&c17);
let [add] = main
.add_dataflow_op(IntOpDef::iadd.with_log_width(5), [lc7, lc17])?
.outputs_arr();
let main = main.finish_with_outputs([lc7, add])?;
let mut hugr = mb.finish_hugr()?;
constant_fold_pass(&mut hugr);
assert!(hugr.get_optype(hugr.root()).is_module());
assert_eq!(
hugr.children(hugr.root()).collect_vec(),
[c7.node(), ad1.node(), ad2.node(), main.node()]
);
let tags = hugr
.children(main.node())
.map(|n| hugr.get_optype(n).tag())
.collect_vec();
for (tag, expected_count) in [
(OpTag::Input, 1),
(OpTag::Output, 1),
(OpTag::Const, 1),
(OpTag::LoadConst, 2),
] {
assert_eq!(tags.iter().filter(|t| **t == tag).count(), expected_count);
}
assert_eq!(
hugr.children(main.node())
.find_map(|n| hugr.get_optype(n).as_const()),
Some(&Const::new(ConstInt::new_u(5, 24).unwrap().into()))
);
Ok(())
}

0 comments on commit 7312c62

Please sign in to comment.