Skip to content

Commit

Permalink
refactor: ExtensionSolution only consists of input extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Aug 31, 2023
1 parent 9b66d6d commit 96b8853
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 42 deletions.
44 changes: 19 additions & 25 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ use std::collections::{HashMap, HashSet};

use thiserror::Error;

/// A mapping from locations on the hugr to extension requirement sets which
/// have been inferred for them
pub type ExtensionSolution = HashMap<(Node, Direction), ExtensionSet>;
/// A mapping from nodes on the hugr to extension requirement sets which have
/// been inferred for their inputs.
pub type ExtensionSolution = HashMap<Node, ExtensionSet>;

/// Infer extensions for a hugr. This is the main API exposed by this module
///
Expand All @@ -38,9 +38,9 @@ pub fn infer_extensions(
let solution = ctx.main_loop()?;
ctx.instantiate_variables();
let closed_solution = ctx.main_loop()?;
let closure: HashMap<(Node, Direction), ExtensionSet> = closed_solution
let closure: ExtensionSolution = closed_solution
.into_iter()
.filter(|(loc, _)| !solution.contains_key(loc))
.filter(|(node, _)| !solution.contains_key(node))
.collect();
Ok((solution, closure))
}
Expand Down Expand Up @@ -536,7 +536,9 @@ impl UnificationContext {
}
}
}?;
results.insert(*loc, rs);
if loc.1 == Direction::Incoming {
results.insert(loc.0, rs);
}
}
debug_assert!(self.live_metas().is_empty());
Ok(results)
Expand Down Expand Up @@ -735,22 +737,11 @@ mod test {
let (_, closure) = infer_extensions(&hugr)?;
let empty = ExtensionSet::new();
let ab = ExtensionSet::from_iter(["A".into(), "B".into()]);
let abc = ExtensionSet::from_iter(["A".into(), "B".into(), "C".into()]);
assert_eq!(*closure.get(&(hugr.root())).unwrap(), empty);
assert_eq!(*closure.get(&(mult_c)).unwrap(), ab);
assert_eq!(*closure.get(&(add_ab)).unwrap(), empty);
assert_eq!(
*closure.get(&(hugr.root(), Direction::Incoming)).unwrap(),
empty
);
assert_eq!(
*closure.get(&(hugr.root(), Direction::Outgoing)).unwrap(),
abc
);
assert_eq!(*closure.get(&(mult_c, Direction::Incoming)).unwrap(), ab);
assert_eq!(*closure.get(&(mult_c, Direction::Outgoing)).unwrap(), abc);
assert_eq!(*closure.get(&(add_ab, Direction::Incoming)).unwrap(), empty);
assert_eq!(*closure.get(&(add_ab, Direction::Outgoing)).unwrap(), ab);
assert_eq!(*closure.get(&(add_ab, Direction::Incoming)).unwrap(), empty);
assert_eq!(
*closure.get(&(add_b, Direction::Incoming)).unwrap(),
*closure.get(&add_b).unwrap(),
ExtensionSet::singleton(&"A".into())
);
Ok(())
Expand Down Expand Up @@ -837,9 +828,9 @@ mod test {
ctx.add_constraint(ab, Constraint::Plus("A".into(), b));
ctx.add_constraint(ab, Constraint::Plus("B".into(), a));
let solution = ctx.main_loop()?;
// We'll only find concrete solutions for the Incoming/Outgoing sides of
// We'll only find concrete solutions for the Incoming extension reqs of
// the main node created by `Hugr::default`
assert_eq!(solution.len(), 2);
assert_eq!(solution.len(), 1);
Ok(())
}

Expand Down Expand Up @@ -983,11 +974,14 @@ mod test {
hugr.connect(lift_node, 0, ochild, 0)?;
hugr.connect(child, 0, output, 0)?;

let (sol, _) = infer_extensions(&hugr)?;
hugr.infer_extensions()?;

// The solution for the const node should be {A, B}!
assert_eq!(
*sol.get(&(const_node, Direction::Outgoing)).unwrap(),
hugr.get_nodetype(const_node)
.signature()
.unwrap()
.output_extensions(),
ExtensionSet::from_iter(["A".into(), "B".into()])
);

Expand Down
18 changes: 12 additions & 6 deletions src/extension/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ use std::collections::HashMap;

use thiserror::Error;

use super::{ExtensionSet, ExtensionSolution};
use crate::hugr::NodeType;
use crate::{Direction, Hugr, HugrView, Node, Port};

use super::ExtensionSet;

/// Context for validating the extension requirements defined in a Hugr.
#[derive(Debug, Clone, Default)]
pub struct ExtensionValidator {
Expand All @@ -23,10 +22,17 @@ impl ExtensionValidator {
///
/// The `closure` argument is a set of extensions which doesn't actually
/// live on the graph, but is used to close the graph for validation
pub fn new(hugr: &Hugr, closure: HashMap<(Node, Direction), ExtensionSet>) -> Self {
let mut validator = ExtensionValidator {
extensions: closure,
};
pub fn new(hugr: &Hugr, closure: ExtensionSolution) -> Self {
let mut extensions: HashMap<(Node, Direction), ExtensionSet> = HashMap::new();
for (node, incoming_sol) in closure.into_iter() {
let op_signature = hugr.get_nodetype(node).op_signature();
let outgoing_sol = op_signature.extension_reqs.union(&incoming_sol);

extensions.insert((node, Direction::Incoming), incoming_sol);
extensions.insert((node, Direction::Outgoing), outgoing_sol);
}

let mut validator = ExtensionValidator { extensions };

for node in hugr.nodes() {
validator.gather_extensions(&node, hugr.get_nodetype(node));
Expand Down
11 changes: 3 additions & 8 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,7 @@ impl Hugr {
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
// uses those to infer the output extensions
for ((node, _), input_extensions) in solution
.iter()
.filter(|((_, dir), _)| *dir == Direction::Incoming)
{
for (node, input_extensions) in solution.iter() {
let nodetype = self.op_types.try_get_mut(node.index).unwrap();
match &nodetype.input_extensions {
None => nodetype.input_extensions = Some(input_extensions.clone()),
Expand Down Expand Up @@ -504,16 +501,14 @@ mod test {
hugr.infer_extensions()?;

assert_eq!(
hugr.op_types
.get(lift.index)
hugr.get_nodetype(lift)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
);
assert_eq!(
hugr.op_types
.get(output.index)
hugr.get_nodetype(output)
.signature()
.unwrap()
.input_extensions,
Expand Down
6 changes: 3 additions & 3 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use pyo3::prelude::*;
use crate::extension::SignatureError;
use crate::extension::{
validate::{ExtensionError, ExtensionValidator},
ExtensionRegistry, ExtensionSet, InferExtensionError,
ExtensionRegistry, ExtensionSolution, InferExtensionError,
};
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::{OpTag, OpTrait, OpType, ValidateOp};
Expand Down Expand Up @@ -53,7 +53,7 @@ impl Hugr {
/// free extension variables
pub fn validate_with_extension_closure(
&self,
closure: HashMap<(Node, Direction), ExtensionSet>,
closure: ExtensionSolution,
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
let mut validator = ValidationContext::new(self, closure, extension_registry);
Expand All @@ -65,7 +65,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
/// Create a new validation context.
pub fn new(
hugr: &'a Hugr,
extension_closure: HashMap<(Node, Direction), ExtensionSet>,
extension_closure: ExtensionSolution,
extension_registry: &'b ExtensionRegistry,
) -> Self {
Self {
Expand Down

0 comments on commit 96b8853

Please sign in to comment.