diff --git a/src/ops/custom.rs b/src/ops/custom.rs index f5c013d60..b179a7bb2 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -10,6 +10,7 @@ use crate::hugr::{HugrView, NodeType}; use crate::types::{type_param::TypeArg, FunctionType}; use crate::{Hugr, Node}; +use super::dataflow::DataflowOpTrait; use super::tag::OpTag; use super::{LeafOp, OpTrait, OpType}; @@ -74,7 +75,7 @@ impl ExternalOp { pub fn description(&self) -> &str { match self { Self::Opaque(op) => op.description.as_str(), - Self::Extension(ExtensionOp { def, .. }) => def.description(), + Self::Extension(ext_op) => DataflowOpTrait::description(ext_op), } } @@ -86,7 +87,7 @@ impl ExternalOp { .signature .clone() .expect("Op should have been serialized with signature."), - Self::Extension(ExtensionOp { signature, .. }) => signature.clone(), + Self::Extension(ext_op) => ext_op.signature(), } } } @@ -170,6 +171,18 @@ impl PartialEq for ExtensionOp { } } +impl DataflowOpTrait for ExtensionOp { + const TAG: OpTag = OpTag::Leaf; + + fn description(&self) -> &str { + self.def().description() + } + + fn signature(&self) -> FunctionType { + self.signature.clone() + } +} + impl Eq for ExtensionOp {} /// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`] diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 56b636f19..4ae262b77 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -7,7 +7,7 @@ use crate::{ Extension, }; -use super::int_types::int_type_var; +use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. @@ -17,15 +17,12 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.con pub fn extension() -> Extension { let ftoi_sig = PolyFuncType::new( vec![LOG_WIDTH_TYPE_PARAM], - FunctionType::new( - type_row![FLOAT64_TYPE], - vec![sum_with_error(int_type_var(0))], - ), + FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]), ); let itof_sig = PolyFuncType::new( vec![LOG_WIDTH_TYPE_PARAM], - FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]), + FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]), ); let mut extension = Extension::new_with_reqs( diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index cd9221deb..fa76adfc7 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,8 +1,13 @@ //! Basic integer operations. -use super::int_types::{get_log_width, int_type_var, LOG_WIDTH_TYPE_PARAM}; +use super::int_types::{get_log_width, int_tv, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{sum_with_error, BOOL_T}; -use crate::extension::{CustomValidator, ValidateJustArgs}; +use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}; +use crate::extension::{ + CustomValidator, ExtensionRegistry, OpDef, SignatureFunc, ValidateJustArgs, PRELUDE, +}; +use crate::ops::custom::ExtensionOp; +use crate::ops::OpName; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; @@ -13,6 +18,8 @@ use crate::{ }; use lazy_static::lazy_static; +use smol_str::SmolStr; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int"); @@ -34,7 +41,181 @@ impl ValidateJustArgs for IOValidator { Ok(()) } } +/// Logic extension operation definitions. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(missing_docs, non_camel_case_types)] +pub enum IntOpDef { + iwiden_u, + iwiden_s, + inarrow_u, + inarrow_s, + itobool, + ifrombool, + ieq, + ine, + ilt_u, + ilt_s, + igt_u, + igt_s, + ile_u, + ile_s, + ige_u, + ige_s, + imax_u, + imax_s, + imin_u, + imin_s, + iadd, + isub, + ineg, + imul, + idivmod_checked_u, + idivmod_u, + idivmod_checked_s, + idivmod_s, + idiv_checked_u, + idiv_u, + imod_checked_u, + imod_u, + idiv_checked_s, + idiv_s, + imod_checked_s, + imod_s, + iabs, + iand, + ior, + ixor, + inot, + ishl, + ishr, + irotl, + irotr, +} + +impl MakeOpDef for IntOpDef { + fn from_def(op_def: &OpDef) -> Result { + crate::extension::simple_op::try_from_name(op_def.name()) + } + + fn signature(&self) -> SignatureFunc { + use IntOpDef::*; + match self { + iwiden_s | iwiden_u => CustomValidator::new_with_validator( + int_polytype(2, vec![int_tv(0)], vec![int_tv(1)]), + IOValidator { f_gt_s: false }, + ) + .into(), + inarrow_s | inarrow_u => CustomValidator::new_with_validator( + int_polytype(2, vec![int_tv(0)], vec![sum_with_error(int_tv(1))]), + IOValidator { f_gt_s: true }, + ) + .into(), + itobool => int_polytype(1, vec![int_tv(0)], type_row![BOOL_T]).into(), + ifrombool => int_polytype(1, type_row![BOOL_T], vec![int_tv(0)]).into(), + ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => { + int_polytype(1, vec![int_tv(0); 2], type_row![BOOL_T]).into() + } + imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor => { + ibinop_sig().into() + } + ineg | iabs | inot => iunop_sig().into(), + //TODO inline + idivmod_checked_u | idivmod_checked_s => { + let intpair: TypeRow = vec![int_tv(0), int_tv(1)].into(); + int_polytype( + 2, + intpair.clone(), + vec![sum_with_error(Type::new_tuple(intpair))], + ) + } + .into(), + idivmod_u | idivmod_s => { + let intpair: TypeRow = vec![int_tv(0), int_tv(1)].into(); + int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)]) + } + .into(), + idiv_u | idiv_s => int_polytype(2, vec![int_tv(0), int_tv(1)], vec![int_tv(0)]).into(), + idiv_checked_u | idiv_checked_s => int_polytype( + 2, + vec![int_tv(0), int_tv(1)], + vec![sum_with_error(int_tv(0))], + ) + .into(), + imod_checked_u | imod_checked_s => int_polytype( + 2, + vec![int_tv(0), int_tv(1).clone()], + vec![sum_with_error(int_tv(1))], + ) + .into(), + imod_u | imod_s => { + int_polytype(2, vec![int_tv(0), int_tv(1).clone()], vec![int_tv(1)]).into() + } + ishl | ishr | irotl | irotr => { + int_polytype(2, vec![int_tv(0), int_tv(1)], vec![int_tv(0)]).into() + } + } + } + fn description(&self) -> String { + use IntOpDef::*; + + match self { + iwiden_u => "widen an unsigned integer to a wider one with the same value", + iwiden_s => "widen a signed integer to a wider one with the same value", + inarrow_u => "narrow an unsigned integer to a narrower one with the same value if possible", + inarrow_s => "narrow a signed integer to a narrower one with the same value if possible", + itobool => "convert to bool (1 is true, 0 is false)", + ifrombool => "convert from bool (1 is true, 0 is false)", + ieq => "equality test", + ine => "inequality test", + ilt_u => "\"less than\" as unsigned integers", + ilt_s => "\"less than\" as signed integers", + igt_u =>"\"greater than\" as unsigned integers", + igt_s => "\"greater than\" as signed integers", + ile_u => "\"less than or equal\" as unsigned integers", + ile_s => "\"less than or equal\" as signed integers", + ige_u => "\"greater than or equal\" as unsigned integers", + ige_s => "\"greater than or equal\" as signed integers", + imax_u => "maximum of unsigned integers", + imax_s => "maximum of signed integers", + imin_u => "minimum of unsigned integers", + imin_s => "minimum of signed integers", + iadd => "addition modulo 2^N (signed and unsigned versions are the same op)", + isub => "subtraction modulo 2^N (signed and unsigned versions are the same op)", + ineg => "negation modulo 2^N (signed and unsigned versions are the same op)", + imul => "multiplication modulo 2^N (signed and unsigned versions are the same op)", + idivmod_checked_u => "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ + q*m+r=n, 0<=r "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ + q*m+r=n, 0<=r "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ + signed q and unsigned r where q*m+r=n, 0<=r "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ + signed q and unsigned r where q*m+r=n, 0<=r "as idivmod_checked_u but discarding the second output", + idiv_u => "as idivmod_u but discarding the second output", + imod_checked_u => "as idivmod_checked_u but discarding the first output", + imod_u => "as idivmod_u but discarding the first output", + idiv_checked_s => "as idivmod_checked_s but discarding the second output", + idiv_s => "as idivmod_s but discarding the second output", + imod_checked_s => "as idivmod_checked_s but discarding the first output", + imod_s => "as idivmod_s but discarding the first output", + iabs => "convert signed to unsigned by taking absolute value", + iand => "bitwise AND", + ior => "bitwise OR", + ixor => "bitwise XOR", + inot => "bitwise NOT", + ishl => "shift first input left by k bits where k is unsigned interpretation of second input \ + (leftmost bits dropped, rightmost bits set to zero", + ishr => "shift first input right by k bits where k is unsigned interpretation of second input \ + (rightmost bits dropped, leftmost bits set to zero)", + irotl => "rotate first input left by k bits where k is unsigned interpretation of second input \ + (leftmost bits replace rightmost bits)", + irotr => "rotate first input right by k bits where k is unsigned interpretation of second input \ + (rightmost bits replace leftmost bits)", + }.into() + } +} fn int_polytype( n_vars: usize, input: impl Into, @@ -47,403 +228,109 @@ fn int_polytype( } fn ibinop_sig() -> PolyFuncType { - let int_type_var = int_type_var(0); + let int_type_var = int_tv(0); int_polytype(1, vec![int_type_var.clone(); 2], vec![int_type_var]) } fn iunop_sig() -> PolyFuncType { - let int_type_var = int_type_var(0); + let int_type_var = int_tv(0); int_polytype(1, vec![int_type_var.clone()], vec![int_type_var]) } -fn idivmod_checked_sig() -> PolyFuncType { - let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); - int_polytype( - 2, - intpair.clone(), - vec![sum_with_error(Type::new_tuple(intpair))], - ) -} - -fn idivmod_sig() -> PolyFuncType { - let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); - int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)]) -} - -/// Extension for basic integer operations. -fn extension() -> Extension { - let itob_sig = int_polytype(1, vec![int_type_var(0)], type_row![BOOL_T]); - - let btoi_sig = int_polytype(1, type_row![BOOL_T], vec![int_type_var(0)]); - - let icmp_sig = int_polytype(1, vec![int_type_var(0); 2], type_row![BOOL_T]); - - let idiv_checked_sig = int_polytype( - 2, - vec![int_type_var(0), int_type_var(1)], - vec![sum_with_error(int_type_var(0))], - ); +lazy_static! { + /// Extension for basic integer operations. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new_with_reqs( + EXTENSION_ID, + ExtensionSet::singleton(&super::int_types::EXTENSION_ID), + ); - let idiv_sig = int_polytype( - 2, - vec![int_type_var(0), int_type_var(1)], - vec![int_type_var(0)], - ); + IntOpDef::load_all_ops(&mut extension).unwrap(); - let imod_checked_sig = int_polytype( - 2, - vec![int_type_var(0), int_type_var(1).clone()], - vec![sum_with_error(int_type_var(1))], - ); + extension + }; - let imod_sig = int_polytype( - 2, - vec![int_type_var(0), int_type_var(1).clone()], - vec![int_type_var(1)], - ); + /// Registry of extensions required to validate integer operations. + pub static ref INT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + super::int_types::EXTENSION.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); +} - let ish_sig = int_polytype( - 2, - vec![int_type_var(0), int_type_var(1)], - vec![int_type_var(0)], - ); +/// Concrete integer operation with either one or two integer widths set. +#[derive(Debug, Clone, PartialEq)] +pub struct IntOpType { + def: IntOpDef, + first_width: u64, + second_width: Option, +} - let widen_poly = int_polytype(2, vec![int_type_var(0)], vec![int_type_var(1)]); - let narrow_poly = int_polytype( - 2, - vec![int_type_var(0)], - vec![sum_with_error(int_type_var(1))], - ); - let mut extension = Extension::new_with_reqs( - EXTENSION_ID, - ExtensionSet::singleton(&super::int_types::EXTENSION_ID), - ); +impl OpName for IntOpType { + fn name(&self) -> SmolStr { + self.def.name() + } +} +impl MakeExtensionOp for IntOpType { + fn from_extension_op(ext_op: &ExtensionOp) -> Result { + let def = IntOpDef::from_def(ext_op.def())?; + let (first_width, second_width) = match *ext_op.args() { + [TypeArg::BoundedNat { n }] => (n, None), + [TypeArg::BoundedNat { n }, TypeArg::BoundedNat { n: n2 }] => (n, Some(n2)), + _ => return Err(SignatureError::InvalidTypeArgs.into()), + }; + Ok(Self { + def, + first_width, + second_width, + }) + } - extension - .add_op( - "iwiden_u".into(), - "widen an unsigned integer to a wider one with the same value".to_owned(), - CustomValidator::new_with_validator(widen_poly.clone(), IOValidator { f_gt_s: false }), - ) - .unwrap(); + fn type_args(&self) -> Vec { + [Some(self.first_width), self.second_width] + .iter() + .flatten() + .map(|&n| TypeArg::BoundedNat { n }) + .collect() + } +} - extension - .add_op( - "iwiden_s".into(), - "widen a signed integer to a wider one with the same value".to_owned(), - CustomValidator::new_with_validator(widen_poly, IOValidator { f_gt_s: false }), - ) - .unwrap(); - extension - .add_op( - "inarrow_u".into(), - "narrow an unsigned integer to a narrower one with the same value if possible" - .to_owned(), - CustomValidator::new_with_validator(narrow_poly.clone(), IOValidator { f_gt_s: true }), - ) - .unwrap(); - extension - .add_op( - "inarrow_s".into(), - "narrow a signed integer to a narrower one with the same value if possible".to_owned(), - CustomValidator::new_with_validator(narrow_poly, IOValidator { f_gt_s: true }), - ) - .unwrap(); - extension - .add_op( - "itobool".into(), - "convert to bool (1 is true, 0 is false)".to_owned(), - itob_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "ifrombool".into(), - "convert from bool (1 is true, 0 is false)".to_owned(), - btoi_sig.clone(), - ) - .unwrap(); - extension - .add_op("ieq".into(), "equality test".to_owned(), icmp_sig.clone()) - .unwrap(); - extension - .add_op("ine".into(), "inequality test".to_owned(), icmp_sig.clone()) - .unwrap(); - extension - .add_op( - "ilt_u".into(), - "\"less than\" as unsigned integers".to_owned(), - icmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "ilt_s".into(), - "\"less than\" as signed integers".to_owned(), - icmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "igt_u".into(), - "\"greater than\" as unsigned integers".to_owned(), - icmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "igt_s".into(), - "\"greater than\" as signed integers".to_owned(), - icmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "ile_u".into(), - "\"less than or equal\" as unsigned integers".to_owned(), - icmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "ile_s".into(), - "\"less than or equal\" as signed integers".to_owned(), - icmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "ige_u".into(), - "\"greater than or equal\" as unsigned integers".to_owned(), - icmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "ige_s".into(), - "\"greater than or equal\" as signed integers".to_owned(), - icmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "imax_u".into(), - "maximum of unsigned integers".to_owned(), - ibinop_sig(), - ) - .unwrap(); - extension - .add_op( - "imax_s".into(), - "maximum of signed integers".to_owned(), - ibinop_sig(), - ) - .unwrap(); - extension - .add_op( - "imin_u".into(), - "minimum of unsigned integers".to_owned(), - ibinop_sig(), - ) - .unwrap(); - extension - .add_op( - "imin_s".into(), - "minimum of signed integers".to_owned(), - ibinop_sig(), - ) - .unwrap(); - extension - .add_op( - "iadd".into(), - "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(), - ) - .unwrap(); - extension - .add_op( - "isub".into(), - "subtraction modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(), - ) - .unwrap(); - extension - .add_op( - "ineg".into(), - "negation modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - iunop_sig(), - ) - .unwrap(); - extension - .add_op( - "imul".into(), - "multiplication modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(), - ) - .unwrap(); - extension - .add_op( - "idivmod_checked_u".into(), - "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ - q*m+r=n, 0<=r ExtensionId { + EXTENSION_ID.to_owned() + } - extension + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &INT_OPS_REGISTRY + } } -lazy_static! { - /// Extension for basic integer operations. - pub static ref EXTENSION: Extension = extension(); +impl IntOpDef { + /// Initialize a concrete [IntOpType] from a [IntOpDef] which requires one + /// integer width set. + pub fn with_width(self, width: u64) -> IntOpType { + IntOpType { + def: self, + first_width: width, + second_width: None, + } + } + /// Initialize a concrete [IntOpType] from a [IntOpDef] which requires two + /// integer widths set. + pub fn with_two_widths(self, first_width: u64, second_width: u64) -> IntOpType { + IntOpType { + def: self, + first_width, + second_width: Some(second_width), + } + } } #[cfg(test)] mod test { - use crate::{ - extension::{ExtensionRegistry, PRELUDE}, - std_extensions::arithmetic::int_types::int_type, - }; + use crate::{ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type}; use super::*; @@ -451,6 +338,7 @@ mod test { fn test_int_ops_extension() { assert_eq!(EXTENSION.name() as &str, "arithmetic.int"); assert_eq!(EXTENSION.types().count(), 0); + assert_eq!(EXTENSION.operations().count(), 45); for (name, _) in EXTENSION.operations() { assert!(name.starts_with('i')); } @@ -461,33 +349,49 @@ mod test { } #[test] fn test_binary_signatures() { - let iwiden_s = EXTENSION.get_op("iwiden_s").unwrap(); - let reg = ExtensionRegistry::try_new([ - EXTENSION.to_owned(), - super::super::int_types::EXTENSION.to_owned(), - PRELUDE.to_owned(), - ]) - .unwrap(); assert_eq!( - iwiden_s.compute_signature(&[ta(3), ta(4)], ®).unwrap(), + IntOpDef::iwiden_s + .with_two_widths(3, 4) + .to_extension_op() + .unwrap() + .signature(), FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],) ); - - let iwiden_u = EXTENSION.get_op("iwiden_u").unwrap(); - iwiden_u - .compute_signature(&[ta(4), ta(3)], ®) - .unwrap_err(); - - let inarrow_s = EXTENSION.get_op("inarrow_s").unwrap(); + assert!( + IntOpDef::iwiden_u + .with_two_widths(4, 3) + .to_extension_op() + .is_none(), + "type arguments invalid" + ); assert_eq!( - inarrow_s.compute_signature(&[ta(2), ta(1)], ®).unwrap(), + IntOpDef::inarrow_s + .with_two_widths(2, 1) + .to_extension_op() + .unwrap() + .signature(), FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))],) ); - let inarrow_u = EXTENSION.get_op("inarrow_u").unwrap(); - inarrow_u - .compute_signature(&[ta(1), ta(2)], ®) - .unwrap_err(); + assert!(IntOpDef::inarrow_u + .with_two_widths(1, 2) + .to_extension_op() + .is_none()); + } + + #[test] + fn test_conversions() { + let o = IntOpDef::itobool.with_width(5); + assert!( + IntOpDef::itobool + .with_two_widths(1, 2) + .to_extension_op() + .is_none(), + "type arguments invalid" + ); + let ext_op = o.clone().to_extension_op().unwrap(); + + assert_eq!(IntOpType::from_extension_op(&ext_op).unwrap(), o); } } diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 7a67de28a..f45d93964 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -203,8 +203,8 @@ lazy_static! { pub static ref EXTENSION: Extension = extension(); } -/// get an integer type variable, given the integer type definition -pub(super) fn int_type_var(var_id: usize) -> Type { +/// get an integer type with width corresponding to a type variable with id `var_id` +pub(super) fn int_tv(var_id: usize) -> Type { Type::new_extension( EXTENSION .get_type(&INT_TYPE_ID)