diff --git a/src/extensions/logic.rs b/src/extensions/logic.rs index c2b08054b..dcbaadbbf 100644 --- a/src/extensions/logic.rs +++ b/src/extensions/logic.rs @@ -9,7 +9,7 @@ use crate::{ resource::{OpDef, ResourceSet}, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, - SimpleType, + HashableType, SimpleType, }, Resource, }; @@ -26,6 +26,7 @@ pub fn bool_type() -> SimpleType { /// Resource for basic logical operations. pub fn resource() -> Resource { + const H_INT: TypeParam = TypeParam::Value(HashableType::Int(8)); let mut resource = Resource::new(resource_id()); let not_op = OpDef::new_with_custom_sig( @@ -45,14 +46,14 @@ pub fn resource() -> Resource { let and_op = OpDef::new_with_custom_sig( "And".into(), "logical 'and'".into(), - vec![TypeParam::Int], + vec![H_INT], HashMap::default(), |arg_values: &[TypeArg]| { let a = arg_values.iter().exactly_one().unwrap(); let n: u128 = match a { TypeArg::Int(n) => *n, _ => { - return Err(TypeArgError::TypeMismatch(a.clone(), TypeParam::Int).into()); + return Err(TypeArgError::TypeMismatch(a.clone(), H_INT).into()); } }; Ok(( @@ -66,14 +67,14 @@ pub fn resource() -> Resource { let or_op = OpDef::new_with_custom_sig( "Or".into(), "logical 'or'".into(), - vec![TypeParam::Int], + vec![H_INT], HashMap::default(), |arg_values: &[TypeArg]| { let a = arg_values.iter().exactly_one().unwrap(); let n: u128 = match a { TypeArg::Int(n) => *n, _ => { - return Err(TypeArgError::TypeMismatch(a.clone(), TypeParam::Int).into()); + return Err(TypeArgError::TypeMismatch(a.clone(), H_INT).into()); } }; Ok(( diff --git a/src/hugr/typecheck.rs b/src/hugr/typecheck.rs index 2ec9fd243..7d86346b0 100644 --- a/src/hugr/typecheck.rs +++ b/src/hugr/typecheck.rs @@ -13,13 +13,9 @@ use crate::types::{ClassicRow, ClassicType, Container, HashableType, PrimType, T use crate::ops::constant::{HugrIntValueStore, HugrIntWidthStore, HUGR_MAX_INT_WIDTH}; -/// Errors that arise from typechecking constants -#[derive(Clone, Debug, Eq, PartialEq, Error)] -pub enum ConstTypeError { - /// This case hasn't been implemented. Possibly because we don't have value - /// constructors to check against it - #[error("Unimplemented: there are no constants of type {0}")] - Unimplemented(ClassicType), +/// An error in fitting an integer constant into its size +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum ConstIntError { /// The value exceeds the max value of its `I` type /// E.g. checking 300 against I8 #[error("Const int {1} too large for type I{0}")] @@ -30,6 +26,18 @@ pub enum ConstTypeError { /// The width of an integer type wasn't a power of 2 #[error("The int type I{0} is invalid, because {0} is not a power of 2")] IntWidthInvalid(HugrIntWidthStore), +} + +/// Errors that arise from typechecking constants +#[derive(Clone, Debug, Eq, PartialEq, Error)] +pub enum ConstTypeError { + /// This case hasn't been implemented. Possibly because we don't have value + /// constructors to check against it + #[error("Unimplemented: there are no constants of type {0}")] + Unimplemented(ClassicType), + /// There was some problem fitting a const int into its declared size + #[error("Error with int constant")] + Int(#[from] ConstIntError), /// Expected width (packed with const int) doesn't match type #[error("Type mismatch for int: expected I{0}, but found I{1}")] IntWidthMismatch(HugrIntWidthStore, HugrIntWidthStore), @@ -57,15 +65,27 @@ lazy_static! { } /// Per the spec, valid widths for integers are 2^n for all n in [0,7] -fn check_valid_width(width: HugrIntWidthStore) -> Result<(), ConstTypeError> { +pub(crate) fn check_int_fits_in_width( + value: HugrIntValueStore, + width: HugrIntWidthStore, +) -> Result<(), ConstIntError> { if width > HUGR_MAX_INT_WIDTH { - return Err(ConstTypeError::IntWidthTooLarge(width)); + return Err(ConstIntError::IntWidthTooLarge(width)); } if VALID_WIDTHS.contains(&width) { - Ok(()) + let max_value = if width == HUGR_MAX_INT_WIDTH { + HugrIntValueStore::MAX + } else { + HugrIntValueStore::pow(2, width as u32) - 1 + }; + if value <= max_value { + Ok(()) + } else { + Err(ConstIntError::IntTooLarge(width, value)) + } } else { - Err(ConstTypeError::IntWidthInvalid(width)) + Err(ConstIntError::IntWidthInvalid(width)) } } @@ -99,21 +119,8 @@ fn map_vals( pub fn typecheck_const(typ: &ClassicType, val: &ConstValue) -> Result<(), ConstTypeError> { match (typ, val) { (ClassicType::Hashable(HashableType::Int(exp_width)), ConstValue::Int { value, width }) => { - // Check that the types make sense - check_valid_width(*exp_width)?; - check_valid_width(*width)?; - // Check that the terms make sense against the types if exp_width == width { - let max_value = if *width == HUGR_MAX_INT_WIDTH { - HugrIntValueStore::MAX - } else { - HugrIntValueStore::pow(2, *width as u32) - 1 - }; - if value <= &max_value { - Ok(()) - } else { - Err(ConstTypeError::IntTooLarge(*width, *value)) - } + check_int_fits_in_width(*value, *width).map_err(ConstTypeError::Int) } else { Err(ConstTypeError::IntWidthMismatch(*exp_width, *width)) } diff --git a/src/types/type_param.rs b/src/types/type_param.rs index f659c0bfd..8878b9ed8 100644 --- a/src/types/type_param.rs +++ b/src/types/type_param.rs @@ -6,9 +6,10 @@ use thiserror::Error; +use crate::hugr::typecheck::{check_int_fits_in_width, ConstIntError}; use crate::ops::constant::HugrIntValueStore; -use super::{ClassicType, SimpleType}; +use super::{simple::Container, ClassicType, HashableType, SimpleType}; /// A parameter declared by an OpDef. Specifies a value /// that must be provided by each operation node. @@ -19,55 +20,84 @@ use super::{ClassicType, SimpleType}; pub enum TypeParam { /// Argument is a [TypeArg::Type] - classic or linear Type, - /// Argument is a [TypeArg::ClassicType] + /// Argument is a [TypeArg::ClassicType] - hashable or otherwise ClassicType, - /// Argument is an integer - Int, + /// Argument is a [TypeArg::HashableType] + HashableType, /// Node must provide a [TypeArg::List] (of whatever length) /// TODO it'd be better to use [`Container`] here. /// /// [`Container`]: crate::types::simple::Container List(Box), - /// Argument is a [TypeArg::Value], containing a yaml-encoded object - /// interpretable by the operation. - Value, + /// Argument is a value of the specified type. + Value(HashableType), } /// A statically-known argument value to an operation. #[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[non_exhaustive] pub enum TypeArg { - /// Where the TypeDef declares that an argument is a [TypeParam::Type] + /// Where the (Type/Op)Def declares that an argument is a [TypeParam::Type] Type(SimpleType), - /// Where the TypeDef declares that an argument is a [TypeParam::ClassicType], + /// Where the (Type/Op)Def declares that an argument is a [TypeParam::ClassicType], /// it'll get one of these (rather than embedding inside a Type) ClassicType(ClassicType), - /// Where the TypeDef declares a [TypeParam::Int] + /// Where the (Type/Op)Def declares that an argument is a [TypeParam::HashableType], + /// this is the value. + HashableType(HashableType), + /// Where the (Type/Op)Def declares a [TypeParam::Value] of type [HashableType::Int], a constant value thereof Int(HugrIntValueStore), - /// Where an argument has type [TypeParam::List]`` - all elements will implicitly - /// be of the same variety of TypeArg, representing a `T`. + /// Where the (Type/Op)Def declares a [TypeParam::Value] of type [HashableType::String], here it is + String(String), + /// Where the (Type/Op)Def declares a [TypeParam::List]`` - all elements will implicitly + /// be of the same variety of TypeArg, i.e. `T`s. List(Vec), - /// Where the TypeDef declares a [TypeParam::Value] - Value(serde_yaml::Value), + /// Where the TypeDef declares a [TypeParam::Value] of [Container::Opaque] + CustomValue(serde_yaml::Value), } /// Checks a [TypeArg] is as expected for a [TypeParam] pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgError> { match (arg, param) { - (TypeArg::Type(_), TypeParam::Type) => (), - (TypeArg::ClassicType(_), TypeParam::ClassicType) => (), - (TypeArg::Int(_), TypeParam::Int) => (), + (TypeArg::Type(_), TypeParam::Type) => Ok(()), + (TypeArg::ClassicType(_), TypeParam::ClassicType) => Ok(()), + (TypeArg::HashableType(_), TypeParam::HashableType) => Ok(()), (TypeArg::List(items), TypeParam::List(ty)) => { for item in items { check_type_arg(item, ty.as_ref())?; } + Ok(()) } - (TypeArg::Value(_), TypeParam::Value) => (), - _ => { - return Err(TypeArgError::TypeMismatch(arg.clone(), param.clone())); + (TypeArg::Int(v), TypeParam::Value(HashableType::Int(width))) => { + check_int_fits_in_width(*v, *width).map_err(TypeArgError::Int) } - }; - Ok(()) + (TypeArg::String(_), TypeParam::Value(HashableType::String)) => Ok(()), + (arg, TypeParam::Value(HashableType::Container(ctr))) => match ctr { + Container::Opaque(_) => match arg { + TypeArg::CustomValue(_) => Ok(()), // Are there more checks we should do here? + _ => Err(TypeArgError::TypeMismatch(arg.clone(), param.clone())), + }, + Container::List(elem) => check_type_arg( + arg, + &TypeParam::List(Box::new(TypeParam::Value((**elem).clone()))), + ), + Container::Map(_) => unimplemented!(), + Container::Tuple(_) => unimplemented!(), + Container::Sum(_) => unimplemented!(), + Container::Array(elem, sz) => { + let TypeArg::List(items) = arg else {return Err(TypeArgError::TypeMismatch(arg.clone(), param.clone()))}; + if items.len() != *sz { + return Err(TypeArgError::WrongNumber(items.len(), *sz)); + } + check_type_arg( + arg, + &TypeParam::List(Box::new(TypeParam::Value((**elem).clone()))), + ) + } + Container::Alias(n) => Err(TypeArgError::NoAliases(n.to_string())), + }, + _ => Err(TypeArgError::TypeMismatch(arg.clone(), param.clone())), + } } /// Errors that can occur fitting a [TypeArg] into a [TypeParam] @@ -83,4 +113,10 @@ pub enum TypeArgError { // However in the future it may be applicable to e.g. contents of Tuples too. #[error("Wrong number of type arguments: {0} vs expected {1} declared type parameters")] WrongNumber(usize, usize), + /// The type declared for a TypeParam was an alias that was not resolved to an actual type + #[error("TypeParam required an unidentified alias type {0}")] + NoAliases(String), + /// There was some problem fitting a const int into its declared size + #[error("Error with int constant")] + Int(#[from] ConstIntError), }