Skip to content

Commit

Permalink
Program arguments and inverter arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Renmusxd committed Aug 26, 2021
1 parent 8d0c800 commit 02db6cc
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/macros/common_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ wrap_fn!(pub not, UnitaryBuilder::not, r);
wrap_fn!(pub swap, (UnitaryBuilder::swap), ra, rb);

wrap_fn!(pub h, UnitaryBuilder::hadamard, r);

wrap_fn!(pub rz(theta: f64), UnitaryBuilder::rz, r);
54 changes: 50 additions & 4 deletions src/macros/inverter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ use crate::{CircuitError, OpBuilder, Register, UnitaryBuilder};
/// ```
#[macro_export]
macro_rules! wrap_and_invert {
(pub $newfunc:ident($arg:ident: $argtype:ident), pub $newinvert:ident, $($tail:tt)*) => {
wrap_fn!(pub $newfunc($arg: $argtype), $($tail)*);
invert_fn!(pub $newinvert($arg: $argtype), $newfunc);
};
($newfunc:ident($arg:ident: $argtype:ident), pub $newinvert:ident, $($tail:tt)*) => {
wrap_fn!($newfunc($arg: $argtype), $($tail)*);
invert_fn!(pub $newinvert($arg: $argtype), $newfunc);
};
(pub $newfunc:ident($arg:ident: $argtype:ident), $newinvert:ident, $($tail:tt)*) => {
wrap_fn!(pub $newfunc($arg: $argtype), $($tail)*);
invert_fn!($newinvert($arg: $argtype), $newfunc);
};
($newfunc:ident($arg:ident: $argtype:ident), $newinvert:ident, $($tail:tt)*) => {
wrap_fn!($newfunc($arg: $argtype), $($tail)*);
invert_fn!($newinvert($arg: $argtype), $newfunc);
};
(pub $newfunc:ident, pub $newinvert:ident, $($tail:tt)*) => {
wrap_fn!(pub $newfunc, $($tail)*);
invert_fn!(pub $newinvert, $newfunc);
Expand Down Expand Up @@ -71,6 +87,25 @@ macro_rules! wrap_and_invert {
/// ```
#[macro_export]
macro_rules! invert_fn {
(pub $newinvert:ident($arg:ident: $argtype:ident), $func:expr) => {
/// Invert the given function.
pub fn $newinvert(
b: &mut dyn $crate::UnitaryBuilder,
rs: Vec<Register>,
$arg: $argtype,
) -> Result<Vec<Register>, $crate::CircuitError> {
$crate::inverter(b, rs, |b, rs| $func(b, rs, $arg))
}
};
($newinvert:ident($arg:ident: $argtype:ident), $func:expr) => {
fn $newinvert(
b: &mut dyn $crate::UnitaryBuilder,
rs: Vec<Register>,
$arg: $argtype,
) -> Result<Vec<Register>, $crate::CircuitError> {
$crate::inverter(b, rs, |b, rs| $func(b, rs, $arg))
}
};
(pub $newinvert:ident, $func:expr) => {
/// Invert the given function.
pub fn $newinvert(
Expand All @@ -91,13 +126,14 @@ macro_rules! invert_fn {
}

/// Invert a circuit applied via the function f.
pub fn inverter<
F: Fn(&mut dyn UnitaryBuilder, Vec<Register>) -> Result<Vec<Register>, CircuitError>,
>(
pub fn inverter<F>(
b: &mut dyn UnitaryBuilder,
mut rs: Vec<Register>,
f: F,
) -> Result<Vec<Register>, CircuitError> {
) -> Result<Vec<Register>, CircuitError>
where
F: Fn(&mut dyn UnitaryBuilder, Vec<Register>) -> Result<Vec<Register>, CircuitError>,
{
let original_indices: Vec<_> = rs.iter().map(|r| r.indices.clone()).collect();
let mut inv_builder = OpBuilder::new();
let new_rs: Vec<_> = original_indices
Expand Down Expand Up @@ -329,4 +365,14 @@ mod inverter_test {
let _r = b.merge(rs)?;
Ok(())
}

#[test]
fn test_invert_and_wrap_rz() -> Result<(), CircuitError> {
wrap_and_invert!(rz_op(theta: f64), inv_rz, UnitaryBuilder::rz, r);
let mut b = OpBuilder::new();
let r = b.qubit();
let rs = rz_op(&mut b, vec![r], 1.0)?;
let _rs = inv_rz(&mut b, rs, 1.0)?;
Ok(())
}
}
100 changes: 100 additions & 0 deletions src/macros/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,31 @@ macro_rules! program {
program!(@skip_to_next_program($builder, $reg_man) $($tail)*)
};

// Start parsing a program of the form "control function(arg) [register <indices>, ...];"
(@program($builder:expr, $reg_man:ident) control $func:ident($funcargs:expr) $($tail:tt)*) => {
program!(@program($builder, $reg_man) control(!0) $func($funcargs) $($tail)*)
};
// Start parsing a program of the form "control function(arg) [register <indices>, ...];"
(@program($builder:expr, $reg_man:ident) control($control:expr) $func:ident($funcargs:expr) $($tail:tt)*) => {
// Get all args
let mut tmp_acc_vec: Vec<$crate::Register> = vec![];
program!(@args_acc($builder, $reg_man, tmp_acc_vec) $($tail)*);
let tmp_cr = tmp_acc_vec.remove(0);

let tmp_cr = $crate::negate_bitmask($builder, tmp_cr, $control);
let mut tmp_cb = $builder.with_condition(tmp_cr);

// Now all the args are in acc_vec
let mut tmp_results: Vec<$crate::Register> = $func(&mut tmp_cb, tmp_acc_vec, $funcargs)?;

let tmp_cr = tmp_cb.release_register();
let tmp_cr = $crate::negate_bitmask($builder, tmp_cr, $control);

tmp_results.push(tmp_cr);
program!(@replace_registers($builder, $reg_man, tmp_results));
program!(@skip_to_next_program($builder, $reg_man) $($tail)*);
};

// Start parsing a program of the form "control function [register <indices>, ...];"
(@program($builder:expr, $reg_man:ident) control $func:ident $($tail:tt)*) => {
program!(@program($builder, $reg_man) control(!0) $func $($tail)*)
Expand All @@ -307,6 +332,19 @@ macro_rules! program {
program!(@replace_registers($builder, $reg_man, tmp_results));
program!(@skip_to_next_program($builder, $reg_man) $($tail)*);
};

// Start parsing a program of the form "function(arg) [register <indices>, ...];"
(@program($builder:expr, $reg_man:ident) $func:ident($funcargs:expr) $($tail:tt)*) => {
// Get all args
let mut acc_vec: Vec<Register> = vec![];
program!(@args_acc($builder, $reg_man, acc_vec) $($tail)*);

// Now all the args are in acc_vec
let tmp_results: Vec<Register> = $func($builder, acc_vec, $funcargs)?;
program!(@replace_registers($builder, $reg_man, tmp_results));
program!(@skip_to_next_program($builder, $reg_man) $($tail)*);
};

// Start parsing a program of the form "function [register <indices>, ...];"
(@program($builder:expr, $reg_man:ident) $func:ident $($tail:tt)*) => {
// Get all args
Expand Down Expand Up @@ -455,6 +493,22 @@ macro_rules! wrap_fn {
wrap_fn!(@unwrap_regs($func, $rs) $($tail)*);
let $name = $rs.pop().ok_or_else(|| $crate::CircuitError::new(format!("Error unwrapping {} for {}", stringify!($name), stringify!($func))))?;
};
(@result_body($builder:expr, $func:expr, $rs:ident, $arg:ident) $($tail:tt)*) => {
{
wrap_fn!(@unwrap_regs($func, $rs) $($tail)*);
let wrap_fn!(@names () <- $($tail)*) = wrap_fn!(@invoke($func, $builder) () <- $($tail)*, $arg) ?;
let $rs: Vec<$crate::Register> = vec![$($tail)*];
Ok($rs)
}
};
(@raw_body($builder:expr, $func:expr, $rs:ident, $arg:ident) $($tail:tt)*) => {
{
wrap_fn!(@unwrap_regs($func, $rs) $($tail)*);
let wrap_fn!(@names () <- $($tail)*) = wrap_fn!(@invoke($func, $builder) () <- $($tail)*, $arg);
let $rs: Vec<$crate::Register> = vec![$($tail)*];
Ok($rs)
}
};
(@result_body($builder:expr, $func:expr, $rs:ident) $($tail:tt)*) => {
{
wrap_fn!(@unwrap_regs($func, $rs) $($tail)*);
Expand All @@ -471,6 +525,28 @@ macro_rules! wrap_fn {
Ok($rs)
}
};
(pub $newfunc:ident($arg:ident: $argtype:ident), ($func:expr), $($tail:tt)*) => {
/// Wrapped version of function
pub fn $newfunc(b: &mut dyn $crate::UnitaryBuilder, mut rs: Vec<$crate::Register>, $arg: $argtype) -> Result<Vec<$crate::Register>, $crate::CircuitError> {
wrap_fn!(@result_body(b, $func, rs, $arg) $($tail)*)
}
};
(pub $newfunc:ident($arg:ident: $argtype:ident), $func:expr, $($tail:tt)*) => {
/// Wrapped version of function
pub fn $newfunc(b: &mut dyn $crate::UnitaryBuilder, mut rs: Vec<$crate::Register>, $arg: $argtype) -> Result<Vec<$crate::Register>, $crate::CircuitError> {
wrap_fn!(@raw_body(b, $func, rs, $arg) $($tail)*)
}
};
($newfunc:ident($arg:ident: $argtype:ident), ($func:expr), $($tail:tt)*) => {
fn $newfunc(b: &mut dyn $crate::UnitaryBuilder, mut rs: Vec<$crate::Register>, $arg: $argtype) -> Result<Vec<$crate::Register>, $crate::CircuitError> {
wrap_fn!(@result_body(b, $func, rs, $arg) $($tail)*)
}
};
($newfunc:ident($arg:ident: $argtype:ident), $func:expr, $($tail:tt)*) => {
fn $newfunc(b: &mut dyn $crate::UnitaryBuilder, mut rs: Vec<$crate::Register>, $arg: $argtype) -> Result<Vec<$crate::Register>, $crate::CircuitError> {
wrap_fn!(@raw_body(b, $func, rs, $arg) $($tail)*)
}
};
(pub $newfunc:ident, ($func:expr), $($tail:tt)*) => {
/// Wrapped version of function
pub fn $newfunc(b: &mut dyn $crate::UnitaryBuilder, mut rs: Vec<$crate::Register>) -> Result<Vec<$crate::Register>, $crate::CircuitError> {
Expand Down Expand Up @@ -1228,4 +1304,28 @@ mod common_circuit_tests {
assert_eq!(macro_circuit, basic_circuit);
Ok(())
}

#[test]
fn wrap_unitary_fn_arg() -> Result<(), CircuitError> {
wrap_fn!(wrapped_rz(theta: f64), UnitaryBuilder::rz, r);

let mut b = OpBuilder::new();
let r = b.qubit();

let r = program!(&mut b, r;
wrapped_rz(1.0) r;
)?;

run_debug(&r)?;
// Compare to expected value
let macro_circuit = make_circuit_matrix::<f64>(1, &r, true);
let mut b = OpBuilder::new();
let r = b.qubit();

let r = b.rz(r, 1.0);
run_debug(&r)?;
let basic_circuit = make_circuit_matrix::<f64>(1, &r, true);
assert_eq!(macro_circuit, basic_circuit);
Ok(())
}
}

0 comments on commit 02db6cc

Please sign in to comment.