Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature | Performance] Add circuit module, Merge gates acting on same qubits #14

Merged
merged 13 commits into from
Apr 16, 2024
12 changes: 6 additions & 6 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,9 @@ class TotalMagnetization:
def __post_init__(self) -> None:
self.paulis = [Z(i) for i in range(self.n_qubits)]

def __call__(self, state: Array, values: dict) -> Array:
return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis])
def __call__(self, out_state: Array, values: dict) -> Array:
projected_state = reduce(add, [apply_gate(out_state, pauli, values) for pauli in self.paulis])
return inner(out_state, projected_state).real


@dataclass
Expand All @@ -284,15 +285,14 @@ class Circuit:
]
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
self.observable = TotalMagnetization(self.n_qubits)
self.state = zero_state(self.n_qubits)

def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
state = zero_state(self.n_qubits)
param_dict = {name: val for name, val in zip(self.param_names, param_vals)}
out_state = apply_gate(
state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
self.state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
)
projected_state = self.observable(state, param_dict)
return jnp.real(inner(out_state, projected_state))
return self.observable(out_state, {})

@property
def n_vparams(self) -> int:
Expand Down
12 changes: 8 additions & 4 deletions horqrux/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Tuple

from jax import Array, custom_vjp
from jax.numpy import real as jnpreal

from horqrux.apply import apply_gate
from horqrux.parametric import Parametric
Expand All @@ -14,15 +13,20 @@
def expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable.
"""
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return jnpreal(inner(out_state, projected_state))
return inner(out_state, projected_state).real


@custom_vjp
def adjoint_expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
"""Custom vector-jacobian product to compute gradients
in O(P) time using O(1) state vectors via Algorithm 1 in https://arxiv.org/abs/2009.02823."""
return expectation(state, gates, observable, values)


Expand All @@ -31,7 +35,7 @@ def adjoint_expectation_fwd(
) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]:
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return jnpreal(inner(out_state, projected_state)), (out_state, projected_state, gates, values)
return inner(out_state, projected_state).real, (out_state, projected_state, gates, values)


def adjoint_expectation_bwd(
Expand All @@ -43,7 +47,7 @@ def adjoint_expectation_bwd(
out_state = apply_gate(out_state, gate, values, OperationType.DAGGER)
if isinstance(gate, Parametric):
mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN)
grads[gate.param] = tangent * 2 * jnpreal(inner(mu, projected_state))
grads[gate.param] = tangent * 2 * inner(mu, projected_state).real
projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER)
return (None, None, None, grads)

Expand Down
54 changes: 54 additions & 0 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,58 @@ def apply_operator(
return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims)


def group_by_index(gates: Iterable[Primitive]) -> Iterable[Primitive]:
sorted_gates = []
gate_batch = []
for gate in gates:
if not is_controlled(gate.control):
gate_batch.append(gate)
else:
if len(gate_batch) > 0:
gate_batch.sort(key=lambda g: g.target)
sorted_gates += gate_batch
gate_batch = []
sorted_gates.append(gate)
if len(gate_batch) > 0:
gate_batch.sort(key=lambda g: g.target)
sorted_gates += gate_batch
return sorted_gates


def merge_operators(
operators: tuple[Array, ...], targets: tuple[int, ...], controls: tuple[int, ...]
) -> tuple[tuple[Array, ...], tuple[int, ...], tuple[int, ...]]:
"""
If possible, merge several gates acting on the same qubits into a single tensordot operation.

Arguments:
operators: The arrays representing the unitaries to be merged.
targets: The corresponding target qubits.
controls: The corresponding control qubits.
Returns:
A tuple of merged operators, targets and controls.

"""
if len(operators) < 2:
return operators, targets, controls
operators, targets, controls = operators[::-1], targets[::-1], controls[::-1]
merged_operator, merged_target, merged_control = operators[0], targets[0], controls[0]
merged_operators = merged_targets = merged_controls = tuple() # type: ignore[var-annotated]
for operator, target, control in zip(operators[1:], targets[1:], controls[1:]):
if target == merged_target and control == merged_control:
merged_operator = merged_operator @ operator
else:
merged_operators += (merged_operator,)
merged_targets += (merged_target,)
merged_controls += (merged_control,)
merged_operator, merged_target, merged_control = operator, target, control
if merged_operator is not None:
merged_operators += (merged_operator,)
merged_targets += (merged_target,)
merged_controls += (merged_control,)
return merged_operators[::-1], merged_targets[::-1], merged_controls[::-1]


def apply_gate(
state: State,
gate: Primitive | Iterable[Primitive],
Expand All @@ -72,9 +124,11 @@ def apply_gate(
operator_fn = getattr(gate, op_type)
operator, target, control = (operator_fn(values),), gate.target, gate.control
else:
gate = group_by_index(gate)
operator = tuple(getattr(g, op_type)(values) for g in gate)
target = reduce(add, [g.target for g in gate])
control = reduce(add, [g.control for g in gate])
operator, target, control = merge_operators(operator, target, control)
return reduce(
lambda state, gate: apply_operator(state, *gate),
zip(operator, target, control),
Expand Down
10 changes: 5 additions & 5 deletions horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def inner(state: Array, projection: Array) -> Array:


def overlap(state: Array, projection: Array) -> Array:
return jnp.real(jnp.power(inner(state, projection), 2))
return jnp.power(inner(state, projection), 2).real


def uniform_state(
Expand All @@ -134,11 +134,11 @@ def uniform_state(
return state.reshape([2] * n_qubits)


def is_controlled(qs: Tuple[int | None, ...] | int | None) -> bool:
if isinstance(qs, int):
def is_controlled(qubit_support: Tuple[int | None, ...] | int | None) -> bool:
if isinstance(qubit_support, int):
return True
elif isinstance(qs, tuple):
return any(is_controlled(q) for q in qs)
elif isinstance(qubit_support, tuple):
return any(is_controlled(q) for q in qubit_support)
return False


Expand Down
18 changes: 17 additions & 1 deletion tests/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from jax import Array

from horqrux.apply import apply_gate, apply_operator
from horqrux.apply import apply_gate, apply_operator, group_by_index, merge_operators
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from horqrux.utils import equivalent_state, product_state, random_state
Expand Down Expand Up @@ -102,3 +102,19 @@ def test_swap_gate(inputs: tuple[str, str, Array]) -> None:
state = product_state(bitstring)
out_state = apply_gate(state, op)
assert equivalent_state(out_state, product_state(expected_bitstring))


def test_merge_gates() -> None:
gates = [RX("a", 0), RZ("b", 1), RY("c", 0)]
gates = group_by_index(gates)
values = {
"a": np.random.uniform(0.1, 2 * np.pi),
"b": np.random.uniform(0.1, 2 * np.pi),
"c": np.random.uniform(0.1, 2 * np.pi),
}
op, trgt, ctrl = merge_operators(
tuple(g.unitary(values) for g in gates),
tuple(g.target for g in gates),
tuple(g.control for g in gates),
)
assert len(op) == 2
Loading