From 87561f14be7c2b426a6718c83ac4fd8e9f7aa47a Mon Sep 17 00:00:00 2001 From: seitzdom Date: Fri, 19 Apr 2024 11:48:46 +0200 Subject: [PATCH 01/57] [Feature] Use single precision by default --- horqrux/matrices.py | 25 +++++++++++-------------- horqrux/parametric.py | 4 ++-- horqrux/utils.py | 12 ++++++------ 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/horqrux/matrices.py b/horqrux/matrices.py index 7079e7b..223cf84 100644 --- a/horqrux/matrices.py +++ b/horqrux/matrices.py @@ -1,19 +1,16 @@ from __future__ import annotations import jax.numpy as jnp -from jax import config -config.update("jax_enable_x64", True) # Quantum ML requires higher precision +_X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64) +_Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64) +_Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64) +_H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) * 1 / jnp.sqrt(2) +_S = jnp.array([[1, 0], [0, 1j]], dtype=jnp.complex64) +_T = jnp.array([[1, 0], [0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex64) +_I = jnp.asarray([[1, 0], [0, 1]], dtype=jnp.complex64) -_X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex128) -_Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex128) -_Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex128) -_H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex128) * 1 / jnp.sqrt(2) -_S = jnp.array([[1, 0], [0, 1j]], dtype=jnp.complex128) -_T = jnp.array([[1, 0], [0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex128) -_I = jnp.asarray([[1, 0], [0, 1]], dtype=jnp.complex128) - -_SWAP = jnp.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=jnp.complex128) +_SWAP = jnp.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=jnp.complex64) _SQSWAP = jnp.asarray( @@ -23,11 +20,11 @@ [0, 0.5 * (1 - 1j), 0.5 * (1 + 1j), 0], [0, 0, 0, 1], ], - dtype=jnp.complex128, + dtype=jnp.complex64, ) _ISWAP = jnp.asarray( - [[1, 0, 0, 0], [0, 0, 1j, 0], [0, 1j, 0, 0], [0, 0, 0, 1]], dtype=jnp.complex128 + [[1, 0, 0, 0], [0, 0, 1j, 0], [0, 1j, 0, 0], [0, 0, 0, 1]], dtype=jnp.complex64 ) _ISQSWAP = jnp.asarray( @@ -37,7 +34,7 @@ [0, 1j / jnp.sqrt(2), 1 / jnp.sqrt(2), 0], [0, 0, 0, 1], ], - dtype=jnp.complex128, + dtype=jnp.complex64, ) OPERATIONS_DICT = { diff --git a/horqrux/parametric.py b/horqrux/parametric.py index bd5d488..94c1ca6 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -118,12 +118,12 @@ def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None, class _PHASE(Parametric): def unitary(self, values: dict[str, float] = dict()) -> Array: - u = jnp.eye(2, 2, dtype=jnp.complex128) + u = jnp.eye(2, 2, dtype=jnp.complex64) u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values))) return u def jacobian(self, values: dict[str, float] = dict()) -> Array: - jac = jnp.zeros((2, 2), dtype=jnp.complex128) + jac = jnp.zeros((2, 2), dtype=jnp.complex64) jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * self.parse_values(values))) return jac diff --git a/horqrux/utils.py b/horqrux/utils.py index 4b1074e..6a1c03f 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -40,7 +40,7 @@ def _dagger(operator: Array) -> Array: def _unitary(generator: Array, theta: float) -> Array: return ( - jnp.cos(theta / 2) * jnp.eye(2, dtype=jnp.complex128) - 1j * jnp.sin(theta / 2) * generator + jnp.cos(theta / 2) * jnp.eye(2, dtype=jnp.complex64) - 1j * jnp.sin(theta / 2) * generator ) @@ -48,14 +48,14 @@ def _jacobian(generator: Array, theta: float) -> Array: return ( -1 / 2 - * (jnp.sin(theta / 2) * jnp.eye(2, dtype=jnp.complex128) + 1j * jnp.cos(theta / 2)) + * (jnp.sin(theta / 2) * jnp.eye(2, dtype=jnp.complex64) + 1j * jnp.cos(theta / 2)) * generator ) def _controlled(operator: Array, n_control: int) -> Array: n_qubits = int(log2(operator.shape[0])) - control = jnp.eye(2 ** (n_control + n_qubits), dtype=jnp.complex128) + control = jnp.eye(2 ** (n_control + n_qubits), dtype=jnp.complex64) control = control.at[-(2**n_qubits) :, -(2**n_qubits) :].set(operator) return control @@ -70,7 +70,7 @@ def product_state(bitstring: str) -> Array: A state corresponding to 'bitstring'. """ n_qubits = len(bitstring) - space = jnp.zeros(tuple(2 for _ in range(n_qubits)), dtype=jnp.complex128) + space = jnp.zeros(tuple(2 for _ in range(n_qubits)), dtype=jnp.complex64) space = space.at[tuple(map(int, bitstring))].set(1.0) return space @@ -129,8 +129,8 @@ def overlap(state: Array, projection: Array) -> Array: def uniform_state( n_qubits: int, ) -> Array: - state = jnp.ones(2**n_qubits, dtype=jnp.complex128) - state = state / jnp.sqrt(jnp.array(2**n_qubits, dtype=jnp.complex128)) + state = jnp.ones(2**n_qubits, dtype=jnp.complex64) + state = state / jnp.sqrt(jnp.array(2**n_qubits, dtype=jnp.complex64)) return state.reshape([2] * n_qubits) From b743496c0a34cb97db6e676e632a578dbe907c68 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Fri, 19 Apr 2024 13:24:50 +0200 Subject: [PATCH 02/57] refac expectation like pyq --- docs/index.md | 6 +++--- horqrux/__init__.py | 2 ++ horqrux/adjoint.py | 4 ++-- horqrux/circuit.py | 20 +++++++++++++++++++- horqrux/utils.py | 14 ++++++++++++++ tests/test_adjoint.py | 14 ++++++-------- 6 files changed, 46 insertions(+), 14 deletions(-) diff --git a/docs/index.md b/docs/index.md index 99d842c..1081a62 100644 --- a/docs/index.md +++ b/docs/index.md @@ -110,11 +110,11 @@ from operator import add from typing import Any, Callable from uuid import uuid4 -from horqrux.adjoint import adjoint_expectation -from horqrux.circuit import Circuit, hea +from horqrux.circuit import Circuit, hea, expectation from horqrux.primitive import Primitive from horqrux.parametric import Parametric from horqrux import Z, RX, RY, NOT, zero_state, apply_gate +from horqrux.utils import DiffMode n_qubits = 5 @@ -137,7 +137,7 @@ class DQC(Circuit): @partial(vmap, in_axes=(None, None, 0)) def __call__(self, param_values: Array, x: Array) -> Array: param_dict = {name: val for name, val in zip(self.param_names, param_values)} - return adjoint_expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}}) + return expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}}, DiffMode.ADJOINT) circ = DQC(n_qubits=n_qubits, feature_map=[RX('phi', i) for i in range(n_qubits)], ansatz=hea(n_qubits, n_layers)) diff --git a/horqrux/__init__.py b/horqrux/__init__.py index 1b46fcf..be34ae8 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -1,9 +1,11 @@ from __future__ import annotations from .apply import apply_gate, apply_operator +from .circuit import Circuit, expectation from .parametric import PHASE, RX, RY, RZ from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z from .utils import ( + DiffMode, equivalent_state, hilbert_reshape, overlap, diff --git a/horqrux/adjoint.py b/horqrux/adjoint.py index d7fa777..fa35b78 100644 --- a/horqrux/adjoint.py +++ b/horqrux/adjoint.py @@ -10,7 +10,7 @@ from horqrux.utils import OperationType, inner -def expectation( +def ad_expectation( state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] ) -> Array: """ @@ -26,7 +26,7 @@ def expectation( def adjoint_expectation( state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] ) -> Array: - return expectation(state, gates, observable, values) + return ad_expectation(state, gates, observable, values) def adjoint_expectation_fwd( diff --git a/horqrux/circuit.py b/horqrux/circuit.py index ad47d2a..388b8ba 100644 --- a/horqrux/circuit.py +++ b/horqrux/circuit.py @@ -7,10 +7,11 @@ from jax import Array from jax.tree_util import register_pytree_node_class +from horqrux.adjoint import ad_expectation, adjoint_expectation from horqrux.apply import apply_gate from horqrux.parametric import RX, RY, Parametric from horqrux.primitive import NOT, Primitive -from horqrux.utils import zero_state +from horqrux.utils import DiffMode, zero_state @register_pytree_node_class @@ -66,3 +67,20 @@ def hea(n_qubits: int, n_layers: int, rot_fns: list[Callable] = [RX, RY, RX]) -> gates += ops return gates + + +def expectation( + state: Array, + gates: list[Primitive], + observable: list[Primitive], + values: dict[str, float], + diff_mode: DiffMode | str = DiffMode.AD, +) -> Array: + """ + Run 'state' through a sequence of 'gates' given parameters 'values' + and compute the expectation given an observable. + """ + if diff_mode == DiffMode.AD: + return ad_expectation(state, gates, observable, values) + else: + return adjoint_expectation(state, gates, observable, values) diff --git a/horqrux/utils.py b/horqrux/utils.py index 6a1c03f..eda4246 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -34,6 +34,20 @@ class OperationType(StrEnum): JACOBIAN = "jacobian" +class DiffMode(StrEnum): + """ + Which Differentiation method to use. + + Options: Automatic Differentiation - Using the autograd engine of JAX. + Adjoint Differentiation - An implementation of "Efficient calculation of gradients + in classical simulations of variational quantum algorithms", + Jones & Gacon, 2020 + """ + + AD = "ad" + ADJOINT = "adjoint" + + def _dagger(operator: Array) -> Array: return jnp.conjugate(operator.T) diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 4647e86..ce5c704 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -5,9 +5,10 @@ from jax import Array, grad from horqrux import random_state -from horqrux.adjoint import adjoint_expectation, expectation +from horqrux.circuit import expectation from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, H, I, S, T, X, Y, Z +from horqrux.utils import DiffMode MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) @@ -25,13 +26,10 @@ def test_gradcheck() -> None: } state = random_state(MAX_QUBITS) - def adjoint_expfn(values) -> Array: - return adjoint_expectation(state, ops, observable, values) + def exp_fn(values: dict, diff_mode: DiffMode) -> Array: + return expectation(state, ops, observable, values, diff_mode) - def ad_expfn(values) -> Array: - return expectation(state, ops, observable, values) - - grads_adjoint = grad(adjoint_expfn)(values) - grad_ad = grad(ad_expfn)(values) + grads_adjoint = grad(exp_fn)(values, "adjoint") + grad_ad = grad(exp_fn)(values, "ad") for param, ad_grad in grad_ad.items(): assert jnp.isclose(grads_adjoint[param], ad_grad, atol=0.09) From 7878a90e5fd67eeae3a86394e85939101c2125b5 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Fri, 21 Jun 2024 17:48:32 +0200 Subject: [PATCH 03/57] Add sample, increase atol --- docs/index.md | 6 +++--- horqrux/__init__.py | 2 +- horqrux/circuit.py | 31 +++++++++++++++++++++++++++++++ tests/test_adjoint.py | 2 +- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/docs/index.md b/docs/index.md index 1081a62..0c0d5ec 100644 --- a/docs/index.md +++ b/docs/index.md @@ -218,14 +218,14 @@ from horqrux.primitive import Primitive from horqrux.parametric import Parametric from horqrux.utils import inner -LEARNING_RATE = 0.01 +LEARNING_RATE = 0.15 N_QUBITS = 4 DEPTH = 3 VARIABLES = ("x", "y") NUM_VARIABLES = len(VARIABLES) X_POS, Y_POS = [i for i in range(NUM_VARIABLES)] -BATCH_SIZE = 150 -N_EPOCHS = 1000 +BATCH_SIZE = 500 +N_EPOCHS = 500 def total_magnetization(n_qubits:int) -> Callable: paulis = [Z(i) for i in range(n_qubits)] diff --git a/horqrux/__init__.py b/horqrux/__init__.py index be34ae8..920fe68 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from .apply import apply_gate, apply_operator -from .circuit import Circuit, expectation +from .circuit import Circuit, expectation, sample from .parametric import PHASE, RX, RY, RZ from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z from .utils import ( diff --git a/horqrux/circuit.py b/horqrux/circuit.py index 388b8ba..9e8b419 100644 --- a/horqrux/circuit.py +++ b/horqrux/circuit.py @@ -1,9 +1,12 @@ from __future__ import annotations +from collections import Counter from dataclasses import dataclass from typing import Any, Callable from uuid import uuid4 +import jax +import jax.numpy as jnp from jax import Array from jax.tree_util import register_pytree_node_class @@ -84,3 +87,31 @@ def expectation( return ad_expectation(state, gates, observable, values) else: return adjoint_expectation(state, gates, observable, values) + + +def sample( + state: Array, + gates: list[Primitive], + values: dict[str, float] = dict(), + n_shots: int = 1000, +) -> Counter: + if n_shots < 1: + raise ValueError("You can only call sample with n_shots>0.") + + wf = apply_gate(state, gates, values) + probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() + key = jax.random.PRNGKey(0) + n_qubits = len(state.shape) + # JAX handles pseudo random number generation by tracking an explicit state via a random key + # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html + samples = jax.vmap( + lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) + )(jax.random.split(key, n_shots)) + + return Counter( + { + format(k, "0{}b".format(n_qubits)): count.item() + for k, count in enumerate(jnp.bincount(samples)) + if count > 0 + } + ) diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index ce5c704..4bbd9e2 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -32,4 +32,4 @@ def exp_fn(values: dict, diff_mode: DiffMode) -> Array: grads_adjoint = grad(exp_fn)(values, "adjoint") grad_ad = grad(exp_fn)(values, "ad") for param, ad_grad in grad_ad.items(): - assert jnp.isclose(grads_adjoint[param], ad_grad, atol=0.09) + assert jnp.isclose(grads_adjoint[param], ad_grad, atol=0.25) From 882a5d90a5df3805150aea8214ae242013d9cd8f Mon Sep 17 00:00:00 2001 From: seitzdom Date: Fri, 21 Jun 2024 18:08:46 +0200 Subject: [PATCH 04/57] rework circ --- docs/index.md | 20 +++++++++----------- horqrux/__init__.py | 2 +- horqrux/circuit.py | 30 +++++++++++++++--------------- horqrux/parametric.py | 8 +++----- horqrux/primitive.py | 4 ++-- 5 files changed, 30 insertions(+), 34 deletions(-) diff --git a/docs/index.md b/docs/index.md index 0c0d5ec..13e64c1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -212,7 +212,7 @@ from jax import Array, jit, value_and_grad, vmap from numpy.random import uniform from horqrux.apply import group_by_index -from horqrux.circuit import Circuit, hea +from horqrux.circuit import QuantumCircuit, hea from horqrux import NOT, RX, RY, Z, apply_gate, zero_state from horqrux.primitive import Primitive from horqrux.parametric import Parametric @@ -238,22 +238,20 @@ def total_magnetization(n_qubits:int) -> Callable: return _total_magnetization -class DQC(Circuit): - def __post_init__(self) -> None: - self.ansatz = group_by_index(self.ansatz) - self.observable = total_magnetization(self.n_qubits) - self.state = zero_state(self.n_qubits) +class DQC(QuantumCircuit): + def __post_init__(self, n_qubits: int, operations: list[Primitive]) -> None: + self.operations = group_by_index(operations) + self.observable = total_magnetization(n_qubits) - def __call__(self, param_vals: Array, x: Array, y: Array) -> Array: - param_dict = {name: val for name, val in zip(self.param_names, param_vals)} + def __call__(self, state, values: dict[str, Array], x: Array, y: Array) -> Array: out_state = apply_gate( - self.state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}} + state, self.operations, {**param_dict, **{"f_x": x, "f_y": y}} ) return self.observable(out_state, {}) -fm = [RX("x", i) for i in range(N_QUBITS // 2)] + [ - RX("y", i) for i in range(N_QUBITS // 2, N_QUBITS) +fm = [RX("f_x", i) for i in range(N_QUBITS // 2)] + [ + RX("f_y", i) for i in range(N_QUBITS // 2, N_QUBITS) ] ansatz = hea(N_QUBITS, DEPTH) circ = DQC(N_QUBITS, fm, ansatz) diff --git a/horqrux/__init__.py b/horqrux/__init__.py index 920fe68..9d82b66 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from .apply import apply_gate, apply_operator -from .circuit import Circuit, expectation, sample +from .circuit import QuantumCircuit, expectation, sample from .parametric import PHASE, RX, RY, RZ from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z from .utils import ( diff --git a/horqrux/circuit.py b/horqrux/circuit.py index 9e8b419..e9f90d8 100644 --- a/horqrux/circuit.py +++ b/horqrux/circuit.py @@ -19,33 +19,33 @@ @register_pytree_node_class @dataclass -class Circuit: +class QuantumCircuit: """A minimalistic circuit class to store a sequence of gates.""" n_qubits: int - feature_map: list[Primitive] - ansatz: list[Primitive] + operations: list[Primitive] - def __post_init__(self) -> None: - self.state = zero_state(self.n_qubits) - - def __call__(self, param_values: Array) -> Array: - return apply_gate( - self.state, - self.feature_map + self.ansatz, - {name: val for name, val in zip(self.param_names, param_values)}, - ) + def __call__(self, state: Array, values: dict[str, Array]) -> Array: + if state is None: + state = zero_state(self.n_qubits) + return apply_gate(state, self.operations, values) @property def param_names(self) -> list[str]: - return [str(op.param) for op in self.ansatz if isinstance(op, Parametric)] + return [ + str(op.param) + for op in self.operations + if isinstance(op, Parametric) + and isinstance(op.param, str) + and op.param.startswith("v_") + ] @property def n_vparams(self) -> int: return len(self.param_names) def tree_flatten(self) -> tuple: - children = (self.feature_map, self.ansatz) + children = (self.operations,) aux_data = (self.n_qubits,) return (aux_data, children) @@ -62,7 +62,7 @@ def hea(n_qubits: int, n_layers: int, rot_fns: list[Callable] = [RX, RY, RX]) -> for _ in range(n_layers): for i in range(n_qubits): ops = [ - fn(str(uuid4()), qubit) + fn("v_" + str(uuid4()), qubit) for fn, qubit in zip(rot_fns, [i for _ in range(len(rot_fns))]) ] param_names += [op.param for op in ops] diff --git a/horqrux/parametric.py b/horqrux/parametric.py index 94c1ca6..aba3963 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -44,8 +44,8 @@ def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: children = () aux_data = ( self.generator_name, - self.target[0], - self.control[0], + self.target, + self.control, self.param, ) return (children, aux_data) @@ -69,9 +69,7 @@ def name(self) -> str: return "C" + base_name if is_controlled(self.control) else base_name def __repr__(self) -> str: - return ( - self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})" - ) + return self.name + f"(target={self.target}, control={self.control}, param={self.param})" def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 9790603..f440168 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -51,7 +51,7 @@ def __iter__(self) -> Iterable: def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]: children = () - aux_data = (self.generator_name, self.target[0], self.control[0]) + aux_data = (self.generator_name, self.target, self.control) return (children, aux_data) @classmethod @@ -69,7 +69,7 @@ def name(self) -> str: return "C" + self.generator_name if is_controlled(self.control) else self.generator_name def __repr__(self) -> str: - return self.name + f"(target={self.target[0]}, control={self.control[0]})" + return self.name + f"(target={self.target}, control={self.control})" def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: From 3888c6eb303a3c85ca014b843fc85e0623e140ce Mon Sep 17 00:00:00 2001 From: Roland Guichard Date: Wed, 7 Aug 2024 11:25:39 +0200 Subject: [PATCH 05/57] Remove spurious import. --- horqrux/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/horqrux/__init__.py b/horqrux/__init__.py index 01398e2..513c031 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -2,7 +2,7 @@ from .api import expectation from .apply import apply_gate, apply_operator -from .circuit import QuantumCircuit, expectation, sample +from .circuit import QuantumCircuit, sample from .parametric import PHASE, RX, RY, RZ from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z from .utils import ( From d6d50f2e13fa43c31d652a0ddf611bb1effe63b5 Mon Sep 17 00:00:00 2001 From: Roland Guichard Date: Wed, 7 Aug 2024 11:25:54 +0200 Subject: [PATCH 06/57] Lint. --- horqrux/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/horqrux/utils.py b/horqrux/utils.py index d5855b2..38db948 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -36,6 +36,7 @@ class OperationType(StrEnum): class DiffMode(StrEnum): """Differentiation mode.""" + AD = "ad" """Automatic Differentiation - Using the autograd engine of JAX.""" ADJOINT = "adjoint" From 0af6134004db21fb5399319364ef04e33eed43fd Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 2 Dec 2024 14:20:57 +0100 Subject: [PATCH 07/57] adding noisy operators --- horqrux/noise.py | 299 ++++++++++++++++++++++++++++++++++++++++++ horqrux/parametric.py | 46 +++++-- horqrux/primitive.py | 98 ++++++++++---- horqrux/utils.py | 2 + 4 files changed, 410 insertions(+), 35 deletions(-) create mode 100644 horqrux/noise.py diff --git a/horqrux/noise.py b/horqrux/noise.py new file mode 100644 index 0000000..31aefc5 --- /dev/null +++ b/horqrux/noise.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Iterable, Tuple, Union + +import jax.numpy as jnp +import numpy as np +from jax import Array +from jax.tree_util import register_pytree_node_class +from jax.typing import ArrayLike + +from .utils import ( + QubitSupport, + TargetQubits, + ErrorProbabilities, +) +from .matrices import OPERATIONS_DICT + + + +@register_pytree_node_class +@dataclass +class Noise: + """Noise class which stores information on .""" + + kraus: list[Array] + target: QubitSupport + error_probability: ErrorProbabilities + + @staticmethod + def parse_idx( + idx: Tuple, + ) -> Tuple: + if isinstance(idx, (int, np.int64)): + return ((idx,),) + elif isinstance(idx, tuple): + return (idx,) + else: + return (idx.astype(int),) + + def __post_init__(self) -> None: + self.target = Noise.parse_idx(self.target) + + def __iter__(self) -> Iterable: + return iter((self.kraus, self.target, self.error_probability)) + + def tree_flatten(self) -> Tuple[Tuple, Tuple[list[Array], TargetQubits, ErrorProbabilities]]: + children = () + aux_data = (self.kraus, self.target[0], self.error_probability) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: + return cls(*children, *aux_data) + +def BitFlip(target: TargetQubits, error_probability: float) -> Noise: + """ + Initialize the BitFlip gate. + + The bit flip channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-p) \\rho + p X \\rho X^{\\dagger} + + Args: + target (int): The index of the qubit being affected by the noise. + error_probability (float): The probability of a bit flip error. + + Raises: + ValueError: If the error_probability value is not a float. + + Returns: + Noise: Noise instance for this protocol. + """ + if error_probability > 1.0 or error_probability < 0.0: + raise ValueError("The error_probability value is not a correct probability") + K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(error_probability) * OPERATIONS_DICT["X"] + kraus_bitflip: list[Array] = [K0, K1] + return Noise(kraus_bitflip, target, error_probability) + +def PhaseFlip(target: TargetQubits, error_probability: float) -> Noise: + """ + Initialize the PhaseFlip gate + + The phase flip channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-p) \\rho + p Z \\rho Z^{\\dagger} + + Args: + target (int): The index of the qubit being affected by the noise. + error_probability (float): The probability of phase flip error. + + Raises: + ValueError: If the error_probability value is not a float. + + Returns: + Noise: Noise instance for this protocol. + """ + + if error_probability > 1.0 or error_probability < 0.0: + raise ValueError("The error_probability value is not a correct probability") + K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(error_probability) * OPERATIONS_DICT["Z"] + kraus: list[Array] = [K0, K1] + return Noise(kraus, target, error_probability) + +def Depolarizing(target: TargetQubits, error_probability: float) -> Noise: + """ + Initialize the Depolarizing gate. + + The depolarizing channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-p) \\rho + + p/3 X \\rho X^{\\dagger} + + p/3 Y \\rho Y^{\\dagger} + + p/3 Z \\rho Z^{\\dagger} + + Args: + target (int): The index of the qubit being affected by the noise. + error_probability (float): The probability of phase flip error. + + Raises: + ValueError: If the error_probability value is not a float. + + Returns: + Noise: Noise instance for this protocol. + """ + + if error_probability > 1.0 or error_probability < 0.0: + raise ValueError("The error_probability value is not a correct probability") + K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["X"] + K2: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["Y"] + K3: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["Z"] + kraus: list[Array] = [K0, K1, K2, K3] + return Noise(kraus, target, error_probability) + +def PauliChannel(target: TargetQubits, error_probability: tuple[float, ...]) -> Noise: + """ + Initialize the PauliChannel gate. + + The pauli channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-px-py-pz) \\rho + + px X \\rho X^{\\dagger} + + py Y \\rho Y^{\\dagger} + + pz Z \\rho Z^{\\dagger} + + Args: + target (int): The index of the qubit being affected by the noise. + error_probability (ErrorProbabilities): Tuple containing probabilities + of X, Y, and Z errors. + + Raises: + ValueError: If the probabilities values do not sum up to 1. + + Returns: + Noise: Noise instance for this protocol. + """ + + sum_prob = sum(error_probability) + if sum_prob > 1.0: + raise ValueError("The sum of probabilities can't be greater than 1.0") + for probability in error_probability: + if probability > 1.0 or probability < 0.0: + raise ValueError("The probability values are not correct probabilities") + px, py, pz = ( + error_probability[0], + error_probability[1], + error_probability[2], + ) + + + K0: Array = jnp.sqrt(1.0 - (px + py + pz)) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(px) * OPERATIONS_DICT["X"] + K2: Array = jnp.sqrt(py) * OPERATIONS_DICT["Y"] + K3: Array = jnp.sqrt(pz) * OPERATIONS_DICT["Z"] + kraus: list[Array] = [K0, K1, K2, K3] + return Noise(kraus, target, error_probability) + +def AmplitudeDamping(target: TargetQubits, error_probability: float) -> Noise: + """ + Initialize the AmplitudeDamping gate. + + The amplitude damping channel is defined as: + + .. math:: + \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} + + with: + + .. code-block:: python + + K0 = [[1, 0], [0, sqrt(1 - rate)]] + K1 = [[0, sqrt(rate)], [0, 0]] + + Args: + target (int): The index of the qubit being affected by the noise. + error_probability (float): The damping rate, indicating the probability of amplitude loss. + + Raises: + ValueError: If the damping rate is not a correct probability. + + Returns: + Noise: Noise instance for this protocol. + """ + + rate = error_probability + if rate > 1.0 or rate < 0.0: + raise ValueError("The damping rate is not a correct probability") + K0: Array = jnp.array([[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128) + K1: Array = jnp.array([[0, jnp.sqrt(rate)], [0, 0]], dtype=jnp.complex128) + kraus: list[Array] = [K0, K1] + return Noise(kraus, target, error_probability) + +def PhaseDamping(target: TargetQubits, error_probability: float) -> Noise: + """ + Initialize the PhaseDamping gate. + + The phase damping channel is defined as: + + .. math:: + \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} + + with: + + .. code-block:: python + + K0 = [[1, 0], [0, sqrt(1 - rate)]] + K1 = [[0, 0], [0, sqrt(rate)]] + + Args: + target (int): The index of the qubit being affected by the noise. + error_probability (float): The damping rate, indicating the probability of phase damping. + + Raises: + ValueError: If the damping rate is not a correct probability. + + Returns: + Noise: Noise instance for this protocol. + """ + + rate = error_probability + if rate > 1.0 or rate < 0.0: + raise ValueError("The damping rate is not a correct probability") + K0: Array = jnp.array([[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128) + K1: Array = jnp.array([[0, 0], [0, jnp.sqrt(rate)]], dtype=jnp.complex128) + kraus: list[Array] = [K0, K1] + return Noise(kraus, target, error_probability) + +def GeneralizedAmplitudeDamping(target: TargetQubits, error_probability: tuple[float, ...]) -> Noise: + """ + Initialize the GeneralizeAmplitudeDamping gate. + + The generalize amplitude damping channel is defined as: + + .. math:: + \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} + + K_2 \\rho K_2^{\\dagger} + K_3 \\rho K_3^{\\dagger} + + with: + + .. code-block:: python + + K0 = sqrt(p) * [[1, 0], [0, sqrt(1 - rate)]] + K1 = sqrt(p) * [[0, sqrt(rate)], [0, 0]] + K2 = sqrt(1-p) * [[sqrt(1 - rate), 0], [0, 1]] + K3 = sqrt(1-p) * [[0, 0], [sqrt(rate), 0]] + + Args: + target (int): The index of the qubit being affected by the noise. + error_probability (ErrorProbabilities): The first float must be the probability + of amplitude damping error, and the second float is the damping rate, indicating + the probability of generalized amplitude damping. + + Raises: + ValueError: If the damping rate is not a correct probability. + + Returns: + Noise: Noise instance for this protocol. + """ + + probability = error_probability[0] + rate = error_probability[1] + if probability > 1.0 or probability < 0.0: + raise ValueError("The probability value is not a correct probability") + if rate > 1.0 or rate < 0.0: + raise ValueError("The damping rate is not a correct probability") + + K0: Array = jnp.sqrt(probability) * jnp.array([[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128) + K1: Array = jnp.sqrt(probability) * jnp.array([[0, jnp.sqrt(rate)], [0, 0]], dtype=jnp.complex128) + K2: Array = jnp.sqrt(1.0-probability) * jnp.array([[jnp.sqrt(1.0-rate), 0], [0, 1]], dtype=jnp.complex128) + K3: Array = jnp.sqrt(1.0-probability) * jnp.array([[0, 0], [jnp.sqrt(rate), 0]], dtype=jnp.complex128) + kraus: list[Array] = [K0, K1, K2, K3] + return Noise(kraus, target, error_probability) \ No newline at end of file diff --git a/horqrux/parametric.py b/horqrux/parametric.py index bd5d488..faa3d4c 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -8,6 +8,7 @@ from jax.tree_util import register_pytree_node_class from .matrices import OPERATIONS_DICT +from .noise import Noise from .primitive import Primitive from .utils import ( ControlQubits, @@ -27,6 +28,7 @@ class Parametric(Primitive): generator_name: str target: QubitSupport control: QubitSupport + noise: Noise | None = None param: str | float = "" def __post_init__(self) -> None: @@ -40,18 +42,21 @@ def parse_val(values: dict[str, float] = dict()) -> float: self.parse_values = parse_dict if isinstance(self.param, str) else parse_val - def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override] + def tree_flatten( # type: ignore[override] + self, + ) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, Noise | None, str | float]]: children = () aux_data = ( self.generator_name, self.target[0], self.control[0], + self.noise, self.param, ) return (children, aux_data) def __iter__(self) -> Iterable: - return iter((self.generator_name, self.target, self.control, self.param)) + return iter((self.generator_name, self.target, self.control, self.noise, self.param)) @classmethod def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: @@ -74,46 +79,64 @@ def __repr__(self) -> str: ) -def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: +def RX( + param: float | str, + target: TargetQubits, + control: ControlQubits = (None,), + noise: Noise | None = None, +) -> Parametric: """RX gate. Arguments: param: Parameter denoting the Rotational angle. target: Tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. + noise: The noise instance. Defaults to None. Returns: Parametric: A Parametric gate object. """ - return Parametric("X", target, control, param) + return Parametric("X", target, control, param=param, noise=noise) -def RY(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: +def RY( + param: float | str, + target: TargetQubits, + control: ControlQubits = (None,), + noise: Noise | None = None, +) -> Parametric: """RY gate. Arguments: param: Parameter denoting the Rotational angle. target: Tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. + noise: The noise instance. Defaults to None. Returns: Parametric: A Parametric gate object. """ - return Parametric("Y", target, control, param) + return Parametric("Y", target, control, param=param, noise=noise) -def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: +def RZ( + param: float | str, + target: TargetQubits, + control: ControlQubits = (None,), + noise: Noise | None = None, +) -> Parametric: """RZ gate. Arguments: param: Parameter denoting the Rotational angle. target: Tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. + noise: The noise instance. Defaults to None. Returns: Parametric: A Parametric gate object. """ - return Parametric("Z", target, control, param) + return Parametric("Z", target, control, param=param, noise=noise) class _PHASE(Parametric): @@ -133,16 +156,19 @@ def name(self) -> str: return "C" + base_name if is_controlled(self.control) else base_name -def PHASE(param: float, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: +def PHASE( + param: float, target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Parametric: """Phase gate. Arguments: param: Parameter denoting the Rotational angle. target: Tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. + noise: The noise instance. Defaults to None. Returns: Parametric: A Parametric gate object. """ - return _PHASE("I", target, control, param) + return _PHASE("I", target, control, param=param, noise=noise) diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 2aac3a4..264eef0 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -8,6 +8,7 @@ from jax.tree_util import register_pytree_node_class from .matrices import OPERATIONS_DICT +from .noise import Noise from .utils import ( ControlQubits, QubitSupport, @@ -27,6 +28,7 @@ class Primitive: generator_name: str target: QubitSupport control: QubitSupport + noise: Noise | None = None @staticmethod def parse_idx( @@ -45,13 +47,22 @@ def __post_init__(self) -> None: self.control = none_like(self.target) else: self.control = Primitive.parse_idx(self.control) + self._validate_noise() + + def _validate_noise(self) -> None: + if self.noise is not None: + dim = 2**self.n_qubits + noise_dim = (2**dim, 2**dim) + for kraus in self.noise.kraus: + if kraus.shape != noise_dim: + raise ValueError(f"Specify all noise tensors with shape {noise_dim}.") def __iter__(self) -> Iterable: - return iter((self.generator_name, self.target, self.control)) + return iter((self.generator_name, self.target, self.control, self.noise)) - def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]: + def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits, Noise | None]]: children = () - aux_data = (self.generator_name, self.target[0], self.control[0]) + aux_data = (self.generator_name, self.target[0], self.control[0], self.noise) return (children, aux_data) @classmethod @@ -68,6 +79,13 @@ def dagger(self, values: dict[str, float] = dict()) -> Array: def name(self) -> str: return "C" + self.generator_name if is_controlled(self.control) else self.generator_name + @property + def n_qubits(self) -> int: + n_qubits = len(self.target) + if self.control[0] is not None: + n_qubits += len(self.control) + return n_qubits + def __repr__(self) -> str: return self.name + f"(target={self.target[0]}, control={self.control[0]})" @@ -75,7 +93,9 @@ def __repr__(self) -> str: GateSequence = Union[Primitive, Iterable[Primitive]] -def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def I( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: """Identity / I gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -84,14 +104,17 @@ def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Args: target: Tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("I", target, control) + return Primitive("I", target, control, noise) -def X(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def X( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: """X gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -101,17 +124,20 @@ def X(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Args: target: Tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("X", target, control) + return Primitive("X", target, control, noise) NOT = X -def Y(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def Y( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: """Y gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -121,14 +147,17 @@ def Y(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Args: target: Tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("Y", target, control) + return Primitive("Y", target, control, noise) -def Z(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def Z( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: """Z gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -138,14 +167,17 @@ def Z(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Args: target: Tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("Z", target, control) + return Primitive("Z", target, control, noise) -def H(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def H( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: """H/ Hadamard gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -154,14 +186,17 @@ def H(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Args: target: Tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("H", target, control) + return Primitive("H", target, control, noise) -def S(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def S( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: """S gate or constant phase gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -170,14 +205,17 @@ def S(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Args: target: Tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("S", target, control) + return Primitive("S", target, control, noise) -def T(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def T( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: """T gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -186,17 +224,20 @@ def T(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Args: target: Tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("T", target, control) + return Primitive("T", target, control, noise) # Multi (target) qubit gates -def SWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def SWAP( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: """SWAP gate. By providing a control, it turns into a controlled gate (Fredkin gate), use None for no control qubits. @@ -207,20 +248,27 @@ def SWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: target: Tuple of ints describing the qubits to apply to. control: Optional tuple of ints or None describing the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("SWAP", target, control) + return Primitive("SWAP", target, control, noise) -def SQSWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: - return Primitive("SQSWAP", target, control) +def SQSWAP( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: + return Primitive("SQSWAP", target, control, noise) -def ISWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: - return Primitive("ISWAP", target, control) +def ISWAP( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: + return Primitive("ISWAP", target, control, noise) -def ISQSWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: - return Primitive("ISQSWAP", target, control) +def ISQSWAP( + target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None +) -> Primitive: + return Primitive("ISQSWAP", target, control, noise) diff --git a/horqrux/utils.py b/horqrux/utils.py index 2641530..1c924d9 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -14,6 +14,8 @@ QubitSupport = Tuple[Any, ...] ControlQubits = Tuple[Union[None, Tuple[int, ...]], ...] TargetQubits = Tuple[Tuple[int, ...], ...] +ErrorProbabilities = Tuple[float, ...] | float + ATOL = 1e-014 From b22d3a76912f0cf9314e8142840b07965b1639bd Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 3 Dec 2024 08:28:07 +0100 Subject: [PATCH 08/57] change interface --- horqrux/noise.py | 313 ++++++----------------------------------- horqrux/parametric.py | 23 +-- horqrux/primitive.py | 51 +++---- horqrux/utils.py | 6 + horqrux/utils_noise.py | 267 +++++++++++++++++++++++++++++++++++ tests/test_gates.py | 29 ++++ 6 files changed, 375 insertions(+), 314 deletions(-) create mode 100644 horqrux/utils_noise.py diff --git a/horqrux/noise.py b/horqrux/noise.py index 31aefc5..1b1826d 100644 --- a/horqrux/noise.py +++ b/horqrux/noise.py @@ -1,299 +1,70 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Iterable, Tuple, Union +from typing import Any, Callable, Iterable, Tuple -import jax.numpy as jnp -import numpy as np from jax import Array from jax.tree_util import register_pytree_node_class -from jax.typing import ArrayLike from .utils import ( - QubitSupport, - TargetQubits, ErrorProbabilities, + StrEnum, +) +from .utils_noise import ( + AmplitudeDamping, + BitFlip, + Depolarizing, + GeneralizedAmplitudeDamping, + PauliChannel, + PhaseDamping, + PhaseFlip, ) -from .matrices import OPERATIONS_DICT +class NoiseType(StrEnum): + BITFLIP = "BitFlip" + PHASEFLIP = "PhaseFlip" + DEPOLARIZING = "Depolarizing" + PAULI_CHANNEL = "PauliChannel" + AMPLITUDE_DAMPING = "AmplitudeDamping" + PHASE_DAMPING = "PhaseDamping" + GENERALIZED_AMPLITUDE_DAMPING = "GeneralizedAmplitudeDamping" -@register_pytree_node_class -@dataclass -class Noise: - """Noise class which stores information on .""" - kraus: list[Array] - target: QubitSupport - error_probability: ErrorProbabilities +PROTOCOL_TO_KRAUS_FN: dict[str, Callable] = { + "BitFlip": BitFlip, + "PhaseFlip": PhaseFlip, + "Depolarizing": Depolarizing, + "PauliChannel": PauliChannel, + "AmplitudeDamping": AmplitudeDamping, + "PhaseDamping": PhaseDamping, + "GeneralizedAmplitudeDamping": GeneralizedAmplitudeDamping, +} - @staticmethod - def parse_idx( - idx: Tuple, - ) -> Tuple: - if isinstance(idx, (int, np.int64)): - return ((idx,),) - elif isinstance(idx, tuple): - return (idx,) - else: - return (idx.astype(int),) - def __post_init__(self) -> None: - self.target = Noise.parse_idx(self.target) +@register_pytree_node_class +@dataclass +class NoiseInstance: + type: NoiseType + error_probability: ErrorProbabilities def __iter__(self) -> Iterable: - return iter((self.kraus, self.target, self.error_probability)) + return iter((self.kraus, self.error_probability)) - def tree_flatten(self) -> Tuple[Tuple, Tuple[list[Array], TargetQubits, ErrorProbabilities]]: + def tree_flatten( + self, + ) -> Tuple[Tuple, Tuple[NoiseType, ErrorProbabilities]]: children = () - aux_data = (self.kraus, self.target[0], self.error_probability) + aux_data = (self.type, self.error_probability) return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: return cls(*children, *aux_data) -def BitFlip(target: TargetQubits, error_probability: float) -> Noise: - """ - Initialize the BitFlip gate. - - The bit flip channel is defined as: - - .. math:: - \\rho \\Rightarrow (1-p) \\rho + p X \\rho X^{\\dagger} - - Args: - target (int): The index of the qubit being affected by the noise. - error_probability (float): The probability of a bit flip error. - - Raises: - ValueError: If the error_probability value is not a float. - - Returns: - Noise: Noise instance for this protocol. - """ - if error_probability > 1.0 or error_probability < 0.0: - raise ValueError("The error_probability value is not a correct probability") - K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] - K1: Array = jnp.sqrt(error_probability) * OPERATIONS_DICT["X"] - kraus_bitflip: list[Array] = [K0, K1] - return Noise(kraus_bitflip, target, error_probability) - -def PhaseFlip(target: TargetQubits, error_probability: float) -> Noise: - """ - Initialize the PhaseFlip gate - - The phase flip channel is defined as: - - .. math:: - \\rho \\Rightarrow (1-p) \\rho + p Z \\rho Z^{\\dagger} - - Args: - target (int): The index of the qubit being affected by the noise. - error_probability (float): The probability of phase flip error. - - Raises: - ValueError: If the error_probability value is not a float. - - Returns: - Noise: Noise instance for this protocol. - """ - - if error_probability > 1.0 or error_probability < 0.0: - raise ValueError("The error_probability value is not a correct probability") - K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] - K1: Array = jnp.sqrt(error_probability) * OPERATIONS_DICT["Z"] - kraus: list[Array] = [K0, K1] - return Noise(kraus, target, error_probability) - -def Depolarizing(target: TargetQubits, error_probability: float) -> Noise: - """ - Initialize the Depolarizing gate. - - The depolarizing channel is defined as: - - .. math:: - \\rho \\Rightarrow (1-p) \\rho - + p/3 X \\rho X^{\\dagger} - + p/3 Y \\rho Y^{\\dagger} - + p/3 Z \\rho Z^{\\dagger} - - Args: - target (int): The index of the qubit being affected by the noise. - error_probability (float): The probability of phase flip error. - - Raises: - ValueError: If the error_probability value is not a float. - - Returns: - Noise: Noise instance for this protocol. - """ - - if error_probability > 1.0 or error_probability < 0.0: - raise ValueError("The error_probability value is not a correct probability") - K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] - K1: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["X"] - K2: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["Y"] - K3: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["Z"] - kraus: list[Array] = [K0, K1, K2, K3] - return Noise(kraus, target, error_probability) - -def PauliChannel(target: TargetQubits, error_probability: tuple[float, ...]) -> Noise: - """ - Initialize the PauliChannel gate. - - The pauli channel is defined as: - - .. math:: - \\rho \\Rightarrow (1-px-py-pz) \\rho - + px X \\rho X^{\\dagger} - + py Y \\rho Y^{\\dagger} - + pz Z \\rho Z^{\\dagger} - - Args: - target (int): The index of the qubit being affected by the noise. - error_probability (ErrorProbabilities): Tuple containing probabilities - of X, Y, and Z errors. - - Raises: - ValueError: If the probabilities values do not sum up to 1. - - Returns: - Noise: Noise instance for this protocol. - """ - - sum_prob = sum(error_probability) - if sum_prob > 1.0: - raise ValueError("The sum of probabilities can't be greater than 1.0") - for probability in error_probability: - if probability > 1.0 or probability < 0.0: - raise ValueError("The probability values are not correct probabilities") - px, py, pz = ( - error_probability[0], - error_probability[1], - error_probability[2], - ) - - - K0: Array = jnp.sqrt(1.0 - (px + py + pz)) * OPERATIONS_DICT["I"] - K1: Array = jnp.sqrt(px) * OPERATIONS_DICT["X"] - K2: Array = jnp.sqrt(py) * OPERATIONS_DICT["Y"] - K3: Array = jnp.sqrt(pz) * OPERATIONS_DICT["Z"] - kraus: list[Array] = [K0, K1, K2, K3] - return Noise(kraus, target, error_probability) - -def AmplitudeDamping(target: TargetQubits, error_probability: float) -> Noise: - """ - Initialize the AmplitudeDamping gate. - - The amplitude damping channel is defined as: - - .. math:: - \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} - - with: - - .. code-block:: python - - K0 = [[1, 0], [0, sqrt(1 - rate)]] - K1 = [[0, sqrt(rate)], [0, 0]] - - Args: - target (int): The index of the qubit being affected by the noise. - error_probability (float): The damping rate, indicating the probability of amplitude loss. - - Raises: - ValueError: If the damping rate is not a correct probability. - - Returns: - Noise: Noise instance for this protocol. - """ - - rate = error_probability - if rate > 1.0 or rate < 0.0: - raise ValueError("The damping rate is not a correct probability") - K0: Array = jnp.array([[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128) - K1: Array = jnp.array([[0, jnp.sqrt(rate)], [0, 0]], dtype=jnp.complex128) - kraus: list[Array] = [K0, K1] - return Noise(kraus, target, error_probability) - -def PhaseDamping(target: TargetQubits, error_probability: float) -> Noise: - """ - Initialize the PhaseDamping gate. - - The phase damping channel is defined as: - - .. math:: - \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} - - with: - - .. code-block:: python - - K0 = [[1, 0], [0, sqrt(1 - rate)]] - K1 = [[0, 0], [0, sqrt(rate)]] - - Args: - target (int): The index of the qubit being affected by the noise. - error_probability (float): The damping rate, indicating the probability of phase damping. - - Raises: - ValueError: If the damping rate is not a correct probability. - - Returns: - Noise: Noise instance for this protocol. - """ - - rate = error_probability - if rate > 1.0 or rate < 0.0: - raise ValueError("The damping rate is not a correct probability") - K0: Array = jnp.array([[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128) - K1: Array = jnp.array([[0, 0], [0, jnp.sqrt(rate)]], dtype=jnp.complex128) - kraus: list[Array] = [K0, K1] - return Noise(kraus, target, error_probability) - -def GeneralizedAmplitudeDamping(target: TargetQubits, error_probability: tuple[float, ...]) -> Noise: - """ - Initialize the GeneralizeAmplitudeDamping gate. - - The generalize amplitude damping channel is defined as: - - .. math:: - \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} - + K_2 \\rho K_2^{\\dagger} + K_3 \\rho K_3^{\\dagger} - - with: - - .. code-block:: python - - K0 = sqrt(p) * [[1, 0], [0, sqrt(1 - rate)]] - K1 = sqrt(p) * [[0, sqrt(rate)], [0, 0]] - K2 = sqrt(1-p) * [[sqrt(1 - rate), 0], [0, 1]] - K3 = sqrt(1-p) * [[0, 0], [sqrt(rate), 0]] - - Args: - target (int): The index of the qubit being affected by the noise. - error_probability (ErrorProbabilities): The first float must be the probability - of amplitude damping error, and the second float is the damping rate, indicating - the probability of generalized amplitude damping. - - Raises: - ValueError: If the damping rate is not a correct probability. + def kraus(self) -> tuple[Array, ...]: + kraus_fn: Callable[..., tuple[Array, ...]] = PROTOCOL_TO_KRAUS_FN[self.type] + return kraus_fn(error_probability=self.error_probability) - Returns: - Noise: Noise instance for this protocol. - """ - probability = error_probability[0] - rate = error_probability[1] - if probability > 1.0 or probability < 0.0: - raise ValueError("The probability value is not a correct probability") - if rate > 1.0 or rate < 0.0: - raise ValueError("The damping rate is not a correct probability") - - K0: Array = jnp.sqrt(probability) * jnp.array([[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128) - K1: Array = jnp.sqrt(probability) * jnp.array([[0, jnp.sqrt(rate)], [0, 0]], dtype=jnp.complex128) - K2: Array = jnp.sqrt(1.0-probability) * jnp.array([[jnp.sqrt(1.0-rate), 0], [0, 1]], dtype=jnp.complex128) - K3: Array = jnp.sqrt(1.0-probability) * jnp.array([[0, 0], [jnp.sqrt(rate), 0]], dtype=jnp.complex128) - kraus: list[Array] = [K0, K1, K2, K3] - return Noise(kraus, target, error_probability) \ No newline at end of file +NoiseProtocol = Tuple[NoiseInstance, ...] diff --git a/horqrux/parametric.py b/horqrux/parametric.py index faa3d4c..ff36b0f 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Iterable, Tuple import jax.numpy as jnp @@ -8,7 +8,7 @@ from jax.tree_util import register_pytree_node_class from .matrices import OPERATIONS_DICT -from .noise import Noise +from .noise import NoiseProtocol from .primitive import Primitive from .utils import ( ControlQubits, @@ -28,8 +28,8 @@ class Parametric(Primitive): generator_name: str target: QubitSupport control: QubitSupport - noise: Noise | None = None param: str | float = "" + noise: NoiseProtocol = field(default_factory=tuple) def __post_init__(self) -> None: super().__post_init__() @@ -44,19 +44,19 @@ def parse_val(values: dict[str, float] = dict()) -> float: def tree_flatten( # type: ignore[override] self, - ) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, Noise | None, str | float]]: + ) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float, NoiseProtocol]]: children = () aux_data = ( self.generator_name, self.target[0], self.control[0], - self.noise, self.param, + self.noise, ) return (children, aux_data) def __iter__(self) -> Iterable: - return iter((self.generator_name, self.target, self.control, self.noise, self.param)) + return iter((self.generator_name, self.target, self.control, self.param, self.noise)) @classmethod def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: @@ -83,7 +83,7 @@ def RX( param: float | str, target: TargetQubits, control: ControlQubits = (None,), - noise: Noise | None = None, + noise: NoiseProtocol = tuple(), ) -> Parametric: """RX gate. @@ -103,7 +103,7 @@ def RY( param: float | str, target: TargetQubits, control: ControlQubits = (None,), - noise: Noise | None = None, + noise: NoiseProtocol = tuple(), ) -> Parametric: """RY gate. @@ -123,7 +123,7 @@ def RZ( param: float | str, target: TargetQubits, control: ControlQubits = (None,), - noise: Noise | None = None, + noise: NoiseProtocol = tuple(), ) -> Parametric: """RZ gate. @@ -157,7 +157,10 @@ def name(self) -> str: def PHASE( - param: float, target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None + param: float, + target: TargetQubits, + control: ControlQubits = (None,), + noise: NoiseProtocol = tuple(), ) -> Parametric: """Phase gate. diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 264eef0..9267d41 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Iterable, Tuple, Union import numpy as np @@ -8,7 +8,7 @@ from jax.tree_util import register_pytree_node_class from .matrices import OPERATIONS_DICT -from .noise import Noise +from .noise import NoiseProtocol from .utils import ( ControlQubits, QubitSupport, @@ -28,7 +28,7 @@ class Primitive: generator_name: str target: QubitSupport control: QubitSupport - noise: Noise | None = None + noise: NoiseProtocol = field(default_factory=tuple) @staticmethod def parse_idx( @@ -47,20 +47,11 @@ def __post_init__(self) -> None: self.control = none_like(self.target) else: self.control = Primitive.parse_idx(self.control) - self._validate_noise() - - def _validate_noise(self) -> None: - if self.noise is not None: - dim = 2**self.n_qubits - noise_dim = (2**dim, 2**dim) - for kraus in self.noise.kraus: - if kraus.shape != noise_dim: - raise ValueError(f"Specify all noise tensors with shape {noise_dim}.") def __iter__(self) -> Iterable: return iter((self.generator_name, self.target, self.control, self.noise)) - def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits, Noise | None]]: + def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits, NoiseProtocol]]: children = () aux_data = (self.generator_name, self.target[0], self.control[0], self.noise) return (children, aux_data) @@ -94,7 +85,7 @@ def __repr__(self) -> str: def I( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() ) -> Primitive: """Identity / I gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -113,7 +104,7 @@ def I( def X( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() ) -> Primitive: """X gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -136,7 +127,7 @@ def X( def Y( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() ) -> Primitive: """Y gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -156,7 +147,7 @@ def Y( def Z( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() ) -> Primitive: """Z gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -176,7 +167,7 @@ def Z( def H( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() ) -> Primitive: """H/ Hadamard gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -195,7 +186,7 @@ def H( def S( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() ) -> Primitive: """S gate or constant phase gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -214,7 +205,7 @@ def S( def T( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() ) -> Primitive: """T gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -236,7 +227,7 @@ def T( def SWAP( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() ) -> Primitive: """SWAP gate. By providing a control, it turns into a controlled gate (Fredkin gate), use None for no control qubits. @@ -256,19 +247,13 @@ def SWAP( return Primitive("SWAP", target, control, noise) -def SQSWAP( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None -) -> Primitive: - return Primitive("SQSWAP", target, control, noise) +def SQSWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: + return Primitive("SQSWAP", target, control) -def ISWAP( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None -) -> Primitive: - return Primitive("ISWAP", target, control, noise) +def ISWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: + return Primitive("ISWAP", target, control) -def ISQSWAP( - target: TargetQubits, control: ControlQubits = (None,), noise: Noise | None = None -) -> Primitive: - return Primitive("ISQSWAP", target, control, noise) +def ISQSWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: + return Primitive("ISQSWAP", target, control) diff --git a/horqrux/utils.py b/horqrux/utils.py index 1c924d9..91d1ea0 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -171,3 +171,9 @@ def _normalize(wf: Array) -> Array: def is_normalized(state: Array) -> bool: return equivalent_state(state, state) + + +def density_mat(state: Array) -> Array: + n_qubits = len(state.shape) + state = state.reshape(2**n_qubits) + return jnp.einsum("i,j->ij", state, state.conj()).reshape(tuple(2 for _ in range(2 * n_qubits))) diff --git a/horqrux/utils_noise.py b/horqrux/utils_noise.py new file mode 100644 index 0000000..19149f1 --- /dev/null +++ b/horqrux/utils_noise.py @@ -0,0 +1,267 @@ + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Iterable, Tuple + +import jax.numpy as jnp +from jax import Array + +from .matrices import OPERATIONS_DICT +from .utils import ( + TargetQubits, +) + + +def BitFlip(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the BitFlip gate. + + The bit flip channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-p) \\rho + p X \\rho X^{\\dagger} + + Args: + error_probability (float): The probability of a bit flip error. + + Raises: + ValueError: If the error_probability value is not a float. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + if error_probability > 1.0 or error_probability < 0.0: + raise ValueError("The error_probability value is not a correct probability") + K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(error_probability) * OPERATIONS_DICT["X"] + kraus: tuple[Array, ...] = (K0, K1) + return kraus + + +def PhaseFlip(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the PhaseFlip gate + + The phase flip channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-p) \\rho + p Z \\rho Z^{\\dagger} + + Args: + error_probability (float): The probability of phase flip error. + + Raises: + ValueError: If the error_probability value is not a float. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + if error_probability > 1.0 or error_probability < 0.0: + raise ValueError("The error_probability value is not a correct probability") + K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(error_probability) * OPERATIONS_DICT["Z"] + kraus: tuple[Array, ...] = (K0, K1) + return kraus + + +def Depolarizing(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the Depolarizing gate. + + The depolarizing channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-p) \\rho + + p/3 X \\rho X^{\\dagger} + + p/3 Y \\rho Y^{\\dagger} + + p/3 Z \\rho Z^{\\dagger} + + Args: + error_probability (float): The probability of phase flip error. + + Raises: + ValueError: If the error_probability value is not a float. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + if error_probability > 1.0 or error_probability < 0.0: + raise ValueError("The error_probability value is not a correct probability") + K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["X"] + K2: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["Y"] + K3: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["Z"] + kraus: tuple[Array, ...] = (K0, K1, K2, K3) + return kraus + + +def PauliChannel(error_probability: tuple[float, ...]) -> tuple[Array, ...]: + """ + Initialize the PauliChannel gate. + + The pauli channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-px-py-pz) \\rho + + px X \\rho X^{\\dagger} + + py Y \\rho Y^{\\dagger} + + pz Z \\rho Z^{\\dagger} + + Args: + error_probability (ErrorProbabilities): Tuple containing probabilities + of X, Y, and Z errors. + + Raises: + ValueError: If the probabilities values do not sum up to 1. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + sum_prob = sum(error_probability) + if sum_prob > 1.0: + raise ValueError("The sum of probabilities can't be greater than 1.0") + for probability in error_probability: + if probability > 1.0 or probability < 0.0: + raise ValueError("The probability values are not correct probabilities") + px, py, pz = ( + error_probability[0], + error_probability[1], + error_probability[2], + ) + + K0: Array = jnp.sqrt(1.0 - (px + py + pz)) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(px) * OPERATIONS_DICT["X"] + K2: Array = jnp.sqrt(py) * OPERATIONS_DICT["Y"] + K3: Array = jnp.sqrt(pz) * OPERATIONS_DICT["Z"] + kraus: tuple[Array, ...] = (K0, K1, K2, K3) + return kraus + + +def AmplitudeDamping(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the AmplitudeDamping gate. + + The amplitude damping channel is defined as: + + .. math:: + \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} + + with: + + .. code-block:: python + + K0 = [[1, 0], [0, sqrt(1 - rate)]] + K1 = [[0, sqrt(rate)], [0, 0]] + + Args: + error_probability (float): The damping rate, indicating the probability of amplitude loss. + + Raises: + ValueError: If the damping rate is not a correct probability. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + rate = error_probability + if rate > 1.0 or rate < 0.0: + raise ValueError("The damping rate is not a correct probability") + K0: Array = jnp.array([[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128) + K1: Array = jnp.array([[0, jnp.sqrt(rate)], [0, 0]], dtype=jnp.complex128) + kraus: tuple[Array, ...] = (K0, K1) + return kraus + + +def PhaseDamping(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the PhaseDamping gate. + + The phase damping channel is defined as: + + .. math:: + \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} + + with: + + .. code-block:: python + + K0 = [[1, 0], [0, sqrt(1 - rate)]] + K1 = [[0, 0], [0, sqrt(rate)]] + + Args: + error_probability (float): The damping rate, indicating the probability of phase damping. + + Raises: + ValueError: If the damping rate is not a correct probability. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + rate = error_probability + if rate > 1.0 or rate < 0.0: + raise ValueError("The damping rate is not a correct probability") + K0: Array = jnp.array([[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128) + K1: Array = jnp.array([[0, 0], [0, jnp.sqrt(rate)]], dtype=jnp.complex128) + kraus: tuple[Array, ...] = (K0, K1) + return kraus + + +def GeneralizedAmplitudeDamping( + error_probability: tuple[float, ...] +) -> tuple[Array, ...]: + """ + Initialize the GeneralizeAmplitudeDamping gate. + + The generalize amplitude damping channel is defined as: + + .. math:: + \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} + + K_2 \\rho K_2^{\\dagger} + K_3 \\rho K_3^{\\dagger} + + with: + + .. code-block:: python + + K0 = sqrt(p) * [[1, 0], [0, sqrt(1 - rate)]] + K1 = sqrt(p) * [[0, sqrt(rate)], [0, 0]] + K2 = sqrt(1-p) * [[sqrt(1 - rate), 0], [0, 1]] + K3 = sqrt(1-p) * [[0, 0], [sqrt(rate), 0]] + + Args: + error_probability (ErrorProbabilities): The first float must be the probability + of amplitude damping error, and the second float is the damping rate, indicating + the probability of generalized amplitude damping. + + Raises: + ValueError: If the damping rate is not a correct probability. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + probability = error_probability[0] + rate = error_probability[1] + if probability > 1.0 or probability < 0.0: + raise ValueError("The probability value is not a correct probability") + if rate > 1.0 or rate < 0.0: + raise ValueError("The damping rate is not a correct probability") + + K0: Array = jnp.sqrt(probability) * jnp.array( + [[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128 + ) + K1: Array = jnp.sqrt(probability) * jnp.array( + [[0, jnp.sqrt(rate)], [0, 0]], dtype=jnp.complex128 + ) + K2: Array = jnp.sqrt(1.0 - probability) * jnp.array( + [[jnp.sqrt(1.0 - rate), 0], [0, 1]], dtype=jnp.complex128 + ) + K3: Array = jnp.sqrt(1.0 - probability) * jnp.array( + [[0, 0], [jnp.sqrt(rate), 0]], dtype=jnp.complex128 + ) + kraus: tuple[Array, ...] = (K0, K1, K2, K3) + return kraus diff --git a/tests/test_gates.py b/tests/test_gates.py index 4c44cd3..ac0d6f0 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -8,6 +8,7 @@ from jax import Array from horqrux.apply import apply_gate, apply_operator +from horqrux.noise import NoiseType from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z from horqrux.utils import equivalent_state, product_state, random_state @@ -15,6 +16,7 @@ MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T) +Noise_ops = (NoiseType.BitFlip, NoiseType.PhaseFlip, NoiseType.Depolarizing) @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) @@ -28,6 +30,33 @@ def test_primitive(gate_fn: Callable) -> None: ) +# @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) +# @pytest.mark.parametrize("noise_fn", Noise_ops) +# def test_noisy_primitive(gate_fn: Callable, noise_fn: Callable) -> None: +# target = np.random.randint(0, MAX_QUBITS) +# noise = noise_fn(target, error_probability=0.1) + +# noisy_gate = gate_fn(target, noise=noise) +# orig_state = random_state(MAX_QUBITS) +# orig_dm = density_mat(orig_state) +# noisy_output_dm = apply_gate(orig_dm, noisy_gate) +# output_dm = apply_gate(orig_dm, gate_fn(target)) +# assert noisy_output_dm.shape == orig_dm.shape +# assert not jnp.allclose(noisy_output_dm, orig_dm) +# assert not jnp.allclose(output_dm, noisy_output_dm) + +# reverse_op = output_dm +# # for kraus in noise.kraus[::-1]: +# # reverse_op = apply_operator( +# # reverse_op, _dagger(kraus), noisy_gate.target[0], noisy_gate.control[0] +# # ) +# reverse_op = apply_operator( +# reverse_op, noisy_gate.dagger(), noisy_gate.target[0], noisy_gate.control[0] +# ) +# assert reverse_op.shape == orig_dm.shape +# assert jnp.allclose(reverse_op, orig_dm) + + @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) def test_controlled_primitive(gate_fn: Callable) -> None: target = np.random.randint(0, MAX_QUBITS) From a49e75a03ce8268b96379a22dd645f96ec21d419 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 3 Dec 2024 10:35:35 +0100 Subject: [PATCH 09/57] test tuple length noise in gate --- horqrux/utils_noise.py | 11 +---------- tests/test_gates.py | 37 ++++++++++--------------------------- 2 files changed, 11 insertions(+), 37 deletions(-) diff --git a/horqrux/utils_noise.py b/horqrux/utils_noise.py index 19149f1..030f467 100644 --- a/horqrux/utils_noise.py +++ b/horqrux/utils_noise.py @@ -1,16 +1,9 @@ - from __future__ import annotations -from dataclasses import dataclass -from typing import Any, Iterable, Tuple - import jax.numpy as jnp from jax import Array from .matrices import OPERATIONS_DICT -from .utils import ( - TargetQubits, -) def BitFlip(error_probability: float) -> tuple[Array, ...]: @@ -211,9 +204,7 @@ def PhaseDamping(error_probability: float) -> tuple[Array, ...]: return kraus -def GeneralizedAmplitudeDamping( - error_probability: tuple[float, ...] -) -> tuple[Array, ...]: +def GeneralizedAmplitudeDamping(error_probability: tuple[float, ...]) -> tuple[Array, ...]: """ Initialize the GeneralizeAmplitudeDamping gate. diff --git a/tests/test_gates.py b/tests/test_gates.py index ac0d6f0..80530ad 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -8,7 +8,7 @@ from jax import Array from horqrux.apply import apply_gate, apply_operator -from horqrux.noise import NoiseType +from horqrux.noise import NoiseInstance, NoiseType from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z from horqrux.utils import equivalent_state, product_state, random_state @@ -16,7 +16,7 @@ MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T) -Noise_ops = (NoiseType.BitFlip, NoiseType.PhaseFlip, NoiseType.Depolarizing) +NOISE_TYPES = (NoiseType.BITFLIP, NoiseType.PHASEFLIP, NoiseType.DEPOLARIZING) @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) @@ -30,31 +30,14 @@ def test_primitive(gate_fn: Callable) -> None: ) -# @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) -# @pytest.mark.parametrize("noise_fn", Noise_ops) -# def test_noisy_primitive(gate_fn: Callable, noise_fn: Callable) -> None: -# target = np.random.randint(0, MAX_QUBITS) -# noise = noise_fn(target, error_probability=0.1) - -# noisy_gate = gate_fn(target, noise=noise) -# orig_state = random_state(MAX_QUBITS) -# orig_dm = density_mat(orig_state) -# noisy_output_dm = apply_gate(orig_dm, noisy_gate) -# output_dm = apply_gate(orig_dm, gate_fn(target)) -# assert noisy_output_dm.shape == orig_dm.shape -# assert not jnp.allclose(noisy_output_dm, orig_dm) -# assert not jnp.allclose(output_dm, noisy_output_dm) - -# reverse_op = output_dm -# # for kraus in noise.kraus[::-1]: -# # reverse_op = apply_operator( -# # reverse_op, _dagger(kraus), noisy_gate.target[0], noisy_gate.control[0] -# # ) -# reverse_op = apply_operator( -# reverse_op, noisy_gate.dagger(), noisy_gate.target[0], noisy_gate.control[0] -# ) -# assert reverse_op.shape == orig_dm.shape -# assert jnp.allclose(reverse_op, orig_dm) +@pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) +@pytest.mark.parametrize("noise_type", NOISE_TYPES) +def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: + target = np.random.randint(0, MAX_QUBITS) + noise = NoiseInstance(noise_type, error_probability=0.1) + + noisy_gate = gate_fn(target, noise=(noise,)) + assert len(noisy_gate.noise) == 1 @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) From 7603eb3ce05a8883826d30eae8bf16bae7578a48 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 3 Dec 2024 17:24:31 +0100 Subject: [PATCH 10/57] fix noise --- horqrux/apply.py | 68 ++++++++++++++++++++++++++++++++++++++++----- horqrux/noise.py | 4 +++ horqrux/utils.py | 15 ++++++---- tests/test_gates.py | 5 ++++ 4 files changed, 79 insertions(+), 13 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index 5bd054c..3326080 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -1,16 +1,18 @@ from __future__ import annotations -from functools import reduce +from functools import partial, reduce from operator import add from typing import Iterable, Tuple +import jax import jax.numpy as jnp import numpy as np from jax import Array from horqrux.primitive import Primitive -from .utils import OperationType, State, _controlled, is_controlled +from .noise import NoiseProtocol +from .utils import DensityMatrix, OperationType, State, _controlled, _dagger, is_controlled def apply_operator( @@ -51,6 +53,46 @@ def apply_operator( return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) +def apply_kraus_operator( + kraus: Array, + state: State, + target: Tuple[int, ...], +) -> State: + state_dims: Tuple[int, ...] = target + n_qubits = int(np.log2(kraus.size)) + kraus = kraus.reshape(tuple(2 for _ in np.arange(n_qubits))) + op_dims = tuple(np.arange(kraus.ndim // 2, kraus.ndim, dtype=int)) + + # Ki rho + state = jnp.tensordot(a=kraus, b=state, axes=(op_dims, state_dims)) + new_state_dims = tuple(i for i in range(len(state_dims))) + state = jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + + # dagger ops + state = jnp.tensordot(a=kraus, b=_dagger(state), axes=(op_dims, state_dims)) + state = _dagger(state) + + return state + + +def apply_operator_with_noise( + state: State, + operator: Array, + target: Tuple[int, ...], + control: Tuple[int | None, ...], + noise: NoiseProtocol, +) -> State: + state_gate = apply_operator(state, operator, target, control) + if len(noise) == 0: + return state_gate + else: + kraus_ops = jnp.stack(tuple(reduce(add, tuple(n.kraus for n in noise)))) + apply_one_kraus = jax.vmap(partial(apply_kraus_operator, state=state_gate, target=target)) + kraus_evol = apply_one_kraus(kraus_ops) + output_dm = jnp.sum(kraus_evol, 0) + return output_dm + + def group_by_index(gates: Iterable[Primitive]) -> Iterable[Primitive]: """Group gates together which are acting on the same qubit.""" sorted_gates = [] @@ -106,7 +148,7 @@ def merge_operators( def apply_gate( - state: State, + state: State | DensityMatrix, gate: Primitive | Iterable[Primitive], values: dict[str, float] = dict(), op_type: OperationType = OperationType.UNITARY, @@ -115,7 +157,7 @@ def apply_gate( ) -> State: """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. Arguments: - state: State to operate on. + state: State or DensityMatrix to operate on. gate: Gate(s) to apply. values: A dictionary with parameter values. op_type: The type of operation to perform: Unitary, Dagger or Jacobian. @@ -126,9 +168,11 @@ def apply_gate( State after applying 'gate'. """ operator: Tuple[Array, ...] + noise = list() if isinstance(gate, Primitive): operator_fn = getattr(gate, op_type) operator, target, control = (operator_fn(values),), gate.target, gate.control + noise += [gate.noise] else: if group_gates: gate = group_by_index(gate) @@ -137,8 +181,18 @@ def apply_gate( control = reduce(add, [g.control for g in gate]) if merge_ops: operator, target, control = merge_operators(operator, target, control) - return reduce( - lambda state, gate: apply_operator(state, *gate), - zip(operator, target, control), + noise = [g.noise for g in gate] + + has_noise = len(reduce(add, noise)) > 0 + if has_noise and not isinstance(state, DensityMatrix): + print(state.shape) + state = DensityMatrix(state).dm + print(state.shape) + + output_state = reduce( + lambda state, gate: apply_operator_with_noise(state, *gate), + zip(operator, target, control, noise), state, ) + + return output_state diff --git a/horqrux/noise.py b/horqrux/noise.py index 1b1826d..e26f2dc 100644 --- a/horqrux/noise.py +++ b/horqrux/noise.py @@ -62,9 +62,13 @@ def tree_flatten( def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: return cls(*children, *aux_data) + @property def kraus(self) -> tuple[Array, ...]: kraus_fn: Callable[..., tuple[Array, ...]] = PROTOCOL_TO_KRAUS_FN[self.type] return kraus_fn(error_probability=self.error_probability) + def __repr__(self) -> str: + return self.type + f"(p={self.error_probability})" + NoiseProtocol = Tuple[NoiseInstance, ...] diff --git a/horqrux/utils.py b/horqrux/utils.py index 91d1ea0..89008c9 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -19,6 +19,15 @@ ATOL = 1e-014 +class DensityMatrix: + def __init__(self, state: ArrayLike) -> None: + n_qubits = len(state.shape) + state = state.reshape(2**n_qubits) + self.dm = jnp.einsum("i,j->ij", state, state.conj()).reshape( + tuple(2 for _ in range(2 * n_qubits)) + ) + + class StrEnum(str, Enum): def __str__(self) -> str: """Used when dumping enum fields in a schema.""" @@ -171,9 +180,3 @@ def _normalize(wf: Array) -> Array: def is_normalized(state: Array) -> bool: return equivalent_state(state, state) - - -def density_mat(state: Array) -> Array: - n_qubits = len(state.shape) - state = state.reshape(2**n_qubits) - return jnp.einsum("i,j->ij", state, state.conj()).reshape(tuple(2 for _ in range(2 * n_qubits))) diff --git a/tests/test_gates.py b/tests/test_gates.py index 80530ad..9f64066 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -39,6 +39,11 @@ def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: noisy_gate = gate_fn(target, noise=(noise,)) assert len(noisy_gate.noise) == 1 + orig_state = random_state(MAX_QUBITS) + output_dm = apply_gate(orig_state, noisy_gate) + # check output is a density matrix + assert len(output_dm.shape) == 2 * MAX_QUBITS + @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) def test_controlled_primitive(gate_fn: Callable) -> None: From 965099f65ccebdab2e0842d5c2b86210b6b0e28f Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 4 Dec 2024 09:07:03 +0100 Subject: [PATCH 11/57] separate tests between noisy and non noisy --- tests/test_gates.py | 17 ------------- tests/test_noise.py | 60 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 17 deletions(-) create mode 100644 tests/test_noise.py diff --git a/tests/test_gates.py b/tests/test_gates.py index 9f64066..4c44cd3 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -8,7 +8,6 @@ from jax import Array from horqrux.apply import apply_gate, apply_operator -from horqrux.noise import NoiseInstance, NoiseType from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z from horqrux.utils import equivalent_state, product_state, random_state @@ -16,7 +15,6 @@ MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T) -NOISE_TYPES = (NoiseType.BITFLIP, NoiseType.PHASEFLIP, NoiseType.DEPOLARIZING) @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) @@ -30,21 +28,6 @@ def test_primitive(gate_fn: Callable) -> None: ) -@pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) -@pytest.mark.parametrize("noise_type", NOISE_TYPES) -def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: - target = np.random.randint(0, MAX_QUBITS) - noise = NoiseInstance(noise_type, error_probability=0.1) - - noisy_gate = gate_fn(target, noise=(noise,)) - assert len(noisy_gate.noise) == 1 - - orig_state = random_state(MAX_QUBITS) - output_dm = apply_gate(orig_state, noisy_gate) - # check output is a density matrix - assert len(output_dm.shape) == 2 * MAX_QUBITS - - @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) def test_controlled_primitive(gate_fn: Callable) -> None: target = np.random.randint(0, MAX_QUBITS) diff --git a/tests/test_noise.py b/tests/test_noise.py new file mode 100644 index 0000000..b6462f8 --- /dev/null +++ b/tests/test_noise.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Callable + +import jax.numpy as jnp +import numpy as np +import pytest +from jax import Array + +from horqrux.apply import apply_gate, apply_operator +from horqrux.noise import NoiseInstance, NoiseType +from horqrux.parametric import PHASE, RX, RY, RZ +from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z +from horqrux.utils import equivalent_state, product_state, random_state + +MAX_QUBITS = 7 +PARAMETRIC_GATES = (RX, RY, RZ, PHASE) +PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T) + +NOISE_oneproba = (NoiseType.BITFLIP, NoiseType.PHASEFLIP, NoiseType.DEPOLARIZING, NoiseType.AMPLITUDE_DAMPING, NoiseType.PHASE_DAMPING) +ALL_NOISES = list(NoiseType) + +def noise_instance(noise_type: NoiseType) -> NoiseInstance: + if noise_type in NOISE_oneproba: + errors = 0.1 + elif noise_type == NoiseType.PAULI_CHANNEL: + errors=(0.4, 0.5, 0.1) + else: + errors = (0.2, 0.8) + + return NoiseInstance(noise_type, error_probability=errors) + +@pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) +@pytest.mark.parametrize("noise_type", ALL_NOISES) +def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: + target = np.random.randint(0, MAX_QUBITS) + noise = noise_instance(noise_type) + + noisy_gate = gate_fn(target, noise=(noise,)) + assert len(noisy_gate.noise) == 1 + + orig_state = random_state(MAX_QUBITS) + output_dm = apply_gate(orig_state, noisy_gate) + # check output is a density matrix + assert len(output_dm.shape) == 2 * MAX_QUBITS + + +@pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) +@pytest.mark.parametrize("noise_type", ALL_NOISES) +def test_noisy_parametric(gate_fn: Callable, noise_type: NoiseType) -> None: + target = np.random.randint(0, MAX_QUBITS) + noise = noise_instance(noise_type) + noisy_gate = gate_fn("theta", target, noise=(noise,)) + values = {"theta": np.random.uniform(0.1, 2 * np.pi)} + orig_state = random_state(MAX_QUBITS) + + output_dm = apply_gate(orig_state, noisy_gate, values) + # check output is a density matrix + assert len(output_dm.shape) == 2 * MAX_QUBITS + \ No newline at end of file From 06c6062eabe1235c319724c4b050137ac27aa9dc Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 4 Dec 2024 14:06:29 +0100 Subject: [PATCH 12/57] lint --- tests/test_noise.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/test_noise.py b/tests/test_noise.py index b6462f8..91ef6d4 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -2,34 +2,40 @@ from typing import Callable -import jax.numpy as jnp import numpy as np import pytest -from jax import Array -from horqrux.apply import apply_gate, apply_operator +from horqrux.apply import apply_gate from horqrux.noise import NoiseInstance, NoiseType from horqrux.parametric import PHASE, RX, RY, RZ -from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z -from horqrux.utils import equivalent_state, product_state, random_state +from horqrux.primitive import NOT, H, I, S, T, X, Y, Z +from horqrux.utils import random_state MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T) -NOISE_oneproba = (NoiseType.BITFLIP, NoiseType.PHASEFLIP, NoiseType.DEPOLARIZING, NoiseType.AMPLITUDE_DAMPING, NoiseType.PHASE_DAMPING) +NOISE_oneproba = ( + NoiseType.BITFLIP, + NoiseType.PHASEFLIP, + NoiseType.DEPOLARIZING, + NoiseType.AMPLITUDE_DAMPING, + NoiseType.PHASE_DAMPING, +) ALL_NOISES = list(NoiseType) + def noise_instance(noise_type: NoiseType) -> NoiseInstance: if noise_type in NOISE_oneproba: errors = 0.1 elif noise_type == NoiseType.PAULI_CHANNEL: - errors=(0.4, 0.5, 0.1) + errors = (0.4, 0.5, 0.1) else: errors = (0.2, 0.8) - + return NoiseInstance(noise_type, error_probability=errors) + @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) @pytest.mark.parametrize("noise_type", ALL_NOISES) def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: @@ -57,4 +63,3 @@ def test_noisy_parametric(gate_fn: Callable, noise_type: NoiseType) -> None: output_dm = apply_gate(orig_state, noisy_gate, values) # check output is a density matrix assert len(output_dm.shape) == 2 * MAX_QUBITS - \ No newline at end of file From 9aca1c8ad1ce68d75428ff1a001ec76c25a1488e Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Fri, 6 Dec 2024 08:46:14 +0100 Subject: [PATCH 13/57] fix order noise and param by inheritance --- horqrux/apply.py | 2 -- horqrux/parametric.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index dc20390..4625d76 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -185,9 +185,7 @@ def apply_gate( has_noise = len(reduce(add, noise)) > 0 if has_noise and not isinstance(state, DensityMatrix): - print(state.shape) state = DensityMatrix(state).dm - print(state.shape) output_state = reduce( lambda state, gate: apply_operator_with_noise(state, *gate), diff --git a/horqrux/parametric.py b/horqrux/parametric.py index 293c7f9..184b9f2 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -31,8 +31,8 @@ class Parametric(Primitive): generator_name: str target: QubitSupport control: QubitSupport - param: str | float = "" noise: NoiseProtocol = field(default_factory=tuple) + param: str | float = "" def __post_init__(self) -> None: super().__post_init__() @@ -47,19 +47,19 @@ def parse_val(values: dict[str, float] = dict()) -> float: def tree_flatten( # type: ignore[override] self, - ) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float, NoiseProtocol]]: + ) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, NoiseProtocol, str | float]]: children = () aux_data = ( self.generator_name, self.target[0], self.control[0], - self.param, self.noise, + self.param, ) return (children, aux_data) def __iter__(self) -> Iterable: - return iter((self.generator_name, self.target, self.control, self.param, self.noise)) + return iter((self.generator_name, self.target, self.control, self.noise, self.param)) @classmethod def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: @@ -97,7 +97,7 @@ def RX( Returns: Parametric: A Parametric gate object. """ - return Parametric("X", target, control, param=param, noise=noise) + return Parametric("X", target, control, noise, param) def RY( @@ -117,7 +117,7 @@ def RY( Returns: Parametric: A Parametric gate object. """ - return Parametric("Y", target, control, param=param, noise=noise) + return Parametric("Y", target, control, noise, param) def RZ( @@ -137,7 +137,7 @@ def RZ( Returns: Parametric: A Parametric gate object. """ - return Parametric("Z", target, control, param=param, noise=noise) + return Parametric("Z", target, control, noise, param) class _PHASE(Parametric): @@ -175,4 +175,4 @@ def PHASE( Parametric: A Parametric gate object. """ - return _PHASE("I", target, control, param=param, noise=noise) + return _PHASE("I", target, control, noise, param) From 6f49beb7135ccf90b593c093abed3ce0fed5f3f4 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Fri, 6 Dec 2024 08:50:00 +0100 Subject: [PATCH 14/57] fix union --- horqrux/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/horqrux/utils.py b/horqrux/utils.py index 2eb87bf..4ed681f 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -18,7 +18,7 @@ QubitSupport = Tuple[Any, ...] ControlQubits = Tuple[Union[None, Tuple[int, ...]], ...] TargetQubits = Tuple[Tuple[int, ...], ...] -ErrorProbabilities = Tuple[float, ...] | float +ErrorProbabilities = Union[Tuple[float, ...], float] ATOL = 1e-014 From b3a9a2802c99464459eaf7ae803311e4efeee4dd Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 9 Dec 2024 09:05:03 +0100 Subject: [PATCH 15/57] rm densitymatrix object - to replace by more functional way --- horqrux/api.py | 18 +++++++++++++++--- horqrux/apply.py | 12 +++++++----- horqrux/utils.py | 19 ++++++++++++------- 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/horqrux/api.py b/horqrux/api.py index 2eeb464..f178535 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -52,7 +52,11 @@ def sample( def __ad_expectation_single_observable( - state: Array, gates: GateSequence, observable: Primitive, values: dict[str, float] + state: Array, + gates: GateSequence, + observable: Primitive, + values: dict[str, float], + is_state_densitymat: bool = False, ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' @@ -60,11 +64,18 @@ def __ad_expectation_single_observable( """ out_state = apply_gate(state, gates, values, OperationType.UNITARY) projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) - return inner(out_state, projected_state).real + if not is_state_densitymat: + return inner(out_state, projected_state).real + + raise NotImplementedError("Expectation from density matrices is not yet supported!") def ad_expectation( - state: Array, gates: GateSequence, observables: list[Primitive], values: dict[str, float] + state: Array, + gates: GateSequence, + observables: list[Primitive], + values: dict[str, float], + is_state_densitymat: bool = False, ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' @@ -85,6 +96,7 @@ def expectation( diff_mode: DiffMode = DiffMode.AD, forward_mode: ForwardMode = ForwardMode.EXACT, n_shots: Optional[int] = None, + is_state_densitymat: bool = False, key: Any = jax.random.PRNGKey(0), ) -> Array: """ diff --git a/horqrux/apply.py b/horqrux/apply.py index 4625d76..dd380f5 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -12,7 +12,7 @@ from horqrux.primitive import Primitive from .noise import NoiseProtocol -from .utils import DensityMatrix, OperationType, State, _controlled, _dagger, is_controlled +from .utils import OperationType, State, _controlled, _dagger, density_mat, is_controlled def apply_operator( @@ -148,12 +148,13 @@ def merge_operators( def apply_gate( - state: State | DensityMatrix, + state: State, gate: Primitive | Iterable[Primitive], values: dict[str, float] = dict(), op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution merge_ops: bool = True, + is_state_densitymat: bool = False, ) -> State: """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. Arguments: @@ -163,9 +164,10 @@ def apply_gate( op_type: The type of operation to perform: Unitary, Dagger or Jacobian. group_gates: Group gates together which are acting on the same qubit. merge_ops: Attempt to merge operators acting on the same qubit. + is_state_densitymat: If True, state is provided as a density matrix. Returns: - State after applying 'gate'. + State or density matrix after applying 'gate'. """ operator: Tuple[Array, ...] noise = list() @@ -184,8 +186,8 @@ def apply_gate( noise = [g.noise for g in gate] has_noise = len(reduce(add, noise)) > 0 - if has_noise and not isinstance(state, DensityMatrix): - state = DensityMatrix(state).dm + if has_noise and not is_state_densitymat: + state = density_mat(state) output_state = reduce( lambda state, gate: apply_operator_with_noise(state, *gate), diff --git a/horqrux/utils.py b/horqrux/utils.py index 4ed681f..c119956 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -23,13 +23,18 @@ ATOL = 1e-014 -class DensityMatrix: - def __init__(self, state: ArrayLike) -> None: - n_qubits = len(state.shape) - state = state.reshape(2**n_qubits) - self.dm = jnp.einsum("i,j->ij", state, state.conj()).reshape( - tuple(2 for _ in range(2 * n_qubits)) - ) +def density_mat(state: Array) -> Array: + """Convert state to density matrix + + Args: + state (State): Input state. + + Returns: + State: Density matrix representation. + """ + n_qubits = len(state.shape) + state = state.reshape(2**n_qubits) + return jnp.einsum("i,j->ij", state, state.conj()).reshape(tuple(2 for _ in range(2 * n_qubits))) class StrEnum(str, Enum): From f3f1aa3a4d120c06b866c1e0feca0c9790bd6d35 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 9 Dec 2024 09:07:31 +0100 Subject: [PATCH 16/57] add raises notimplementederrors --- horqrux/api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/horqrux/api.py b/horqrux/api.py index f178535..f17302e 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -82,7 +82,7 @@ def ad_expectation( and compute the expectation given an observable. """ outputs = [ - __ad_expectation_single_observable(state, gates, observable, values) + __ad_expectation_single_observable(state, gates, observable, values, is_state_densitymat) for observable in observables ] return jnp.stack(outputs) @@ -117,4 +117,6 @@ def expectation( ) # Type checking is disabled because mypy doesn't parse checkify.check. # type: ignore + if is_state_densitymat: + raise NotImplementedError("Expectation from density matrices is not yet supported!") return finite_shots_fwd(state, gates, observables, values, n_shots=n_shots, key=key) From 78e7519fe75e997453a8fe2e6277d6877c4d0215 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 9 Dec 2024 11:20:16 +0100 Subject: [PATCH 17/57] add boolean in fwd fcts --- horqrux/api.py | 14 +++++++++++++- horqrux/shots.py | 16 +++++++++++----- tests/test_noise.py | 22 +++++++++++++++++++--- tests/test_shots.py | 6 ++++-- 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/horqrux/api.py b/horqrux/api.py index f17302e..f4f7382 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -28,10 +28,14 @@ def sample( gates: GateSequence, values: dict[str, float] = dict(), n_shots: int = 1000, + is_state_densitymat: bool = False, ) -> Counter: if n_shots < 1: raise ValueError("You can only call sample with n_shots>0.") + if is_state_densitymat: + raise NotImplementedError("Sampling with density matrices is not yet supported!") + wf = apply_gate(state, gates, values) probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() key = jax.random.PRNGKey(0) @@ -119,4 +123,12 @@ def expectation( # type: ignore if is_state_densitymat: raise NotImplementedError("Expectation from density matrices is not yet supported!") - return finite_shots_fwd(state, gates, observables, values, n_shots=n_shots, key=key) + return finite_shots_fwd( + state, + gates, + observables, + values, + n_shots=n_shots, + is_state_densitymat=is_state_densitymat, + key=key, + ) diff --git a/horqrux/shots.py b/horqrux/shots.py index 4383100..9069425 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -33,20 +33,21 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0]) -@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5)) +@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5, 6)) def finite_shots_fwd( state: Array, gates: GateSequence, observables: list[Primitive], values: dict[str, float], n_shots: int = 100, + is_state_densitymat: bool = False, key: Any = jax.random.PRNGKey(0), ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ - state = apply_gate(state, gates, values) + state = apply_gate(state, gates, values, is_state_densitymat=is_state_densitymat) n_qubits = len(state.shape) mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] @@ -100,6 +101,7 @@ def finite_shots_jvp( gates: GateSequence, observable: Primitive, n_shots: int, + is_state_densitymat: bool, key: Array, primals: tuple[dict[str, float]], tangents: tuple[dict[str, float]], @@ -116,14 +118,18 @@ def jvp_component(param_name: str, key: Array) -> Array: up_key, down_key = random.split(key) up_val = values.copy() up_val[param_name] = up_val[param_name] + shift - f_up = finite_shots_fwd(state, gates, observable, up_val, n_shots, up_key) + f_up = finite_shots_fwd( + state, gates, observable, up_val, n_shots, is_state_densitymat, up_key + ) down_val = values.copy() down_val[param_name] = down_val[param_name] - shift - f_down = finite_shots_fwd(state, gates, observable, down_val, n_shots, down_key) + f_down = finite_shots_fwd( + state, gates, observable, down_val, n_shots, is_state_densitymat, down_key + ) grad = spectral_gap * (f_up - f_down) / (4.0 * jnp.sin(spectral_gap * shift / 2.0)) return grad * tangent_dict[param_name] params_with_keys = zip(values.keys(), random.split(key, len(values))) - fwd = finite_shots_fwd(state, gates, observable, values, n_shots, key) + fwd = finite_shots_fwd(state, gates, observable, values, n_shots, is_state_densitymat, key) jvp = sum(jvp_component(param, key) for param, key in params_with_keys) return fwd, jvp.reshape(fwd.shape) diff --git a/tests/test_noise.py b/tests/test_noise.py index 91ef6d4..686d71b 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -2,6 +2,7 @@ from typing import Callable +import jax.numpy as jnp import numpy as np import pytest @@ -9,7 +10,7 @@ from horqrux.noise import NoiseInstance, NoiseType from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, H, I, S, T, X, Y, Z -from horqrux.utils import random_state +from horqrux.utils import density_mat, random_state MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) @@ -45,10 +46,17 @@ def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: noisy_gate = gate_fn(target, noise=(noise,)) assert len(noisy_gate.noise) == 1 + dm_shape_len = 2 * MAX_QUBITS + orig_state = random_state(MAX_QUBITS) output_dm = apply_gate(orig_state, noisy_gate) # check output is a density matrix - assert len(output_dm.shape) == 2 * MAX_QUBITS + assert len(output_dm.shape) == dm_shape_len + + orig_dm = density_mat(orig_state) + assert len(orig_dm.shape) == dm_shape_len + output_dm2 = apply_gate(orig_dm, noisy_gate, is_state_densitymat=True) + assert jnp.allclose(output_dm2, output_dm) @pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) @@ -60,6 +68,14 @@ def test_noisy_parametric(gate_fn: Callable, noise_type: NoiseType) -> None: values = {"theta": np.random.uniform(0.1, 2 * np.pi)} orig_state = random_state(MAX_QUBITS) + dm_shape_len = 2 * MAX_QUBITS + output_dm = apply_gate(orig_state, noisy_gate, values) # check output is a density matrix - assert len(output_dm.shape) == 2 * MAX_QUBITS + assert len(output_dm.shape) == dm_shape_len + + orig_dm = density_mat(orig_state) + assert len(orig_dm.shape) == dm_shape_len + + output_dm2 = apply_gate(orig_dm, noisy_gate, values, is_state_densitymat=True) + assert jnp.allclose(output_dm2, output_dm) diff --git a/tests/test_shots.py b/tests/test_shots.py index c98062d..85e98db 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -20,11 +20,13 @@ def test_shots() -> None: def exact(x): values = {"theta": x} - return expectation(state, ops, observables, values, "ad") + return expectation(state, ops, observables, values, diff_mode="ad") def shots(x): values = {"theta": x} - return expectation(state, ops, observables, values, "gpsr", "shots", n_shots=N_SHOTS) + return expectation( + state, ops, observables, values, diff_mode="gpsr", forward_mode="shots", n_shots=N_SHOTS + ) exp_exact = exact(x) exp_shots = shots(x) From d310f7811dc17de38f5df9b48dfa8252a0f9241d Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 09:01:12 +0100 Subject: [PATCH 18/57] add current implementation of channel apply --- horqrux/__init__.py | 2 +- horqrux/api.py | 7 ++++--- horqrux/apply.py | 27 ++++++++++++++++++++------- horqrux/shots.py | 21 +++++++++++++++------ 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/horqrux/__init__.py b/horqrux/__init__.py index 513c031..2569a54 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .api import expectation +from .api import expectation, run from .apply import apply_gate, apply_operator from .circuit import QuantumCircuit, sample from .parametric import PHASE, RX, RY, RZ diff --git a/horqrux/api.py b/horqrux/api.py index f4f7382..76990b6 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -19,8 +19,9 @@ def run( circuit: GateSequence, state: Array, values: dict[str, float] = dict(), + is_state_densitymat: bool = False, ) -> Array: - return apply_gate(state, circuit, values) + return apply_gate(state, circuit, values, is_state_densitymat=is_state_densitymat) def sample( @@ -121,8 +122,8 @@ def expectation( ) # Type checking is disabled because mypy doesn't parse checkify.check. # type: ignore - if is_state_densitymat: - raise NotImplementedError("Expectation from density matrices is not yet supported!") + # if is_state_densitymat: + # raise NotImplementedError("Expectation with density matrices is not yet supported!") return finite_shots_fwd( state, gates, diff --git a/horqrux/apply.py b/horqrux/apply.py index dd380f5..9e8de9f 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -20,6 +20,7 @@ def apply_operator( operator: Array, target: Tuple[int, ...], control: Tuple[int | None, ...], + is_state_densitymat: bool = False, ) -> State: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits. @@ -37,6 +38,7 @@ def apply_operator( operator: Array to contract over 'state'. target: Tuple of target qubits on which to apply the 'operator' to. control: Tuple of control qubits. + is_state_densitymat: Whether the state is provided as a density matrix. Returns: State after applying 'operator'. @@ -45,12 +47,21 @@ def apply_operator( if is_controlled(control): operator = _controlled(operator, len(control)) state_dims = (*control, *target) # type: ignore[arg-type] - n_qubits = int(np.log2(operator.shape[1])) - operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits))) - op_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) - state = jnp.tensordot(a=operator, b=state, axes=(op_dims, state_dims)) + n_qubits_op = int(np.log2(operator.shape[1])) + operator_reshaped = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) + op_dims = tuple(np.arange(operator_reshaped.ndim // 2, operator_reshaped.ndim, dtype=int)) + # Apply operator + new_state_dims = tuple(i for i in range(len(state_dims))) - return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_dims, state_dims)) + if not is_state_densitymat: + return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + operator_dagger = _dagger(operator_reshaped) + + # Apply operator to density matrix: ρ' = O ρ O† + + state = jnp.tensordot(a=operator_dagger, b=state, axes=(op_dims, state_dims)) + return state def apply_kraus_operator( @@ -81,8 +92,9 @@ def apply_operator_with_noise( target: Tuple[int, ...], control: Tuple[int | None, ...], noise: NoiseProtocol, + is_state_densitymat: bool = False, ) -> State: - state_gate = apply_operator(state, operator, target, control) + state_gate = apply_operator(state, operator, target, control, is_state_densitymat) if len(noise) == 0: return state_gate else: @@ -188,10 +200,11 @@ def apply_gate( has_noise = len(reduce(add, noise)) > 0 if has_noise and not is_state_densitymat: state = density_mat(state) + is_state_densitymat = True output_state = reduce( lambda state, gate: apply_operator_with_noise(state, *gate), - zip(operator, target, control, noise), + zip(operator, target, control, noise, (is_state_densitymat,) * len(target)), state, ) diff --git a/horqrux/shots.py b/horqrux/shots.py index 9069425..614dfd8 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -49,12 +49,21 @@ def finite_shots_fwd( """ state = apply_gate(state, gates, values, is_state_densitymat=is_state_densitymat) n_qubits = len(state.shape) - mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] - eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] - eigvecs, eigvals = align_eigenvectors(eigs) - inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) - probs = jnp.abs(inner_prod) ** 2 - return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) + if not is_state_densitymat: + mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] + eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] + eigvecs, eigvals = align_eigenvectors(eigs) + inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) + probs = jnp.abs(inner_prod) ** 2 + return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) + else: + n_qubits = n_qubits // 2 + mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] + mat_obs = jnp.stack(mat_obs) + dim = 2**n_qubits + rho = state.reshape((dim, dim)) + prod = jnp.matmul(mat_obs, rho) + return jnp.trace(prod, axis1=-2, axis2=-1).real def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]: From 4340185a04babb9e8dbc22a0753d822cb19ba20e Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 12:52:37 +0100 Subject: [PATCH 19/57] change dagger --- horqrux/apply.py | 15 ++++++++------- horqrux/utils.py | 13 ++++++++++++- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index 9e8de9f..f6dd3eb 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -48,19 +48,20 @@ def apply_operator( operator = _controlled(operator, len(control)) state_dims = (*control, *target) # type: ignore[arg-type] n_qubits_op = int(np.log2(operator.shape[1])) - operator_reshaped = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) - op_dims = tuple(np.arange(operator_reshaped.ndim // 2, operator_reshaped.ndim, dtype=int)) + operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) + op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) + op_in_dims = tuple(np.arange(0, operator.ndim // 2, dtype=int)) # Apply operator - + state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, state_dims)) new_state_dims = tuple(i for i in range(len(state_dims))) - state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_dims, state_dims)) if not is_state_densitymat: return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) - operator_dagger = _dagger(operator_reshaped) - # Apply operator to density matrix: ρ' = O ρ O† + state = _dagger(state) + state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, op_in_dims)) + state = jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + state = _dagger(state) - state = jnp.tensordot(a=operator_dagger, b=state, axes=(op_dims, state_dims)) return state diff --git a/horqrux/utils.py b/horqrux/utils.py index c119956..e047a26 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -73,7 +73,18 @@ class ForwardMode(StrEnum): def _dagger(operator: Array) -> Array: - return jnp.conjugate(operator.T) + # If the operator is a tensor with repeated 2D axes + if operator.ndim > 2: + # Conjugate and swap the last two axes + conjugated = operator.conj() + + # Create the transpose axes: swap pairs of indices + half = operator.ndim // 2 + axes = tuple(range(half, operator.ndim)) + tuple(range(half)) + return jnp.transpose(conjugated, axes) + else: + # For standard matrices, use conjugate transpose + return jnp.conjugate(operator.T) def _unitary(generator: Array, theta: float) -> Array: From 3f1645f7ee4525e1ca8a35653240eee7cac6b14e Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 14:19:56 +0100 Subject: [PATCH 20/57] apply_gate correct on density matrices before permutation --- horqrux/apply.py | 13 ++++++------- horqrux/utils.py | 7 ++++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index f6dd3eb..f3061c9 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -48,20 +48,19 @@ def apply_operator( operator = _controlled(operator, len(control)) state_dims = (*control, *target) # type: ignore[arg-type] n_qubits_op = int(np.log2(operator.shape[1])) - operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) - op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) - op_in_dims = tuple(np.arange(0, operator.ndim // 2, dtype=int)) + operator_reshaped = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) + op_out_dims = tuple(np.arange(operator_reshaped.ndim // 2, operator_reshaped.ndim, dtype=int)) + op_in_dims = tuple(np.arange(0, operator_reshaped.ndim // 2, dtype=int)) # Apply operator - state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, state_dims)) + state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) new_state_dims = tuple(i for i in range(len(state_dims))) if not is_state_densitymat: return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) # Apply operator to density matrix: ρ' = O ρ O† state = _dagger(state) - state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, op_in_dims)) - state = jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, op_in_dims)) state = _dagger(state) - + state = jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) return state diff --git a/horqrux/utils.py b/horqrux/utils.py index e047a26..17d63de 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -32,9 +32,10 @@ def density_mat(state: Array) -> Array: Returns: State: Density matrix representation. """ - n_qubits = len(state.shape) - state = state.reshape(2**n_qubits) - return jnp.einsum("i,j->ij", state, state.conj()).reshape(tuple(2 for _ in range(2 * n_qubits))) + # Expand dimensions to enable broadcasting + ket = jnp.expand_dims(state, axis=tuple(range(state.ndim, 2 * state.ndim))) + bra = jnp.conj(jnp.expand_dims(state, axis=tuple(range(state.ndim)))) + return ket * bra class StrEnum(str, Enum): From 58d2f0befc8348650e49e3020230448acf00c99a Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 14:57:18 +0100 Subject: [PATCH 21/57] fix permutation with density matrices --- horqrux/apply.py | 14 ++++++++++++-- horqrux/utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index f3061c9..dfb0a37 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -12,7 +12,15 @@ from horqrux.primitive import Primitive from .noise import NoiseProtocol -from .utils import OperationType, State, _controlled, _dagger, density_mat, is_controlled +from .utils import ( + OperationType, + State, + _controlled, + _dagger, + density_mat, + is_controlled, + permute_basis, +) def apply_operator( @@ -55,12 +63,14 @@ def apply_operator( state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) new_state_dims = tuple(i for i in range(len(state_dims))) if not is_state_densitymat: + # only return O ρ with correctly swaped axis for tensordot return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) # Apply operator to density matrix: ρ' = O ρ O† state = _dagger(state) state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, op_in_dims)) state = _dagger(state) - state = jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + support_perm = target + tuple(set(new_state_dims) - set(target)) + state = permute_basis(state, support_perm, True) return state diff --git a/horqrux/utils.py b/horqrux/utils.py index 17d63de..934b16a 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -38,6 +38,30 @@ def density_mat(state: Array) -> Array: return ket * bra +def permute_basis(operator: Array, qubit_support: tuple, inv: bool = False) -> Array: + """Takes an operator tensor and permutes the rows and + columns according to the order of the qubit support. + + Args: + operator (Tensor): Operator to permute over. + qubit_support (tuple): Qubit support. + inv (bool): Applies the inverse permutation instead. + + Returns: + Tensor: Permuted operator. + """ + ordered_support = np.argsort(qubit_support) + ranked_support = np.argsort(ordered_support) + n_qubits = len(qubit_support) + if all(a == b for a, b in zip(ranked_support, list(range(n_qubits)))): + return operator + + perm = tuple(ranked_support) + tuple(ranked_support + n_qubits) + if inv: + perm = np.argsort(perm) + return jnp.transpose(operator, perm) + + class StrEnum(str, Enum): def __str__(self) -> str: """Used when dumping enum fields in a schema.""" From 26ecf21ee47bb311e9ce8be8e178820fb34345e1 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 15:23:10 +0100 Subject: [PATCH 22/57] fix permute_basis --- horqrux/apply.py | 5 +++-- horqrux/utils.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index dfb0a37..2d79d01 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -61,15 +61,16 @@ def apply_operator( op_in_dims = tuple(np.arange(0, operator_reshaped.ndim // 2, dtype=int)) # Apply operator state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) - new_state_dims = tuple(i for i in range(len(state_dims))) + if not is_state_densitymat: # only return O ρ with correctly swaped axis for tensordot + new_state_dims = tuple(i for i in range(len(state_dims))) return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) # Apply operator to density matrix: ρ' = O ρ O† state = _dagger(state) state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, op_in_dims)) state = _dagger(state) - support_perm = target + tuple(set(new_state_dims) - set(target)) + support_perm = target + tuple(set(tuple(range(state.ndim // 2))) - set(target)) state = permute_basis(state, support_perm, True) return state diff --git a/horqrux/utils.py b/horqrux/utils.py index 934b16a..d89d84e 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -53,13 +53,13 @@ def permute_basis(operator: Array, qubit_support: tuple, inv: bool = False) -> A ordered_support = np.argsort(qubit_support) ranked_support = np.argsort(ordered_support) n_qubits = len(qubit_support) - if all(a == b for a, b in zip(ranked_support, list(range(n_qubits)))): + if all(a == b for a, b in zip(ranked_support, tuple(range(n_qubits)))): return operator perm = tuple(ranked_support) + tuple(ranked_support + n_qubits) if inv: perm = np.argsort(perm) - return jnp.transpose(operator, perm) + return jnp.moveaxis(operator, source=tuple(range(operator.ndim)), destination=perm) class StrEnum(str, Enum): From 3c83b17bc2a3dde6782073b20d5922ae279f9256 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 15:58:43 +0100 Subject: [PATCH 23/57] adding tests with dm --- horqrux/apply.py | 15 +++++++++++++-- tests/test_gates.py | 13 ++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index 2d79d01..d56f511 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -61,17 +61,28 @@ def apply_operator( op_in_dims = tuple(np.arange(0, operator_reshaped.ndim // 2, dtype=int)) # Apply operator state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) - + if not is_state_densitymat: # only return O ρ with correctly swaped axis for tensordot + state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) new_state_dims = tuple(i for i in range(len(state_dims))) return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + support_perm = target + tuple(set(tuple(range(state.ndim // 2))) - set(target)) + print("init 1", state.reshape((4, 4)).round(4)) + state = permute_basis(state, support_perm, False) + print("permute 1", state.reshape((4, 4)).round(4)) + state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) # Apply operator to density matrix: ρ' = O ρ O† + print("einsum 1", state.reshape((4, 4)).round(4)) state = _dagger(state) + print("dag 1", state.reshape((4, 4)).round(4)) state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, op_in_dims)) + print("einsum 2", state.reshape((4, 4)).round(4)) state = _dagger(state) - support_perm = target + tuple(set(tuple(range(state.ndim // 2))) - set(target)) + print("dag 2", state.reshape((4, 4)).round(4)) + state = permute_basis(state, support_perm, True) + print("permute", state.reshape((4, 4)).round(4)) return state diff --git a/tests/test_gates.py b/tests/test_gates.py index 4c44cd3..90f9732 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -10,7 +10,7 @@ from horqrux.apply import apply_gate, apply_operator from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z -from horqrux.utils import equivalent_state, product_state, random_state +from horqrux.utils import density_mat, equivalent_state, product_state, random_state MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) @@ -22,11 +22,22 @@ def test_primitive(gate_fn: Callable) -> None: target = np.random.randint(0, MAX_QUBITS) gate = gate_fn(target) orig_state = random_state(MAX_QUBITS) + assert len(orig_state) == 2 state = apply_gate(orig_state, gate) assert jnp.allclose( apply_operator(state, gate.dagger(), gate.target[0], gate.control[0]), orig_state ) + # test density matrix is similar to pure state + dm = apply_operator( + density_mat(orig_state), + gate.unitary(), + gate.target[0], + gate.control[0], + is_state_densitymat=True, + ) + assert jnp.allclose(dm, density_mat(state)) + @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) def test_controlled_primitive(gate_fn: Callable) -> None: From b50ab1fb87ee68cdeddbafdf2e4af4eb759fe14a Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 15:59:21 +0100 Subject: [PATCH 24/57] adding test shots --- tests/test_shots.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_shots.py b/tests/test_shots.py index 85e98db..6dd37d5 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -6,6 +6,7 @@ from horqrux import expectation, random_state from horqrux.parametric import RX from horqrux.primitive import Z +from horqrux.utils import density_mat N_QUBITS = 2 SHOTS_ATOL = 0.01 @@ -28,10 +29,25 @@ def shots(x): state, ops, observables, values, diff_mode="gpsr", forward_mode="shots", n_shots=N_SHOTS ) + def shots_dm(x): + values = {"theta": x} + return expectation( + density_mat(state), + ops, + observables, + values, + diff_mode="gpsr", + is_state_densitymat=True, + forward_mode="shots", + n_shots=N_SHOTS, + ) + exp_exact = exact(x) exp_shots = shots(x) + exp_shots_dm = shots_dm(x) assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) + assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) d_exact = jax.grad(lambda x: exact(x).sum()) d_shots = jax.grad(lambda x: shots(x).sum()) From b4e03d9b77211205926d8bd4c2fcbce81af58c65 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 16:24:35 +0100 Subject: [PATCH 25/57] add permute_basis at beginning of apply --- horqrux/apply.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index d56f511..387d663 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -60,29 +60,24 @@ def apply_operator( op_out_dims = tuple(np.arange(operator_reshaped.ndim // 2, operator_reshaped.ndim, dtype=int)) op_in_dims = tuple(np.arange(0, operator_reshaped.ndim // 2, dtype=int)) # Apply operator - state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) - + new_state_dims = tuple(i for i in range(len(state_dims))) if not is_state_densitymat: # only return O ρ with correctly swaped axis for tensordot state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) - new_state_dims = tuple(i for i in range(len(state_dims))) return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + + # Apply operator to density matrix: ρ' = O ρ O† + support_perm = target + tuple(set(tuple(range(state.ndim // 2))) - set(target)) - print("init 1", state.reshape((4, 4)).round(4)) + # print("init 1", state.reshape((4, 4)).round(4)) state = permute_basis(state, support_perm, False) - print("permute 1", state.reshape((4, 4)).round(4)) - state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) - # Apply operator to density matrix: ρ' = O ρ O† - print("einsum 1", state.reshape((4, 4)).round(4)) + state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, new_state_dims)) + state = _dagger(state) - print("dag 1", state.reshape((4, 4)).round(4)) state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, op_in_dims)) - print("einsum 2", state.reshape((4, 4)).round(4)) state = _dagger(state) - print("dag 2", state.reshape((4, 4)).round(4)) state = permute_basis(state, support_perm, True) - print("permute", state.reshape((4, 4)).round(4)) return state From 3cdc5f187e59375827fb6b99cc0d2d5f2bb19953 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 16:42:49 +0100 Subject: [PATCH 26/57] using jax lax transpose --- horqrux/apply.py | 2 -- horqrux/utils.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index 387d663..08d8c2c 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -67,9 +67,7 @@ def apply_operator( return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) # Apply operator to density matrix: ρ' = O ρ O† - support_perm = target + tuple(set(tuple(range(state.ndim // 2))) - set(target)) - # print("init 1", state.reshape((4, 4)).round(4)) state = permute_basis(state, support_perm, False) state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, new_state_dims)) diff --git a/horqrux/utils.py b/horqrux/utils.py index d89d84e..3d5f369 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -59,7 +59,8 @@ def permute_basis(operator: Array, qubit_support: tuple, inv: bool = False) -> A perm = tuple(ranked_support) + tuple(ranked_support + n_qubits) if inv: perm = np.argsort(perm) - return jnp.moveaxis(operator, source=tuple(range(operator.ndim)), destination=perm) + return jax.lax.transpose(operator, perm) + # return jnp.moveaxis(operator, source=tuple(range(operator.ndim)), destination=perm) class StrEnum(str, Enum): From 304cdf4d72f8cb8174f9d58e77e154c53ccf3c61 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 16:45:06 +0100 Subject: [PATCH 27/57] adding test dm parameteric --- tests/test_gates.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_gates.py b/tests/test_gates.py index 90f9732..3685889 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -64,6 +64,16 @@ def test_parametric(gate_fn: Callable) -> None: apply_operator(state, gate.dagger(values), gate.target[0], gate.control[0]), orig_state ) + # test density matrix is similar to pure state + dm = apply_operator( + density_mat(orig_state), + gate.unitary(values), + gate.target[0], + gate.control[0], + is_state_densitymat=True, + ) + assert jnp.allclose(dm, density_mat(state)) + @pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) def test_controlled_parametric(gate_fn: Callable) -> None: From aea2d339b2b0f5c14a71bd8db363657e56e09f16 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 16:49:11 +0100 Subject: [PATCH 28/57] fix controlled ops with dm --- horqrux/apply.py | 2 +- tests/test_gates.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index 08d8c2c..4908b8e 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -67,7 +67,7 @@ def apply_operator( return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) # Apply operator to density matrix: ρ' = O ρ O† - support_perm = target + tuple(set(tuple(range(state.ndim // 2))) - set(target)) + support_perm = state_dims + tuple(set(tuple(range(state.ndim // 2))) - set(state_dims)) state = permute_basis(state, support_perm, False) state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, new_state_dims)) diff --git a/tests/test_gates.py b/tests/test_gates.py index 3685889..d90783b 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -52,6 +52,16 @@ def test_controlled_primitive(gate_fn: Callable) -> None: apply_operator(state, gate.dagger(), gate.target[0], gate.control[0]), orig_state ) + # test density matrix is similar to pure state + dm = apply_operator( + density_mat(orig_state), + gate.unitary(), + gate.target[0], + gate.control[0], + is_state_densitymat=True, + ) + assert jnp.allclose(dm, density_mat(state)) + @pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) def test_parametric(gate_fn: Callable) -> None: @@ -89,6 +99,16 @@ def test_controlled_parametric(gate_fn: Callable) -> None: apply_operator(state, gate.dagger(values), gate.target[0], gate.control[0]), orig_state ) + # test density matrix is similar to pure state + dm = apply_operator( + density_mat(orig_state), + gate.unitary(values), + gate.target[0], + gate.control[0], + is_state_densitymat=True, + ) + assert jnp.allclose(dm, density_mat(state)) + @pytest.mark.parametrize( ["bitstring", "expected_state"], From 819833fff20c4d2301068dd52e151759e29d1e2d Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 16:59:28 +0100 Subject: [PATCH 29/57] checking tests expectation work for dm --- tests/test_shots.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_shots.py b/tests/test_shots.py index 6dd37d5..660f089 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -22,6 +22,10 @@ def test_shots() -> None: def exact(x): values = {"theta": x} return expectation(state, ops, observables, values, diff_mode="ad") + + def exact_dm(x): + values = {"theta": x} + return expectation(density_mat(state), ops, observables, values, diff_mode="ad", is_state_densitymat=True) def shots(x): values = {"theta": x} @@ -43,6 +47,9 @@ def shots_dm(x): ) exp_exact = exact(x) + exp_exact_dm = exact_dm(x) + assert jnp.allclose(exp_exact, exp_exact_dm) + exp_shots = shots(x) exp_shots_dm = shots_dm(x) From bc479ba7abdba66af623d265db14e6bfb5f3f07b Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 17:06:20 +0100 Subject: [PATCH 30/57] separate apply_operator with a density matrix version --- horqrux/apply.py | 53 +++++++++++++++++++++++++++++++++++---------- tests/test_gates.py | 14 +++++------- tests/test_shots.py | 6 +++-- 3 files changed, 51 insertions(+), 22 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index 4908b8e..d320c21 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -28,7 +28,6 @@ def apply_operator( operator: Array, target: Tuple[int, ...], control: Tuple[int | None, ...], - is_state_densitymat: bool = False, ) -> State: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits. @@ -56,23 +55,51 @@ def apply_operator( operator = _controlled(operator, len(control)) state_dims = (*control, *target) # type: ignore[arg-type] n_qubits_op = int(np.log2(operator.shape[1])) - operator_reshaped = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) - op_out_dims = tuple(np.arange(operator_reshaped.ndim // 2, operator_reshaped.ndim, dtype=int)) - op_in_dims = tuple(np.arange(0, operator_reshaped.ndim // 2, dtype=int)) + operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) + op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) # Apply operator new_state_dims = tuple(i for i in range(len(state_dims))) - if not is_state_densitymat: - # only return O ρ with correctly swaped axis for tensordot - state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims)) - return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, state_dims)) + return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + + +def apply_operator_dm( + state: State, + operator: Array, + target: Tuple[int, ...], + control: Tuple[int | None, ...], +) -> State: + """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given density matrix + of shape [2 for _ in range(2 * n_qubits)] for a given set of target and control qubits. + In case of a controlled operation, the 'operator' array will be embedded into a controlled array. + + Arguments: + state: Density matrix to operate on. + operator: Array to contract over 'state'. + target: Tuple of target qubits on which to apply the 'operator' to. + control: Tuple of control qubits. + is_state_densitymat: Whether the state is provided as a density matrix. + + Returns: + Density matrix after applying 'operator'. + """ + state_dims: Tuple[int, ...] = target + if is_controlled(control): + operator = _controlled(operator, len(control)) + state_dims = (*control, *target) # type: ignore[arg-type] + n_qubits_op = int(np.log2(operator.shape[1])) + operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) + op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) + op_in_dims = tuple(np.arange(0, operator.ndim // 2, dtype=int)) + new_state_dims = tuple(i for i in range(len(state_dims))) # Apply operator to density matrix: ρ' = O ρ O† support_perm = state_dims + tuple(set(tuple(range(state.ndim // 2))) - set(state_dims)) state = permute_basis(state, support_perm, False) - state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, new_state_dims)) + state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, new_state_dims)) state = _dagger(state) - state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, op_in_dims)) + state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, op_in_dims)) state = _dagger(state) state = permute_basis(state, support_perm, True) @@ -109,7 +136,11 @@ def apply_operator_with_noise( noise: NoiseProtocol, is_state_densitymat: bool = False, ) -> State: - state_gate = apply_operator(state, operator, target, control, is_state_densitymat) + state_gate = ( + apply_operator(state, operator, target, control) + if not is_state_densitymat + else apply_operator_dm(state, operator, target, control) + ) if len(noise) == 0: return state_gate else: diff --git a/tests/test_gates.py b/tests/test_gates.py index d90783b..bca0e40 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -7,7 +7,7 @@ import pytest from jax import Array -from horqrux.apply import apply_gate, apply_operator +from horqrux.apply import apply_gate, apply_operator, apply_operator_dm from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z from horqrux.utils import density_mat, equivalent_state, product_state, random_state @@ -29,12 +29,11 @@ def test_primitive(gate_fn: Callable) -> None: ) # test density matrix is similar to pure state - dm = apply_operator( + dm = apply_operator_dm( density_mat(orig_state), gate.unitary(), gate.target[0], gate.control[0], - is_state_densitymat=True, ) assert jnp.allclose(dm, density_mat(state)) @@ -53,12 +52,11 @@ def test_controlled_primitive(gate_fn: Callable) -> None: ) # test density matrix is similar to pure state - dm = apply_operator( + dm = apply_operator_dm( density_mat(orig_state), gate.unitary(), gate.target[0], gate.control[0], - is_state_densitymat=True, ) assert jnp.allclose(dm, density_mat(state)) @@ -75,12 +73,11 @@ def test_parametric(gate_fn: Callable) -> None: ) # test density matrix is similar to pure state - dm = apply_operator( + dm = apply_operator_dm( density_mat(orig_state), gate.unitary(values), gate.target[0], gate.control[0], - is_state_densitymat=True, ) assert jnp.allclose(dm, density_mat(state)) @@ -100,12 +97,11 @@ def test_controlled_parametric(gate_fn: Callable) -> None: ) # test density matrix is similar to pure state - dm = apply_operator( + dm = apply_operator_dm( density_mat(orig_state), gate.unitary(values), gate.target[0], gate.control[0], - is_state_densitymat=True, ) assert jnp.allclose(dm, density_mat(state)) diff --git a/tests/test_shots.py b/tests/test_shots.py index 660f089..1c3edb7 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -22,10 +22,12 @@ def test_shots() -> None: def exact(x): values = {"theta": x} return expectation(state, ops, observables, values, diff_mode="ad") - + def exact_dm(x): values = {"theta": x} - return expectation(density_mat(state), ops, observables, values, diff_mode="ad", is_state_densitymat=True) + return expectation( + density_mat(state), ops, observables, values, diff_mode="ad", is_state_densitymat=True + ) def shots(x): values = {"theta": x} From 5666c69632010333c15e72e60637ad6b1572e851 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Thu, 12 Dec 2024 15:27:15 +0100 Subject: [PATCH 31/57] add sampling from density mat --- horqrux/api.py | 13 ++++++++----- tests/test_noise.py | 39 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/horqrux/api.py b/horqrux/api.py index 76990b6..f2763e8 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -34,13 +34,16 @@ def sample( if n_shots < 1: raise ValueError("You can only call sample with n_shots>0.") + output_circuit = apply_gate(state, gates, values, is_state_densitymat=is_state_densitymat) if is_state_densitymat: - raise NotImplementedError("Sampling with density matrices is not yet supported!") - - wf = apply_gate(state, gates, values) - probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() + n_qubits = len(state.shape) // 2 + d = 2**n_qubits + probs = jnp.diagonal(output_circuit.reshape((d, d))).real + else: + n_qubits = len(state.shape) + probs = jnp.abs(jnp.float_power(output_circuit, 2.0)).ravel() key = jax.random.PRNGKey(0) - n_qubits = len(state.shape) + # JAX handles pseudo random number generation by tracking an explicit state via a random key # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html samples = jax.vmap( diff --git a/tests/test_noise.py b/tests/test_noise.py index 686d71b..159f20b 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -6,11 +6,12 @@ import numpy as np import pytest +from horqrux.api import run, sample from horqrux.apply import apply_gate from horqrux.noise import NoiseInstance, NoiseType from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, H, I, S, T, X, Y, Z -from horqrux.utils import density_mat, random_state +from horqrux.utils import density_mat, product_state, random_state MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) @@ -58,6 +59,10 @@ def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: output_dm2 = apply_gate(orig_dm, noisy_gate, is_state_densitymat=True) assert jnp.allclose(output_dm2, output_dm) + perfect_gate = gate_fn(target) + perfect_output = density_mat(apply_gate(orig_state, perfect_gate)) + assert not jnp.allclose(perfect_output, output_dm) + @pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) @pytest.mark.parametrize("noise_type", ALL_NOISES) @@ -79,3 +84,35 @@ def test_noisy_parametric(gate_fn: Callable, noise_type: NoiseType) -> None: output_dm2 = apply_gate(orig_dm, noisy_gate, values, is_state_densitymat=True) assert jnp.allclose(output_dm2, output_dm) + + perfect_gate = gate_fn("theta", target) + perfect_output = density_mat(apply_gate(orig_state, perfect_gate, values)) + assert not jnp.allclose(perfect_output, output_dm) + + +def simple_depolarizing_test() -> None: + noise = (NoiseInstance(NoiseType.DEPOLARIZING, 0.1),) + ops = [X(0, noise=noise), X(1)] + state = product_state("00") + state_output = run(ops, state) + + assert jnp.allclose( + state_output, + jnp.array( + [ + [ + [[0.0 - 0.0j, 0.0 - 0.0j], [0.0 - 0.0j, 0.0 - 0.0j]], + [[0.0 - 0.0j, 0.06666667 - 0.0j], [0.0 - 0.0j, 0.0 - 0.0j]], + ], + [ + [[0.0 - 0.0j, 0.0 - 0.0j], [0.0 - 0.0j, 0.0 - 0.0j]], + [[0.0 - 0.0j, 0.0 - 0.0j], [0.0 - 0.0j, 0.93333333 - 0.0j]], + ], + ], + dtype=jnp.complex128, + ), + ) + + sampling_output = sample(density_mat(state), ops, is_state_densitymat=True) + assert "11" in sampling_output.keys() + assert "01" in sampling_output.keys() From 47328f201078a3aad1c58e01e3806cc6051edd80 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Thu, 12 Dec 2024 15:57:40 +0100 Subject: [PATCH 32/57] fix expectation with density matrices --- horqrux/api.py | 28 ++++++++++++++++++++-------- tests/test_noise.py | 11 +++++++++-- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/horqrux/api.py b/horqrux/api.py index f2763e8..7a2d6af 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -11,7 +11,7 @@ from horqrux.adjoint import adjoint_expectation from horqrux.apply import apply_gate from horqrux.primitive import GateSequence, Primitive -from horqrux.shots import finite_shots_fwd +from horqrux.shots import finite_shots_fwd, observable_to_matrix from horqrux.utils import DiffMode, ForwardMode, OperationType, inner @@ -70,12 +70,26 @@ def __ad_expectation_single_observable( Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ - out_state = apply_gate(state, gates, values, OperationType.UNITARY) - projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) - if not is_state_densitymat: - return inner(out_state, projected_state).real + out_state = apply_gate( + state, gates, values, OperationType.UNITARY, is_state_densitymat=is_state_densitymat + ) + # in case we have noisy simulations + out_state_densitymat = is_state_densitymat or (out_state.shape != state.shape) - raise NotImplementedError("Expectation from density matrices is not yet supported!") + if not out_state_densitymat: + projected_state = apply_gate( + out_state, + observable, + values, + OperationType.UNITARY, + is_state_densitymat=out_state_densitymat, + ) + return inner(out_state, projected_state).real + n_qubits = len(out_state.shape) // 2 + mat_obs = observable_to_matrix(observable, n_qubits) + d = 2**n_qubits + prod = jnp.matmul(mat_obs, out_state.reshape((d, d))) + return jnp.trace(prod, axis1=-2, axis2=-1).real def ad_expectation( @@ -125,8 +139,6 @@ def expectation( ) # Type checking is disabled because mypy doesn't parse checkify.check. # type: ignore - # if is_state_densitymat: - # raise NotImplementedError("Expectation with density matrices is not yet supported!") return finite_shots_fwd( state, gates, diff --git a/tests/test_noise.py b/tests/test_noise.py index 159f20b..d7046d2 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from horqrux.api import run, sample +from horqrux.api import expectation, run, sample from horqrux.apply import apply_gate from horqrux.noise import NoiseInstance, NoiseType from horqrux.parametric import PHASE, RX, RY, RZ @@ -96,6 +96,7 @@ def simple_depolarizing_test() -> None: state = product_state("00") state_output = run(ops, state) + # test run assert jnp.allclose( state_output, jnp.array( @@ -113,6 +114,12 @@ def simple_depolarizing_test() -> None: ), ) - sampling_output = sample(density_mat(state), ops, is_state_densitymat=True) + # test sampling + dm_state = density_mat(state) + sampling_output = sample(dm_state, ops, is_state_densitymat=True) assert "11" in sampling_output.keys() assert "01" in sampling_output.keys() + + # test expectation + exp_dm = expectation(dm_state, ops, [Z(0)], {}) + assert jnp.allclose(exp_dm, jnp.array([-0.86666667], dtype=jnp.float64)) From 40d7cf80a796cb2cadbad993bd44a32a28e86290 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Thu, 12 Dec 2024 16:01:11 +0100 Subject: [PATCH 33/57] test also shots with noise --- tests/test_noise.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_noise.py b/tests/test_noise.py index d7046d2..85bf781 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -11,7 +11,7 @@ from horqrux.noise import NoiseInstance, NoiseType from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, H, I, S, T, X, Y, Z -from horqrux.utils import density_mat, product_state, random_state +from horqrux.utils import ForwardMode, density_mat, product_state, random_state MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) @@ -123,3 +123,9 @@ def simple_depolarizing_test() -> None: # test expectation exp_dm = expectation(dm_state, ops, [Z(0)], {}) assert jnp.allclose(exp_dm, jnp.array([-0.86666667], dtype=jnp.float64)) + + # test shots expectation + exp_dm_shots = expectation( + dm_state, ops, [Z(0)], {}, forward_mode=ForwardMode.SHOTS, n_shots=1000 + ) + assert jnp.allclose(exp_dm, exp_dm_shots, atol=1e-02) From 54da26fe8582875aa8862690141b92606dc5a6fe Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Thu, 12 Dec 2024 17:59:42 +0100 Subject: [PATCH 34/57] adding docs --- docs/noise.md | 101 ++++++++++++++++++++++++++++++++++++++++++++++++++ mkdocs.yml | 2 + 2 files changed, 103 insertions(+) create mode 100644 docs/noise.md diff --git a/docs/noise.md b/docs/noise.md new file mode 100644 index 0000000..8526738 --- /dev/null +++ b/docs/noise.md @@ -0,0 +1,101 @@ +## Digital Noise + +In the description of closed quantum systems, a pure state vector is used to represent the complete quantum state. Thus, pure quantum states are represented by state vectors $|\psi \rangle $. + +However, this description is not sufficient to study open quantum systems. When the system interacts with its environment, quantum systems can be in a mixed state, where quantum information is no longer entirely contained in a single state vector but is distributed probabilistically. + +To address these more general cases, we consider a probabilistic combination $p_i$ of possible pure states $|\psi_i \rangle$. Thus, the system is described by a density matrix $\rho$ defined as follows: + +$$ +\rho = \sum_i p_i |\psi_i\rangle \langle \psi_i| +$$ + +The transformations of the density operator of an open quantum system interacting with its environment (noise) are represented by the super-operator $S: \rho \rightarrow S(\rho)$, often referred to as a quantum channel. +Quantum channels, due to the conservation of the probability distribution, must be CPTP (Completely Positive and Trace Preserving). Any CPTP super-operator can be written in the following form: + +$$ +S(\rho) = \sum_i K_i \rho K^{\dagger}_i +$$ + +Where $K_i$ are the Kraus operators, and satisfy the property $\sum_i K_i K^{\dagger}_i = \mathbb{I}$. As noise is the result of system interactions with its environment, it is therefore possible to simulate noisy quantum circuit with noise type gates. + +Thus, `horqrux` implements a large selection of single qubit noise gates such as: + +- The bit flip channel defined as: $\textbf{BitFlip}(\rho) =(1-p) \rho + p X \rho X^{\dagger}$ +- The phase flip channel defined as: $\textbf{PhaseFlip}(\rho) = (1-p) \rho + p Z \rho Z^{\dagger}$ +- The depolarizing channel defined as: $\textbf{Depolarizing}(\rho) = (1-p) \rho + \frac{p}{3} (X \rho X^{\dagger} + Y \rho Y^{\dagger} + Z \rho Z^{\dagger})$ +- The pauli channel defined as: $\textbf{PauliChannel}(\rho) = (1-p_x-p_y-p_z) \rho + + p_x X \rho X^{\dagger} + + p_y Y \rho Y^{\dagger} + + p_z Z \rho Z^{\dagger}$ +- The amplitude damping channel defined as: $\textbf{AmplitudeDamping}(\rho) = K_0 \rho K_0^{\dagger} + K_1 \rho K_1^{\dagger}$ + with: + $\begin{equation*} + K_{0} \ =\begin{pmatrix} + 1 & 0\\ + 0 & \sqrt{1-\ \gamma } + \end{pmatrix} ,\ K_{1} \ =\begin{pmatrix} + 0 & \sqrt{\ \gamma }\\ + 0 & 0 + \end{pmatrix} + \end{equation*}$ +- The phase damping channel defined as: $\textbf{PhaseDamping}(\rho) = K_0 \rho K_0^{\dagger} + K_1 \rho K_1^{\dagger}$ + with: + $\begin{equation*} + K_{0} \ =\begin{pmatrix} + 1 & 0\\ + 0 & \sqrt{1-\ \gamma } + \end{pmatrix}, \ K_{1} \ =\begin{pmatrix} + 0 & 0\\ + 0 & \sqrt{\ \gamma } + \end{pmatrix} + \end{equation*}$ +* The generalize amplitude damping channel is defined as: $\textbf{GeneralizedAmplitudeDamping}(\rho) = K_0 \rho K_0^{\dagger} + K_1 \rho K_1^{\dagger} + K_2 \rho K_2^{\dagger} + K_3 \rho K_3^{\dagger}$ + with: +$\begin{cases} +K_{0} \ =\sqrt{p} \ \begin{pmatrix} +1 & 0\\ +0 & \sqrt{1-\ \gamma } +\end{pmatrix} ,\ K_{1} \ =\sqrt{p} \ \begin{pmatrix} +0 & 0\\ +0 & \sqrt{\ \gamma } +\end{pmatrix} \\ +K_{2} \ =\sqrt{1\ -p} \ \begin{pmatrix} +\sqrt{1-\ \gamma } & 0\\ +0 & 1 +\end{pmatrix} ,\ K_{3} \ =\sqrt{1-p} \ \begin{pmatrix} +0 & 0\\ +\sqrt{\ \gamma } & 0 +\end{pmatrix} +\end{cases}$ + +Noise protocols can be added to gates by instantiating `NoiseInstance` providing the `NoiseType` and the `error_probability` (either float or tuple of float): + +```python exec="on" source="material-block" html="1" +from horqrux.noise import NoiseInstance, NoiseType + +noise_prob = 0.3 +AmpD = NoiseInstance(NoiseType.AMPLITUDE_DAMPING, error_probability=noise_prob) + +``` + +Then a gate can be instantiated by providing a tuple of `NoiseInstance` instances. Let’s show this through the simulation of a realistic $X$ gate. + +We know that an $X$ gate flips the state of the qubit, for instance $X|0\rangle = |1\rangle$. In practice, it's common for the target qubit to stay in its original state after applying $X$ due to the interactions between it and its environment. The possibility of failure can be represented by a `BitFlip` `NoiseInstance`, which flips the state again after the application of the $X$ gate, returning it to its original state with a probability `1 - gate_fidelity`. + +```python exec="on" source="material-block" +from horqrux.api import sample +from horqrux.noise import NoiseInstance, NoiseType +from horqrux.utils import density_mat, product_state +from horqrux.primitive import X + +noise = (NoiseInstance(NoiseType.BITFLIP, 0.1),) +ops = [X(0)] +noisy_ops = [X(0, noise=noise)] +state = product_state("0") + +noiseless_samples = sample(state, ops) +noisy_samples = sample(density_mat(state), noisy_ops, is_state_densitymat=True) +print("Noiseless samples", noiseless_samples) # markdown-exec: hide +print("Noiseless samples", noisy_samples) # markdown-exec: hide +``` diff --git a/mkdocs.yml b/mkdocs.yml index e3ac86a..0f214f4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -6,6 +6,8 @@ nav: - horqrux in a nutshell: index.md - Contribute: CONTRIBUTING.md - Code of Conduct: CODE_OF_CONDUCT.md + - Advanced Features: + - Noisy simulation: noise.md theme: name: material From dd2625da183e40f531603b87ef69371947a8fbe3 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Fri, 13 Dec 2024 08:15:20 +0100 Subject: [PATCH 35/57] fix lint --- docs/noise.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/noise.md b/docs/noise.md index 8526738..3bb2f0b 100644 --- a/docs/noise.md +++ b/docs/noise.md @@ -79,7 +79,7 @@ AmpD = NoiseInstance(NoiseType.AMPLITUDE_DAMPING, error_probability=noise_prob) ``` -Then a gate can be instantiated by providing a tuple of `NoiseInstance` instances. Let’s show this through the simulation of a realistic $X$ gate. +Then a gate can be instantiated by providing a tuple of `NoiseInstance` instances. Let’s show this through the simulation of a realistic $X$ gate. We know that an $X$ gate flips the state of the qubit, for instance $X|0\rangle = |1\rangle$. In practice, it's common for the target qubit to stay in its original state after applying $X$ due to the interactions between it and its environment. The possibility of failure can be represented by a `BitFlip` `NoiseInstance`, which flips the state again after the application of the $X$ gate, returning it to its original state with a probability `1 - gate_fidelity`. From 3f90c128c1b53ab1d6f9e3385a295d385f8862d0 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Fri, 13 Dec 2024 08:28:55 +0100 Subject: [PATCH 36/57] rm circuit methods in favor of api --- docs/index.md | 5 +++-- horqrux/__init__.py | 4 ++-- horqrux/circuit.py | 51 +------------------------------------------ tests/test_adjoint.py | 5 ++--- 4 files changed, 8 insertions(+), 57 deletions(-) diff --git a/docs/index.md b/docs/index.md index 87c805e..1111304 100644 --- a/docs/index.md +++ b/docs/index.md @@ -110,10 +110,11 @@ from operator import add from typing import Any, Callable from uuid import uuid4 -from horqrux.circuit import QuantumCircuit, hea, expectation +from horqrux import expectation +from horqrux import Z, RX, RY, NOT, zero_state, apply_gate +from horqrux.circuit import QuantumCircuit, hea from horqrux.primitive import Primitive from horqrux.parametric import Parametric -from horqrux import Z, RX, RY, NOT, zero_state, apply_gate from horqrux.utils import DiffMode diff --git a/horqrux/__init__.py b/horqrux/__init__.py index 2569a54..3e31d9d 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -1,8 +1,8 @@ from __future__ import annotations -from .api import expectation, run +from .api import expectation, run, sample from .apply import apply_gate, apply_operator -from .circuit import QuantumCircuit, sample +from .circuit import QuantumCircuit from .parametric import PHASE, RX, RY, RZ from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z from .utils import ( diff --git a/horqrux/circuit.py b/horqrux/circuit.py index ae4bb03..a398106 100644 --- a/horqrux/circuit.py +++ b/horqrux/circuit.py @@ -1,20 +1,16 @@ from __future__ import annotations -from collections import Counter from dataclasses import dataclass, field from typing import Any, Callable from uuid import uuid4 -import jax -import jax.numpy as jnp from jax import Array from jax.tree_util import register_pytree_node_class -from horqrux.adjoint import ad_expectation, adjoint_expectation from horqrux.apply import apply_gate from horqrux.parametric import RX, RY, Parametric from horqrux.primitive import NOT, Primitive -from horqrux.utils import DiffMode, zero_state +from horqrux.utils import zero_state @register_pytree_node_class @@ -113,48 +109,3 @@ def hea( gates += ops return gates - - -def expectation( - state: Array, - gates: list[Primitive], - observable: list[Primitive], - values: dict[str, float], - diff_mode: DiffMode | str = DiffMode.AD, -) -> Array: - """ - Run 'state' through a sequence of 'gates' given parameters 'values' - and compute the expectation given an observable. - """ - if diff_mode == DiffMode.AD: - return ad_expectation(state, gates, observable, values) - else: - return adjoint_expectation(state, gates, observable, values) - - -def sample( - state: Array, - gates: list[Primitive], - values: dict[str, float] = dict(), - n_shots: int = 1000, -) -> Counter: - if n_shots < 1: - raise ValueError("You can only call sample with n_shots>0.") - - wf = apply_gate(state, gates, values) - probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() - key = jax.random.PRNGKey(0) - n_qubits = len(state.shape) - # JAX handles pseudo random number generation by tracking an explicit state via a random key - # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html - samples = jax.vmap( - lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) - )(jax.random.split(key, n_shots)) - - return Counter( - { - format(k, "0{}b".format(n_qubits)): count.item() - for k, count in enumerate(jnp.bincount(samples)) - if count > 0 - } - ) diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 27f0164..2df8bdd 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -4,8 +4,7 @@ import numpy as np from jax import Array, grad -from horqrux import random_state -from horqrux.circuit import expectation +from horqrux import expectation, random_state from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, H, I, S, T, X, Y, Z from horqrux.utils import DiffMode @@ -27,7 +26,7 @@ def test_gradcheck() -> None: state = random_state(MAX_QUBITS) def exp_fn(values: dict, diff_mode: DiffMode = "ad") -> Array: - return expectation(state, ops, observable, values, diff_mode) + return expectation(state, ops, observable, values, diff_mode).item() grads_adjoint = grad(exp_fn)(values, "adjoint") grad_ad = grad(exp_fn)(values) From be6cc2fd868c4a16d0e2c9d882b150ba12ba517b Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 16 Dec 2024 11:40:41 +0100 Subject: [PATCH 37/57] rm comment --- horqrux/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/horqrux/utils.py b/horqrux/utils.py index 3d5f369..843a4d6 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -60,7 +60,6 @@ def permute_basis(operator: Array, qubit_support: tuple, inv: bool = False) -> A if inv: perm = np.argsort(perm) return jax.lax.transpose(operator, perm) - # return jnp.moveaxis(operator, source=tuple(range(operator.ndim)), destination=perm) class StrEnum(str, Enum): From e32e288a4782cc45913dd25123ac6a0c7d58048c Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 16 Dec 2024 16:28:22 +0100 Subject: [PATCH 38/57] change for isdensity --- docs/noise.md | 2 +- horqrux/api.py | 28 +++++++++++++--------------- horqrux/apply.py | 18 +++++++++--------- horqrux/shots.py | 18 +++++++----------- tests/test_noise.py | 6 +++--- tests/test_shots.py | 4 ++-- 6 files changed, 35 insertions(+), 41 deletions(-) diff --git a/docs/noise.md b/docs/noise.md index 3bb2f0b..369b8f4 100644 --- a/docs/noise.md +++ b/docs/noise.md @@ -95,7 +95,7 @@ noisy_ops = [X(0, noise=noise)] state = product_state("0") noiseless_samples = sample(state, ops) -noisy_samples = sample(density_mat(state), noisy_ops, is_state_densitymat=True) +noisy_samples = sample(density_mat(state), noisy_ops, is_density=True) print("Noiseless samples", noiseless_samples) # markdown-exec: hide print("Noiseless samples", noisy_samples) # markdown-exec: hide ``` diff --git a/horqrux/api.py b/horqrux/api.py index 7a2d6af..199c317 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -19,9 +19,9 @@ def run( circuit: GateSequence, state: Array, values: dict[str, float] = dict(), - is_state_densitymat: bool = False, + is_density: bool = False, ) -> Array: - return apply_gate(state, circuit, values, is_state_densitymat=is_state_densitymat) + return apply_gate(state, circuit, values, is_density=is_density) def sample( @@ -29,13 +29,13 @@ def sample( gates: GateSequence, values: dict[str, float] = dict(), n_shots: int = 1000, - is_state_densitymat: bool = False, + is_density: bool = False, ) -> Counter: if n_shots < 1: raise ValueError("You can only call sample with n_shots>0.") - output_circuit = apply_gate(state, gates, values, is_state_densitymat=is_state_densitymat) - if is_state_densitymat: + output_circuit = apply_gate(state, gates, values, is_density=is_density) + if is_density: n_qubits = len(state.shape) // 2 d = 2**n_qubits probs = jnp.diagonal(output_circuit.reshape((d, d))).real @@ -64,17 +64,15 @@ def __ad_expectation_single_observable( gates: GateSequence, observable: Primitive, values: dict[str, float], - is_state_densitymat: bool = False, + is_density: bool = False, ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ - out_state = apply_gate( - state, gates, values, OperationType.UNITARY, is_state_densitymat=is_state_densitymat - ) + out_state = apply_gate(state, gates, values, OperationType.UNITARY, is_density=is_density) # in case we have noisy simulations - out_state_densitymat = is_state_densitymat or (out_state.shape != state.shape) + out_state_densitymat = is_density or (out_state.shape != state.shape) if not out_state_densitymat: projected_state = apply_gate( @@ -82,7 +80,7 @@ def __ad_expectation_single_observable( observable, values, OperationType.UNITARY, - is_state_densitymat=out_state_densitymat, + is_density=out_state_densitymat, ) return inner(out_state, projected_state).real n_qubits = len(out_state.shape) // 2 @@ -97,14 +95,14 @@ def ad_expectation( gates: GateSequence, observables: list[Primitive], values: dict[str, float], - is_state_densitymat: bool = False, + is_density: bool = False, ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ outputs = [ - __ad_expectation_single_observable(state, gates, observable, values, is_state_densitymat) + __ad_expectation_single_observable(state, gates, observable, values, is_density) for observable in observables ] return jnp.stack(outputs) @@ -118,7 +116,7 @@ def expectation( diff_mode: DiffMode = DiffMode.AD, forward_mode: ForwardMode = ForwardMode.EXACT, n_shots: Optional[int] = None, - is_state_densitymat: bool = False, + is_density: bool = False, key: Any = jax.random.PRNGKey(0), ) -> Array: """ @@ -145,6 +143,6 @@ def expectation( observables, values, n_shots=n_shots, - is_state_densitymat=is_state_densitymat, + is_density=is_density, key=key, ) diff --git a/horqrux/apply.py b/horqrux/apply.py index d320c21..20cb5aa 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -45,7 +45,7 @@ def apply_operator( operator: Array to contract over 'state'. target: Tuple of target qubits on which to apply the 'operator' to. control: Tuple of control qubits. - is_state_densitymat: Whether the state is provided as a density matrix. + is_density: Whether the state is provided as a density matrix. Returns: State after applying 'operator'. @@ -78,7 +78,7 @@ def apply_operator_dm( operator: Array to contract over 'state'. target: Tuple of target qubits on which to apply the 'operator' to. control: Tuple of control qubits. - is_state_densitymat: Whether the state is provided as a density matrix. + is_density: Whether the state is provided as a density matrix. Returns: Density matrix after applying 'operator'. @@ -134,11 +134,11 @@ def apply_operator_with_noise( target: Tuple[int, ...], control: Tuple[int | None, ...], noise: NoiseProtocol, - is_state_densitymat: bool = False, + is_density: bool = False, ) -> State: state_gate = ( apply_operator(state, operator, target, control) - if not is_state_densitymat + if not is_density else apply_operator_dm(state, operator, target, control) ) if len(noise) == 0: @@ -212,7 +212,7 @@ def apply_gate( op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution merge_ops: bool = True, - is_state_densitymat: bool = False, + is_density: bool = False, ) -> State: """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. Arguments: @@ -222,7 +222,7 @@ def apply_gate( op_type: The type of operation to perform: Unitary, Dagger or Jacobian. group_gates: Group gates together which are acting on the same qubit. merge_ops: Attempt to merge operators acting on the same qubit. - is_state_densitymat: If True, state is provided as a density matrix. + is_density: If True, state is provided as a density matrix. Returns: State or density matrix after applying 'gate'. @@ -244,13 +244,13 @@ def apply_gate( noise = [g.noise for g in gate] has_noise = len(reduce(add, noise)) > 0 - if has_noise and not is_state_densitymat: + if has_noise and not is_density: state = density_mat(state) - is_state_densitymat = True + is_density = True output_state = reduce( lambda state, gate: apply_operator_with_noise(state, *gate), - zip(operator, target, control, noise, (is_state_densitymat,) * len(target)), + zip(operator, target, control, noise, (is_density,) * len(target)), state, ) diff --git a/horqrux/shots.py b/horqrux/shots.py index 614dfd8..a330314 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -40,16 +40,16 @@ def finite_shots_fwd( observables: list[Primitive], values: dict[str, float], n_shots: int = 100, - is_state_densitymat: bool = False, + is_density: bool = False, key: Any = jax.random.PRNGKey(0), ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ - state = apply_gate(state, gates, values, is_state_densitymat=is_state_densitymat) + state = apply_gate(state, gates, values, is_density=is_density) n_qubits = len(state.shape) - if not is_state_densitymat: + if not is_density: mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] eigvecs, eigvals = align_eigenvectors(eigs) @@ -110,7 +110,7 @@ def finite_shots_jvp( gates: GateSequence, observable: Primitive, n_shots: int, - is_state_densitymat: bool, + is_density: bool, key: Array, primals: tuple[dict[str, float]], tangents: tuple[dict[str, float]], @@ -127,18 +127,14 @@ def jvp_component(param_name: str, key: Array) -> Array: up_key, down_key = random.split(key) up_val = values.copy() up_val[param_name] = up_val[param_name] + shift - f_up = finite_shots_fwd( - state, gates, observable, up_val, n_shots, is_state_densitymat, up_key - ) + f_up = finite_shots_fwd(state, gates, observable, up_val, n_shots, is_density, up_key) down_val = values.copy() down_val[param_name] = down_val[param_name] - shift - f_down = finite_shots_fwd( - state, gates, observable, down_val, n_shots, is_state_densitymat, down_key - ) + f_down = finite_shots_fwd(state, gates, observable, down_val, n_shots, is_density, down_key) grad = spectral_gap * (f_up - f_down) / (4.0 * jnp.sin(spectral_gap * shift / 2.0)) return grad * tangent_dict[param_name] params_with_keys = zip(values.keys(), random.split(key, len(values))) - fwd = finite_shots_fwd(state, gates, observable, values, n_shots, is_state_densitymat, key) + fwd = finite_shots_fwd(state, gates, observable, values, n_shots, is_density, key) jvp = sum(jvp_component(param, key) for param, key in params_with_keys) return fwd, jvp.reshape(fwd.shape) diff --git a/tests/test_noise.py b/tests/test_noise.py index 85bf781..52acb5e 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -56,7 +56,7 @@ def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: orig_dm = density_mat(orig_state) assert len(orig_dm.shape) == dm_shape_len - output_dm2 = apply_gate(orig_dm, noisy_gate, is_state_densitymat=True) + output_dm2 = apply_gate(orig_dm, noisy_gate, is_density=True) assert jnp.allclose(output_dm2, output_dm) perfect_gate = gate_fn(target) @@ -82,7 +82,7 @@ def test_noisy_parametric(gate_fn: Callable, noise_type: NoiseType) -> None: orig_dm = density_mat(orig_state) assert len(orig_dm.shape) == dm_shape_len - output_dm2 = apply_gate(orig_dm, noisy_gate, values, is_state_densitymat=True) + output_dm2 = apply_gate(orig_dm, noisy_gate, values, is_density=True) assert jnp.allclose(output_dm2, output_dm) perfect_gate = gate_fn("theta", target) @@ -116,7 +116,7 @@ def simple_depolarizing_test() -> None: # test sampling dm_state = density_mat(state) - sampling_output = sample(dm_state, ops, is_state_densitymat=True) + sampling_output = sample(dm_state, ops, is_density=True) assert "11" in sampling_output.keys() assert "01" in sampling_output.keys() diff --git a/tests/test_shots.py b/tests/test_shots.py index 1c3edb7..8447f41 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -26,7 +26,7 @@ def exact(x): def exact_dm(x): values = {"theta": x} return expectation( - density_mat(state), ops, observables, values, diff_mode="ad", is_state_densitymat=True + density_mat(state), ops, observables, values, diff_mode="ad", is_density=True ) def shots(x): @@ -43,7 +43,7 @@ def shots_dm(x): observables, values, diff_mode="gpsr", - is_state_densitymat=True, + is_density=True, forward_mode="shots", n_shots=N_SHOTS, ) From 37b2a9abc5a5989a9c0bdb4fd0fa0e1e37926c6f Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 16 Dec 2024 16:31:35 +0100 Subject: [PATCH 39/57] Tuple to tuple --- horqrux/adjoint.py | 6 ++---- horqrux/apply.py | 32 ++++++++++++++++---------------- horqrux/noise.py | 6 +++--- horqrux/parametric.py | 12 ++++++------ horqrux/primitive.py | 24 ++++++++++++------------ horqrux/utils.py | 16 ++++++++-------- horqrux/utils_noise.py | 2 +- 7 files changed, 48 insertions(+), 50 deletions(-) diff --git a/horqrux/adjoint.py b/horqrux/adjoint.py index de7f925..d3dc138 100644 --- a/horqrux/adjoint.py +++ b/horqrux/adjoint.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Tuple - from jax import Array, custom_vjp from horqrux.apply import apply_gate @@ -37,14 +35,14 @@ def adjoint_expectation( def adjoint_expectation_single_observable_fwd( state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] -) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]: +) -> tuple[Array, tuple[Array, Array, list[Primitive], dict[str, float]]]: out_state = apply_gate(state, gates, values, OperationType.UNITARY) projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) return inner(out_state, projected_state).real, (out_state, projected_state, gates, values) def adjoint_expectation_single_observable_bwd( - res: Tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array + res: tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array ) -> tuple: """Implementation of Algorithm 1 of https://arxiv.org/abs/2009.02823 which computes the vector-jacobian product in O(P) time using O(1) state vectors diff --git a/horqrux/apply.py b/horqrux/apply.py index 20cb5aa..e1f54a0 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -2,7 +2,7 @@ from functools import partial, reduce from operator import add -from typing import Iterable, Tuple +from typing import Iterable import jax import jax.numpy as jnp @@ -26,8 +26,8 @@ def apply_operator( state: State, operator: Array, - target: Tuple[int, ...], - control: Tuple[int | None, ...], + target: tuple[int, ...], + control: tuple[int | None, ...], ) -> State: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits. @@ -43,14 +43,14 @@ def apply_operator( Arguments: state: State to operate on. operator: Array to contract over 'state'. - target: Tuple of target qubits on which to apply the 'operator' to. - control: Tuple of control qubits. + target: tuple of target qubits on which to apply the 'operator' to. + control: tuple of control qubits. is_density: Whether the state is provided as a density matrix. Returns: State after applying 'operator'. """ - state_dims: Tuple[int, ...] = target + state_dims: tuple[int, ...] = target if is_controlled(control): operator = _controlled(operator, len(control)) state_dims = (*control, *target) # type: ignore[arg-type] @@ -66,8 +66,8 @@ def apply_operator( def apply_operator_dm( state: State, operator: Array, - target: Tuple[int, ...], - control: Tuple[int | None, ...], + target: tuple[int, ...], + control: tuple[int | None, ...], ) -> State: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given density matrix of shape [2 for _ in range(2 * n_qubits)] for a given set of target and control qubits. @@ -76,14 +76,14 @@ def apply_operator_dm( Arguments: state: Density matrix to operate on. operator: Array to contract over 'state'. - target: Tuple of target qubits on which to apply the 'operator' to. - control: Tuple of control qubits. + target: tuple of target qubits on which to apply the 'operator' to. + control: tuple of control qubits. is_density: Whether the state is provided as a density matrix. Returns: Density matrix after applying 'operator'. """ - state_dims: Tuple[int, ...] = target + state_dims: tuple[int, ...] = target if is_controlled(control): operator = _controlled(operator, len(control)) state_dims = (*control, *target) # type: ignore[arg-type] @@ -109,9 +109,9 @@ def apply_operator_dm( def apply_kraus_operator( kraus: Array, state: State, - target: Tuple[int, ...], + target: tuple[int, ...], ) -> State: - state_dims: Tuple[int, ...] = target + state_dims: tuple[int, ...] = target n_qubits = int(np.log2(kraus.size)) kraus = kraus.reshape(tuple(2 for _ in np.arange(n_qubits))) op_dims = tuple(np.arange(kraus.ndim // 2, kraus.ndim, dtype=int)) @@ -131,8 +131,8 @@ def apply_kraus_operator( def apply_operator_with_noise( state: State, operator: Array, - target: Tuple[int, ...], - control: Tuple[int | None, ...], + target: tuple[int, ...], + control: tuple[int | None, ...], noise: NoiseProtocol, is_density: bool = False, ) -> State: @@ -227,7 +227,7 @@ def apply_gate( Returns: State or density matrix after applying 'gate'. """ - operator: Tuple[Array, ...] + operator: tuple[Array, ...] noise = list() if isinstance(gate, Primitive): operator_fn = getattr(gate, op_type) diff --git a/horqrux/noise.py b/horqrux/noise.py index e26f2dc..3f2b11d 100644 --- a/horqrux/noise.py +++ b/horqrux/noise.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Iterable, Tuple +from typing import Any, Callable, Iterable from jax import Array from jax.tree_util import register_pytree_node_class @@ -53,7 +53,7 @@ def __iter__(self) -> Iterable: def tree_flatten( self, - ) -> Tuple[Tuple, Tuple[NoiseType, ErrorProbabilities]]: + ) -> tuple[tuple, tuple[NoiseType, ErrorProbabilities]]: children = () aux_data = (self.type, self.error_probability) return (children, aux_data) @@ -71,4 +71,4 @@ def __repr__(self) -> str: return self.type + f"(p={self.error_probability})" -NoiseProtocol = Tuple[NoiseInstance, ...] +NoiseProtocol = tuple[NoiseInstance, ...] diff --git a/horqrux/parametric.py b/horqrux/parametric.py index 184b9f2..6a53cae 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Iterable, Tuple +from typing import Any, Iterable import jax.numpy as jnp from jax import Array @@ -47,7 +47,7 @@ def parse_val(values: dict[str, float] = dict()) -> float: def tree_flatten( # type: ignore[override] self, - ) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, NoiseProtocol, str | float]]: + ) -> tuple[tuple, tuple[str, tuple, tuple, NoiseProtocol, str | float]]: children = () aux_data = ( self.generator_name, @@ -90,7 +90,7 @@ def RX( Arguments: param: Parameter denoting the Rotational angle. - target: Tuple of target qubits denoted as ints. + target: tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. noise: The noise instance. Defaults to None. @@ -110,7 +110,7 @@ def RY( Arguments: param: Parameter denoting the Rotational angle. - target: Tuple of target qubits denoted as ints. + target: tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. noise: The noise instance. Defaults to None. @@ -130,7 +130,7 @@ def RZ( Arguments: param: Parameter denoting the Rotational angle. - target: Tuple of target qubits denoted as ints. + target: tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. noise: The noise instance. Defaults to None. @@ -167,7 +167,7 @@ def PHASE( Arguments: param: Parameter denoting the Rotational angle. - target: Tuple of target qubits denoted as ints. + target: tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. noise: The noise instance. Defaults to None. diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 830d91b..edb3cd9 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Iterable, Tuple, Union +from typing import Any, Iterable, Union import numpy as np from jax import Array @@ -32,8 +32,8 @@ class Primitive: @staticmethod def parse_idx( - idx: Tuple, - ) -> Tuple: + idx: tuple, + ) -> tuple: if isinstance(idx, (int, np.int64)): return ((idx,),) elif isinstance(idx, tuple): @@ -51,7 +51,7 @@ def __post_init__(self) -> None: def __iter__(self) -> Iterable: return iter((self.generator_name, self.target, self.control, self.noise)) - def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits, NoiseProtocol]]: + def tree_flatten(self) -> tuple[tuple, tuple[str, TargetQubits, ControlQubits, NoiseProtocol]]: children = () aux_data = (self.generator_name, self.target[0], self.control[0], self.noise) return (children, aux_data) @@ -93,7 +93,7 @@ def I( Example usage: I(1) represents the instruction to apply I to qubit 1. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). noise: The noise instance. Defaults to None. @@ -113,7 +113,7 @@ def X( Example usage controlled: X(1, 0) represents the instruction to apply CX / CNOT to qubit 1 with controlled qubit 0. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). noise: The noise instance. Defaults to None. @@ -136,7 +136,7 @@ def Y( Example usage controlled: Y(1, 0) represents the instruction to apply CY to qubit 1 with controlled qubit 0. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). noise: The noise instance. Defaults to None. @@ -156,7 +156,7 @@ def Z( Example usage controlled: Z(1, 0) represents the instruction to apply CZ to qubit 1 with controlled qubit 0. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). noise: The noise instance. Defaults to None. @@ -175,7 +175,7 @@ def H( Example usage: H(1) represents the instruction to apply Hadamard to qubit 1. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). noise: The noise instance. Defaults to None. @@ -194,7 +194,7 @@ def S( Example usage: S(1) represents the instruction to apply S to qubit 1. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). noise: The noise instance. Defaults to None. @@ -213,7 +213,7 @@ def T( Example usage: T(1) represents the instruction to apply Hadamard to qubit 1. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). noise: The noise instance. Defaults to None. @@ -236,7 +236,7 @@ def SWAP( Example usage controlled: SWAP(((0, 1), ), ((2, ))) swaps qubits 0 and 1 with controlled bit 2. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints or None describing the control qubits. Defaults to (None,). noise: The noise instance. Defaults to None. diff --git a/horqrux/utils.py b/horqrux/utils.py index 843a4d6..cb45ba7 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Iterable, Tuple, Union +from typing import Any, Iterable, Union import jax import jax.numpy as jnp @@ -15,10 +15,10 @@ default_dtype = default_complex_dtype() State = ArrayLike -QubitSupport = Tuple[Any, ...] -ControlQubits = Tuple[Union[None, Tuple[int, ...]], ...] -TargetQubits = Tuple[Tuple[int, ...], ...] -ErrorProbabilities = Union[Tuple[float, ...], float] +QubitSupport = tuple[Any, ...] +ControlQubits = tuple[Union[None, tuple[int, ...]], ...] +TargetQubits = tuple[tuple[int, ...], ...] +ErrorProbabilities = Union[tuple[float, ...], float] ATOL = 1e-014 @@ -153,14 +153,14 @@ def zero_state(n_qubits: int) -> Array: return product_state("0" * n_qubits) -def none_like(x: Iterable) -> Tuple[None, ...]: +def none_like(x: Iterable) -> tuple[None, ...]: """Generates a tuple of Nones with equal length to x. Useful for gates with multiple targets but no control. Args: x (Iterable): Iterable to be mimicked. Returns: - Tuple[None, ...]: Tuple of Nones of length x. + tuple[None, ...]: tuple of Nones of length x. """ return tuple(map(lambda _: None, x)) @@ -208,7 +208,7 @@ def uniform_state( return state.reshape([2] * n_qubits) -def is_controlled(qubit_support: Tuple[int | None, ...] | int | None) -> bool: +def is_controlled(qubit_support: tuple[int | None, ...] | int | None) -> bool: if isinstance(qubit_support, int): return True elif isinstance(qubit_support, tuple): diff --git a/horqrux/utils_noise.py b/horqrux/utils_noise.py index 030f467..a0caaa6 100644 --- a/horqrux/utils_noise.py +++ b/horqrux/utils_noise.py @@ -104,7 +104,7 @@ def PauliChannel(error_probability: tuple[float, ...]) -> tuple[Array, ...]: + pz Z \\rho Z^{\\dagger} Args: - error_probability (ErrorProbabilities): Tuple containing probabilities + error_probability (ErrorProbabilities): tuple containing probabilities of X, Y, and Z errors. Raises: From 0936e16583a06592fdad1e8cac53a10633b1cbb7 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 16 Dec 2024 16:43:14 +0100 Subject: [PATCH 40/57] change from default noise tuple to None --- horqrux/apply.py | 5 +++-- horqrux/noise.py | 7 +++---- horqrux/parametric.py | 12 ++++++------ horqrux/primitive.py | 20 ++++++++++---------- horqrux/utils.py | 2 -- horqrux/utils_noise.py | 4 ++-- 6 files changed, 24 insertions(+), 26 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index e1f54a0..855e348 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -141,7 +141,7 @@ def apply_operator_with_noise( if not is_density else apply_operator_dm(state, operator, target, control) ) - if len(noise) == 0: + if noise is None: return state_gate else: kraus_ops = jnp.stack(tuple(reduce(add, tuple(n.kraus for n in noise)))) @@ -243,7 +243,8 @@ def apply_gate( operator, target, control = merge_operators(operator, target, control) noise = [g.noise for g in gate] - has_noise = len(reduce(add, noise)) > 0 + # faster way to check has_noise + has_noise = noise != [None] * len(noise) if has_noise and not is_density: state = density_mat(state) is_density = True diff --git a/horqrux/noise.py b/horqrux/noise.py index 3f2b11d..4db2464 100644 --- a/horqrux/noise.py +++ b/horqrux/noise.py @@ -7,7 +7,6 @@ from jax.tree_util import register_pytree_node_class from .utils import ( - ErrorProbabilities, StrEnum, ) from .utils_noise import ( @@ -46,14 +45,14 @@ class NoiseType(StrEnum): @dataclass class NoiseInstance: type: NoiseType - error_probability: ErrorProbabilities + error_probability: tuple[float, ...] | float def __iter__(self) -> Iterable: return iter((self.kraus, self.error_probability)) def tree_flatten( self, - ) -> tuple[tuple, tuple[NoiseType, ErrorProbabilities]]: + ) -> tuple[tuple, tuple[NoiseType, tuple[float, ...] | float]]: children = () aux_data = (self.type, self.error_probability) return (children, aux_data) @@ -71,4 +70,4 @@ def __repr__(self) -> str: return self.type + f"(p={self.error_probability})" -NoiseProtocol = tuple[NoiseInstance, ...] +NoiseProtocol = tuple[NoiseInstance, ...] | None diff --git a/horqrux/parametric.py b/horqrux/parametric.py index 6a53cae..0f37914 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Iterable import jax.numpy as jnp @@ -31,7 +31,7 @@ class Parametric(Primitive): generator_name: str target: QubitSupport control: QubitSupport - noise: NoiseProtocol = field(default_factory=tuple) + noise: NoiseProtocol = None param: str | float = "" def __post_init__(self) -> None: @@ -84,7 +84,7 @@ def RX( param: float | str, target: TargetQubits, control: ControlQubits = (None,), - noise: NoiseProtocol = tuple(), + noise: NoiseProtocol = None, ) -> Parametric: """RX gate. @@ -104,7 +104,7 @@ def RY( param: float | str, target: TargetQubits, control: ControlQubits = (None,), - noise: NoiseProtocol = tuple(), + noise: NoiseProtocol = None, ) -> Parametric: """RY gate. @@ -124,7 +124,7 @@ def RZ( param: float | str, target: TargetQubits, control: ControlQubits = (None,), - noise: NoiseProtocol = tuple(), + noise: NoiseProtocol = None, ) -> Parametric: """RZ gate. @@ -161,7 +161,7 @@ def PHASE( param: float, target: TargetQubits, control: ControlQubits = (None,), - noise: NoiseProtocol = tuple(), + noise: NoiseProtocol = None, ) -> Parametric: """Phase gate. diff --git a/horqrux/primitive.py b/horqrux/primitive.py index edb3cd9..2507fc9 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Iterable, Union import numpy as np @@ -28,7 +28,7 @@ class Primitive: generator_name: str target: QubitSupport control: QubitSupport - noise: NoiseProtocol = field(default_factory=tuple) + noise: NoiseProtocol | None = None @staticmethod def parse_idx( @@ -85,7 +85,7 @@ def __repr__(self) -> str: def I( - target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None ) -> Primitive: """Identity / I gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -104,7 +104,7 @@ def I( def X( - target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None ) -> Primitive: """X gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -127,7 +127,7 @@ def X( def Y( - target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None ) -> Primitive: """Y gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -147,7 +147,7 @@ def Y( def Z( - target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None ) -> Primitive: """Z gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -167,7 +167,7 @@ def Z( def H( - target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None ) -> Primitive: """H/ Hadamard gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -186,7 +186,7 @@ def H( def S( - target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None ) -> Primitive: """S gate or constant phase gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -205,7 +205,7 @@ def S( def T( - target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None ) -> Primitive: """T gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -227,7 +227,7 @@ def T( def SWAP( - target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = tuple() + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None ) -> Primitive: """SWAP gate. By providing a control, it turns into a controlled gate (Fredkin gate), use None for no control qubits. diff --git a/horqrux/utils.py b/horqrux/utils.py index cb45ba7..8e91ed2 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -18,8 +18,6 @@ QubitSupport = tuple[Any, ...] ControlQubits = tuple[Union[None, tuple[int, ...]], ...] TargetQubits = tuple[tuple[int, ...], ...] -ErrorProbabilities = Union[tuple[float, ...], float] - ATOL = 1e-014 diff --git a/horqrux/utils_noise.py b/horqrux/utils_noise.py index a0caaa6..67e5326 100644 --- a/horqrux/utils_noise.py +++ b/horqrux/utils_noise.py @@ -104,7 +104,7 @@ def PauliChannel(error_probability: tuple[float, ...]) -> tuple[Array, ...]: + pz Z \\rho Z^{\\dagger} Args: - error_probability (ErrorProbabilities): tuple containing probabilities + error_probability (tuple[float, ...] | float): tuple containing probabilities of X, Y, and Z errors. Raises: @@ -224,7 +224,7 @@ def GeneralizedAmplitudeDamping(error_probability: tuple[float, ...]) -> tuple[A K3 = sqrt(1-p) * [[0, 0], [sqrt(rate), 0]] Args: - error_probability (ErrorProbabilities): The first float must be the probability + error_probability (tuple[float, ...] | float): The first float must be the probability of amplitude damping error, and the second float is the damping rate, indicating the probability of generalized amplitude damping. From f2606d4d29a8066ee958ab56d7eac890cd869a72 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 16 Dec 2024 16:50:59 +0100 Subject: [PATCH 41/57] more docstr in apply --- horqrux/apply.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index 855e348..3549d9e 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -111,17 +111,25 @@ def apply_kraus_operator( state: State, target: tuple[int, ...], ) -> State: + """Apply K \\rho K^\\dagger. + + Args: + kraus (Array): Kraus operator K. + state (State): Input density matrix. + target (tuple[int, ...]): Target qubits. + + Returns: + State: Output density matrix. + """ state_dims: tuple[int, ...] = target n_qubits = int(np.log2(kraus.size)) kraus = kraus.reshape(tuple(2 for _ in np.arange(n_qubits))) op_dims = tuple(np.arange(kraus.ndim // 2, kraus.ndim, dtype=int)) - # Ki rho state = jnp.tensordot(a=kraus, b=state, axes=(op_dims, state_dims)) new_state_dims = tuple(i for i in range(len(state_dims))) state = jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) - # dagger ops state = jnp.tensordot(a=kraus, b=_dagger(state), axes=(op_dims, state_dims)) state = _dagger(state) @@ -136,6 +144,25 @@ def apply_operator_with_noise( noise: NoiseProtocol, is_density: bool = False, ) -> State: + """Evolves the input state and applies a noisy quantum channel + on the evolved state :math:`\rho`. + + The evolution is represented as a sum of Kraus operators: + .. math:: + S(\\rho) = \\sum_i K_i \\rho K_i^\\dagger, + + Args: + state (State): Input state + operator (Array): Operator to apply. + target (tuple[int, ...]): Target qubits. + control (tuple[int | None, ...]): Control qubits. + noise (NoiseProtocol): The noise protocol. + is_density (bool, optional): If true, state is provided as a density matrix. + Defaults to False. + + Returns: + State: Output state or density matrix. + """ state_gate = ( apply_operator(state, operator, target, control) if not is_density From e3b1f2d3a6028ffc561af2110851117cbd91f20a Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 16 Dec 2024 16:52:40 +0100 Subject: [PATCH 42/57] change tuple call new state dims --- horqrux/apply.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index 3549d9e..aa807ad 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -58,7 +58,7 @@ def apply_operator( operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) # Apply operator - new_state_dims = tuple(i for i in range(len(state_dims))) + new_state_dims = tuple(range(len(state_dims))) state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, state_dims)) return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) @@ -91,7 +91,7 @@ def apply_operator_dm( operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) op_in_dims = tuple(np.arange(0, operator.ndim // 2, dtype=int)) - new_state_dims = tuple(i for i in range(len(state_dims))) + new_state_dims = tuple(range(len(state_dims))) # Apply operator to density matrix: ρ' = O ρ O† support_perm = state_dims + tuple(set(tuple(range(state.ndim // 2))) - set(state_dims)) From 4fef6bf090a16398d45e94ed056f10b65e21a51a Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 16 Dec 2024 16:58:51 +0100 Subject: [PATCH 43/57] fix union --- horqrux/noise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/horqrux/noise.py b/horqrux/noise.py index 4db2464..84460b3 100644 --- a/horqrux/noise.py +++ b/horqrux/noise.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Iterable +from typing import Any, Callable, Iterable, Union from jax import Array from jax.tree_util import register_pytree_node_class @@ -70,4 +70,4 @@ def __repr__(self) -> str: return self.type + f"(p={self.error_probability})" -NoiseProtocol = tuple[NoiseInstance, ...] | None +NoiseProtocol = Union[tuple[NoiseInstance, ...], None] From d519a99ec914abf7bc4f1f96d84f4e1585f9befd Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 16 Dec 2024 17:49:02 +0100 Subject: [PATCH 44/57] mention density matrix simulator --- README.md | 2 +- docs/index.md | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2d3f427..e995b78 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Pypi](https://badge.fury.io/py/horqrux.svg)](https://pypi.org/project/horqrux/) -`horqrux` is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning and acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface. +`horqrux` is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector and density matrix simulator designed for quantum machine learning and acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface. ## Installation diff --git a/docs/index.md b/docs/index.md index 1111304..1ddbe5c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,6 @@ # Welcome to horqrux -**horqrux** is a state vector simulator designed for quantum machine learning written in [JAX](https://jax.readthedocs.io/). +**horqrux** is a state vector and density matrix simulator designed for quantum machine learning written in [JAX](https://jax.readthedocs.io/). ## Setup diff --git a/pyproject.toml b/pyproject.toml index 13dcd50..ab1411a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "horqrux" -description = "Jax-based quantum state vector simulator." +description = "Jax-based quantum state vector and density matrix simulator." authors = [ { name = "Gert-Jan Both" , email = "gert-jan.both@pasqal.com" }, { name = "Dominik Seitz", email = "dominik.seitz@pasqal.com" }, From 8488bd8ae22e7e20a7e8100a521ab9b541201f37 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 17 Dec 2024 17:38:44 +0100 Subject: [PATCH 45/57] using single dispatch - breaking --- horqrux/api.py | 95 +++++++++++++++---------- horqrux/apply.py | 164 ++++++++++++++++++++++++++++++++------------ horqrux/shots.py | 51 +++++++------- horqrux/utils.py | 25 ++++++- tests/test_gates.py | 18 ++--- tests/test_noise.py | 32 ++++++--- tests/test_shots.py | 5 +- 7 files changed, 258 insertions(+), 132 deletions(-) diff --git a/horqrux/api.py b/horqrux/api.py index 199c317..cd549f6 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import Counter +from functools import singledispatch from typing import Any, Optional import jax @@ -12,36 +13,18 @@ from horqrux.apply import apply_gate from horqrux.primitive import GateSequence, Primitive from horqrux.shots import finite_shots_fwd, observable_to_matrix -from horqrux.utils import DiffMode, ForwardMode, OperationType, inner +from horqrux.utils import DensityMatrix, DiffMode, ForwardMode, OperationType, inner def run( circuit: GateSequence, - state: Array, + state: Array | DensityMatrix, values: dict[str, float] = dict(), - is_density: bool = False, -) -> Array: - return apply_gate(state, circuit, values, is_density=is_density) - +) -> Array | DensityMatrix: + return apply_gate(state, circuit, values) -def sample( - state: Array, - gates: GateSequence, - values: dict[str, float] = dict(), - n_shots: int = 1000, - is_density: bool = False, -) -> Counter: - if n_shots < 1: - raise ValueError("You can only call sample with n_shots>0.") - output_circuit = apply_gate(state, gates, values, is_density=is_density) - if is_density: - n_qubits = len(state.shape) // 2 - d = 2**n_qubits - probs = jnp.diagonal(output_circuit.reshape((d, d))).real - else: - n_qubits = len(state.shape) - probs = jnp.abs(jnp.float_power(output_circuit, 2.0)).ravel() +def sample_from_probs(probs: Array, n_qubits: int, n_shots: int) -> Counter: key = jax.random.PRNGKey(0) # JAX handles pseudo random number generation by tracking an explicit state via a random key @@ -59,64 +42,101 @@ def sample( ) -def __ad_expectation_single_observable( +@singledispatch +def sample( + state: Array, + gates: GateSequence, + values: dict[str, float] = dict(), + n_shots: int = 1000, +) -> Counter: + raise NotImplementedError("sample method is not implemented") + + +@sample.register +def _( state: Array, gates: GateSequence, + values: dict[str, float] = dict(), + n_shots: int = 1000, +) -> Counter: + if n_shots < 1: + raise ValueError("You can only call sample with n_shots>0.") + + output_circuit = apply_gate(state, gates, values) + n_qubits = len(state.shape) + probs = jnp.abs(jnp.float_power(output_circuit, 2.0)).ravel() + return sample_from_probs(probs, n_qubits, n_shots) + + +@sample.register +def _( + state: DensityMatrix, + gates: GateSequence, + values: dict[str, float] = dict(), + n_shots: int = 1000, +) -> Counter: + if n_shots < 1: + raise ValueError("You can only call sample with n_shots>0.") + + output_circuit = apply_gate(state, gates, values) + n_qubits = len(state.array.shape) // 2 + d = 2**n_qubits + probs = jnp.diagonal(output_circuit.array.reshape((d, d))).real + return sample_from_probs(probs, n_qubits, n_shots) + + +def __ad_expectation_single_observable( + state: Array | DensityMatrix, + gates: GateSequence, observable: Primitive, values: dict[str, float], - is_density: bool = False, ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ - out_state = apply_gate(state, gates, values, OperationType.UNITARY, is_density=is_density) - # in case we have noisy simulations - out_state_densitymat = is_density or (out_state.shape != state.shape) + out_state = apply_gate(state, gates, values, OperationType.UNITARY) - if not out_state_densitymat: + if not isinstance(out_state, DensityMatrix): projected_state = apply_gate( out_state, observable, values, OperationType.UNITARY, - is_density=out_state_densitymat, ) return inner(out_state, projected_state).real - n_qubits = len(out_state.shape) // 2 + n_qubits = len(out_state.array.shape) // 2 mat_obs = observable_to_matrix(observable, n_qubits) d = 2**n_qubits - prod = jnp.matmul(mat_obs, out_state.reshape((d, d))) + prod = jnp.matmul(mat_obs, out_state.array.reshape((d, d))) return jnp.trace(prod, axis1=-2, axis2=-1).real def ad_expectation( - state: Array, + state: Array | DensityMatrix, gates: GateSequence, observables: list[Primitive], values: dict[str, float], - is_density: bool = False, ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ outputs = [ - __ad_expectation_single_observable(state, gates, observable, values, is_density) + __ad_expectation_single_observable(state, gates, observable, values) for observable in observables ] return jnp.stack(outputs) def expectation( - state: Array, + state: Array | DensityMatrix, gates: GateSequence, observables: list[Primitive], values: dict[str, float], diff_mode: DiffMode = DiffMode.AD, forward_mode: ForwardMode = ForwardMode.EXACT, n_shots: Optional[int] = None, - is_density: bool = False, key: Any = jax.random.PRNGKey(0), ) -> Array: """ @@ -143,6 +163,5 @@ def expectation( observables, values, n_shots=n_shots, - is_density=is_density, key=key, ) diff --git a/horqrux/apply.py b/horqrux/apply.py index aa807ad..977aa71 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import partial, reduce +from functools import partial, reduce, singledispatch from operator import add from typing import Iterable @@ -13,8 +13,8 @@ from .noise import NoiseProtocol from .utils import ( + DensityMatrix, OperationType, - State, _controlled, _dagger, density_mat, @@ -23,12 +23,23 @@ ) +@singledispatch def apply_operator( - state: State, + state: Array, operator: Array, target: tuple[int, ...], control: tuple[int | None, ...], -) -> State: +) -> Array: + raise NotImplementedError("apply_operator is not implemented") + + +@apply_operator.register +def _( + state: Array, + operator: Array, + target: tuple[int, ...], + control: tuple[int | None, ...], +) -> Array: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits. In case of a controlled operation, the 'operator' array will be embedded into a controlled array. @@ -41,14 +52,14 @@ def apply_operator( are moved to their original positions and the state is returned. Arguments: - state: State to operate on. + state: Array to operate on. operator: Array to contract over 'state'. target: tuple of target qubits on which to apply the 'operator' to. control: tuple of control qubits. is_density: Whether the state is provided as a density matrix. Returns: - State after applying 'operator'. + Array after applying 'operator'. """ state_dims: tuple[int, ...] = target if is_controlled(control): @@ -63,12 +74,13 @@ def apply_operator( return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) -def apply_operator_dm( - state: State, +@apply_operator.register +def _( + state: DensityMatrix, operator: Array, target: tuple[int, ...], control: tuple[int | None, ...], -) -> State: +) -> DensityMatrix: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given density matrix of shape [2 for _ in range(2 * n_qubits)] for a given set of target and control qubits. In case of a controlled operation, the 'operator' array will be embedded into a controlled array. @@ -94,32 +106,32 @@ def apply_operator_dm( new_state_dims = tuple(range(len(state_dims))) # Apply operator to density matrix: ρ' = O ρ O† - support_perm = state_dims + tuple(set(tuple(range(state.ndim // 2))) - set(state_dims)) - state = permute_basis(state, support_perm, False) - state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, new_state_dims)) + support_perm = state_dims + tuple(set(tuple(range(state.array.ndim // 2))) - set(state_dims)) + out_state = permute_basis(state.array, support_perm, False) + out_state = jnp.tensordot(a=operator, b=out_state, axes=(op_out_dims, new_state_dims)) - state = _dagger(state) - state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, op_in_dims)) - state = _dagger(state) + out_state = _dagger(out_state) + out_state = jnp.tensordot(a=operator, b=out_state, axes=(op_out_dims, op_in_dims)) + out_state = _dagger(out_state) - state = permute_basis(state, support_perm, True) - return state + out_state = permute_basis(out_state, support_perm, True) + return DensityMatrix(out_state) def apply_kraus_operator( kraus: Array, - state: State, + state: Array, target: tuple[int, ...], -) -> State: +) -> Array: """Apply K \\rho K^\\dagger. Args: kraus (Array): Kraus operator K. - state (State): Input density matrix. + state (Array): Input density matrix. target (tuple[int, ...]): Target qubits. Returns: - State: Output density matrix. + Array: Output density matrix. """ state_dims: tuple[int, ...] = target n_qubits = int(np.log2(kraus.size)) @@ -137,13 +149,12 @@ def apply_kraus_operator( def apply_operator_with_noise( - state: State, + state: Array | DensityMatrix, operator: Array, target: tuple[int, ...], control: tuple[int | None, ...], noise: NoiseProtocol, - is_density: bool = False, -) -> State: +) -> Array: """Evolves the input state and applies a noisy quantum channel on the evolved state :math:`\rho`. @@ -152,27 +163,27 @@ def apply_operator_with_noise( S(\\rho) = \\sum_i K_i \\rho K_i^\\dagger, Args: - state (State): Input state + state (Array | DensityMatrix): Input state or density matrix. operator (Array): Operator to apply. target (tuple[int, ...]): Target qubits. control (tuple[int | None, ...]): Control qubits. noise (NoiseProtocol): The noise protocol. - is_density (bool, optional): If true, state is provided as a density matrix. - Defaults to False. Returns: - State: Output state or density matrix. + Array: Output state or density matrix. """ - state_gate = ( - apply_operator(state, operator, target, control) - if not is_density - else apply_operator_dm(state, operator, target, control) - ) + state_gate = apply_operator(state, operator, target, control) if noise is None: return state_gate else: kraus_ops = jnp.stack(tuple(reduce(add, tuple(n.kraus for n in noise)))) - apply_one_kraus = jax.vmap(partial(apply_kraus_operator, state=state_gate, target=target)) + apply_one_kraus = jax.vmap( + partial( + apply_kraus_operator, + state=state_gate.array if isinstance(state_gate, DensityMatrix) else state_gate, + target=target, + ) + ) kraus_evol = apply_one_kraus(kraus_ops) output_dm = jnp.sum(kraus_evol, 0) return output_dm @@ -232,18 +243,30 @@ def merge_operators( return merged_operators[::-1], merged_targets[::-1], merged_controls[::-1] +@singledispatch def apply_gate( - state: State, + state: Array, gate: Primitive | Iterable[Primitive], values: dict[str, float] = dict(), op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution merge_ops: bool = True, - is_density: bool = False, -) -> State: +) -> Array | DensityMatrix: + raise NotImplementedError("apply_gate is not implemented") + + +@apply_gate.register +def _( + state: Array, + gate: Primitive | Iterable[Primitive], + values: dict[str, float] = dict(), + op_type: OperationType = OperationType.UNITARY, + group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution + merge_ops: bool = True, +) -> Array | DensityMatrix: """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. Arguments: - state: State or DensityMatrix to operate on. + state: Array or DensityMatrix to operate on. gate: Gate(s) to apply. values: A dictionary with parameter values. op_type: The type of operation to perform: Unitary, Dagger or Jacobian. @@ -252,7 +275,7 @@ def apply_gate( is_density: If True, state is provided as a density matrix. Returns: - State or density matrix after applying 'gate'. + Array or density matrix after applying 'gate'. """ operator: tuple[Array, ...] noise = list() @@ -272,14 +295,67 @@ def apply_gate( # faster way to check has_noise has_noise = noise != [None] * len(noise) - if has_noise and not is_density: + if has_noise: state = density_mat(state) - is_density = True + + output_state = reduce( + lambda state, gate: apply_operator_with_noise(state, *gate), + zip(operator, target, control, noise), + state.array, + ) + output_state = DensityMatrix(output_state) + else: + output_state = reduce( + lambda state, gate: apply_operator_with_noise(state, *gate), + zip(operator, target, control, noise), + state, + ) + + return output_state + + +@apply_gate.register +def _( + state: DensityMatrix, + gate: Primitive | Iterable[Primitive], + values: dict[str, float] = dict(), + op_type: OperationType = OperationType.UNITARY, + group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution + merge_ops: bool = True, +) -> DensityMatrix: + """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. + Arguments: + state: Array or DensityMatrix to operate on. + gate: Gate(s) to apply. + values: A dictionary with parameter values. + op_type: The type of operation to perform: Unitary, Dagger or Jacobian. + group_gates: Group gates together which are acting on the same qubit. + merge_ops: Attempt to merge operators acting on the same qubit. + is_density: If True, state is provided as a density matrix. + + Returns: + Array or density matrix after applying 'gate'. + """ + operator: tuple[Array, ...] + noise = list() + if isinstance(gate, Primitive): + operator_fn = getattr(gate, op_type) + operator, target, control = (operator_fn(values),), gate.target, gate.control + noise += [gate.noise] + else: + if group_gates: + gate = group_by_index(gate) + operator = tuple(getattr(g, op_type)(values) for g in gate) + target = reduce(add, [g.target for g in gate]) + control = reduce(add, [g.control for g in gate]) + if merge_ops: + operator, target, control = merge_operators(operator, target, control) + noise = [g.noise for g in gate] output_state = reduce( lambda state, gate: apply_operator_with_noise(state, *gate), - zip(operator, target, control, noise, (is_density,) * len(target)), - state, + zip(operator, target, control, noise), + state.array, ) - return output_state + return DensityMatrix(output_state) diff --git a/horqrux/shots.py b/horqrux/shots.py index a330314..4e515aa 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -10,7 +10,7 @@ from horqrux.apply import apply_gate from horqrux.primitive import GateSequence, Primitive -from horqrux.utils import none_like +from horqrux.utils import DensityMatrix, none_like def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: @@ -33,37 +33,41 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0]) -@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5, 6)) -def finite_shots_fwd( +def eigenval_decomposition_sampling( state: Array, + observables: list[Primitive], + n_qubits: int, + n_shots: int, + key: Any = jax.random.PRNGKey(0), +) -> Array: + mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] + eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] + eigvecs, eigvals = align_eigenvectors(eigs) + inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) + probs = jnp.abs(inner_prod) ** 2 + return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) + + +@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5)) +def finite_shots_fwd( + state: Array | DensityMatrix, gates: GateSequence, observables: list[Primitive], values: dict[str, float], n_shots: int = 100, - is_density: bool = False, key: Any = jax.random.PRNGKey(0), ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ - state = apply_gate(state, gates, values, is_density=is_density) - n_qubits = len(state.shape) - if not is_density: - mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] - eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] - eigvecs, eigvals = align_eigenvectors(eigs) - inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) - probs = jnp.abs(inner_prod) ** 2 - return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) + if isinstance(state, DensityMatrix): + output_gates = apply_gate(state, gates, values).array + n_qubits = len(output_gates.array.shape) // 2 else: - n_qubits = n_qubits // 2 - mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] - mat_obs = jnp.stack(mat_obs) - dim = 2**n_qubits - rho = state.reshape((dim, dim)) - prod = jnp.matmul(mat_obs, rho) - return jnp.trace(prod, axis1=-2, axis2=-1).real + output_gates = apply_gate(state, gates, values) + n_qubits = len(state.shape) + return eigenval_decomposition_sampling(output_gates, observables, n_qubits, n_shots, key) def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]: @@ -110,7 +114,6 @@ def finite_shots_jvp( gates: GateSequence, observable: Primitive, n_shots: int, - is_density: bool, key: Array, primals: tuple[dict[str, float]], tangents: tuple[dict[str, float]], @@ -127,14 +130,14 @@ def jvp_component(param_name: str, key: Array) -> Array: up_key, down_key = random.split(key) up_val = values.copy() up_val[param_name] = up_val[param_name] + shift - f_up = finite_shots_fwd(state, gates, observable, up_val, n_shots, is_density, up_key) + f_up = finite_shots_fwd(state, gates, observable, up_val, n_shots, up_key) down_val = values.copy() down_val[param_name] = down_val[param_name] - shift - f_down = finite_shots_fwd(state, gates, observable, down_val, n_shots, is_density, down_key) + f_down = finite_shots_fwd(state, gates, observable, down_val, n_shots, down_key) grad = spectral_gap * (f_up - f_down) / (4.0 * jnp.sin(spectral_gap * shift / 2.0)) return grad * tangent_dict[param_name] params_with_keys = zip(values.keys(), random.split(key, len(values))) - fwd = finite_shots_fwd(state, gates, observable, values, n_shots, is_density, key) + fwd = finite_shots_fwd(state, gates, observable, values, n_shots, key) jvp = sum(jvp_component(param, key) for param, key in params_with_keys) return fwd, jvp.reshape(fwd.shape) diff --git a/horqrux/utils.py b/horqrux/utils.py index 8e91ed2..dc0bfdb 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from enum import Enum from typing import Any, Iterable, Union @@ -7,6 +8,7 @@ import jax.numpy as jnp import numpy as np from jax import Array +from jax.tree_util import register_pytree_node_class from jax.typing import ArrayLike from numpy import log2 @@ -21,7 +23,24 @@ ATOL = 1e-014 -def density_mat(state: Array) -> Array: +@register_pytree_node_class +@dataclass +class DensityMatrix: + """Dataclass to identify density matrices from states.""" + + array: Array + + def tree_flatten(self) -> tuple[tuple, tuple[Array]]: + children = () + aux_data = self.array + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: + return cls(*children, *aux_data) + + +def density_mat(state: Array | DensityMatrix) -> DensityMatrix: """Convert state to density matrix Args: @@ -31,9 +50,11 @@ def density_mat(state: Array) -> Array: State: Density matrix representation. """ # Expand dimensions to enable broadcasting + if isinstance(state, DensityMatrix): + return state ket = jnp.expand_dims(state, axis=tuple(range(state.ndim, 2 * state.ndim))) bra = jnp.conj(jnp.expand_dims(state, axis=tuple(range(state.ndim)))) - return ket * bra + return DensityMatrix(ket * bra) def permute_basis(operator: Array, qubit_support: tuple, inv: bool = False) -> Array: diff --git a/tests/test_gates.py b/tests/test_gates.py index bca0e40..afb737d 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -7,7 +7,7 @@ import pytest from jax import Array -from horqrux.apply import apply_gate, apply_operator, apply_operator_dm +from horqrux.apply import apply_gate, apply_operator from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z from horqrux.utils import density_mat, equivalent_state, product_state, random_state @@ -29,13 +29,13 @@ def test_primitive(gate_fn: Callable) -> None: ) # test density matrix is similar to pure state - dm = apply_operator_dm( + dm = apply_operator( density_mat(orig_state), gate.unitary(), gate.target[0], gate.control[0], ) - assert jnp.allclose(dm, density_mat(state)) + assert jnp.allclose(dm.array, density_mat(state).array) @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) @@ -52,13 +52,13 @@ def test_controlled_primitive(gate_fn: Callable) -> None: ) # test density matrix is similar to pure state - dm = apply_operator_dm( + dm = apply_operator( density_mat(orig_state), gate.unitary(), gate.target[0], gate.control[0], ) - assert jnp.allclose(dm, density_mat(state)) + assert jnp.allclose(dm.array, density_mat(state).array) @pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) @@ -73,13 +73,13 @@ def test_parametric(gate_fn: Callable) -> None: ) # test density matrix is similar to pure state - dm = apply_operator_dm( + dm = apply_operator( density_mat(orig_state), gate.unitary(values), gate.target[0], gate.control[0], ) - assert jnp.allclose(dm, density_mat(state)) + assert jnp.allclose(dm.array, density_mat(state).array) @pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) @@ -97,13 +97,13 @@ def test_controlled_parametric(gate_fn: Callable) -> None: ) # test density matrix is similar to pure state - dm = apply_operator_dm( + dm = apply_operator( density_mat(orig_state), gate.unitary(values), gate.target[0], gate.control[0], ) - assert jnp.allclose(dm, density_mat(state)) + assert jnp.allclose(dm.array, density_mat(state).array) @pytest.mark.parametrize( diff --git a/tests/test_noise.py b/tests/test_noise.py index 52acb5e..7add33c 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -52,16 +52,19 @@ def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: orig_state = random_state(MAX_QUBITS) output_dm = apply_gate(orig_state, noisy_gate) # check output is a density matrix - assert len(output_dm.shape) == dm_shape_len + assert len(output_dm.array.shape) == dm_shape_len orig_dm = density_mat(orig_state) - assert len(orig_dm.shape) == dm_shape_len - output_dm2 = apply_gate(orig_dm, noisy_gate, is_density=True) - assert jnp.allclose(output_dm2, output_dm) + assert len(orig_dm.array.shape) == dm_shape_len + output_dm2 = apply_gate( + orig_dm, + noisy_gate, + ) + assert jnp.allclose(output_dm2.array, output_dm.array) perfect_gate = gate_fn(target) perfect_output = density_mat(apply_gate(orig_state, perfect_gate)) - assert not jnp.allclose(perfect_output, output_dm) + assert not jnp.allclose(perfect_output.array, output_dm.array) @pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) @@ -77,17 +80,21 @@ def test_noisy_parametric(gate_fn: Callable, noise_type: NoiseType) -> None: output_dm = apply_gate(orig_state, noisy_gate, values) # check output is a density matrix - assert len(output_dm.shape) == dm_shape_len + assert len(output_dm.array.shape) == dm_shape_len orig_dm = density_mat(orig_state) - assert len(orig_dm.shape) == dm_shape_len + assert len(orig_dm.array.shape) == dm_shape_len - output_dm2 = apply_gate(orig_dm, noisy_gate, values, is_density=True) - assert jnp.allclose(output_dm2, output_dm) + output_dm2 = apply_gate( + orig_dm, + noisy_gate, + values, + ) + assert jnp.allclose(output_dm2.array, output_dm.array) perfect_gate = gate_fn("theta", target) perfect_output = density_mat(apply_gate(orig_state, perfect_gate, values)) - assert not jnp.allclose(perfect_output, output_dm) + assert not jnp.allclose(perfect_output.array, output_dm.array) def simple_depolarizing_test() -> None: @@ -116,7 +123,10 @@ def simple_depolarizing_test() -> None: # test sampling dm_state = density_mat(state) - sampling_output = sample(dm_state, ops, is_density=True) + sampling_output = sample( + dm_state, + ops, + ) assert "11" in sampling_output.keys() assert "01" in sampling_output.keys() diff --git a/tests/test_shots.py b/tests/test_shots.py index 8447f41..07e4934 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -25,9 +25,7 @@ def exact(x): def exact_dm(x): values = {"theta": x} - return expectation( - density_mat(state), ops, observables, values, diff_mode="ad", is_density=True - ) + return expectation(density_mat(state), ops, observables, values, diff_mode="ad") def shots(x): values = {"theta": x} @@ -43,7 +41,6 @@ def shots_dm(x): observables, values, diff_mode="gpsr", - is_density=True, forward_mode="shots", n_shots=N_SHOTS, ) From 4050e27964776e248d053d0462e5025ddae5fc8a Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 17 Dec 2024 18:26:43 +0100 Subject: [PATCH 46/57] shots not working --- horqrux/shots.py | 4 +++- horqrux/utils.py | 2 +- tests/test_shots.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/horqrux/shots.py b/horqrux/shots.py index 4e515aa..5827ffe 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -63,7 +63,9 @@ def finite_shots_fwd( """ if isinstance(state, DensityMatrix): output_gates = apply_gate(state, gates, values).array - n_qubits = len(output_gates.array.shape) // 2 + n_qubits = len(output_gates.shape) // 2 + d = 2**n_qubits + output_gates = output_gates.reshape((d, d)) else: output_gates = apply_gate(state, gates, values) n_qubits = len(state.shape) diff --git a/horqrux/utils.py b/horqrux/utils.py index dc0bfdb..7d29306 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -32,7 +32,7 @@ class DensityMatrix: def tree_flatten(self) -> tuple[tuple, tuple[Array]]: children = () - aux_data = self.array + aux_data = (self.array,) return (children, aux_data) @classmethod diff --git a/tests/test_shots.py b/tests/test_shots.py index 07e4934..7e68e6b 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -47,13 +47,13 @@ def shots_dm(x): exp_exact = exact(x) exp_exact_dm = exact_dm(x) - assert jnp.allclose(exp_exact, exp_exact_dm) + # assert jnp.allclose(exp_exact, exp_exact_dm) exp_shots = shots(x) exp_shots_dm = shots_dm(x) assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) - assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) + # assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) d_exact = jax.grad(lambda x: exact(x).sum()) d_shots = jax.grad(lambda x: shots(x).sum()) From 32ee665c8d9d5bce03de95466f71ad462fdcf3f0 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 17 Dec 2024 18:28:08 +0100 Subject: [PATCH 47/57] add fixmes --- tests/test_shots.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_shots.py b/tests/test_shots.py index 7e68e6b..5f012b0 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -46,11 +46,13 @@ def shots_dm(x): ) exp_exact = exact(x) - exp_exact_dm = exact_dm(x) + # FIXME: DM expectation not working + # exp_exact_dm = exact_dm(x) # assert jnp.allclose(exp_exact, exp_exact_dm) exp_shots = shots(x) - exp_shots_dm = shots_dm(x) + # FIXME: DM expectation not working + # exp_shots_dm = shots_dm(x) assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) # assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) From 8c6c8815ef0978af8c9e53828e8539efa39c7b55 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 13:44:06 +0100 Subject: [PATCH 48/57] fix single dispatch methods --- horqrux/api.py | 114 ++++++++++++++++++-------------------------- horqrux/apply.py | 114 ++++++++++++++++++++++++++++---------------- horqrux/utils.py | 46 ++++++++++++++++++ tests/test_noise.py | 1 + tests/test_shots.py | 19 ++++---- 5 files changed, 177 insertions(+), 117 deletions(-) diff --git a/horqrux/api.py b/horqrux/api.py index cd549f6..7bf36c2 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -13,7 +13,15 @@ from horqrux.apply import apply_gate from horqrux.primitive import GateSequence, Primitive from horqrux.shots import finite_shots_fwd, observable_to_matrix -from horqrux.utils import DensityMatrix, DiffMode, ForwardMode, OperationType, inner +from horqrux.utils import ( + DensityMatrix, + DiffMode, + ForwardMode, + OperationType, + get_probas, + inner, + sample_from_probs, +) def run( @@ -24,91 +32,61 @@ def run( return apply_gate(state, circuit, values) -def sample_from_probs(probs: Array, n_qubits: int, n_shots: int) -> Counter: - key = jax.random.PRNGKey(0) - - # JAX handles pseudo random number generation by tracking an explicit state via a random key - # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html - samples = jax.vmap( - lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) - )(jax.random.split(key, n_shots)) - - return Counter( - { - format(k, "0{}b".format(n_qubits)): count.item() - for k, count in enumerate(jnp.bincount(samples)) - if count > 0 - } - ) - - -@singledispatch def sample( - state: Array, - gates: GateSequence, - values: dict[str, float] = dict(), - n_shots: int = 1000, -) -> Counter: - raise NotImplementedError("sample method is not implemented") - - -@sample.register -def _( - state: Array, + state: Array | DensityMatrix, gates: GateSequence, values: dict[str, float] = dict(), n_shots: int = 1000, ) -> Counter: if n_shots < 1: raise ValueError("You can only call sample with n_shots>0.") - output_circuit = apply_gate(state, gates, values) - n_qubits = len(state.shape) - probs = jnp.abs(jnp.float_power(output_circuit, 2.0)).ravel() - return sample_from_probs(probs, n_qubits, n_shots) + if isinstance(output_circuit, DensityMatrix): + n_qubits = len(output_circuit.array.shape) // 2 + d = 2**n_qubits + output_circuit.array = output_circuit.array.reshape((d, d)) + else: + n_qubits = len(output_circuit.array.shape) -@sample.register -def _( - state: DensityMatrix, - gates: GateSequence, - values: dict[str, float] = dict(), - n_shots: int = 1000, -) -> Counter: - if n_shots < 1: - raise ValueError("You can only call sample with n_shots>0.") - - output_circuit = apply_gate(state, gates, values) - n_qubits = len(state.array.shape) // 2 - d = 2**n_qubits - probs = jnp.diagonal(output_circuit.array.reshape((d, d))).real + probs = get_probas(output_circuit) return sample_from_probs(probs, n_qubits, n_shots) +@singledispatch def __ad_expectation_single_observable( - state: Array | DensityMatrix, - gates: GateSequence, + output_state: Array, observable: Primitive, values: dict[str, float], ) -> Array: - """ - Run 'state' through a sequence of 'gates' given parameters 'values' - and compute the expectation given an observable. - """ - out_state = apply_gate(state, gates, values, OperationType.UNITARY) + raise NotImplementedError("__ad_expectation_single_observable is not implemented") - if not isinstance(out_state, DensityMatrix): - projected_state = apply_gate( - out_state, - observable, - values, - OperationType.UNITARY, - ) - return inner(out_state, projected_state).real - n_qubits = len(out_state.array.shape) // 2 + +@__ad_expectation_single_observable.register +def _( + state: Array, + observable: Primitive, + values: dict[str, float], +) -> Array: + projected_state = apply_gate( + state, + observable, + values, + OperationType.UNITARY, + ) + return inner(state, projected_state).real + + +@__ad_expectation_single_observable.register +def _( + state: DensityMatrix, + observable: Primitive, + values: dict[str, float], +) -> Array: + n_qubits = len(state.array.shape) // 2 mat_obs = observable_to_matrix(observable, n_qubits) d = 2**n_qubits - prod = jnp.matmul(mat_obs, out_state.array.reshape((d, d))) + prod = jnp.matmul(mat_obs, state.array.reshape((d, d))) return jnp.trace(prod, axis1=-2, axis2=-1).real @@ -123,7 +101,9 @@ def ad_expectation( and compute the expectation given an observable. """ outputs = [ - __ad_expectation_single_observable(state, gates, observable, values) + __ad_expectation_single_observable( + apply_gate(state, gates, values, OperationType.UNITARY), observable, values + ) for observable in observables ] return jnp.stack(outputs) diff --git a/horqrux/apply.py b/horqrux/apply.py index 977aa71..f112b56 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -30,6 +30,20 @@ def apply_operator( target: tuple[int, ...], control: tuple[int | None, ...], ) -> Array: + """Apply an operator on a state or density matrix. + + Args: + state (Array): Array to operate on. + operator (Array): Array to contract over 'state'. + target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to. + control (tuple[int | None, ...]): tuple of control qubits. + + Raises: + NotImplementedError: If not implemented for given types. + + Returns: + Array: The output of the application of the operator. + """ raise NotImplementedError("apply_operator is not implemented") @@ -51,12 +65,11 @@ def _( dimension 'i' of 'state'. To restore the former order of dimensions, the affected dimensions are moved to their original positions and the state is returned. - Arguments: - state: Array to operate on. - operator: Array to contract over 'state'. - target: tuple of target qubits on which to apply the 'operator' to. - control: tuple of control qubits. - is_density: Whether the state is provided as a density matrix. + Args: + state (Array): Array to operate on. + operator (Array): Array to contract over 'state'. + target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to. + control (tuple[int | None, ...]): tuple of control qubits. Returns: Array after applying 'operator'. @@ -85,12 +98,11 @@ def _( of shape [2 for _ in range(2 * n_qubits)] for a given set of target and control qubits. In case of a controlled operation, the 'operator' array will be embedded into a controlled array. - Arguments: - state: Density matrix to operate on. - operator: Array to contract over 'state'. - target: tuple of target qubits on which to apply the 'operator' to. - control: tuple of control qubits. - is_density: Whether the state is provided as a density matrix. + Args: + state (Array): Array to operate on. + operator (Array): Array to contract over 'state'. + target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to. + control (tuple[int | None, ...]): tuple of control qubits. Returns: Density matrix after applying 'operator'. @@ -106,8 +118,10 @@ def _( new_state_dims = tuple(range(len(state_dims))) # Apply operator to density matrix: ρ' = O ρ O† - support_perm = state_dims + tuple(set(tuple(range(state.array.ndim // 2))) - set(state_dims)) - out_state = permute_basis(state.array, support_perm, False) + out_state = state.array + support_perm = state_dims + tuple(set(tuple(range(out_state.ndim // 2))) - set(state_dims)) + + out_state = permute_basis(out_state, support_perm, False) out_state = jnp.tensordot(a=operator, b=out_state, axes=(op_out_dims, new_state_dims)) out_state = _dagger(out_state) @@ -120,7 +134,7 @@ def _( def apply_kraus_operator( kraus: Array, - state: Array, + array: Array, target: tuple[int, ...], ) -> Array: """Apply K \\rho K^\\dagger. @@ -138,23 +152,53 @@ def apply_kraus_operator( kraus = kraus.reshape(tuple(2 for _ in np.arange(n_qubits))) op_dims = tuple(np.arange(kraus.ndim // 2, kraus.ndim, dtype=int)) - state = jnp.tensordot(a=kraus, b=state, axes=(op_dims, state_dims)) + array = jnp.tensordot(a=kraus, b=array, axes=(op_dims, state_dims)) new_state_dims = tuple(i for i in range(len(state_dims))) - state = jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + array = jnp.moveaxis(a=array, source=new_state_dims, destination=state_dims) + + array = jnp.tensordot(a=kraus, b=_dagger(array), axes=(op_dims, state_dims)) + array = _dagger(array) + + return array + + +def apply_kraus_sum( + kraus_ops: NoiseProtocol, + array: Array, + target: tuple[int, ...], +) -> DensityMatrix: + """Apply the following evolution as a sum of Kraus operators: + .. math:: + S(\\rho) = \\sum_i K_i \\rho K_i^\\dagger - state = jnp.tensordot(a=kraus, b=_dagger(state), axes=(op_dims, state_dims)) - state = _dagger(state) + Args: + noise (NoiseProtocol): Noise containing the K_i + state (Array): Input array + target (tuple[int, ...]): Qubits the operator is defined on. - return state + Returns: + DensityMatrix: Output density matrix. + """ + + apply_one_kraus = jax.vmap( + partial( + apply_kraus_operator, + array=array, + target=target, + ) + ) + kraus_evol = apply_one_kraus(kraus_ops) + output_dm = jnp.sum(kraus_evol, 0) + return DensityMatrix(output_dm) def apply_operator_with_noise( - state: Array | DensityMatrix, + state: DensityMatrix, operator: Array, target: tuple[int, ...], control: tuple[int | None, ...], noise: NoiseProtocol, -) -> Array: +) -> Array | DensityMatrix: """Evolves the input state and applies a noisy quantum channel on the evolved state :math:`\rho`. @@ -177,15 +221,7 @@ def apply_operator_with_noise( return state_gate else: kraus_ops = jnp.stack(tuple(reduce(add, tuple(n.kraus for n in noise)))) - apply_one_kraus = jax.vmap( - partial( - apply_kraus_operator, - state=state_gate.array if isinstance(state_gate, DensityMatrix) else state_gate, - target=target, - ) - ) - kraus_evol = apply_one_kraus(kraus_ops) - output_dm = jnp.sum(kraus_evol, 0) + output_dm = apply_kraus_sum(kraus_ops, state_gate.array, target) return output_dm @@ -272,7 +308,6 @@ def _( op_type: The type of operation to perform: Unitary, Dagger or Jacobian. group_gates: Group gates together which are acting on the same qubit. merge_ops: Attempt to merge operators acting on the same qubit. - is_density: If True, state is provided as a density matrix. Returns: Array or density matrix after applying 'gate'. @@ -301,16 +336,14 @@ def _( output_state = reduce( lambda state, gate: apply_operator_with_noise(state, *gate), zip(operator, target, control, noise), - state.array, + state, ) - output_state = DensityMatrix(output_state) else: output_state = reduce( - lambda state, gate: apply_operator_with_noise(state, *gate), - zip(operator, target, control, noise), + lambda state, gate: apply_operator(state, *gate), + zip(operator, target, control), state, ) - return output_state @@ -331,7 +364,6 @@ def _( op_type: The type of operation to perform: Unitary, Dagger or Jacobian. group_gates: Group gates together which are acting on the same qubit. merge_ops: Attempt to merge operators acting on the same qubit. - is_density: If True, state is provided as a density matrix. Returns: Array or density matrix after applying 'gate'. @@ -351,11 +383,9 @@ def _( if merge_ops: operator, target, control = merge_operators(operator, target, control) noise = [g.noise for g in gate] - output_state = reduce( lambda state, gate: apply_operator_with_noise(state, *gate), zip(operator, target, control, noise), - state.array, + state, ) - - return DensityMatrix(output_state) + return output_state diff --git a/horqrux/utils.py b/horqrux/utils.py index 7d29306..cd24193 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -1,7 +1,9 @@ from __future__ import annotations +from collections import Counter from dataclasses import dataclass from enum import Enum +from functools import singledispatch from typing import Any, Iterable, Union import jax @@ -251,3 +253,47 @@ def _normalize(wf: Array) -> Array: def is_normalized(state: Array) -> bool: return equivalent_state(state, state) + + +def sample_from_probs(probs: Array, n_qubits: int, n_shots: int) -> Counter: + key = jax.random.PRNGKey(0) + + # JAX handles pseudo random number generation by tracking an explicit state via a random key + # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html + samples = jax.vmap( + lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) + )(jax.random.split(key, n_shots)) + + return Counter( + { + format(k, "0{}b".format(n_qubits)): count.item() + for k, count in enumerate(jnp.bincount(samples)) + if count > 0 + } + ) + + +@singledispatch +def get_probas(state: Array) -> Array: + """Extract probabilities from state or density matrix. + + Args: + state (Array): Input array. + + Raises: + NotImplementedError: If not implemented for given types. + + Returns: + Array: Vector of probabilities. + """ + raise NotImplementedError("get_probas is not implemented") + + +@get_probas.register +def _(state: Array) -> Array: + return jnp.abs(jnp.float_power(state, 2.0)).ravel() + + +@get_probas.register +def _(state: DensityMatrix) -> Array: + return jnp.diagonal(state.array).real diff --git a/tests/test_noise.py b/tests/test_noise.py index 7add33c..89b8263 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -51,6 +51,7 @@ def test_noisy_primitive(gate_fn: Callable, noise_type: NoiseType) -> None: orig_state = random_state(MAX_QUBITS) output_dm = apply_gate(orig_state, noisy_gate) + # check output is a density matrix assert len(output_dm.array.shape) == dm_shape_len diff --git a/tests/test_shots.py b/tests/test_shots.py index 5f012b0..368d0f1 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from horqrux import expectation, random_state +from horqrux import expectation, random_state, run from horqrux.parametric import RX from horqrux.primitive import Z from horqrux.utils import density_mat @@ -45,17 +45,20 @@ def shots_dm(x): n_shots=N_SHOTS, ) + expected_dm = density_mat(run(ops, state, {"theta": x})) + output_dm = run(ops, density_mat(state), {"theta": x}) + assert jnp.allclose(expected_dm.array, output_dm.array) + exp_exact = exact(x) - # FIXME: DM expectation not working - # exp_exact_dm = exact_dm(x) - # assert jnp.allclose(exp_exact, exp_exact_dm) + exp_exact_dm = exact_dm(x) + assert jnp.allclose(exp_exact, exp_exact_dm) exp_shots = shots(x) - # FIXME: DM expectation not working - # exp_shots_dm = shots_dm(x) + # # FIXME: DM expectation not working + # # exp_shots_dm = shots_dm(x) - assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) - # assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) + # assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) + # # assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) d_exact = jax.grad(lambda x: exact(x).sum()) d_shots = jax.grad(lambda x: shots(x).sum()) From 3a403a906365054b78541baa9dff95b814c00f0e Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 13:51:43 +0100 Subject: [PATCH 49/57] adding values to observable to matrix --- horqrux/api.py | 2 +- horqrux/shots.py | 15 +++++++++++---- tests/test_shots.py | 8 ++++---- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/horqrux/api.py b/horqrux/api.py index 7bf36c2..6bfd7b9 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -84,7 +84,7 @@ def _( values: dict[str, float], ) -> Array: n_qubits = len(state.array.shape) // 2 - mat_obs = observable_to_matrix(observable, n_qubits) + mat_obs = observable_to_matrix(observable, n_qubits, values) d = 2**n_qubits prod = jnp.matmul(mat_obs, state.array.reshape((d, d))) return jnp.trace(prod, axis1=-2, axis2=-1).real diff --git a/horqrux/shots.py b/horqrux/shots.py index 5827ffe..bcce401 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -13,7 +13,11 @@ from horqrux.utils import DensityMatrix, none_like -def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: +def observable_to_matrix( + observable: Primitive, + n_qubits: int, + values: dict[str, float], +) -> Array: """For finite shot sampling we need to calculate the eigenvalues/vectors of an observable. This helper function takes an observable and system size (n_qubits) and returns the overall action of the observable on the whole @@ -25,7 +29,7 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: observable.control == observable.parse_idx(none_like(observable.target)), "Controlled gates cannot be promoted from observables to operations on the whole state vector", ) - unitary = observable.unitary() + unitary = observable.unitary(values=values) target = observable.target[0][0] identity = jnp.eye(2, dtype=unitary.dtype) ops = [identity for _ in range(n_qubits)] @@ -36,11 +40,12 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: def eigenval_decomposition_sampling( state: Array, observables: list[Primitive], + values: dict[str, float], n_qubits: int, n_shots: int, key: Any = jax.random.PRNGKey(0), ) -> Array: - mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] + mat_obs = [observable_to_matrix(observable, n_qubits, values) for observable in observables] eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] eigvecs, eigvals = align_eigenvectors(eigs) inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) @@ -69,7 +74,9 @@ def finite_shots_fwd( else: output_gates = apply_gate(state, gates, values) n_qubits = len(state.shape) - return eigenval_decomposition_sampling(output_gates, observables, n_qubits, n_shots, key) + return eigenval_decomposition_sampling( + output_gates, observables, values, n_qubits, n_shots, key + ) def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]: diff --git a/tests/test_shots.py b/tests/test_shots.py index 368d0f1..8184d1e 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -54,11 +54,11 @@ def shots_dm(x): assert jnp.allclose(exp_exact, exp_exact_dm) exp_shots = shots(x) - # # FIXME: DM expectation not working - # # exp_shots_dm = shots_dm(x) + # FIXME: DM expectation not working + # exp_shots_dm = shots_dm(x) - # assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) - # # assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) + assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) + # assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) d_exact = jax.grad(lambda x: exact(x).sum()) d_shots = jax.grad(lambda x: shots(x).sum()) From 49172fc1276c5ad32fae2869e223468983eb3c42 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 14:53:11 +0100 Subject: [PATCH 50/57] fix shots dm --- horqrux/shots.py | 62 ++++++++++++++++++++++++++++++++++++++++----- tests/test_shots.py | 5 ++-- 2 files changed, 57 insertions(+), 10 deletions(-) diff --git a/horqrux/shots.py b/horqrux/shots.py index bcce401..1f2911e 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import partial, reduce +from functools import partial, reduce, singledispatch from typing import Any import jax @@ -37,8 +37,57 @@ def observable_to_matrix( return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0]) +@singledispatch +def probs_from_eigenvectors_state(state: Array, eigvecs: Array) -> Array: + """Obtain the probabilities using an input state and the eigenvectors decomposition + of an observable. + + Args: + state (Array): Input array. + eigvecs (Array): Eigenvectors of the observables. + + Returns: + Array: The probabilities. + """ + raise NotImplementedError("prod_eigenvectors_state is not implemented") + + +@probs_from_eigenvectors_state.register +def _(state: Array, eigvecs: Array) -> Array: + """Obtain the probabilities using an input quantum state vector + and the eigenvectors decomposition + of an observable. + + Args: + state (Array): Input array. + eigvecs (Array): Eigenvectors of the observables. + + Returns: + Array: The probabilities. + """ + inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) + return jnp.abs(inner_prod) ** 2 + + +@probs_from_eigenvectors_state.register +def _(state: DensityMatrix, eigvecs: Array) -> Array: + """Obtain the probabilities using an input quantum density matrix + and the eigenvectors decomposition + of an observable. + + Args: + state (DensityMatrix): Input array. + eigvecs (Array): Eigenvectors of the observables. + + Returns: + Array: The probabilities. + """ + mat_prob = jnp.conjugate(eigvecs.T) @ state.array @ eigvecs + return mat_prob.diagonal().real + + def eigenval_decomposition_sampling( - state: Array, + state: Array | DensityMatrix, observables: list[Primitive], values: dict[str, float], n_qubits: int, @@ -48,8 +97,7 @@ def eigenval_decomposition_sampling( mat_obs = [observable_to_matrix(observable, n_qubits, values) for observable in observables] eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] eigvecs, eigvals = align_eigenvectors(eigs) - inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) - probs = jnp.abs(inner_prod) ** 2 + probs = probs_from_eigenvectors_state(state, eigvecs) return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) @@ -67,10 +115,10 @@ def finite_shots_fwd( and compute the expectation given an observable. """ if isinstance(state, DensityMatrix): - output_gates = apply_gate(state, gates, values).array - n_qubits = len(output_gates.shape) // 2 + output_gates = apply_gate(state, gates, values) + n_qubits = len(output_gates.array.shape) // 2 d = 2**n_qubits - output_gates = output_gates.reshape((d, d)) + output_gates.array = output_gates.array.reshape((d, d)) else: output_gates = apply_gate(state, gates, values) n_qubits = len(state.shape) diff --git a/tests/test_shots.py b/tests/test_shots.py index 8184d1e..5660610 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -54,11 +54,10 @@ def shots_dm(x): assert jnp.allclose(exp_exact, exp_exact_dm) exp_shots = shots(x) - # FIXME: DM expectation not working - # exp_shots_dm = shots_dm(x) + exp_shots_dm = shots_dm(x) assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) - # assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) + assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) d_exact = jax.grad(lambda x: exact(x).sum()) d_shots = jax.grad(lambda x: shots(x).sum()) From d5a3d80c959354cf3309b9e6608d891e16bfa6da Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 15:31:22 +0100 Subject: [PATCH 51/57] fix doc strings --- horqrux/api.py | 59 ++++++++++++++++++++++------ horqrux/apply.py | 100 ++++++++++++++++++++++++++++------------------- horqrux/shots.py | 10 ++--- horqrux/utils.py | 6 ++- 4 files changed, 116 insertions(+), 59 deletions(-) diff --git a/horqrux/api.py b/horqrux/api.py index 6bfd7b9..96bb308 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -18,6 +18,7 @@ DiffMode, ForwardMode, OperationType, + State, get_probas, inner, sample_from_probs, @@ -26,18 +27,32 @@ def run( circuit: GateSequence, - state: Array | DensityMatrix, + state: State, values: dict[str, float] = dict(), -) -> Array | DensityMatrix: +) -> State: return apply_gate(state, circuit, values) def sample( - state: Array | DensityMatrix, + state: State, gates: GateSequence, values: dict[str, float] = dict(), n_shots: int = 1000, ) -> Counter: + """Sample from a quantum program. + + Args: + state (State): Input state vector or density matrix. + gates (GateSequence): Sequence of gates. + values (dict[str, float], optional): _description_. Defaults to dict(). + n_shots (int, optional): Parameter values.. Defaults to 1000. + + Raises: + ValueError: If n_shots < 1. + + Returns: + Counter: Bitstrings and frequencies. + """ if n_shots < 1: raise ValueError("You can only call sample with n_shots>0.") output_circuit = apply_gate(state, gates, values) @@ -55,7 +70,7 @@ def sample( @singledispatch def __ad_expectation_single_observable( - output_state: Array, + state: Any, observable: Primitive, values: dict[str, float], ) -> Array: @@ -91,14 +106,22 @@ def _( def ad_expectation( - state: Array | DensityMatrix, + state: State, gates: GateSequence, observables: list[Primitive], values: dict[str, float], ) -> Array: - """ - Run 'state' through a sequence of 'gates' given parameters 'values' - and compute the expectation given an observable. + """Run 'state' through a sequence of 'gates' given parameters 'values' + and compute the expectation given an observable. + + Args: + state (State): Input state vector or density matrix. + gates (GateSequence): Sequence of gates. + observables (list[Primitive]): List of observables. + values (dict[str, float]): Parameter values. + + Returns: + Array: Expectation values. """ outputs = [ __ad_expectation_single_observable( @@ -110,7 +133,7 @@ def ad_expectation( def expectation( - state: Array | DensityMatrix, + state: State, gates: GateSequence, observables: list[Primitive], values: dict[str, float], @@ -119,13 +142,27 @@ def expectation( n_shots: Optional[int] = None, key: Any = jax.random.PRNGKey(0), ) -> Array: - """ - Run 'state' through a sequence of 'gates' given parameters 'values' + """Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. + + Args: + state (State): Input state vector or density matrix. + gates (GateSequence): Sequence of gates. + observables (list[Primitive]): List of observables. + values (dict[str, float]): Parameter values. + diff_mode (DiffMode, optional): Differentiation mode. Defaults to DiffMode.AD. + forward_mode (ForwardMode, optional): Type of forward method. Defaults to ForwardMode.EXACT. + n_shots (Optional[int], optional): Number of shots. Defaults to None. + key (Any, optional): Random key. Defaults to jax.random.PRNGKey(0). + + Returns: + Array: Expectation values. """ if diff_mode == DiffMode.AD: return ad_expectation(state, gates, observables, values) elif diff_mode == DiffMode.ADJOINT: + if isinstance(state, DensityMatrix): + raise ValueError("Adjoint does not support density matrices.") return adjoint_expectation(state, gates, observables, values) elif diff_mode == DiffMode.GPSR: checkify.check( diff --git a/horqrux/apply.py b/horqrux/apply.py index f112b56..3bf3032 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -2,7 +2,7 @@ from functools import partial, reduce, singledispatch from operator import add -from typing import Iterable +from typing import Any, Iterable import jax import jax.numpy as jnp @@ -15,6 +15,7 @@ from .utils import ( DensityMatrix, OperationType, + State, _controlled, _dagger, density_mat, @@ -25,7 +26,7 @@ @singledispatch def apply_operator( - state: Array, + state: Any, operator: Array, target: tuple[int, ...], control: tuple[int | None, ...], @@ -33,7 +34,7 @@ def apply_operator( """Apply an operator on a state or density matrix. Args: - state (Array): Array to operate on. + state (Any): Array to operate on. operator (Array): Array to contract over 'state'. target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to. control (tuple[int | None, ...]): tuple of control qubits. @@ -99,7 +100,7 @@ def _( In case of a controlled operation, the 'operator' array will be embedded into a controlled array. Args: - state (Array): Array to operate on. + state (DensityMatrix): Array to operate on. operator (Array): Array to contract over 'state'. target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to. control (tuple[int | None, ...]): tuple of control qubits. @@ -145,7 +146,7 @@ def apply_kraus_operator( target (tuple[int, ...]): Target qubits. Returns: - Array: Output density matrix. + Array: K \\rho K^\\dagger. """ state_dims: tuple[int, ...] = target n_qubits = int(np.log2(kraus.size)) @@ -163,7 +164,7 @@ def apply_kraus_operator( def apply_kraus_sum( - kraus_ops: NoiseProtocol, + kraus_ops: Array, array: Array, target: tuple[int, ...], ) -> DensityMatrix: @@ -172,8 +173,8 @@ def apply_kraus_sum( S(\\rho) = \\sum_i K_i \\rho K_i^\\dagger Args: - noise (NoiseProtocol): Noise containing the K_i - state (Array): Input array + kraus_ops (Array): Stacked K_i. + state (Array): Input array. target (tuple[int, ...]): Qubits the operator is defined on. Returns: @@ -198,7 +199,7 @@ def apply_operator_with_noise( target: tuple[int, ...], control: tuple[int | None, ...], noise: NoiseProtocol, -) -> Array | DensityMatrix: +) -> State: """Evolves the input state and applies a noisy quantum channel on the evolved state :math:`\rho`. @@ -207,7 +208,7 @@ def apply_operator_with_noise( S(\\rho) = \\sum_i K_i \\rho K_i^\\dagger, Args: - state (Array | DensityMatrix): Input state or density matrix. + state (State): Input state or density matrix. operator (Array): Operator to apply. target (tuple[int, ...]): Target qubits. control (tuple[int | None, ...]): Control qubits. @@ -255,9 +256,9 @@ def merge_operators( operators: The arrays representing the unitaries to be merged. targets: The corresponding target qubits. controls: The corresponding control qubits. + Returns: A tuple of merged operators, targets and controls. - """ if len(operators) < 2: return operators, targets, controls @@ -281,36 +282,37 @@ def merge_operators( @singledispatch def apply_gate( - state: Array, + state: Any, gate: Primitive | Iterable[Primitive], values: dict[str, float] = dict(), op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution merge_ops: bool = True, -) -> Array | DensityMatrix: +) -> State: raise NotImplementedError("apply_gate is not implemented") -@apply_gate.register -def _( - state: Array, +def prepare_sequence_reduce( gate: Primitive | Iterable[Primitive], values: dict[str, float] = dict(), op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution merge_ops: bool = True, -) -> Array | DensityMatrix: - """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. - Arguments: - state: Array or DensityMatrix to operate on. - gate: Gate(s) to apply. - values: A dictionary with parameter values. - op_type: The type of operation to perform: Unitary, Dagger or Jacobian. - group_gates: Group gates together which are acting on the same qubit. - merge_ops: Attempt to merge operators acting on the same qubit. +) -> tuple[tuple[Array, ...], tuple, tuple, list[NoiseProtocol]]: + """Prepare the tuples to be used when applying operations. + + Args: + gate (Primitive | Iterable[Primitive]): Gate(s) to apply. + values (dict[str, float], optional): A dictionary with parameter values. + Defaults to dict(). + op_type (OperationType, optional): The type of operation to perform: Unitary, Dagger or Jacobian. + Defaults to OperationType.UNITARY. + group_gates (bool, optional): Group gates together which are acting on the same qubit. + Defaults to False. Returns: - Array or density matrix after applying 'gate'. + tuple[tuple[Array, ...], tuple, tuple, list[NoiseProtocol]]: Operators, targets, + controls and noise. """ operator: tuple[Array, ...] noise = list() @@ -328,6 +330,34 @@ def _( operator, target, control = merge_operators(operator, target, control) noise = [g.noise for g in gate] + return operator, target, control, noise + + +@apply_gate.register +def _( + state: Array, + gate: Primitive | Iterable[Primitive], + values: dict[str, float] = dict(), + op_type: OperationType = OperationType.UNITARY, + group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution + merge_ops: bool = True, +) -> State: + """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. + Arguments: + state: Array or DensityMatrix to operate on. + gate: Gate(s) to apply. + values: A dictionary with parameter values. + op_type: The type of operation to perform: Unitary, Dagger or Jacobian. + group_gates: Group gates together which are acting on the same qubit. + merge_ops: Attempt to merge operators acting on the same qubit. + + Returns: + Array or density matrix after applying 'gate'. + """ + operator, target, control, noise = prepare_sequence_reduce( + gate, values, op_type, group_gates, merge_ops + ) + # faster way to check has_noise has_noise = noise != [None] * len(noise) if has_noise: @@ -368,21 +398,9 @@ def _( Returns: Array or density matrix after applying 'gate'. """ - operator: tuple[Array, ...] - noise = list() - if isinstance(gate, Primitive): - operator_fn = getattr(gate, op_type) - operator, target, control = (operator_fn(values),), gate.target, gate.control - noise += [gate.noise] - else: - if group_gates: - gate = group_by_index(gate) - operator = tuple(getattr(g, op_type)(values) for g in gate) - target = reduce(add, [g.target for g in gate]) - control = reduce(add, [g.control for g in gate]) - if merge_ops: - operator, target, control = merge_operators(operator, target, control) - noise = [g.noise for g in gate] + operator, target, control, noise = prepare_sequence_reduce( + gate, values, op_type, group_gates, merge_ops + ) output_state = reduce( lambda state, gate: apply_operator_with_noise(state, *gate), zip(operator, target, control, noise), diff --git a/horqrux/shots.py b/horqrux/shots.py index 1f2911e..508cebe 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -10,7 +10,7 @@ from horqrux.apply import apply_gate from horqrux.primitive import GateSequence, Primitive -from horqrux.utils import DensityMatrix, none_like +from horqrux.utils import DensityMatrix, State, none_like def observable_to_matrix( @@ -38,7 +38,7 @@ def observable_to_matrix( @singledispatch -def probs_from_eigenvectors_state(state: Array, eigvecs: Array) -> Array: +def probs_from_eigenvectors_state(state: Any, eigvecs: Array) -> Array: """Obtain the probabilities using an input state and the eigenvectors decomposition of an observable. @@ -76,7 +76,7 @@ def _(state: DensityMatrix, eigvecs: Array) -> Array: of an observable. Args: - state (DensityMatrix): Input array. + state (DensityMatrix): Input density matrix. eigvecs (Array): Eigenvectors of the observables. Returns: @@ -87,7 +87,7 @@ def _(state: DensityMatrix, eigvecs: Array) -> Array: def eigenval_decomposition_sampling( - state: Array | DensityMatrix, + state: State, observables: list[Primitive], values: dict[str, float], n_qubits: int, @@ -103,7 +103,7 @@ def eigenval_decomposition_sampling( @partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5)) def finite_shots_fwd( - state: Array | DensityMatrix, + state: State, gates: GateSequence, observables: list[Primitive], values: dict[str, float], diff --git a/horqrux/utils.py b/horqrux/utils.py index cd24193..457fec0 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -18,7 +18,6 @@ default_dtype = default_complex_dtype() -State = ArrayLike QubitSupport = tuple[Any, ...] ControlQubits = tuple[Union[None, tuple[int, ...]], ...] TargetQubits = tuple[tuple[int, ...], ...] @@ -42,7 +41,10 @@ def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: return cls(*children, *aux_data) -def density_mat(state: Array | DensityMatrix) -> DensityMatrix: +State = Union[ArrayLike, DensityMatrix] + + +def density_mat(state: State) -> DensityMatrix: """Convert state to density matrix Args: From c855a9776dbaf54ffc940388e6db58cf8584a3c0 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 15:38:02 +0100 Subject: [PATCH 52/57] update State typing --- horqrux/api.py | 2 +- horqrux/apply.py | 4 ++-- horqrux/shots.py | 2 +- horqrux/utils.py | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/horqrux/api.py b/horqrux/api.py index 96bb308..7d68636 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -73,7 +73,7 @@ def __ad_expectation_single_observable( state: Any, observable: Primitive, values: dict[str, float], -) -> Array: +) -> Any: raise NotImplementedError("__ad_expectation_single_observable is not implemented") diff --git a/horqrux/apply.py b/horqrux/apply.py index 3bf3032..efb8cfc 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -30,7 +30,7 @@ def apply_operator( operator: Array, target: tuple[int, ...], control: tuple[int | None, ...], -) -> Array: +) -> Any: """Apply an operator on a state or density matrix. Args: @@ -288,7 +288,7 @@ def apply_gate( op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution merge_ops: bool = True, -) -> State: +) -> Any: raise NotImplementedError("apply_gate is not implemented") diff --git a/horqrux/shots.py b/horqrux/shots.py index 508cebe..5dc69fa 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -43,7 +43,7 @@ def probs_from_eigenvectors_state(state: Any, eigvecs: Array) -> Array: of an observable. Args: - state (Array): Input array. + state (Any): Input. eigvecs (Array): Eigenvectors of the observables. Returns: diff --git a/horqrux/utils.py b/horqrux/utils.py index 457fec0..5b7f79c 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -44,14 +44,14 @@ def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: State = Union[ArrayLike, DensityMatrix] -def density_mat(state: State) -> DensityMatrix: +def density_mat(state: ArrayLike) -> DensityMatrix: """Convert state to density matrix Args: - state (State): Input state. + state (ArrayLike): Input state. Returns: - State: Density matrix representation. + DensityMatrix: Density matrix representation. """ # Expand dimensions to enable broadcasting if isinstance(state, DensityMatrix): From 4f45580bb38106f23a6b19c3df5f1a7db89af735 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 15:40:14 +0100 Subject: [PATCH 53/57] update typing primitive noiseprotocol --- horqrux/primitive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 2507fc9..c8c3948 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -28,7 +28,7 @@ class Primitive: generator_name: str target: QubitSupport control: QubitSupport - noise: NoiseProtocol | None = None + noise: NoiseProtocol = None @staticmethod def parse_idx( From 45bce8447339c7c21836fb7e98cf1fd4836930bd Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 16:09:19 +0100 Subject: [PATCH 54/57] fix docs is_density --- docs/noise.md | 2 +- horqrux/api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/noise.md b/docs/noise.md index 369b8f4..e79f08c 100644 --- a/docs/noise.md +++ b/docs/noise.md @@ -95,7 +95,7 @@ noisy_ops = [X(0, noise=noise)] state = product_state("0") noiseless_samples = sample(state, ops) -noisy_samples = sample(density_mat(state), noisy_ops, is_density=True) +noisy_samples = sample(density_mat(state), noisy_ops) print("Noiseless samples", noiseless_samples) # markdown-exec: hide print("Noiseless samples", noisy_samples) # markdown-exec: hide ``` diff --git a/horqrux/api.py b/horqrux/api.py index 7d68636..781c558 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -62,7 +62,7 @@ def sample( d = 2**n_qubits output_circuit.array = output_circuit.array.reshape((d, d)) else: - n_qubits = len(output_circuit.array.shape) + n_qubits = len(output_circuit.shape) probs = get_probas(output_circuit) return sample_from_probs(probs, n_qubits, n_shots) From ff52facd0528cdcfe2d09bd24faa293b15d76b4f Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 16:54:07 +0100 Subject: [PATCH 55/57] Union instead of pipe --- horqrux/apply.py | 10 +++++----- horqrux/utils.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index efb8cfc..e18b435 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -2,7 +2,7 @@ from functools import partial, reduce, singledispatch from operator import add -from typing import Any, Iterable +from typing import Any, Iterable, Union import jax import jax.numpy as jnp @@ -29,7 +29,7 @@ def apply_operator( state: Any, operator: Array, target: tuple[int, ...], - control: tuple[int | None, ...], + control: tuple[Union[int, None], ...], ) -> Any: """Apply an operator on a state or density matrix. @@ -53,7 +53,7 @@ def _( state: Array, operator: Array, target: tuple[int, ...], - control: tuple[int | None, ...], + control: tuple[Union[int, None], ...], ) -> Array: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits. @@ -93,7 +93,7 @@ def _( state: DensityMatrix, operator: Array, target: tuple[int, ...], - control: tuple[int | None, ...], + control: tuple[Union[int, None], ...], ) -> DensityMatrix: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given density matrix of shape [2 for _ in range(2 * n_qubits)] for a given set of target and control qubits. @@ -197,7 +197,7 @@ def apply_operator_with_noise( state: DensityMatrix, operator: Array, target: tuple[int, ...], - control: tuple[int | None, ...], + control: tuple[Union[int, None], ...], noise: NoiseProtocol, ) -> State: """Evolves the input state and applies a noisy quantum channel diff --git a/horqrux/utils.py b/horqrux/utils.py index 5b7f79c..e011e52 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -231,7 +231,7 @@ def uniform_state( return state.reshape([2] * n_qubits) -def is_controlled(qubit_support: tuple[int | None, ...] | int | None) -> bool: +def is_controlled(qubit_support: tuple[Union[int, None], ...] | Union[int, None]) -> bool: if isinstance(qubit_support, int): return True elif isinstance(qubit_support, tuple): From d8d7c0bb43b403631a53b26f64937f492b51b45c Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 17:03:09 +0100 Subject: [PATCH 56/57] change union pipe is_controlled --- horqrux/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/horqrux/utils.py b/horqrux/utils.py index e011e52..012eed7 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -231,7 +231,7 @@ def uniform_state( return state.reshape([2] * n_qubits) -def is_controlled(qubit_support: tuple[Union[int, None], ...] | Union[int, None]) -> bool: +def is_controlled(qubit_support: Union[tuple[Union[int, None], ...], int, None]) -> bool: if isinstance(qubit_support, int): return True elif isinstance(qubit_support, tuple): From 015697c9fb0ca99d70601b18f92e92055dae4852 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 18 Dec 2024 17:06:55 +0100 Subject: [PATCH 57/57] union of primitive --- horqrux/apply.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/horqrux/apply.py b/horqrux/apply.py index e18b435..3a3e04d 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -336,7 +336,7 @@ def prepare_sequence_reduce( @apply_gate.register def _( state: Array, - gate: Primitive | Iterable[Primitive], + gate: Union[Primitive, Iterable[Primitive]], values: dict[str, float] = dict(), op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution @@ -380,7 +380,7 @@ def _( @apply_gate.register def _( state: DensityMatrix, - gate: Primitive | Iterable[Primitive], + gate: Union[Primitive, Iterable[Primitive]], values: dict[str, float] = dict(), op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution