Skip to content

Commit

Permalink
feat(hugr-core): HasDef and HasConcrete traits for def/concrete o…
Browse files Browse the repository at this point in the history
…p design pattern

When there are custom types standing in for opdefs and concrete versions of those ops, implementing these simple traits (how to instantiate one from the other) allows them to be loaded from `CustomOp`.
  • Loading branch information
ss2165 committed Jul 24, 2024
1 parent dbb3232 commit ed35c90
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 19 deletions.
57 changes: 56 additions & 1 deletion hugr-core/src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use strum::IntoEnumIterator;

use crate::ops::{OpName, OpNameRef};
use crate::ops::{CustomOp, OpName, OpNameRef};
use crate::{
ops::{custom::ExtensionOp, NamedOp, OpType},
types::TypeArg,
Expand Down Expand Up @@ -85,6 +85,43 @@ pub trait MakeOpDef: NamedOp {
}
Ok(())
}

/// If the definition can be loaded from a string, load from an [ExtensionOp].
fn from_op(custom_op: &CustomOp) -> Result<Self, OpLoadError>
where
Self: Sized + std::str::FromStr,
{
match custom_op {
CustomOp::Extension(ext) => Self::from_extension_op(ext),
CustomOp::Opaque(opaque) => try_from_name(opaque.name(), opaque.extension()),
}
}
}

/// [MakeOpDef] with an associate concrete Op type which can be instantiated with type arguments.
pub trait HasConcrete: MakeOpDef {
/// Associated concrete type.
type Concrete: MakeExtensionOp;

/// Instantiate the operation with type arguments.
fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError>;
}

/// [MakeExtensionOp] with an associated [HasConcrete].
pub trait HasDef: MakeExtensionOp {
/// Associated [HasConcrete] type.
type Def: HasConcrete<Concrete = Self> + std::str::FromStr;

/// Load the operation from a [CustomOp].
fn from_op(custom_op: &CustomOp) -> Result<Self, OpLoadError>
where
Self: Sized,
{
match custom_op {
CustomOp::Extension(ext) => Self::from_extension_op(ext),
CustomOp::Opaque(opaque) => Self::Def::from_op(custom_op)?.instantiate(opaque.args()),
}
}
}

/// Traits implemented by types which can be loaded from [`ExtensionOp`]s,
Expand Down Expand Up @@ -264,6 +301,18 @@ mod test {
EXT_ID.to_owned()
}
}

impl HasConcrete for DummyEnum {
type Concrete = Self;

fn instantiate(&self, _type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
if _type_args.is_empty() {
Ok(self.clone())
} else {
Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
}
}
}
const_extension_ids! {
const EXT_ID: ExtensionId = "DummyExt";
}
Expand Down Expand Up @@ -302,5 +351,11 @@ mod test {
);
let registered: RegisteredOp<_> = o.clone().into();
assert_eq!(registered.to_inner(), o);

assert_eq!(o.instantiate(&[]), Ok(o.clone()));
assert_eq!(
o.instantiate(&[TypeArg::BoundedNat { n: 1 }]),
Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
);
}
}
42 changes: 32 additions & 10 deletions hugr-core/src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

use super::int_types::{get_log_width, int_tv, LOG_WIDTH_TYPE_PARAM};
use crate::extension::prelude::{sum_with_error, BOOL_T, STRING_TYPE};
use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError};
use crate::extension::simple_op::{
HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
};
use crate::extension::{
CustomValidator, ExtensionRegistry, OpDef, SignatureFunc, ValidateJustArgs, PRELUDE,
};
Expand Down Expand Up @@ -278,6 +280,25 @@ lazy_static! {
.unwrap();
}

impl HasConcrete for IntOpDef {
type Concrete = ConcreteIntOp;

fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
let log_widths: Vec<u8> = type_args
.iter()
.map(|a| get_log_width(a).map_err(|_| SignatureError::InvalidTypeArgs))
.collect::<Result<_, _>>()?;
Ok(ConcreteIntOp {
def: *self,
log_widths,
})
}
}

impl HasDef for ConcreteIntOp {
type Def = IntOpDef;
}

/// Concrete integer operation with integer widths set.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
Expand All @@ -298,12 +319,7 @@ impl NamedOp for ConcreteIntOp {
impl MakeExtensionOp for ConcreteIntOp {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def = IntOpDef::from_def(ext_op.def())?;
let args = ext_op.args();
let log_widths: Vec<u8> = args
.iter()
.map(|a| get_log_width(a).map_err(|_| SignatureError::InvalidTypeArgs))
.collect::<Result<_, _>>()?;
Ok(Self { def, log_widths })
def.instantiate(ext_op.args())
}

fn type_args(&self) -> Vec<TypeArg> {
Expand Down Expand Up @@ -355,7 +371,8 @@ fn sum_ty_with_err(t: Type) -> Type {
#[cfg(test)]
mod test {
use crate::{
ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type,
ops::{dataflow::DataflowOpTrait, CustomOp},
std_extensions::arithmetic::int_types::int_type,
types::Signature,
};

Expand Down Expand Up @@ -432,8 +449,13 @@ mod test {
.is_none(),
"type arguments invalid"
);
let ext_op = o.clone().to_extension_op().unwrap();
let custom_op: CustomOp = o.clone().to_extension_op().unwrap().into();

assert_eq!(ConcreteIntOp::from_extension_op(&ext_op).unwrap(), o);
assert_eq!(ConcreteIntOp::from_op(&custom_op).unwrap(), o);
assert_eq!(IntOpDef::from_op(&custom_op).unwrap(), IntOpDef::itobool);
assert_eq!(
IntOpDef::from_op(&custom_op.into_opaque().into()).unwrap(),
IntOpDef::itobool
);
}
}
37 changes: 29 additions & 8 deletions hugr-core/src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use crate::types::{FuncValueType, Signature};
use crate::{
extension::{
prelude::BOOL_T,
simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError},
simple_op::{
try_from_name, HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp,
OpLoadError,
},
ExtensionId, ExtensionRegistry, OpDef, SignatureError, SignatureFromArgs, SignatureFunc,
},
ops::{self, custom::ExtensionOp, NamedOp},
Expand Down Expand Up @@ -82,6 +85,21 @@ impl MakeOpDef for NaryLogic {
}
}

impl HasConcrete for NaryLogic {
type Concrete = ConcreteLogicOp;

fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
let [TypeArg::BoundedNat { n }] = type_args else {
return Err(SignatureError::InvalidTypeArgs.into());
};
Ok(self.with_n_inputs(*n))
}
}

impl HasDef for ConcreteLogicOp {
type Def = NaryLogic;
}

/// Make a [NaryLogic] operation concrete by setting the type argument.
#[derive(Debug, Clone, PartialEq)]
pub struct ConcreteLogicOp(pub NaryLogic, u64);
Expand All @@ -101,10 +119,7 @@ impl NamedOp for ConcreteLogicOp {
impl MakeExtensionOp for ConcreteLogicOp {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def: NaryLogic = NaryLogic::from_def(ext_op.def())?;
let [TypeArg::BoundedNat { n }] = *ext_op.args() else {
return Err(SignatureError::InvalidTypeArgs.into());
};
Ok(Self(def, n))
def.instantiate(ext_op.args())
}

fn type_args(&self) -> Vec<TypeArg> {
Expand Down Expand Up @@ -245,9 +260,9 @@ pub(crate) mod test {
use crate::{
extension::{
prelude::BOOL_T,
simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp},
simple_op::{HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp},
},
ops::{NamedOp, Value},
ops::{CustomOp, NamedOp, Value},
Extension,
};

Expand All @@ -273,7 +288,13 @@ pub(crate) mod test {
for def in [NaryLogic::And, NaryLogic::Or] {
let o = def.with_n_inputs(3);
let ext_op = o.clone().to_extension_op().unwrap();
assert_eq!(ConcreteLogicOp::from_extension_op(&ext_op).unwrap(), o);
let custom_op: CustomOp = ext_op.into();
assert_eq!(NaryLogic::from_op(&custom_op).unwrap(), def);
assert_eq!(ConcreteLogicOp::from_op(&custom_op).unwrap(), o);
assert_eq!(
ConcreteLogicOp::from_op(&custom_op.into_opaque().into()).unwrap(),
o
);
}

NotOp::from_extension_op(&NotOp.to_extension_op().unwrap()).unwrap();
Expand Down

0 comments on commit ed35c90

Please sign in to comment.