Skip to content

Commit

Permalink
feat(hugr-core): Add pointer standard extension (#1337)
Browse files Browse the repository at this point in the history
ptr<T> type and new, read and write operations.

---------

Co-authored-by: Douglas Wilson <douglas.wilson@quantinuum.com>
  • Loading branch information
ss2165 and doug-q authored Jul 25, 2024
1 parent 7ac015b commit 88af215
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 0 deletions.
1 change: 1 addition & 0 deletions hugr-core/src/std_extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
pub mod arithmetic;
pub mod collections;
pub mod logic;
pub mod ptr;
280 changes: 280 additions & 0 deletions hugr-core/src/std_extensions/ptr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
//! Pointer type and operations.

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::builder::{BuildError, Dataflow};
use crate::extension::TypeDefBound;
use crate::ops::OpName;
use crate::types::{CustomType, PolyFuncType, Signature, Type, TypeBound, TypeName};
use crate::Wire;
use crate::{
extension::{
simple_op::{
HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
},
ExtensionId, ExtensionRegistry, OpDef, SignatureError, SignatureFunc,
},
ops::{custom::ExtensionOp, NamedOp},
type_row,
types::type_param::{TypeArg, TypeParam},
Extension,
};
use lazy_static::lazy_static;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs)]
#[non_exhaustive]
/// Pointer operation definitions.
pub enum PtrOpDef {
/// Create a new pointer.
New,
/// Read a value from a pointer.
Read,
/// Write a value to a pointer.
Write,
}

impl PtrOpDef {
/// Create a new concrete pointer operation with the given value type.
pub fn with_type(self, ty: Type) -> PtrOp {
PtrOp::new(self, ty)
}
}

impl MakeOpDef for PtrOpDef {
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
where
Self: Sized,
{
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension())
}

fn signature(&self) -> SignatureFunc {
let ptr_t = ptr_type(Type::new_var_use(0, TypeBound::Copyable));
let inner_t = Type::new_var_use(0, TypeBound::Copyable);
let body = match self {
PtrOpDef::New => Signature::new(inner_t, ptr_t),
PtrOpDef::Read => Signature::new(ptr_t, inner_t),
PtrOpDef::Write => Signature::new(vec![ptr_t, inner_t], type_row![]),
};

PolyFuncType::new(TYPE_PARAMS, body).into()
}

fn extension(&self) -> ExtensionId {
EXTENSION_ID
}

fn description(&self) -> String {
match self {
PtrOpDef::New => "Create a new pointer from a value.".into(),
PtrOpDef::Read => "Read a value from a pointer.".into(),
PtrOpDef::Write => "Write a value to a pointer, overwriting existing value.".into(),
}
}
}

/// Name of pointer extension.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("ptr");
/// Name of pointer type.
pub const PTR_TYPE_ID: TypeName = TypeName::new_inline("ptr");
const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::Type {
b: TypeBound::Copyable,
}];
/// Extension for pointer operations.
fn extension() -> Extension {
let mut extension = Extension::new(EXTENSION_ID);
extension
.add_type(
PTR_TYPE_ID,
TYPE_PARAMS.into(),
"Standard extension pointer type.".into(),
TypeDefBound::Explicit(TypeBound::Eq),
)
.unwrap();
PtrOpDef::load_all_ops(&mut extension).unwrap();

extension
}

lazy_static! {
/// Reference to the pointer Extension.
pub static ref EXTENSION: Extension = extension();
/// Registry required to validate pointer extension.
pub static ref PTR_REG: ExtensionRegistry =
ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap();
}

/// Integer type of a given bit width (specified by the TypeArg). Depending on
/// the operation, the semantic interpretation may be unsigned integer, signed
/// integer or bit string.
pub fn ptr_custom_type(ty: impl Into<Type>) -> CustomType {
let ty = ty.into();
CustomType::new(PTR_TYPE_ID, [ty.into()], EXTENSION_ID, TypeBound::Eq)
}

/// Integer type of a given bit width (specified by the TypeArg).
///
/// Constructed from [ptr_custom_type].
pub fn ptr_type(ty: impl Into<Type>) -> Type {
Type::new_extension(ptr_custom_type(ty))
}

#[derive(Clone, Debug, PartialEq)]
/// A concrete pointer operation.
pub struct PtrOp {
/// The operation definition.
pub def: PtrOpDef,
/// Type of the value being pointed to.
pub ty: Type,
}

impl PtrOp {
fn new(op: PtrOpDef, ty: Type) -> Self {
Self { def: op, ty }
}
}

impl NamedOp for PtrOp {
fn name(&self) -> OpName {
self.def.name()
}
}

impl MakeExtensionOp for PtrOp {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def = PtrOpDef::from_def(ext_op.def())?;
def.instantiate(ext_op.args())
}

fn type_args(&self) -> Vec<TypeArg> {
vec![self.ty.clone().into()]
}
}

impl MakeRegisteredOp for PtrOp {
fn extension_id(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
&PTR_REG
}
}

/// An extension trait for [Dataflow] providing methods to add pointer
/// operations.
pub trait PtrOpBuilder: Dataflow {
/// Add a "ptr.New" op.
fn add_new_ptr(&mut self, val_wire: Wire) -> Result<Wire, BuildError> {
let ty = self.get_wire_type(val_wire)?;
let handle = self.add_dataflow_op(PtrOpDef::New.with_type(ty), [val_wire])?;

Ok(handle.out_wire(0))
}

/// Add a "ptr.Read" op.
fn add_read_ptr(&mut self, ptr_wire: Wire, ty: Type) -> Result<Wire, BuildError> {
let handle = self.add_dataflow_op(PtrOpDef::Read.with_type(ty.clone()), [ptr_wire])?;
Ok(handle.out_wire(0))
}

/// Add a "ptr.Write" op.
fn add_write_ptr(&mut self, ptr_wire: Wire, val_wire: Wire) -> Result<(), BuildError> {
let ty = self.get_wire_type(val_wire)?;

let handle = self.add_dataflow_op(PtrOpDef::Write.with_type(ty), [ptr_wire, val_wire])?;
debug_assert_eq!(handle.outputs().len(), 0);
Ok(())
}
}

impl<D: Dataflow> PtrOpBuilder for D {}

impl HasConcrete for PtrOpDef {
type Concrete = PtrOp;

fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
let ty = match type_args {
[TypeArg::Type { ty }] => ty.clone(),
_ => return Err(SignatureError::InvalidTypeArgs.into()),
};

Ok(self.with_type(ty))
}
}

impl HasDef for PtrOp {
type Def = PtrOpDef;
}

#[cfg(test)]
pub(crate) mod test {
use crate::builder::DFGBuilder;
use crate::extension::prelude::BOOL_T;
use crate::ops::CustomOp;
use crate::{
builder::{Dataflow, DataflowHugr},
ops::NamedOp,
std_extensions::arithmetic::int_types::INT_TYPES,
};
use cool_asserts::assert_matches;
use std::sync::Arc;
use strum::IntoEnumIterator;

use super::*;
use crate::std_extensions::arithmetic::float_types::{
EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE,
};
fn get_opdef(op: impl NamedOp) -> Option<&'static Arc<OpDef>> {
EXTENSION.get_op(&op.name())
}

#[test]
fn create_extension() {
assert_eq!(EXTENSION.name(), &EXTENSION_ID);

for o in PtrOpDef::iter() {
assert_eq!(PtrOpDef::from_def(get_opdef(o).unwrap()), Ok(o));
}
}

#[test]
fn test_ops() {
let ops = [
PtrOp::new(PtrOpDef::New, BOOL_T.clone()),
PtrOp::new(PtrOpDef::Read, FLOAT64_TYPE.clone()),
PtrOp::new(PtrOpDef::Write, INT_TYPES[5].clone()),
];
for op in ops {
let op_t: CustomOp = op.clone().to_extension_op().unwrap().into();
let def_op = PtrOpDef::from_op(&op_t).unwrap();
assert_eq!(op.def, def_op);
let new_op = PtrOp::from_op(&op_t).unwrap();
assert_eq!(new_op, op);
}
}

#[test]
fn test_build() {
let in_row = vec![BOOL_T, FLOAT64_TYPE];

let reg =
ExtensionRegistry::try_new([EXTENSION.to_owned(), FLOAT_EXTENSION.to_owned()]).unwrap();
let hugr = {
let mut builder = DFGBuilder::new(
Signature::new(in_row.clone(), type_row![]).with_extension_delta(EXTENSION_ID),
)
.unwrap();

let in_wires: [Wire; 2] = builder.input_wires_arr();
for (ty, w) in in_row.into_iter().zip(in_wires.iter()) {
let new_ptr = builder.add_new_ptr(*w).unwrap();
let read = builder.add_read_ptr(new_ptr, ty).unwrap();
builder.add_write_ptr(new_ptr, read).unwrap();
}

builder.finish_hugr_with_outputs([], &reg).unwrap()
};
assert_matches!(hugr.validate(&reg), Ok(_));
}
}

0 comments on commit 88af215

Please sign in to comment.