Skip to content

Commit

Permalink
Rewrite cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Renmusxd committed Feb 23, 2022
1 parent abfd9eb commit 91e98a9
Show file tree
Hide file tree
Showing 10 changed files with 386 additions and 95 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "qip"
version = "0.16.0"
version = "1.0.0"
authors = ["Sumner Hearth <sumnernh@gmail.com>"]
description = "A library for efficient quantum computing simulations."
repository = "https://github.com/Renmusxd/RustQIP"
Expand All @@ -19,6 +19,6 @@ optimization = []
num-rational = "^0.4"
num-traits = "^0.2"
num-complex = "^0.4"
rayon = {version = "^1.5", optional = true }
rayon = { version = "^1.5", optional = true }
rand = "^0.8"
smallvec = "^1.7"
smallvec = "^1.8"
3 changes: 1 addition & 2 deletions examples/optimizer_example.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use qip::builder::{BuilderCircuitObject, Qudit};
use qip::macros::program_ops::*;
use qip::prelude::*;
use rand::{thread_rng, Rng};
Expand Down Expand Up @@ -88,7 +87,7 @@ fn main() -> Result<(), CircuitError> {

let mut b = LocalBuilder::<f64>::default();
let r = b.register(NonZeroUsize::new(3).unwrap());
b.apply_optimizer_circuit(r, opt.get_ops());
b.apply_optimizer_circuit(r, opt.get_ops())?;
let (state, _) = b.calculate_state();
println!("{:?}", state);

Expand Down
181 changes: 111 additions & 70 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ pub struct BuilderCircuitObject<P: Precision> {
object: BuilderCircuitObjectType<P>,
}

#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone)]
pub enum BuilderCircuitObjectType<P: Precision> {
Unitary(UnitaryMatrixObject<P>),
Measurement(MeasurementObject),
}

#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone)]
pub enum UnitaryMatrixObject<P: Precision> {
X,
Y,
Expand All @@ -114,7 +114,7 @@ pub enum UnitaryMatrixObject<P: Precision> {
GlobalPhase(RotationObject<P>),
}

#[derive(Debug, Clone, Eq, PartialEq)]
#[derive(Debug, Clone)]
pub enum RotationObject<P: Precision> {
Floating(P),
PiRational(Ratio<i64>),
Expand All @@ -129,6 +129,110 @@ impl<P: Precision> RotationObject<P> {
}
}

impl<P: Precision> PartialEq for UnitaryMatrixObject<P> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::X, Self::X)
| (Self::Y, Self::Y)
| (Self::Z, Self::Z)
| (Self::H, Self::H)
| (Self::S, Self::S)
| (Self::T, Self::T)
| (Self::CNOT, Self::CNOT)
| (Self::SWAP, Self::SWAP) => true,
(Self::Rz(ra), Self::Rz(rb)) => ra.eq(rb),
(Self::MAT(ma), Self::MAT(mb)) => ma.eq(mb),
(Self::GlobalPhase(ra), Self::GlobalPhase(rb)) => ra.eq(rb),
(_, _) => false,
}
}
}

impl<P: Precision> PartialEq for BuilderCircuitObjectType<P> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Unitary(ua), Self::Unitary(ub)) => ua.eq(ub),
(Self::Measurement(ma), Self::Measurement(mb)) => ma.eq(mb),
(_, _) => false,
}
}
}

impl<P: Precision> PartialEq for RotationObject<P> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Floating(pa), Self::Floating(pb)) => pa.eq(pb),
(Self::PiRational(ra), Self::PiRational(rb)) => ra.eq(rb),
(_, _) => false,
}
}
}

impl<P: Precision> Eq for UnitaryMatrixObject<P> {}

impl<P: Precision> Eq for BuilderCircuitObjectType<P> {}

impl<P: Precision> Eq for RotationObject<P> {}

fn hash_p<P: Precision, H: Hasher>(f: P, state: &mut H) {
format!("{}", f).hash(state)
}

impl<P: Precision> Hash for BuilderCircuitObjectType<P> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
BuilderCircuitObjectType::Measurement(m) => {
state.write_i8(0);
m.hash(state)
}
BuilderCircuitObjectType::Unitary(u) => {
state.write_i8(1);
u.hash(state)
}
}
}
}

impl<P: Precision> Hash for UnitaryMatrixObject<P> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
UnitaryMatrixObject::X => state.write_i8(0),
UnitaryMatrixObject::Y => state.write_i8(1),
UnitaryMatrixObject::Z => state.write_i8(2),
UnitaryMatrixObject::H => state.write_i8(3),
UnitaryMatrixObject::S => state.write_i8(4),
UnitaryMatrixObject::T => state.write_i8(5),
UnitaryMatrixObject::CNOT => state.write_i8(6),
UnitaryMatrixObject::SWAP => state.write_i8(7),
UnitaryMatrixObject::Rz(rot) => {
state.write_i8(8);
rot.hash(state);
}
UnitaryMatrixObject::GlobalPhase(rot) => {
state.write_i8(9);
rot.hash(state);
}
UnitaryMatrixObject::MAT(data) => {
state.write_i8(10);
data.iter().for_each(|c| {
hash_p(c.re, state);
hash_p(c.im, state);
})
}
}
}
}

impl<P: Precision> Hash for RotationObject<P> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
// Grossly inefficient but also don't hash floats.
RotationObject::Floating(f) => hash_p(*f, state),
RotationObject::PiRational(r) => r.hash(state),
}
}
}

#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum MeasurementObject {
Measurement,
Expand Down Expand Up @@ -247,15 +351,15 @@ impl<P: Precision> CircuitBuilder for LocalBuilder<P> {

let mut initial_index = 0;
it.into_iter()
.map(|(r, x)| {
.flat_map(|(r, x)| {
let rn = r.n();
r.indices
.iter()
.rev()
.cloned()
.enumerate()
.map(move |(ri, i)| (r.n() - 1 - i, (x >> ri) & 1))
.map(move |(ri, i)| (n - 1 - i, (x >> (rn - 1 - ri)) & 1))
})
.flatten()
.for_each(|(index, bit)| initial_index |= bit << index);
state[initial_index] = Complex::one();

Expand Down Expand Up @@ -707,8 +811,7 @@ where
let mut rs = cb.split_all_register(r);
let max_r_index = sc
.iter()
.map(|(indices, _)| indices.iter().cloned().max())
.flatten()
.flat_map(|(indices, _)| indices.iter().cloned().max())
.max()
.unwrap();
// Need temp qubits for the max_r_index - rn
Expand Down Expand Up @@ -797,68 +900,6 @@ pub mod optimizers {
use crate::optimizer::mc_optimizer::MonteCarloOptimizer;
use std::path::Path;

impl<P: Precision> Eq for UnitaryMatrixObject<P> {}
impl<P: Precision> Eq for BuilderCircuitObjectType<P> {}

impl<P: Precision> Hash for BuilderCircuitObjectType<P> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
BuilderCircuitObjectType::Measurement(m) => {
state.write_i8(0);
m.hash(state)
}
BuilderCircuitObjectType::Unitary(u) => {
state.write_i8(1);
u.hash(state)
}
}
}
}

impl<P: Precision> Hash for UnitaryMatrixObject<P> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
UnitaryMatrixObject::X => state.write_i8(0),
UnitaryMatrixObject::Y => state.write_i8(1),
UnitaryMatrixObject::Z => state.write_i8(2),
UnitaryMatrixObject::H => state.write_i8(3),
UnitaryMatrixObject::S => state.write_i8(4),
UnitaryMatrixObject::T => state.write_i8(5),
UnitaryMatrixObject::CNOT => state.write_i8(6),
UnitaryMatrixObject::SWAP => state.write_i8(7),
UnitaryMatrixObject::Rz(rot) => {
state.write_i8(8);
rot.hash(state);
}
UnitaryMatrixObject::GlobalPhase(rot) => {
state.write_i8(9);
rot.hash(state);
}
UnitaryMatrixObject::MAT(data) => {
state.write_i8(10);
data.iter().for_each(|c| {
hash_p(c.re, state);
hash_p(c.im, state);
})
}
}
}
}

impl<P: Precision> Hash for RotationObject<P> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
// Grossly inefficient but also don't hash floats.
RotationObject::Floating(f) => hash_p(*f, state),
RotationObject::PiRational(r) => r.hash(state),
}
}
}

fn hash_p<P: Precision, H: Hasher>(f: P, state: &mut H) {
format!("{}", f).hash(state)
}

pub type OptimizerTrie<P> = IndexTrie<
(Vec<usize>, BuilderCircuitObjectType<P>),
Vec<(Vec<usize>, BuilderCircuitObjectType<P>)>,
Expand Down
7 changes: 4 additions & 3 deletions src/builder_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ pub trait CircuitBuilder {
fn qubit(&mut self) -> Self::Register {
self.register(NonZeroUsize::new(1).unwrap())
}

fn qudit(&mut self, n: usize) -> Option<Self::Register> {
NonZeroUsize::new(n).map(|n| self.register(n))
}
fn register(&mut self, n: NonZeroUsize) -> Self::Register;

fn try_register(&mut self, n: usize) -> Option<Self::Register> {
Expand Down Expand Up @@ -115,11 +117,10 @@ pub trait CircuitBuilder {
.collect::<Vec<_>>();
let selected_rs = indices
.into_iter()
.map(|is| {
.flat_map(|is| {
let subrs = is.into_iter().map(|i| rs[i].take().unwrap());
self.merge_registers(subrs)
})
.flatten()
.collect();
let remaining_rs = self.merge_registers(rs.into_iter().flatten());
match remaining_rs {
Expand Down
Loading

0 comments on commit 91e98a9

Please sign in to comment.