Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Move int conversions to conversions ext, add to/from usize #1490

Merged
merged 5 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions hugr-core/src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,15 +351,13 @@ pub enum OpaqueOpError {
#[cfg(test)]
mod test {

use crate::std_extensions::arithmetic::conversions::{self, CONVERT_OPS_REGISTRY};
use crate::{
extension::{
prelude::{BOOL_T, QB_T, USIZE_T},
SignatureFunc,
},
std_extensions::arithmetic::{
int_ops::{self, INT_OPS_REGISTRY},
int_types::INT_TYPES,
},
std_extensions::arithmetic::int_types::INT_TYPES,
types::FuncValueType,
Extension,
};
Expand Down Expand Up @@ -387,10 +385,10 @@ mod test {

#[test]
fn resolve_opaque_op() {
let registry = &INT_OPS_REGISTRY;
let registry = &CONVERT_OPS_REGISTRY;
let i0 = &INT_TYPES[0];
let opaque = OpaqueOp::new(
int_ops::EXTENSION_ID,
conversions::EXTENSION_ID,
"itobool",
"description".into(),
vec![],
Expand Down
176 changes: 144 additions & 32 deletions hugr-core/src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::extension::prelude::{BOOL_T, STRING_TYPE, USIZE_T};
use crate::extension::simple_op::{HasConcrete, HasDef};
use crate::ops::OpName;
use crate::std_extensions::arithmetic::int_ops::int_polytype;
use crate::std_extensions::arithmetic::int_types::int_type;
use crate::{
extension::{
prelude::sum_with_error,
Expand All @@ -12,12 +16,12 @@ use crate::{
},
ops::{custom::ExtensionOp, NamedOp},
type_row,
types::{FuncValueType, PolyFuncTypeRV, TypeArg, TypeRV},
types::{TypeArg, TypeRV},
Extension,
};

use super::int_types::int_tv;
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
use super::float_types::FLOAT64_TYPE;
use super::int_types::{get_log_width, int_tv};
use lazy_static::lazy_static;
mod const_fold;
/// The extension identifier.
Expand All @@ -34,6 +38,12 @@ pub enum ConvertOpDef {
trunc_s,
convert_u,
convert_s,
itobool,
ifrombool,
itostring_u,
itostring_s,
itousize,
ifromusize,
}

impl MakeOpDef for ConvertOpDef {
Expand All @@ -47,18 +57,19 @@ impl MakeOpDef for ConvertOpDef {

fn signature(&self) -> SignatureFunc {
use ConvertOpDef::*;
PolyFuncTypeRV::new(
vec![LOG_WIDTH_TYPE_PARAM],
match self {
trunc_s | trunc_u => FuncValueType::new(
type_row![FLOAT64_TYPE],
TypeRV::from(sum_with_error(int_tv(0))),
),
convert_s | convert_u => {
FuncValueType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE])
}
},
)
match self {
trunc_s | trunc_u => int_polytype(
1,
type_row![FLOAT64_TYPE],
TypeRV::from(sum_with_error(int_tv(0))),
),
convert_s | convert_u => int_polytype(1, vec![int_tv(0)], type_row![FLOAT64_TYPE]),
itobool => int_polytype(0, vec![int_type(0)], vec![BOOL_T]),
ifrombool => int_polytype(0, vec![BOOL_T], vec![int_type(0)]),
itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![STRING_TYPE]),
itousize => int_polytype(0, vec![int_type(6)], vec![USIZE_T]),
ifromusize => int_polytype(0, vec![USIZE_T], vec![int_type(6)]),
}
.into()
}

Expand All @@ -69,6 +80,12 @@ impl MakeOpDef for ConvertOpDef {
trunc_s => "float to signed int",
convert_u => "unsigned int to float",
convert_s => "signed int to float",
itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
itostring_s => "convert a signed integer to its string representation",
itostring_u => "convert an unsigned integer to its string representation",
itousize => "convert a 64b unsigned integer to its usize representation",
ifromusize => "convert a usize to a 64b unsigned integer",
}
.to_string()
}
Expand All @@ -79,19 +96,44 @@ impl MakeOpDef for ConvertOpDef {
}

impl ConvertOpDef {
/// Initialise a conversion op with an integer log width type argument.
/// Initialize a [ConvertOpType] from a [ConvertOpDef] which requires no
/// integer widths set.
pub fn without_log_width(self) -> ConvertOpType {
ConvertOpType {
def: self,
log_width: None,
}
}
/// Initialize a [ConvertOpType] from a [ConvertOpDef] which requires one
/// integer width set.
pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
ConvertOpType {
def: self,
log_width,
log_width: Some(log_width),
}
}
}
/// Concrete convert operation with integer log width set.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvertOpType {
/// The kind of conversion op.
def: ConvertOpDef,
log_width: u8,
/// The integer width parameter of the conversion op, if any. This is interpreted
/// differently, depending on `def`. The integer types in the inputs and
/// outputs of the op will have [int_type]s of this width.
log_width: Option<u8>,
}

impl ConvertOpType {
/// Returns the generic [ConvertOpDef] of this [ConvertOpType].
pub fn def(&self) -> &ConvertOpDef {
&self.def
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a getter for the widths too


/// Returns the integer width parameters of this [ConvertOpType], if any.
pub fn log_widths(&self) -> &[u8] {
self.log_width.as_slice()
}
}

impl NamedOp for ConvertOpType {
Expand All @@ -103,20 +145,11 @@ impl NamedOp for ConvertOpType {
impl MakeExtensionOp for ConvertOpType {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def = ConvertOpDef::from_def(ext_op.def())?;
let log_width: u64 = match *ext_op.args() {
[TypeArg::BoundedNat { n }] => n,
_ => return Err(SignatureError::InvalidTypeArgs.into()),
};
Ok(Self {
def,
log_width: u8::try_from(log_width).unwrap(),
})
def.instantiate(ext_op.args())
}

fn type_args(&self) -> Vec<crate::types::TypeArg> {
vec![TypeArg::BoundedNat {
n: self.log_width as u64,
}]
fn type_args(&self) -> Vec<TypeArg> {
self.log_width.iter().map(|&n| (n as u64).into()).collect()
}
}

Expand Down Expand Up @@ -157,17 +190,96 @@ impl MakeRegisteredOp for ConvertOpType {
}
}

impl HasConcrete for ConvertOpDef {
type Concrete = ConvertOpType;

fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
let log_width = match type_args {
[] => None,
[arg] => Some(get_log_width(arg).map_err(|_| SignatureError::InvalidTypeArgs)?),
_ => return Err(SignatureError::InvalidTypeArgs.into()),
};
Ok(ConvertOpType {
def: *self,
log_width,
})
}
}

impl HasDef for ConvertOpType {
type Def = ConvertOpDef;
}

#[cfg(test)]
mod test {
use rstest::rstest;

use crate::extension::prelude::ConstUsize;
use crate::ops::Value;
use crate::std_extensions::arithmetic::int_types::ConstInt;
use crate::IncomingPort;

use super::*;

#[test]
fn test_conversions_extension() {
let r = &EXTENSION;
assert_eq!(r.name() as &str, "arithmetic.conversions");
assert_eq!(r.types().count(), 0);
for (name, _) in r.operations() {
assert!(name.as_str().starts_with("convert") || name.as_str().starts_with("trunc"));
}

#[test]
fn test_conversions() {
// Initialization with an invalid number of type arguments should fail.
assert!(
ConvertOpDef::itobool
.with_log_width(1)
.to_extension_op()
.is_none(),
"type arguments invalid"
);

// This should work
let o = ConvertOpDef::itobool.without_log_width();
let ext_op: ExtensionOp = o.clone().to_extension_op().unwrap();

assert_eq!(ConvertOpType::from_op(&ext_op).unwrap(), o);
assert_eq!(
ConvertOpDef::from_op(&ext_op).unwrap(),
ConvertOpDef::itobool
);
}

#[rstest]
#[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])]
#[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])]
#[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])]
#[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])]
#[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])]
#[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])]
fn convert_fold(
#[case] op: ConvertOpType,
#[case] inputs: &[Value],
#[case] outputs: &[Value],
) {
use crate::ops::Value;

let consts: Vec<(IncomingPort, Value)> = inputs
.iter()
.enumerate()
.map(|(i, v)| (i.into(), v.clone()))
.collect();

let res = op
.to_extension_op()
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, expected) in outputs.iter().enumerate() {
let res_val: &Value = &res.get(i).unwrap().1;

assert_eq!(res_val, expected);
}
}
}
Loading
Loading