Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: erratic stack overflow in infer.rs (live_var) #638

Merged
merged 7 commits into from
Nov 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 86 additions & 44 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use super::validate::ExtensionError;

use petgraph::graph as pg;

use std::collections::{HashMap, HashSet};
use std::collections::{HashMap, HashSet, VecDeque};

use thiserror::Error;

Expand Down Expand Up @@ -521,57 +521,46 @@ impl UnificationContext {
pub fn results(&self) -> Result<ExtensionSolution, InferExtensionError> {
// Check that all of the metavariables associated with nodes of the
// graph are solved
let depended_upon = {
let mut h: HashMap<Meta, Vec<Meta>> = HashMap::new();
for (m, m2) in self.constraints.iter().flat_map(|(m, cs)| {
cs.iter().flat_map(|c| match c {
Constraint::Plus(_, m2) => Some((*m, self.resolve(*m2))),
_ => None,
})
}) {
h.entry(m2).or_default().push(m);
}
h
};
// Calculate everything dependent upon a variable.
// Note it would be better to find metas ALL of whose dependencies were (transitively)
// on variables, but this is more complex, and hard to define if there are cycles
// of PLUS constraints, so leaving that as a TODO until we've handled such cycles.
let mut depends_on_var = HashSet::new();
let mut queue = VecDeque::from_iter(self.variables.iter());
while let Some(m) = queue.pop_front() {
if depends_on_var.insert(m) {
if let Some(d) = depended_upon.get(m) {
queue.extend(d.iter())
}
}
}

let mut results: ExtensionSolution = HashMap::new();
for (loc, meta) in self.extensions.iter() {
if let Some(rs) = self.get_solution(meta) {
if loc.1 == Direction::Incoming {
results.insert(loc.0, rs.clone());
}
} else if self.live_var(meta).is_some() {
// If it depends on some other live meta, that's bad news.
return Err(InferExtensionError::Unsolved { location: *loc });
}
// If it only depends on graph variables, then we don't have
// a *solution*, but it's fine
}
debug_assert!(self.live_metas().is_empty());
Ok(results)
}

// Get the live var associated with a meta.
// TODO: This should really be a list
fn live_var(&self, m: &Meta) -> Option<Meta> {
if self.variables.contains(m) || self.variables.contains(&self.resolve(*m)) {
return None;
}

// TODO: We should be doing something to ensure that these are the same check...
if self.get_solution(m).is_none() {
if let Some(cs) = self.get_constraints(m) {
for c in cs {
match c {
Constraint::Plus(_, m) => return self.live_var(m),
_ => panic!("we shouldn't be here!"),
}
} else {
// Unsolved nodes must be unsolved because they depend on graph variables.
if !depends_on_var.contains(&self.resolve(*meta)) {
return Err(InferExtensionError::Unsolved { location: *loc });
}
}
Some(*m)
} else {
None
}
}

/// Return the set of "live" metavariables in the context.
/// "Live" here means a metavariable:
/// - Is associated to a location in the graph in `UnifyContext.extensions`
/// - Is still unsolved
/// - Isn't a variable
fn live_metas(&self) -> HashSet<Meta> {
self.extensions
.values()
.filter_map(|m| self.live_var(m))
.filter(|m| !self.variables.contains(m))
.collect()
Ok(results)
}

/// Iterates over a set of metas (the argument) and tries to solve
Expand Down Expand Up @@ -665,12 +654,16 @@ mod test {

use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::prelude::QB_T;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::OpType;
use crate::ops::custom::{ExternalOp, OpaqueOp};
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait};
use crate::ops::{LeafOp, OpType};

use crate::type_row;
use crate::types::{FunctionType, Type, TypeRow};

Expand Down Expand Up @@ -1539,4 +1532,53 @@ mod test {

Ok(())
}

/// This was stack-overflowing approx 50% of the time,
/// see https://github.com/CQCL/hugr/issues/633
#[test]
fn plus_on_self() -> Result<(), Box<dyn std::error::Error>> {
let ext = ExtensionId::new("unknown1").unwrap();
let delta = ExtensionSet::singleton(&ext);
let ft = FunctionType::new_linear(type_row![QB_T, QB_T]).with_extension_delta(&delta);
let mut dfg = DFGBuilder::new(ft.clone())?;

// While https://github.com/CQCL-DEV/hugr/issues/388 is unsolved,
// most operations have empty extension_reqs (not including their own extension).
// Define some that do.
let binop: LeafOp = ExternalOp::Opaque(OpaqueOp::new(
ext.clone(),
"2qb_op",
String::new(),
vec![],
Some(ft),
))
.into();
let unary_sig = FunctionType::new_linear(type_row![QB_T])
.with_extension_delta(&ExtensionSet::singleton(&ext));
let unop: LeafOp = ExternalOp::Opaque(OpaqueOp::new(
ext,
"1qb_op",
String::new(),
vec![],
Some(unary_sig),
))
.into();
// Constrain q1,q2 as PLUS(ext1, inputs):
let [q1, q2] = dfg
.add_dataflow_op(binop.clone(), dfg.input_wires())?
.outputs_arr();
// Constrain q1 as PLUS(ext2, q2):
let [q1] = dfg.add_dataflow_op(unop, [q1])?.outputs_arr();
// Constrain q1 as EQUALS(q2) by using both together
dfg.finish_hugr_with_outputs([q1, q2], &PRELUDE_REGISTRY)?;
// The combined q1+q2 variable now has two PLUS constraints - on itself and the inputs.
Ok(())
}

/// [plus_on_self] had about a 50% rate of failing with stack overflow.
/// So if we run 10 times, that would succeed about 1 run in 2^10, i.e. <0.1%
#[test]
fn plus_on_self_10_times() {
[0; 10].iter().for_each(|_| plus_on_self().unwrap())
}
}