Skip to content

Commit

Permalink
accept AsSubcircuit and IndicesInfo for apply
Browse files Browse the repository at this point in the history
  • Loading branch information
OmmyZhang committed Mar 5, 2024
1 parent 5b5fe2d commit d251009
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 80 deletions.
264 changes: 184 additions & 80 deletions qip/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,143 @@ type Registers<CB, const N: usize> = [<CB as CircuitBuilder>::Register; N];
type CircuitFunction<CB, const N: usize> =
dyn Fn(&mut CB, Registers<CB, N>) -> CircuitResult<Registers<CB, N>>;

/// indices for N registers
pub type Idx<const N: usize> = [Vec<(usize, usize)>; N];

/// A subcircuit that you can apply
pub trait AsSubcircuit<CB: CircuitBuilder, const L: usize> {
/// The innner function of this circuit
fn circuit_func(self) -> Rc<CircuitFunction<CB, L>>;
}

impl<CB: CircuitBuilder, const L: usize, F> AsSubcircuit<CB, L> for F
where
F: Fn(&mut CB, Registers<CB, L>) -> CircuitResult<Registers<CB, L>> + 'static,
{
fn circuit_func(self) -> Rc<CircuitFunction<CB, L>> {
Rc::new(self)
}
}

/// Provide indices information when apply a subcircuit
pub trait IndicesInfo<CB: CircuitBuilder, const N: usize, const L: usize>: 'static {
/// Temporary intermediate type for storing other registers
type IntermediateRegisters;

/// Get new registers
fn get_new_registers(
&self,
cb: &mut CB,
orig_rs: Registers<CB, N>,
) -> (Self::IntermediateRegisters, Registers<CB, L>);

/// Restore original registers
fn restore_original_registers(
&self,
cb: &mut CB,
itm_rs: Self::IntermediateRegisters,
sub_rs: Registers<CB, L>,
) -> Registers<CB, N>;
}

impl<CB: CircuitBuilder, const N: usize, const L: usize> IndicesInfo<CB, N, L> for [usize; L] {
type IntermediateRegisters = [Option<CB::Register>; N];

fn get_new_registers(
&self,
_cb: &mut CB,
orig_rs: Registers<CB, N>,
) -> (Self::IntermediateRegisters, Registers<CB, L>) {
let mut itm = orig_rs.map(Some);
let sub_rs = self.map(|idx| itm[idx].take().unwrap());
(itm, sub_rs)
}

fn restore_original_registers(
&self,
_cb: &mut CB,
mut itm_rs: Self::IntermediateRegisters,
sub_rs: Registers<CB, L>,
) -> Registers<CB, N> {
iter::zip(self, sub_rs).for_each(|(&idx, r)| itm_rs[idx] = Some(r));
itm_rs.map(|r| r.unwrap())
}
}

impl<CB: CircuitBuilder, const N: usize, const L: usize, MAP> IndicesInfo<CB, N, L> for MAP
where
MAP: Fn(Idx<N>) -> Idx<L> + 'static,
{
type IntermediateRegisters = [Vec<Option<CB::Register>>; N];

fn get_new_registers(
&self,
cb: &mut CB,
orig_rs: Registers<CB, N>,
) -> (Self::IntermediateRegisters, Registers<CB, L>) {
let mut itm = orig_rs.map(|r| {
let qubits = cb.split_all_register(r);
qubits.into_iter().map(Some).collect::<Vec<_>>()
});

let init_indices: [Vec<(usize, usize)>; N] = itm
.iter()
.enumerate()
.map(|(reg_idx, qubits)| {
(0..qubits.len())
.map(|idx| (reg_idx, idx))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
let new_indices = self(init_indices);

let sub_rs = new_indices.map(|qubit_positions| {
cb.merge_registers(
qubit_positions
.iter()
.map(|&(reg_idx, idx)| itm[reg_idx][idx].take().unwrap()),
)
.unwrap()
});
(itm, sub_rs)
}

fn restore_original_registers(
&self,
cb: &mut CB,
mut itm_rs: Self::IntermediateRegisters,
sub_rs: Registers<CB, L>,
) -> Registers<CB, N> {
let init_indices: [Vec<(usize, usize)>; N] = itm_rs
.iter()
.enumerate()
.map(|(reg_idx, qubits)| {
(0..qubits.len())
.map(|idx| (reg_idx, idx))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
let new_indices = self(init_indices);

iter::zip(new_indices, sub_rs.map(|r| cb.split_all_register(r))).for_each(
|(qubit_positions, out_qubits)| {
iter::zip(qubit_positions, out_qubits).for_each(|((reg_idx, idx), qubit)| {
itm_rs[reg_idx][idx] = Some(qubit);
});
},
);

itm_rs.map(|qubits| {
cb.merge_registers(qubits.into_iter().map(|qubit| qubit.unwrap()))
.unwrap()
})
}
}

/// A circuit described by a function
#[derive(Clone)]
pub struct Circuit<CB: CircuitBuilder, const N: usize> {
Expand All @@ -30,6 +167,18 @@ impl<CB: CircuitBuilder, const N: usize> Default for Circuit<CB, N> {
}
}

impl<CB: CircuitBuilder, const N: usize> AsSubcircuit<CB, N> for Circuit<CB, N> {
fn circuit_func(self) -> Rc<CircuitFunction<CB, N>> {
self.func.clone()
}
}

impl<CB: CircuitBuilder, const N: usize> AsSubcircuit<CB, N> for &Circuit<CB, N> {
fn circuit_func(self) -> Rc<CircuitFunction<CB, N>> {
self.func.clone()
}
}

impl<CB: CircuitBuilder + 'static, const N: usize> Circuit<CB, N> {
/// From a function
pub fn from<F>(f: F) -> Self
Expand All @@ -40,88 +189,20 @@ impl<CB: CircuitBuilder + 'static, const N: usize> Circuit<CB, N> {
}

/// Apply a function to part of this circuit
pub fn apply<F, const L: usize>(self, f: F, indices: [usize; L]) -> Self
where
F: Fn(&mut CB, Registers<CB, L>) -> CircuitResult<Registers<CB, L>> + 'static,
{
let func = self.func.clone();
Self {
func: Rc::new(move |cb, rs| {
let out = (*func)(cb, rs)?;
let mut out = out.map(Some);
let f_input = indices.map(|idx| out[idx].take().unwrap());
let f_out = f(cb, f_input)?;

iter::zip(indices, f_out).for_each(|(idx, r)| out[idx] = Some(r));

Ok(out.map(|r| r.unwrap()))
}),
}
}

/// Apply a sub circuit for specific qubits under some new indices combine
pub fn apply_subcircuit<MAP, const L: usize>(
pub fn apply<const L: usize>(
self,
indices_map: MAP,
sub_circuit: &Circuit<CB, L>,
) -> Self
where
MAP: Fn([Vec<(usize, usize)>; N]) -> [Vec<(usize, usize)>; L] + 'static,
{
subcircuit: impl AsSubcircuit<CB, L>,
indices: impl IndicesInfo<CB, N, L>,
) -> Self {
let func = self.func.clone();
let sub_func = sub_circuit.func.clone();

let sub_func = subcircuit.circuit_func();
Self {
func: Rc::new(move |cb, rs| {
let out = (*func)(cb, rs)?;
let (itm, f_input) = indices.get_new_registers(cb, out);
let f_out = sub_func(cb, f_input)?;

//split
let mut out = out.map(|r| {
let qubits = cb.split_all_register(r);
qubits.into_iter().map(Some).collect::<Vec<_>>()
});

// combine to new registers
let init_indices: [Vec<(usize, usize)>; N] = out
.iter()
.enumerate()
.map(|(reg_idx, qubits)| {
(0..qubits.len())
.map(|idx| (reg_idx, idx))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
let new_indices = indices_map(init_indices);

let f_input = new_indices.clone().map(|qubit_positions| {
cb.merge_registers(
qubit_positions
.iter()
.map(|&(reg_idx, idx)| out[reg_idx][idx].take().unwrap()),
)
.unwrap()
});

let f_output = (*sub_func)(cb, f_input)?;
let f_output_qubits = f_output.map(|r| cb.split_all_register(r));

// restore
iter::zip(new_indices, f_output_qubits).for_each(
|(qubit_positions, out_qubits)| {
iter::zip(qubit_positions, out_qubits).for_each(
|((reg_idx, idx), qubit)| {
out[reg_idx][idx] = Some(qubit);
},
);
},
);

Ok(out.map(|qubits| {
cb.merge_registers(qubits.into_iter().map(|qubit| qubit.unwrap()))
.unwrap()
}))
Ok(indices.restore_original_registers(cb, itm, f_out))
}),
}
}
Expand Down Expand Up @@ -154,6 +235,8 @@ mod tests {
use super::*;
use crate::prelude::*;

type CurrentBuilderType = LocalBuilder<f64>;

fn gamma<B>(b: &mut B, rs: [B::Register; 2]) -> CircuitResult<[B::Register; 2]>
where
B: AdvancedCircuitBuilder<f64>,
Expand All @@ -166,7 +249,7 @@ mod tests {

#[test]
fn test_chain_circuit() -> CircuitResult<()> {
let mut b = LocalBuilder::default();
let mut b = CurrentBuilderType::default();
let ra = b.try_register(3).unwrap();
let rb = b.try_register(3).unwrap();

Expand All @@ -176,11 +259,32 @@ mod tests {
// Applies gamma to |ra>|rb>
.apply(gamma, [0, 1])
// Applies gamma to |ra[0] ra[1]>|ra[2]>
.apply_subcircuit(|[ra, _]| [ra[0..=1].to_vec(), vec![ra[2]]], &gamma_circuit)
.apply(&gamma_circuit, |[ra, _]: Idx<2>| {
[ra[0..=1].to_vec(), vec![ra[2]]]
})
// Applies gamma to |ra[0] rb[0]>|ra[2]>
.apply_subcircuit(|[ra, rb]| [vec![ra[0], rb[0]], vec![ra[2]]], &gamma_circuit)
.apply(&gamma_circuit, |[ra, rb]: Idx<2>| {
[vec![ra[0], rb[0]], vec![ra[2]]]
})
// Applies gamma to |ra[0]>|rb[0] ra[2]>
.apply_subcircuit(|[ra, rb]| [vec![ra[0]], vec![rb[0], ra[2]]], &gamma_circuit)
.apply(&gamma_circuit, |[ra, rb]: Idx<2>| {
[vec![ra[0]], vec![rb[0], ra[2]]]
})
// Applies a more complex subcircuit to |ra[1]>|ra[2]>|rb>
.apply(
Circuit::default()
.apply(gamma, [0, 1])
.apply(gamma, [1, 2])
.apply(
|b: &mut CurrentBuilderType, rs| {
let [ra, rb] = rs;
let (ra, rb) = b.cnot(ra, rb)?;
Ok([ra, rb])
},
|[_, r2, r3]: Idx<3>| [r2, vec![r3[0]]],
),
|[ra, rb]: Idx<2>| [vec![ra[1]], vec![ra[2]], rb],
)
.input([ra, rb])
.run(&mut b)?;

Expand Down
3 changes: 3 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[toolchain]
channel = "beta"
components = [ "rustfmt", "rustc-dev" ]

0 comments on commit d251009

Please sign in to comment.