
[Feature] Add Operator Product #156

Merged · 27 commits · Apr 24, 2024
Changes from 18 commits
41 changes: 36 additions & 5 deletions pyqtorch/apply.py
@@ -1,24 +1,26 @@
from __future__ import annotations

from math import log2
from string import ascii_letters as ABC
from typing import Tuple

import torch
from numpy import array
from numpy.typing import NDArray
from torch import einsum
from torch import Tensor, einsum

from pyqtorch.utils import Operator, State
from pyqtorch.utils import batch_first, batch_last, promote_op

ABC_ARRAY: NDArray = array(list(ABC))


def apply_operator(
    state: State,
    operator: Operator,
    state: Tensor,
    operator: Tensor,
    qubits: Tuple[int, ...] | list[int],
    n_qubits: int = None,
    batch_size: int = None,
) -> State:
) -> Tensor:
"""Applies an operator, i.e. a single tensor of shape [2, 2, ...], on a given state
of shape [2 for _ in range(n_qubits)] for a given set of (target and control) qubits.

@@ -57,3 +59,32 @@ def apply_operator(
        map(lambda e: "".join(list(e)), [operator_dims, in_state_dims, out_state_dims])
    )
    return einsum(f"{operator_dims},{in_state_dims}->{out_state_dims}", operator, state)


def apply_op_op(operator_1: Tensor, operator_2: Tensor, target: int) -> Tensor:
"""
Compute the product of two operators.

Args:
operator_1 (Tensor): The first operator.
operator_2 (Tensor): The second operator.
target (int): The target qubit index.

Returns:
Tensor: The product of the two operators.
"""

    n_qubits_1 = int(log2(operator_1.size(1)))
    n_qubits_2 = int(log2(operator_2.size(1)))
    batch_size_1 = operator_1.size(-1)
    batch_size_2 = operator_2.size(-1)
    if n_qubits_1 > n_qubits_2:
        operator_2 = promote_op(operator_2, target, n_qubits_1)
    if n_qubits_1 < n_qubits_2:
        operator_1 = promote_op(operator_1, target, n_qubits_2)
    if batch_size_1 > batch_size_2:
        operator_2 = operator_2.repeat(1, 1, batch_size_1)
    if batch_size_2 > batch_size_1:
        operator_1 = operator_1.repeat(1, 1, batch_size_2)

    return batch_last(torch.bmm(batch_first(operator_1), batch_first(operator_2)))
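
Note (not part of the diff): a minimal usage sketch of apply_op_op, assuming the shape conventions used above, i.e. square operators with a trailing batch dimension:

    import torch
    from pyqtorch.apply import apply_op_op
    from pyqtorch.primitive import X, Z

    x = X(0).unitary()  # shape [2, 2, 1]: single-qubit operator, batch size 1
    z = Z(0).unitary()
    xz = apply_op_op(x, z, target=0)  # batched matrix product X @ Z
    # xz.squeeze(2) is [[0, -1], [1, 0]], i.e. -iY, still with batch dim last.

When the two operators act on different numbers of qubits, the smaller one is first promoted to the larger support at `target` via promote_op before the batched matmul.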
13 changes: 7 additions & 6 deletions pyqtorch/primitive.py
@@ -4,14 +4,15 @@
from typing import Any, Tuple

import torch
from torch import Tensor

from pyqtorch.apply import apply_operator
from pyqtorch.matrices import OPERATIONS_DICT, _controlled, _dagger
from pyqtorch.utils import Operator, State, product_state
from pyqtorch.utils import product_state


class Primitive(torch.nn.Module):
    def __init__(self, pauli: torch.Tensor, target: int) -> None:
    def __init__(self, pauli: Tensor, target: int) -> None:
        super().__init__()
        self.target: int = target
        self.qubit_support: Tuple[int, ...] = (target,)
@@ -40,15 +41,15 @@ def extra_repr(self) -> str:
    def param_type(self) -> None:
        return self._param_type

    def unitary(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator:
    def unitary(self, values: dict[str, Tensor] | Tensor = dict()) -> Tensor:
        return self.pauli.unsqueeze(2)

    def forward(self, state: State, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> State:
    def forward(self, state: Tensor, values: dict[str, Tensor] | Tensor = dict()) -> Tensor:
        return apply_operator(
            state, self.unitary(values), self.qubit_support, len(state.size()) - 1
        )

    def dagger(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator:
    def dagger(self, values: dict[str, Tensor] | Tensor = dict()) -> Tensor:
        return _dagger(self.unitary(values))

    @property
@@ -85,7 +86,7 @@ class I(Primitive):  # noqa: E742
    def __init__(self, target: int):
        super().__init__(OPERATIONS_DICT["I"], target)

    def forward(self, state: State, values: dict[str, torch.Tensor] = None) -> State:
    def forward(self, state: Tensor, values: dict[str, Tensor] = None) -> Tensor:
        return state


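Note (not part of the diff): as the unchanged body return self.pauli.unsqueeze(2) shows, Primitive.unitary() returns a tensor of shape [2, 2, 1], so the batch dimension always sits last:

    from pyqtorch.primitive import X

    u = X(0).unitary()
    assert u.shape == (2, 2, 1)  # [2**n_qubits, 2**n_qubits, batch_size]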
60 changes: 60 additions & 0 deletions pyqtorch/utils.py
@@ -142,3 +142,63 @@ def density_mat(state: Tensor) -> Tensor:
    state = torch.permute(state, batch_first_perm).reshape(batch_size, 2**n_qubits)
    undo_perm = (1, 2, 0)
    return torch.permute(torch.einsum("bi,bj->bij", (state, state.conj())), undo_perm)


def promote_op(operator: Tensor, target: int, n_qubits: int) -> Tensor:
    """
    Promotes `operator` to the full size of the circuit (number of qubits and batch).
    Qubit indices are zero-based, so a valid target satisfies 0 <= target <= n_qubits - 1.

    Args:
        operator (Tensor): The operator tensor to be promoted.
        target (int): The index of the target qubit to which the operator is applied.
        n_qubits (int): Number of qubits in the circuit.

    Returns:
        Tensor: The promoted operator tensor.

    Raises:
        ValueError: If `target` is outside the valid range of qubits.
    """
    from pyqtorch.primitive import I

    if target > n_qubits - 1:
        raise ValueError("The target must be a valid qubit index within the circuit's range.")
    qubits = torch.arange(0, n_qubits)
    qubits = qubits[qubits != target]
    for qubit in qubits:
        operator = torch.where(
            target > qubit,
            torch.kron(I(target).unitary(), operator.contiguous()),
            torch.kron(operator.contiguous(), I(target).unitary()),
        )
    return operator
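
Note (not part of the diff): a small sketch of promote_op under the conventions above; promoting Z on qubit 1 into a 2-qubit circuit pads it with identities on the other qubits:

    import torch
    from pyqtorch.primitive import Z
    from pyqtorch.utils import promote_op

    z = Z(1).unitary()                            # shape [2, 2, 1]
    z_full = promote_op(z, target=1, n_qubits=2)  # kron(I, Z): identity on qubit 0
    assert z_full.shape == (4, 4, 1)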


def batch_first(operator: Tensor) -> Tensor:
"""
Permute the operator's batch dimension on first dimension.

Args:
operator (Tensor): Operator in size [2**n_qubits, 2**n_qubits,batch_size].

Returns:
Tensor: Operator in size [batch_size, 2**n_qubits, 2**n_qubits].
"""
batch_first_perm = (2, 0, 1)
return torch.permute(operator, batch_first_perm)


def batch_last(operator: Tensor) -> Tensor:
"""
Permute the operator's batch dimension on last dimension.

Args:
operator (Tensor): Operator in size [batch_size,2**n_qubits, 2**n_qubits].

Returns:
Tensor: Operator in size [2**n_qubits, 2**n_qubits,batch_size].
"""
undo_perm = (1, 2, 0)
return torch.permute(operator, undo_perm)
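
Note (not part of the diff): batch_first and batch_last are inverse permutations, used to move between pyqtorch's batch-last layout and the batch-first layout torch.bmm expects:

    import torch
    from pyqtorch.utils import batch_first, batch_last

    op = torch.rand(4, 4, 10)  # [2**n_qubits, 2**n_qubits, batch_size]
    assert batch_first(op).shape == (10, 4, 4)
    assert torch.equal(batch_last(batch_first(op)), op)  # round trip is exact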
43 changes: 40 additions & 3 deletions tests/test_digital.py
@@ -9,10 +9,11 @@
from torch import Tensor

import pyqtorch as pyq
from pyqtorch.apply import apply_operator
from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, ZMAT
from pyqtorch.apply import apply_op_op, apply_operator
from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, ZMAT, _dagger
from pyqtorch.parametric import Parametric
from pyqtorch.utils import ATOL, density_mat, product_state, random_state
from pyqtorch.primitive import H, I, S, T, X, Y, Z
from pyqtorch.utils import ATOL, density_mat, product_state, promote_op, random_state

state_000 = product_state("000")
state_001 = product_state("001")
@@ -322,3 +323,39 @@ def test_dm(n_qubits: Tensor, batch_size: Tensor) -> None:
    dm = density_mat(state_cat)
    assert dm.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size])
    assert torch.allclose(dm, dm_proj)


size = (5, 2)
random_param = torch.randperm(size[0] * size[1])
random_param = random_param.reshape(size)
random_param = torch.sort(random_param, dim=1)[0]
GATESET = [I, X, Y, Z, H, T, S]


@pytest.mark.parametrize("target,n_qubits", random_param)
@pytest.mark.parametrize("operator", GATESET)
def test_promote(target: int, n_qubits: int, operator: Tensor) -> None:
    op_prom = promote_op(operator(target).unitary(), target, n_qubits)
    assert op_prom.size() == torch.Size([2**n_qubits, 2**n_qubits, 1])
    assert torch.allclose(
        apply_op_op(op_prom, _dagger(op_prom), target),
        torch.eye(2**n_qubits, dtype=torch.cdouble).unsqueeze(2),
    )


size = (3, 3)
random_param = torch.randperm(size[0] * size[1])
random_param = random_param.reshape(size)
random_param = torch.sort(random_param, dim=1)[0]


@pytest.mark.parametrize("target,n_qubits,batch_size", random_param)
@pytest.mark.parametrize("operator", GATESET)
def test_apply_op_op(target: int, n_qubits: int, batch_size: int, operator: Tensor) -> None:
    op_prom: Tensor = promote_op(operator(target).unitary(), target, n_qubits)
    op_prom = op_prom.repeat(1, 1, batch_size)
    op_mul = apply_op_op(op_prom, _dagger(operator(target).unitary()), target)
    assert op_mul.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size])
    assert torch.allclose(
        op_mul, torch.eye(2**n_qubits, dtype=torch.cdouble).unsqueeze(2).repeat(1, 1, batch_size)
    )