diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 837127f7c..9c8ce6731 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -524,21 +524,21 @@ mod test { inputs: vec![listy.clone()].into(), sum_rows: vec![type_row![]], other_outputs: vec![listy.clone()].into(), - extension_delta: collections::EXTENSION_NAME.into(), + extension_delta: collections::EXTENSION_ID.into(), }, ); let r_df1 = replacement.add_node_with_parent( r_bb, DFG { signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())) - .with_extension_delta(collections::EXTENSION_NAME), + .with_extension_delta(collections::EXTENSION_ID), }, ); let r_df2 = replacement.add_node_with_parent( r_bb, DFG { signature: Signature::new(intermed, simple_unary_plus(just_list.clone())) - .with_extension_delta(collections::EXTENSION_NAME), + .with_extension_delta(collections::EXTENSION_ID), }, ); [0, 1] diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index a5e901a3d..cc30ec7fc 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -547,7 +547,7 @@ fn no_polymorphic_consts() -> Result<(), Box> { PolyFuncType::new( [BOUND], Signature::new(vec![], vec![list_of_var.clone()]) - .with_extension_delta(collections::EXTENSION_NAME), + .with_extension_delta(collections::EXTENSION_ID), ), )?; let empty_list = Value::extension(collections::ListValue::new_empty(Type::new_var_use( diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 480780496..67178c076 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -1,36 +1,37 @@ //! List type and operations. +mod list_fold; + +use std::str::FromStr; + use itertools::Itertools; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; +use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; +use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE}; use crate::ops::constant::ValueName; use crate::ops::{OpName, Value}; -use crate::types::TypeName; +use crate::types::{TypeName, TypeRowRV}; use crate::{ extension::{ simple_op::{MakeExtensionOp, OpLoadError}, - ConstFold, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, - TypeDefBound, + ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, TypeDefBound, }, ops::constant::CustomConst, - ops::{self, custom::ExtensionOp, NamedOp}, + ops::{custom::ExtensionOp, NamedOp}, types::{ type_param::{TypeArg, TypeParam}, CustomCheckFailure, CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeBound, }, - utils::sorted_consts, Extension, }; /// Reported unique name of the list type. pub const LIST_TYPENAME: TypeName = TypeName::new_inline("List"); -/// Pop operation name. -pub const POP_NAME: OpName = OpName::new_inline("pop"); -/// Push operation name. -pub const PUSH_NAME: OpName = OpName::new_inline("push"); /// Reported unique name of the extension -pub const EXTENSION_NAME: ExtensionId = ExtensionId::new_unchecked("collections"); +pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections"); /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); @@ -100,92 +101,154 @@ impl CustomConst for ListValue { fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) - .union(EXTENSION_NAME.into()) + .union(EXTENSION_ID.into()) } } -struct PopFold; +/// A list operation +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(non_camel_case_types)] +#[non_exhaustive] +pub enum ListOp { + /// Pop from end of list + pop, + /// Push to end of list + push, +} -impl ConstFold for PopFold { - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, ops::Value)], - ) -> crate::extension::ConstFoldResult { - let [list]: [&ops::Value; 1] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); - let mut list = list.clone(); - let elem = list.0.pop()?; // empty list fails to evaluate "pop" +impl ListOp { + /// Type parameter used in the list types. + const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; - Some(vec![(0.into(), list.into()), (1.into(), elem)]) + /// Instantiate a list operation with an `element_type` + pub fn with_type(self, element_type: Type) -> ListOpInst { + ListOpInst { + elem_type: element_type, + op: self, + } + } + + /// Compute the signature of the operation, given the list type definition. + fn compute_signature(self, list_type_def: &TypeDef) -> SignatureFunc { + use ListOp::*; + let e = Type::new_var_use(0, TypeBound::Any); + let l = self.list_type(list_type_def, 0); + match self { + pop => self + .list_polytype(vec![l.clone()], vec![l.clone(), e.clone()]) + .into(), + push => self.list_polytype(vec![l.clone(), e], vec![l]).into(), + } + } + + /// Compute a polymorphic function type for a list operation. + fn list_polytype( + self, + input: impl Into, + output: impl Into, + ) -> PolyFuncTypeRV { + PolyFuncTypeRV::new(vec![Self::TP], FuncValueType::new(input, output)) + } + + /// Returns the type of a generic list, associated with the element type parameter at index `idx`. + fn list_type(self, list_type_def: &TypeDef, idx: usize) -> Type { + Type::new_extension( + list_type_def + .instantiate(vec![TypeArg::new_var_use(idx, Self::TP)]) + .unwrap(), + ) } } -struct PushFold; +impl MakeOpDef for ListOp { + fn from_def(op_def: &OpDef) -> Result { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + /// Add an operation implemented as an [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + // + // This method is re-defined here since we need to pass the list type def while computing the signature, + // to avoid recursive loops initializing the extension. + fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> { + let sig = self.compute_signature(extension.get_type(&LIST_TYPENAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig)?; + + self.post_opdef(def); -impl ConstFold for PushFold { - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, ops::Value)], - ) -> crate::extension::ConstFoldResult { - let [list, elem]: [&ops::Value; 2] = sorted_consts(consts).try_into().ok()?; - let list: &ListValue = list.get_custom_value().expect("Should be list value."); - let mut list = list.clone(); - list.0.push(elem.clone()); + Ok(()) + } - Some(vec![(0.into(), list.into())]) + fn signature(&self) -> SignatureFunc { + self.compute_signature(list_type_def()) + } + + fn description(&self) -> String { + use ListOp::*; + + match self { + pop => "Pop from back of list", + push => "Push to back of list", + } + .into() + } + + fn post_opdef(&self, def: &mut OpDef) { + list_fold::set_fold(self, def) } } -const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; -fn extension() -> Extension { - let mut extension = Extension::new(EXTENSION_NAME, VERSION); +lazy_static! { + /// Extension for list operations. + pub static ref EXTENSION: Extension = { + println!("creating collections extension"); + let mut extension = Extension::new(EXTENSION_ID, VERSION); - extension - .add_type( + // The list type must be defined before the operations are added. + extension.add_type( LIST_TYPENAME, - vec![TP], + vec![ListOp::TP], "Generic dynamically sized list of type T.".into(), TypeDefBound::from_params(vec![0]), ) .unwrap(); - let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap(); - - let (l, e) = list_and_elem_type_vars(list_type_def); - extension - .add_op( - POP_NAME, - "Pop from back of list".into(), - PolyFuncTypeRV::new( - vec![TP], - FuncValueType::new(vec![l.clone()], vec![l.clone(), e.clone()]), - ), - ) - .unwrap() - .set_constant_folder(PopFold); - extension - .add_op( - PUSH_NAME, - "Push to back of list".into(), - PolyFuncTypeRV::new(vec![TP], FuncValueType::new(vec![l.clone(), e], vec![l])), - ) - .unwrap() - .set_constant_folder(PushFold); - extension + ListOp::load_all_ops(&mut extension).unwrap(); + + extension + }; + + /// Registry of extensions required to validate list operations. + pub static ref COLLECTIONS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); } -lazy_static! { - /// Collections extension definition. - pub static ref EXTENSION: Extension = extension(); +impl MakeRegisteredOp for ListOp { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &COLLECTIONS_REGISTRY + } +} + +/// Get the type of a list of `elem_type` as a `CustomType`. +pub fn list_type_def() -> &'static TypeDef { + // This must not be called while the extension is being built. + EXTENSION.get_type(&LIST_TYPENAME).unwrap() } /// Get the type of a list of `elem_type` as a `CustomType`. pub fn list_custom_type(elem_type: Type) -> CustomType { - EXTENSION - .get_type(&LIST_TYPENAME) - .unwrap() + list_type_def() .instantiate(vec![TypeArg::Type { ty: elem_type }]) .unwrap() } @@ -195,37 +258,9 @@ pub fn list_type(elem_type: Type) -> Type { list_custom_type(elem_type).into() } -fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) { - let elem_type = Type::new_var_use(0, TypeBound::Any); - let list_type = Type::new_extension( - list_type_def - .instantiate(vec![TypeArg::new_var_use(0, TP)]) - .unwrap(), - ); - (list_type, elem_type) -} - -/// A list operation -#[derive(Debug, Clone, PartialEq)] -#[non_exhaustive] -pub enum ListOp { - /// Pop from end of list - Pop, - /// Push to end of list - Push, -} - -impl ListOp { - /// Instantiate a list operation with an `element_type` - pub fn with_type(self, element_type: Type) -> ListOpInst { - ListOpInst { - elem_type: element_type, - op: self, - } - } -} - /// A list operation with a concrete element type. +/// +/// See [ListOp] for the parametric version. #[derive(Debug, Clone, PartialEq)] pub struct ListOpInst { op: ListOp, @@ -234,10 +269,8 @@ pub struct ListOpInst { impl NamedOp for ListOpInst { fn name(&self) -> OpName { - match self.op { - ListOp::Pop => POP_NAME, - ListOp::Push => PUSH_NAME, - } + let name: &str = self.op.into(); + name.into() } } @@ -249,11 +282,8 @@ impl MakeExtensionOp for ListOpInst { return Err(SignatureError::InvalidTypeArgs.into()); }; let name = ext_op.def().name(); - let op = match name { - // can't use const SmolStr in pattern - _ if name == &POP_NAME => ListOp::Pop, - _ if name == &PUSH_NAME => ListOp::Push, - _ => return Err(OpLoadError::NotMember(name.to_string())), + let Ok(op) = ListOp::from_str(name) else { + return Err(OpLoadError::NotMember(name.to_string())); }; Ok(Self { @@ -283,7 +313,7 @@ impl ListOpInst { ) .unwrap(); ExtensionOp::new( - registry.get(&EXTENSION_NAME)?.get_op(&self.name())?.clone(), + registry.get(&EXTENSION_ID)?.get_op(&self.name())?.clone(), self.type_args(), ®istry, ) @@ -307,16 +337,17 @@ mod test { #[test] fn test_extension() { - let r: Extension = extension(); - assert_eq!(r.name(), &EXTENSION_NAME); - let ops = r.operations(); - assert_eq!(ops.count(), 2); + assert_eq!(&ListOp::push.extension_id(), EXTENSION.name()); + assert_eq!(&ListOp::push.extension(), EXTENSION.name()); + assert!(ListOp::pop.registry().contains(EXTENSION.name())); + for (_, op_def) in EXTENSION.operations() { + assert_eq!(op_def.extension(), &EXTENSION_ID); + } } #[test] fn test_list() { - let r: Extension = extension(); - let list_def = r.get_type(&LIST_TYPENAME).unwrap(); + let list_def = list_type_def(); let list_type = list_def .instantiate([TypeArg::Type { ty: USIZE_T }]) @@ -340,7 +371,7 @@ mod test { let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]) .unwrap(); - let pop_op = ListOp::Pop.with_type(QB_T); + let pop_op = ListOp::pop.with_type(QB_T); let pop_ext = pop_op.clone().to_extension_op(®).unwrap(); assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op); let pop_sig = pop_ext.dataflow_signature().unwrap(); @@ -352,7 +383,7 @@ mod test { assert_eq!(pop_sig.input(), &just_list_row); assert_eq!(pop_sig.output(), &both_row); - let push_op = ListOp::Push.with_type(FLOAT64_TYPE); + let push_op = ListOp::push.with_type(FLOAT64_TYPE); let push_ext = push_op.clone().to_extension_op(®).unwrap(); assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op); let push_sig = push_ext.dataflow_signature().unwrap(); diff --git a/hugr-core/src/std_extensions/collections/list_fold.rs b/hugr-core/src/std_extensions/collections/list_fold.rs new file mode 100644 index 000000000..ca9b78c77 --- /dev/null +++ b/hugr-core/src/std_extensions/collections/list_fold.rs @@ -0,0 +1,49 @@ +//! Folding definitions for list operations. + +use crate::extension::{ConstFold, OpDef}; +use crate::ops; +use crate::types::type_param::TypeArg; +use crate::utils::sorted_consts; + +use super::{ListOp, ListValue}; + +pub(super) fn set_fold(op: &ListOp, def: &mut OpDef) { + match op { + ListOp::pop => def.set_constant_folder(PopFold), + ListOp::push => def.set_constant_folder(PushFold), + } +} + +pub struct PopFold; + +impl ConstFold for PopFold { + fn fold( + &self, + _type_args: &[TypeArg], + consts: &[(crate::IncomingPort, ops::Value)], + ) -> crate::extension::ConstFoldResult { + let [list]: [&ops::Value; 1] = sorted_consts(consts).try_into().ok()?; + let list: &ListValue = list.get_custom_value().expect("Should be list value."); + let mut list = list.clone(); + let elem = list.0.pop()?; // empty list fails to evaluate "pop" + + Some(vec![(0.into(), list.into()), (1.into(), elem)]) + } +} + +pub struct PushFold; + +impl ConstFold for PushFold { + fn fold( + &self, + _type_args: &[TypeArg], + consts: &[(crate::IncomingPort, ops::Value)], + ) -> crate::extension::ConstFoldResult { + let [list, elem]: [&ops::Value; 2] = sorted_consts(consts).try_into().ok()?; + let list: &ListValue = list.get_custom_value().expect("Should be list value."); + let mut list = list.clone(); + list.0.push(elem.clone()); + + Some(vec![(0.into(), list.into())]) + } +} diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 28f43f31e..0898e93d5 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -143,12 +143,12 @@ fn test_list_ops() -> Result<(), Box> { let list_wire = build.add_load_const(list.clone()); let pop = build.add_dataflow_op( - ListOp::Pop.with_type(BOOL_T).to_extension_op(®).unwrap(), + ListOp::pop.with_type(BOOL_T).to_extension_op(®).unwrap(), [list_wire], )?; let push = build.add_dataflow_op( - ListOp::Push + ListOp::push .with_type(BOOL_T) .to_extension_op(®) .unwrap(),