-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(hugr-core): Add pointer standard extension (#1337)
ptr<T> type and new, read and write operations. --------- Co-authored-by: Douglas Wilson <douglas.wilson@quantinuum.com>
- Loading branch information
Showing
2 changed files
with
281 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ | |
pub mod arithmetic; | ||
pub mod collections; | ||
pub mod logic; | ||
pub mod ptr; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,280 @@ | ||
//! Pointer type and operations. | ||
|
||
use strum_macros::{EnumIter, EnumString, IntoStaticStr}; | ||
|
||
use crate::builder::{BuildError, Dataflow}; | ||
use crate::extension::TypeDefBound; | ||
use crate::ops::OpName; | ||
use crate::types::{CustomType, PolyFuncType, Signature, Type, TypeBound, TypeName}; | ||
use crate::Wire; | ||
use crate::{ | ||
extension::{ | ||
simple_op::{ | ||
HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, | ||
}, | ||
ExtensionId, ExtensionRegistry, OpDef, SignatureError, SignatureFunc, | ||
}, | ||
ops::{custom::ExtensionOp, NamedOp}, | ||
type_row, | ||
types::type_param::{TypeArg, TypeParam}, | ||
Extension, | ||
}; | ||
use lazy_static::lazy_static; | ||
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] | ||
#[allow(missing_docs)] | ||
#[non_exhaustive] | ||
/// Pointer operation definitions. | ||
pub enum PtrOpDef { | ||
/// Create a new pointer. | ||
New, | ||
/// Read a value from a pointer. | ||
Read, | ||
/// Write a value to a pointer. | ||
Write, | ||
} | ||
|
||
impl PtrOpDef { | ||
/// Create a new concrete pointer operation with the given value type. | ||
pub fn with_type(self, ty: Type) -> PtrOp { | ||
PtrOp::new(self, ty) | ||
} | ||
} | ||
|
||
impl MakeOpDef for PtrOpDef { | ||
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> | ||
where | ||
Self: Sized, | ||
{ | ||
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) | ||
} | ||
|
||
fn signature(&self) -> SignatureFunc { | ||
let ptr_t = ptr_type(Type::new_var_use(0, TypeBound::Copyable)); | ||
let inner_t = Type::new_var_use(0, TypeBound::Copyable); | ||
let body = match self { | ||
PtrOpDef::New => Signature::new(inner_t, ptr_t), | ||
PtrOpDef::Read => Signature::new(ptr_t, inner_t), | ||
PtrOpDef::Write => Signature::new(vec![ptr_t, inner_t], type_row![]), | ||
}; | ||
|
||
PolyFuncType::new(TYPE_PARAMS, body).into() | ||
} | ||
|
||
fn extension(&self) -> ExtensionId { | ||
EXTENSION_ID | ||
} | ||
|
||
fn description(&self) -> String { | ||
match self { | ||
PtrOpDef::New => "Create a new pointer from a value.".into(), | ||
PtrOpDef::Read => "Read a value from a pointer.".into(), | ||
PtrOpDef::Write => "Write a value to a pointer, overwriting existing value.".into(), | ||
} | ||
} | ||
} | ||
|
||
/// Name of pointer extension. | ||
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("ptr"); | ||
/// Name of pointer type. | ||
pub const PTR_TYPE_ID: TypeName = TypeName::new_inline("ptr"); | ||
const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::Type { | ||
b: TypeBound::Copyable, | ||
}]; | ||
/// Extension for pointer operations. | ||
fn extension() -> Extension { | ||
let mut extension = Extension::new(EXTENSION_ID); | ||
extension | ||
.add_type( | ||
PTR_TYPE_ID, | ||
TYPE_PARAMS.into(), | ||
"Standard extension pointer type.".into(), | ||
TypeDefBound::Explicit(TypeBound::Eq), | ||
) | ||
.unwrap(); | ||
PtrOpDef::load_all_ops(&mut extension).unwrap(); | ||
|
||
extension | ||
} | ||
|
||
lazy_static! { | ||
/// Reference to the pointer Extension. | ||
pub static ref EXTENSION: Extension = extension(); | ||
/// Registry required to validate pointer extension. | ||
pub static ref PTR_REG: ExtensionRegistry = | ||
ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap(); | ||
} | ||
|
||
/// Integer type of a given bit width (specified by the TypeArg). Depending on | ||
/// the operation, the semantic interpretation may be unsigned integer, signed | ||
/// integer or bit string. | ||
pub fn ptr_custom_type(ty: impl Into<Type>) -> CustomType { | ||
let ty = ty.into(); | ||
CustomType::new(PTR_TYPE_ID, [ty.into()], EXTENSION_ID, TypeBound::Eq) | ||
} | ||
|
||
/// Integer type of a given bit width (specified by the TypeArg). | ||
/// | ||
/// Constructed from [ptr_custom_type]. | ||
pub fn ptr_type(ty: impl Into<Type>) -> Type { | ||
Type::new_extension(ptr_custom_type(ty)) | ||
} | ||
|
||
#[derive(Clone, Debug, PartialEq)] | ||
/// A concrete pointer operation. | ||
pub struct PtrOp { | ||
/// The operation definition. | ||
pub def: PtrOpDef, | ||
/// Type of the value being pointed to. | ||
pub ty: Type, | ||
} | ||
|
||
impl PtrOp { | ||
fn new(op: PtrOpDef, ty: Type) -> Self { | ||
Self { def: op, ty } | ||
} | ||
} | ||
|
||
impl NamedOp for PtrOp { | ||
fn name(&self) -> OpName { | ||
self.def.name() | ||
} | ||
} | ||
|
||
impl MakeExtensionOp for PtrOp { | ||
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> { | ||
let def = PtrOpDef::from_def(ext_op.def())?; | ||
def.instantiate(ext_op.args()) | ||
} | ||
|
||
fn type_args(&self) -> Vec<TypeArg> { | ||
vec![self.ty.clone().into()] | ||
} | ||
} | ||
|
||
impl MakeRegisteredOp for PtrOp { | ||
fn extension_id(&self) -> ExtensionId { | ||
EXTENSION_ID.to_owned() | ||
} | ||
|
||
fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { | ||
&PTR_REG | ||
} | ||
} | ||
|
||
/// An extension trait for [Dataflow] providing methods to add pointer | ||
/// operations. | ||
pub trait PtrOpBuilder: Dataflow { | ||
/// Add a "ptr.New" op. | ||
fn add_new_ptr(&mut self, val_wire: Wire) -> Result<Wire, BuildError> { | ||
let ty = self.get_wire_type(val_wire)?; | ||
let handle = self.add_dataflow_op(PtrOpDef::New.with_type(ty), [val_wire])?; | ||
|
||
Ok(handle.out_wire(0)) | ||
} | ||
|
||
/// Add a "ptr.Read" op. | ||
fn add_read_ptr(&mut self, ptr_wire: Wire, ty: Type) -> Result<Wire, BuildError> { | ||
let handle = self.add_dataflow_op(PtrOpDef::Read.with_type(ty.clone()), [ptr_wire])?; | ||
Ok(handle.out_wire(0)) | ||
} | ||
|
||
/// Add a "ptr.Write" op. | ||
fn add_write_ptr(&mut self, ptr_wire: Wire, val_wire: Wire) -> Result<(), BuildError> { | ||
let ty = self.get_wire_type(val_wire)?; | ||
|
||
let handle = self.add_dataflow_op(PtrOpDef::Write.with_type(ty), [ptr_wire, val_wire])?; | ||
debug_assert_eq!(handle.outputs().len(), 0); | ||
Ok(()) | ||
} | ||
} | ||
|
||
impl<D: Dataflow> PtrOpBuilder for D {} | ||
|
||
impl HasConcrete for PtrOpDef { | ||
type Concrete = PtrOp; | ||
|
||
fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> { | ||
let ty = match type_args { | ||
[TypeArg::Type { ty }] => ty.clone(), | ||
_ => return Err(SignatureError::InvalidTypeArgs.into()), | ||
}; | ||
|
||
Ok(self.with_type(ty)) | ||
} | ||
} | ||
|
||
impl HasDef for PtrOp { | ||
type Def = PtrOpDef; | ||
} | ||
|
||
#[cfg(test)] | ||
pub(crate) mod test { | ||
use crate::builder::DFGBuilder; | ||
use crate::extension::prelude::BOOL_T; | ||
use crate::ops::CustomOp; | ||
use crate::{ | ||
builder::{Dataflow, DataflowHugr}, | ||
ops::NamedOp, | ||
std_extensions::arithmetic::int_types::INT_TYPES, | ||
}; | ||
use cool_asserts::assert_matches; | ||
use std::sync::Arc; | ||
use strum::IntoEnumIterator; | ||
|
||
use super::*; | ||
use crate::std_extensions::arithmetic::float_types::{ | ||
EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE, | ||
}; | ||
fn get_opdef(op: impl NamedOp) -> Option<&'static Arc<OpDef>> { | ||
EXTENSION.get_op(&op.name()) | ||
} | ||
|
||
#[test] | ||
fn create_extension() { | ||
assert_eq!(EXTENSION.name(), &EXTENSION_ID); | ||
|
||
for o in PtrOpDef::iter() { | ||
assert_eq!(PtrOpDef::from_def(get_opdef(o).unwrap()), Ok(o)); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_ops() { | ||
let ops = [ | ||
PtrOp::new(PtrOpDef::New, BOOL_T.clone()), | ||
PtrOp::new(PtrOpDef::Read, FLOAT64_TYPE.clone()), | ||
PtrOp::new(PtrOpDef::Write, INT_TYPES[5].clone()), | ||
]; | ||
for op in ops { | ||
let op_t: CustomOp = op.clone().to_extension_op().unwrap().into(); | ||
let def_op = PtrOpDef::from_op(&op_t).unwrap(); | ||
assert_eq!(op.def, def_op); | ||
let new_op = PtrOp::from_op(&op_t).unwrap(); | ||
assert_eq!(new_op, op); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_build() { | ||
let in_row = vec![BOOL_T, FLOAT64_TYPE]; | ||
|
||
let reg = | ||
ExtensionRegistry::try_new([EXTENSION.to_owned(), FLOAT_EXTENSION.to_owned()]).unwrap(); | ||
let hugr = { | ||
let mut builder = DFGBuilder::new( | ||
Signature::new(in_row.clone(), type_row![]).with_extension_delta(EXTENSION_ID), | ||
) | ||
.unwrap(); | ||
|
||
let in_wires: [Wire; 2] = builder.input_wires_arr(); | ||
for (ty, w) in in_row.into_iter().zip(in_wires.iter()) { | ||
let new_ptr = builder.add_new_ptr(*w).unwrap(); | ||
let read = builder.add_read_ptr(new_ptr, ty).unwrap(); | ||
builder.add_write_ptr(new_ptr, read).unwrap(); | ||
} | ||
|
||
builder.finish_hugr_with_outputs([], ®).unwrap() | ||
}; | ||
assert_matches!(hugr.validate(®), Ok(_)); | ||
} | ||
} |