Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Adding digital noise #34

Open
wants to merge 62 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
87561f1
[Feature] Use single precision by default
dominikandreasseitz Apr 19, 2024
b743496
refac expectation like pyq
dominikandreasseitz Apr 19, 2024
7878a90
Add sample, increase atol
dominikandreasseitz Jun 21, 2024
882a5d9
rework circ
dominikandreasseitz Jun 21, 2024
2f74706
Merge branch 'main' into ds/single_prec
RolandMacDoland Aug 7, 2024
3888c6e
Remove spurious import.
RolandMacDoland Aug 7, 2024
d6d50f2
Lint.
RolandMacDoland Aug 7, 2024
0af6134
adding noisy operators
Dec 2, 2024
b22d3a7
change interface
Dec 3, 2024
a49e75a
test tuple length noise in gate
Dec 3, 2024
7603eb3
fix noise
Dec 3, 2024
965099f
separate tests between noisy and non noisy
Dec 4, 2024
9df4446
fix docs and tests with target control index 0
Dec 4, 2024
06c6062
lint
Dec 4, 2024
62a217b
Merge remote-tracking branch 'origin/ds/single_prec' into cm/krauss_ops
Dec 4, 2024
53c520a
Merge remote-tracking branch 'origin/main' into cm/krauss_ops
Dec 5, 2024
9aca1c8
fix order noise and param by inheritance
Dec 6, 2024
6f49beb
fix union
Dec 6, 2024
b3a9a28
rm densitymatrix object - to replace by more functional way
Dec 9, 2024
f3f1aa3
add raises notimplementederrors
Dec 9, 2024
78e7519
add boolean in fwd fcts
Dec 9, 2024
d310f78
add current implementation of channel apply
Dec 10, 2024
4340185
change dagger
Dec 10, 2024
3f1645f
apply_gate correct on density matrices before permutation
Dec 10, 2024
58d2f0b
fix permutation with density matrices
Dec 10, 2024
26ecf21
fix permute_basis
Dec 10, 2024
3c83b17
adding tests with dm
Dec 10, 2024
b50ab1f
adding test shots
Dec 10, 2024
b4e03d9
add permute_basis at beginning of apply
Dec 10, 2024
3cdc5f1
using jax lax transpose
Dec 10, 2024
304cdf4
adding test dm parameteric
Dec 10, 2024
aea2d33
fix controlled ops with dm
Dec 10, 2024
819833f
checking tests expectation work for dm
Dec 10, 2024
bc479ba
separate apply_operator with a density matrix version
Dec 10, 2024
5666c69
add sampling from density mat
Dec 12, 2024
47328f2
fix expectation with density matrices
Dec 12, 2024
40d7cf8
test also shots with noise
Dec 12, 2024
54da26f
adding docs
Dec 12, 2024
dd2625d
fix lint
Dec 13, 2024
3f90c12
rm circuit methods in favor of api
Dec 13, 2024
be6cc2f
rm comment
Dec 16, 2024
e62389a
Merge remote-tracking branch 'origin/main' into cm/krauss_ops
Dec 16, 2024
e32e288
change for isdensity
Dec 16, 2024
37b2a9a
Tuple to tuple
Dec 16, 2024
0936e16
change from default noise tuple to None
Dec 16, 2024
f2606d4
more docstr in apply
Dec 16, 2024
e3b1f2d
change tuple call new state dims
Dec 16, 2024
4fef6bf
fix union
Dec 16, 2024
d519a99
mention density matrix simulator
Dec 16, 2024
8488bd8
using single dispatch - breaking
Dec 17, 2024
4050e27
shots not working
Dec 17, 2024
32ee665
add fixmes
Dec 17, 2024
8c6c881
fix single dispatch methods
Dec 18, 2024
3a403a9
adding values to observable to matrix
Dec 18, 2024
49172fc
fix shots dm
Dec 18, 2024
d5a3d80
fix doc strings
Dec 18, 2024
c855a97
update State typing
Dec 18, 2024
4f45580
update typing primitive noiseprotocol
Dec 18, 2024
45bce84
fix docs is_density
Dec 18, 2024
ff52fac
Union instead of pipe
Dec 18, 2024
d8d7c0b
change union pipe is_controlled
Dec 18, 2024
015697c
union of primitive
Dec 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Pypi](https://badge.fury.io/py/horqrux.svg)](https://pypi.org/project/horqrux/)

`horqrux` is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning and acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface.
`horqrux` is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector and density matrix simulator designed for quantum machine learning and acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface.

## Installation

Expand Down
7 changes: 4 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Welcome to horqrux

**horqrux** is a state vector simulator designed for quantum machine learning written in [JAX](https://jax.readthedocs.io/).
**horqrux** is a state vector and density matrix simulator designed for quantum machine learning written in [JAX](https://jax.readthedocs.io/).

## Setup

Expand Down Expand Up @@ -110,10 +110,11 @@ from operator import add
from typing import Any, Callable
from uuid import uuid4

from horqrux.circuit import QuantumCircuit, hea, expectation
from horqrux import expectation
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate
from horqrux.circuit import QuantumCircuit, hea
from horqrux.primitive import Primitive
from horqrux.parametric import Parametric
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate
from horqrux.utils import DiffMode


Expand Down
101 changes: 101 additions & 0 deletions docs/noise.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
## Digital Noise

In the description of closed quantum systems, a pure state vector is used to represent the complete quantum state. Thus, pure quantum states are represented by state vectors $|\psi \rangle $.

However, this description is not sufficient to study open quantum systems. When the system interacts with its environment, quantum systems can be in a mixed state, where quantum information is no longer entirely contained in a single state vector but is distributed probabilistically.

To address these more general cases, we consider a probabilistic combination $p_i$ of possible pure states $|\psi_i \rangle$. Thus, the system is described by a density matrix $\rho$ defined as follows:

$$
\rho = \sum_i p_i |\psi_i\rangle \langle \psi_i|
$$

The transformations of the density operator of an open quantum system interacting with its environment (noise) are represented by the super-operator $S: \rho \rightarrow S(\rho)$, often referred to as a quantum channel.
Quantum channels, due to the conservation of the probability distribution, must be CPTP (Completely Positive and Trace Preserving). Any CPTP super-operator can be written in the following form:

$$
S(\rho) = \sum_i K_i \rho K^{\dagger}_i
$$

Where $K_i$ are the Kraus operators, and satisfy the property $\sum_i K_i K^{\dagger}_i = \mathbb{I}$. As noise is the result of system interactions with its environment, it is therefore possible to simulate noisy quantum circuit with noise type gates.

Thus, `horqrux` implements a large selection of single qubit noise gates such as:

- The bit flip channel defined as: $\textbf{BitFlip}(\rho) =(1-p) \rho + p X \rho X^{\dagger}$
- The phase flip channel defined as: $\textbf{PhaseFlip}(\rho) = (1-p) \rho + p Z \rho Z^{\dagger}$
- The depolarizing channel defined as: $\textbf{Depolarizing}(\rho) = (1-p) \rho + \frac{p}{3} (X \rho X^{\dagger} + Y \rho Y^{\dagger} + Z \rho Z^{\dagger})$
- The pauli channel defined as: $\textbf{PauliChannel}(\rho) = (1-p_x-p_y-p_z) \rho
+ p_x X \rho X^{\dagger}
+ p_y Y \rho Y^{\dagger}
+ p_z Z \rho Z^{\dagger}$
- The amplitude damping channel defined as: $\textbf{AmplitudeDamping}(\rho) = K_0 \rho K_0^{\dagger} + K_1 \rho K_1^{\dagger}$
with:
$\begin{equation*}
K_{0} \ =\begin{pmatrix}
1 & 0\\
0 & \sqrt{1-\ \gamma }
\end{pmatrix} ,\ K_{1} \ =\begin{pmatrix}
0 & \sqrt{\ \gamma }\\
0 & 0
\end{pmatrix}
\end{equation*}$
- The phase damping channel defined as: $\textbf{PhaseDamping}(\rho) = K_0 \rho K_0^{\dagger} + K_1 \rho K_1^{\dagger}$
with:
$\begin{equation*}
K_{0} \ =\begin{pmatrix}
1 & 0\\
0 & \sqrt{1-\ \gamma }
\end{pmatrix}, \ K_{1} \ =\begin{pmatrix}
0 & 0\\
0 & \sqrt{\ \gamma }
\end{pmatrix}
\end{equation*}$
* The generalize amplitude damping channel is defined as: $\textbf{GeneralizedAmplitudeDamping}(\rho) = K_0 \rho K_0^{\dagger} + K_1 \rho K_1^{\dagger} + K_2 \rho K_2^{\dagger} + K_3 \rho K_3^{\dagger}$
with:
$\begin{cases}
K_{0} \ =\sqrt{p} \ \begin{pmatrix}
1 & 0\\
0 & \sqrt{1-\ \gamma }
\end{pmatrix} ,\ K_{1} \ =\sqrt{p} \ \begin{pmatrix}
0 & 0\\
0 & \sqrt{\ \gamma }
\end{pmatrix} \\
K_{2} \ =\sqrt{1\ -p} \ \begin{pmatrix}
\sqrt{1-\ \gamma } & 0\\
0 & 1
\end{pmatrix} ,\ K_{3} \ =\sqrt{1-p} \ \begin{pmatrix}
0 & 0\\
\sqrt{\ \gamma } & 0
\end{pmatrix}
\end{cases}$

Noise protocols can be added to gates by instantiating `NoiseInstance` providing the `NoiseType` and the `error_probability` (either float or tuple of float):

```python exec="on" source="material-block" html="1"
from horqrux.noise import NoiseInstance, NoiseType

noise_prob = 0.3
AmpD = NoiseInstance(NoiseType.AMPLITUDE_DAMPING, error_probability=noise_prob)

```

Then a gate can be instantiated by providing a tuple of `NoiseInstance` instances. Let’s show this through the simulation of a realistic $X$ gate.

We know that an $X$ gate flips the state of the qubit, for instance $X|0\rangle = |1\rangle$. In practice, it's common for the target qubit to stay in its original state after applying $X$ due to the interactions between it and its environment. The possibility of failure can be represented by a `BitFlip` `NoiseInstance`, which flips the state again after the application of the $X$ gate, returning it to its original state with a probability `1 - gate_fidelity`.

```python exec="on" source="material-block"
from horqrux.api import sample
from horqrux.noise import NoiseInstance, NoiseType
from horqrux.utils import density_mat, product_state
from horqrux.primitive import X

noise = (NoiseInstance(NoiseType.BITFLIP, 0.1),)
ops = [X(0)]
noisy_ops = [X(0, noise=noise)]
state = product_state("0")

noiseless_samples = sample(state, ops)
noisy_samples = sample(density_mat(state), noisy_ops)
print("Noiseless samples", noiseless_samples) # markdown-exec: hide
print("Noiseless samples", noisy_samples) # markdown-exec: hide
```
4 changes: 2 additions & 2 deletions horqrux/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from .api import expectation
from .api import expectation, run, sample
from .apply import apply_gate, apply_operator
from .circuit import QuantumCircuit, sample
from .circuit import QuantumCircuit
from .parametric import PHASE, RX, RY, RZ
from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from .utils import (
Expand Down
6 changes: 2 additions & 4 deletions horqrux/adjoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import Tuple

from jax import Array, custom_vjp

from horqrux.apply import apply_gate
Expand Down Expand Up @@ -37,14 +35,14 @@ def adjoint_expectation(

def adjoint_expectation_single_observable_fwd(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]:
) -> tuple[Array, tuple[Array, Array, list[Primitive], dict[str, float]]]:
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return inner(out_state, projected_state).real, (out_state, projected_state, gates, values)


def adjoint_expectation_single_observable_bwd(
res: Tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array
res: tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array
) -> tuple:
"""Implementation of Algorithm 1 of https://arxiv.org/abs/2009.02823
which computes the vector-jacobian product in O(P) time using O(1) state vectors
Expand Down
154 changes: 115 additions & 39 deletions horqrux/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import Counter
from functools import singledispatch
from typing import Any, Optional

import jax
Expand All @@ -11,74 +12,128 @@
from horqrux.adjoint import adjoint_expectation
from horqrux.apply import apply_gate
from horqrux.primitive import GateSequence, Primitive
from horqrux.shots import finite_shots_fwd
from horqrux.utils import DiffMode, ForwardMode, OperationType, inner
from horqrux.shots import finite_shots_fwd, observable_to_matrix
from horqrux.utils import (
DensityMatrix,
DiffMode,
ForwardMode,
OperationType,
State,
get_probas,
inner,
sample_from_probs,
)


def run(
circuit: GateSequence,
state: Array,
state: State,
values: dict[str, float] = dict(),
) -> Array:
) -> State:
return apply_gate(state, circuit, values)


def sample(
state: Array,
state: State,
gates: GateSequence,
values: dict[str, float] = dict(),
n_shots: int = 1000,
) -> Counter:
"""Sample from a quantum program.

Args:
state (State): Input state vector or density matrix.
gates (GateSequence): Sequence of gates.
values (dict[str, float], optional): _description_. Defaults to dict().
n_shots (int, optional): Parameter values.. Defaults to 1000.

Raises:
ValueError: If n_shots < 1.

Returns:
Counter: Bitstrings and frequencies.
"""
if n_shots < 1:
raise ValueError("You can only call sample with n_shots>0.")
output_circuit = apply_gate(state, gates, values)

wf = apply_gate(state, gates, values)
probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel()
key = jax.random.PRNGKey(0)
n_qubits = len(state.shape)
# JAX handles pseudo random number generation by tracking an explicit state via a random key
# For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html
samples = jax.vmap(
lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs)
)(jax.random.split(key, n_shots))

return Counter(
{
format(k, "0{}b".format(n_qubits)): count.item()
for k, count in enumerate(jnp.bincount(samples))
if count > 0
}
)
if isinstance(output_circuit, DensityMatrix):
n_qubits = len(output_circuit.array.shape) // 2
d = 2**n_qubits
output_circuit.array = output_circuit.array.reshape((d, d))
else:
n_qubits = len(output_circuit.shape)

probs = get_probas(output_circuit)
return sample_from_probs(probs, n_qubits, n_shots)


@singledispatch
def __ad_expectation_single_observable(
state: Array, gates: GateSequence, observable: Primitive, values: dict[str, float]
state: Any,
observable: Primitive,
values: dict[str, float],
) -> Any:
raise NotImplementedError("__ad_expectation_single_observable is not implemented")


@__ad_expectation_single_observable.register
def _(
state: Array,
observable: Primitive,
values: dict[str, float],
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return inner(out_state, projected_state).real
projected_state = apply_gate(
state,
observable,
values,
OperationType.UNITARY,
)
return inner(state, projected_state).real


@__ad_expectation_single_observable.register
def _(
state: DensityMatrix,
observable: Primitive,
values: dict[str, float],
) -> Array:
n_qubits = len(state.array.shape) // 2
mat_obs = observable_to_matrix(observable, n_qubits, values)
d = 2**n_qubits
prod = jnp.matmul(mat_obs, state.array.reshape((d, d)))
return jnp.trace(prod, axis1=-2, axis2=-1).real


def ad_expectation(
state: Array, gates: GateSequence, observables: list[Primitive], values: dict[str, float]
state: State,
gates: GateSequence,
observables: list[Primitive],
values: dict[str, float],
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.

Args:
state (State): Input state vector or density matrix.
gates (GateSequence): Sequence of gates.
observables (list[Primitive]): List of observables.
values (dict[str, float]): Parameter values.

Returns:
Array: Expectation values.
"""
outputs = [
__ad_expectation_single_observable(state, gates, observable, values)
__ad_expectation_single_observable(
apply_gate(state, gates, values, OperationType.UNITARY), observable, values
)
for observable in observables
]
return jnp.stack(outputs)


def expectation(
state: Array,
state: State,
gates: GateSequence,
observables: list[Primitive],
values: dict[str, float],
Expand All @@ -87,13 +142,27 @@ def expectation(
n_shots: Optional[int] = None,
key: Any = jax.random.PRNGKey(0),
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
"""Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.

Args:
state (State): Input state vector or density matrix.
gates (GateSequence): Sequence of gates.
observables (list[Primitive]): List of observables.
values (dict[str, float]): Parameter values.
diff_mode (DiffMode, optional): Differentiation mode. Defaults to DiffMode.AD.
forward_mode (ForwardMode, optional): Type of forward method. Defaults to ForwardMode.EXACT.
n_shots (Optional[int], optional): Number of shots. Defaults to None.
key (Any, optional): Random key. Defaults to jax.random.PRNGKey(0).

Returns:
Array: Expectation values.
"""
if diff_mode == DiffMode.AD:
return ad_expectation(state, gates, observables, values)
elif diff_mode == DiffMode.ADJOINT:
if isinstance(state, DensityMatrix):
raise ValueError("Adjoint does not support density matrices.")
return adjoint_expectation(state, gates, observables, values)
elif diff_mode == DiffMode.GPSR:
checkify.check(
Expand All @@ -105,4 +174,11 @@ def expectation(
)
# Type checking is disabled because mypy doesn't parse checkify.check.
# type: ignore
return finite_shots_fwd(state, gates, observables, values, n_shots=n_shots, key=key)
return finite_shots_fwd(
state,
gates,
observables,
values,
n_shots=n_shots,
key=key,
)
Loading
Loading