Skip to content

Commit

Permalink
Check Const's are well-formed using a constructor (#319)
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc authored Jul 31, 2023
1 parent 37da9b7 commit 294ebc1
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 37 deletions.
6 changes: 6 additions & 0 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use thiserror::Error;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

use crate::hugr::typecheck::ConstTypeError;
use crate::hugr::{HugrError, Node, ValidationError, Wire};
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::types::SimpleType;
Expand Down Expand Up @@ -41,6 +42,11 @@ pub enum BuildError {
/// The constructed HUGR is invalid.
#[error("The constructed HUGR is invalid: {0}.")]
InvalidHUGR(#[from] ValidationError),
/// Tried to add a malformed [ConstValue]
///
/// [ConstValue]: crate::ops::constant::ConstValue
#[error("Constant failed typechecking: {0}")]
BadConstant(#[from] ConstTypeError),
/// HUGR construction error.
#[error("Error when mutating HUGR: {0}.")]
ConstructError(#[from] HugrError),
Expand Down
2 changes: 1 addition & 1 deletion src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub trait Container {
/// [`OpType::Const`] node.
fn add_constant(&mut self, val: ConstValue) -> Result<ConstID, BuildError> {
let typ = val.const_type();
let const_n = self.add_child_op(ops::Const(val))?;
let const_n = self.add_child_op(ops::Const::new(val).map_err(BuildError::BadConstant)?)?;

Ok((const_n, typ).into())
}
Expand Down
37 changes: 15 additions & 22 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@ use thiserror::Error;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;

use crate::hugr::typecheck::{typecheck_const, ConstTypeError};
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::OpTag;
use crate::ops::{self, OpTrait, OpType, ValidateOp};
use crate::ops::{OpTrait, OpType, ValidateOp};
use crate::resource::ResourceSet;
use crate::types::ClassicType;
use crate::types::{EdgeKind, SimpleType};
Expand Down Expand Up @@ -434,21 +433,16 @@ impl<'a> ValidationContext<'a> {
let local = Some(from_parent) == to_parent;

let is_static = match from_optype.port_kind(from_offset).unwrap() {
// Inter-graph constant wires do not have restrictions
EdgeKind::Static(typ) => {
if let OpType::Const(ops::Const(val)) = from_optype {
typecheck_const(&typ, val).map_err(ValidationError::from)?;
} else {
// If const edges aren't coming from const nodes, they're graph
// edges coming from FuncDecl or FuncDefn
if !OpTag::Function.is_superset(from_optype.tag()) {
return Err(InterGraphEdgeError::InvalidConstSrc {
from,
from_offset,
typ,
}
.into());
};
if !(OpTag::Const.is_superset(from_optype.tag())
|| OpTag::Function.is_superset(from_optype.tag()))
{
return Err(InterGraphEdgeError::InvalidConstSrc {
from,
from_offset,
typ,
}
.into());
};
true
}
Expand Down Expand Up @@ -651,9 +645,6 @@ pub enum ValidationError {
/// There are invalid inter-graph edges.
#[error(transparent)]
InterGraphEdgeError(#[from] InterGraphEdgeError),
/// Type error for constant values
#[error("Type error for constant value: {0}.")]
ConstTypeError(#[from] ConstTypeError),
/// Missing lift node
#[error("Resources at target node {to:?} ({to_offset:?}) ({to_resources}) exceed those at source {from:?} ({from_offset:?}) ({from_resources})")]
TgtExceedsSrcResources {
Expand Down Expand Up @@ -817,7 +808,7 @@ mod test {
parent: Node,
predicate_size: usize,
) -> (Node, Node, Node, Node) {
let const_op = ops::Const(ConstValue::simple_predicate(0, predicate_size));
let const_op = ops::Const::new(ConstValue::simple_predicate(0, predicate_size)).unwrap();
let tag_type = SimpleType::Classic(ClassicType::new_simple_predicate(predicate_size));

let input = b
Expand Down Expand Up @@ -1147,8 +1138,10 @@ mod test {
})
);
// Second input of Xor from a constant
let cst =
h.add_op_with_parent(h.root(), ops::Const(ConstValue::Int { width: 1, value: 1 }))?;
let cst = h.add_op_with_parent(
h.root(),
ops::Const::new(ConstValue::Int { width: 1, value: 1 }).unwrap(),
)?;
let lcst = h.add_op_with_parent(
h.root(),
ops::LoadConstant {
Expand Down
33 changes: 19 additions & 14 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::any::Any;

use crate::{
classic_row,
hugr::typecheck::{typecheck_const, ConstTypeError},
macros::impl_box_clone,
types::{ClassicRow, ClassicType, Container, CustomType, EdgeKind, HashableType},
};
Expand All @@ -16,7 +17,16 @@ use super::{OpName, OpTrait, StaticTag};

/// A constant value definition.
#[derive(Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize)]
pub struct Const(pub ConstValue);
pub struct Const(ConstValue);

impl Const {
/// Creates a new Const, type-checking the value.
pub fn new(val: ConstValue) -> Result<Self, ConstTypeError> {
typecheck_const(&val.const_type(), &val)?;
Ok(Const(val))
}
}

impl OpName for Const {
fn name(&self) -> SmolStr {
self.0.name()
Expand Down Expand Up @@ -214,7 +224,7 @@ mod test {
use crate::{
builder::{BuildError, Container, DFGBuilder, Dataflow, DataflowHugr},
classic_row,
hugr::{typecheck::ConstTypeError, ValidationError},
hugr::typecheck::ConstTypeError,
type_row,
types::{ClassicType, SimpleRow, SimpleType},
};
Expand Down Expand Up @@ -257,19 +267,14 @@ mod test {
let pred_ty = SimpleType::new_predicate(pred_rows.clone());

let mut b = DFGBuilder::new(type_row![], SimpleRow::from(vec![pred_ty])).unwrap();
let c = b
.add_constant(ConstValue::predicate(
0,
ConstValue::Tuple(vec![]),
pred_rows,
))
.unwrap();
let w = b.load_const(&c).unwrap();
let res = b.add_constant(ConstValue::predicate(
0,
ConstValue::Tuple(vec![]),
pred_rows,
));
assert_eq!(
b.finish_hugr_with_outputs([w]),
Err(BuildError::InvalidHUGR(ValidationError::ConstTypeError(
ConstTypeError::TupleWrongLength
)))
res,
Err(BuildError::BadConstant(ConstTypeError::TupleWrongLength))
);
}
}

0 comments on commit 294ebc1

Please sign in to comment.