Skip to content

Commit

Permalink
feat: IntOpType convenience struct
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 30, 2023
1 parent faaf444 commit abcd598
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 31 deletions.
17 changes: 15 additions & 2 deletions src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::hugr::{HugrView, NodeType};
use crate::types::{type_param::TypeArg, FunctionType};
use crate::{Hugr, Node};

use super::dataflow::DataflowOpTrait;
use super::tag::OpTag;
use super::{LeafOp, OpTrait, OpType};

Expand Down Expand Up @@ -74,7 +75,7 @@ impl ExternalOp {
pub fn description(&self) -> &str {
match self {
Self::Opaque(op) => op.description.as_str(),
Self::Extension(ExtensionOp { def, .. }) => def.description(),
Self::Extension(ext_op) => DataflowOpTrait::description(ext_op),
}
}

Expand All @@ -86,7 +87,7 @@ impl ExternalOp {
.signature
.clone()
.expect("Op should have been serialized with signature."),
Self::Extension(ExtensionOp { signature, .. }) => signature.clone(),
Self::Extension(ext_op) => ext_op.signature(),
}
}
}
Expand Down Expand Up @@ -170,6 +171,18 @@ impl PartialEq for ExtensionOp {
}
}

impl DataflowOpTrait for ExtensionOp {
const TAG: OpTag = OpTag::Leaf;

fn description(&self) -> &str {
self.def().description()
}

fn signature(&self) -> FunctionType {
self.signature.clone()
}
}

impl Eq for ExtensionOp {}

/// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`]
Expand Down
155 changes: 126 additions & 29 deletions src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

use super::int_types::{get_log_width, int_tv, LOG_WIDTH_TYPE_PARAM};
use crate::extension::prelude::{sum_with_error, BOOL_T};
use crate::extension::simple_op::MakeOpDef;
use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs};
use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError};
use crate::extension::{
CustomValidator, ExtensionRegistry, OpDef, SignatureFunc, ValidateJustArgs, PRELUDE,
};
use crate::ops::custom::ExtensionOp;
use crate::ops::OpName;
use crate::type_row;
use crate::types::{FunctionType, PolyFuncType};
use crate::utils::collect_array;
Expand All @@ -14,6 +18,7 @@ use crate::{
};

use lazy_static::lazy_static;
use smol_str::SmolStr;
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

/// The extension identifier.
Expand All @@ -39,7 +44,7 @@ impl ValidateJustArgs for IOValidator {
/// Logic extension operation definitions.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs, non_camel_case_types)]
pub enum IntOps {
pub enum IntOpDef {
iwiden_u,
iwiden_s,
inarrow_u,
Expand Down Expand Up @@ -87,13 +92,13 @@ pub enum IntOps {
irotr,
}

impl MakeOpDef for IntOps {
impl MakeOpDef for IntOpDef {
fn from_def(op_def: &OpDef) -> Result<Self, crate::extension::simple_op::OpLoadError> {
crate::extension::simple_op::try_from_name(op_def.name())
}

fn signature(&self) -> SignatureFunc {
use IntOps::*;
use IntOpDef::*;
match self {
iwiden_s | iwiden_u => CustomValidator::new_with_validator(
int_polytype(2, vec![int_tv(0)], vec![int_tv(1)]),
Expand Down Expand Up @@ -152,7 +157,7 @@ impl MakeOpDef for IntOps {
}

fn description(&self) -> String {
use IntOps::*;
use IntOpDef::*;

match self {
iwiden_u => "widen an unsigned integer to a wider one with the same value",
Expand Down Expand Up @@ -241,18 +246,91 @@ lazy_static! {
ExtensionSet::singleton(&super::int_types::EXTENSION_ID),
);

IntOps::load_all_ops(&mut extension).unwrap();
IntOpDef::load_all_ops(&mut extension).unwrap();

extension
};

/// Registry of extensions required to validate integer operations.
pub static ref INT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
super::int_types::EXTENSION.to_owned(),
EXTENSION.to_owned(),
])
.unwrap();
}

/// Concrete integer operation with either one or two integer widths set.
#[derive(Debug, Clone, PartialEq)]
pub struct IntOpType {
def: IntOpDef,
first_width: u64,
second_width: Option<u64>,
}

impl OpName for IntOpType {
fn name(&self) -> SmolStr {
self.def.name()
}
}
impl MakeExtensionOp for IntOpType {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def = IntOpDef::from_def(ext_op.def())?;
let (first_width, second_width) = match *ext_op.args() {
[TypeArg::BoundedNat { n }] => (n, None),
[TypeArg::BoundedNat { n }, TypeArg::BoundedNat { n: n2 }] => (n, Some(n2)),
_ => return Err(SignatureError::InvalidTypeArgs.into()),
};
Ok(Self {
def,
first_width,
second_width,
})
}

fn type_args(&self) -> Vec<TypeArg> {
[Some(self.first_width), self.second_width]
.iter()
.flatten()
.map(|&n| TypeArg::BoundedNat { n })
.collect()
}
}

impl MakeRegisteredOp for IntOpType {
fn extension_id(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
&INT_OPS_REGISTRY
}
}

impl IntOpDef {
/// Initialize a concrete [IntOpType] from a [IntOpDef] which requires one
/// integer width set.
pub fn one_width(self, width: u64) -> IntOpType {
IntOpType {
def: self,
first_width: width,
second_width: None,
}
}
/// Initialize a concrete [IntOpType] from a [IntOpDef] which requires two
/// integer widths set.
pub fn two_widths(self, first_width: u64, second_width: u64) -> IntOpType {
IntOpType {
def: self,
first_width,
second_width: Some(second_width),
}
}
}

#[cfg(test)]
mod test {
use crate::{
extension::{ExtensionRegistry, PRELUDE},
std_extensions::arithmetic::int_types::int_type,
};
use crate::{ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type};

use super::*;

Expand All @@ -271,33 +349,52 @@ mod test {
}
#[test]
fn test_binary_signatures() {
let iwiden_s = EXTENSION.get_op("iwiden_s").unwrap();
let reg = ExtensionRegistry::try_new([
EXTENSION.to_owned(),
super::super::int_types::EXTENSION.to_owned(),
PRELUDE.to_owned(),
])
.unwrap();
assert_eq!(
iwiden_s.compute_signature(&[ta(3), ta(4)], &reg).unwrap(),
// iwiden_s
// .compute_signature(&[ta(3), ta(4)], &INT_OPS_REGISTRY)
// .unwrap(),
IntOpDef::iwiden_s
.two_widths(3, 4)
.to_extension_op()
.unwrap()
.signature(),
FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],)
);

let iwiden_u = EXTENSION.get_op("iwiden_u").unwrap();
iwiden_u
.compute_signature(&[ta(4), ta(3)], &reg)
.unwrap_err();
// let iwiden_u = EXTENSION.get_op("iwiden_u").unwrap();
// iwiden_u
// .compute_signature(&[ta(4), ta(3)], &INT_OPS_REGISTRY)
// .unwrap_err();

let inarrow_s = EXTENSION.get_op("inarrow_s").unwrap();
assert!(IntOpDef::iwiden_u
.two_widths(4, 3)
.to_extension_op()
.is_none());

assert_eq!(
inarrow_s.compute_signature(&[ta(2), ta(1)], &reg).unwrap(),
IntOpDef::inarrow_s
.two_widths(2, 1)
.to_extension_op()
.unwrap()
.signature(),
FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))],)
);

let inarrow_u = EXTENSION.get_op("inarrow_u").unwrap();
inarrow_u
.compute_signature(&[ta(1), ta(2)], &reg)
.unwrap_err();
assert!(IntOpDef::inarrow_u
.two_widths(1, 2)
.to_extension_op()
.is_none());
}

#[test]
fn test_conversions() {
let o = IntOpDef::itobool.one_width(5);
assert!(IntOpDef::itobool
.two_widths(1, 2)
.to_extension_op()
.is_none());
let ext_op = o.clone().to_extension_op().unwrap();

assert_eq!(IntOpType::from_extension_op(&ext_op).unwrap(), o);
}
}

0 comments on commit abcd598

Please sign in to comment.