[Feature] Promote operators/gates #155

Closed · wants to merge 7 commits
10 changes: 4 additions & 6 deletions pyqtorch/apply.py
@@ -5,20 +5,18 @@
 from numpy import array
 from numpy.typing import NDArray
-from torch import einsum
-
-from pyqtorch.utils import Operator, State
+from torch import Tensor, einsum
 
 ABC_ARRAY: NDArray = array(list(ABC))
 
 
 def apply_operator(
-    state: State,
-    operator: Operator,
+    state: Tensor,
+    operator: Tensor,
     qubits: Tuple[int, ...] | list[int],
     n_qubits: int = None,
     batch_size: int = None,
-) -> State:
+) -> Tensor:
     """Applies an operator, i.e. a single tensor of shape [2, 2, ...], on a given state
     of shape [2 for _ in range(n_qubits)] for a given set of (target and control) qubits.
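For reference, a minimal usage sketch of the now Tensor-typed apply_operator; the shapes are assumptions inferred from Primitive.unitary()'s trailing unsqueeze(2) batch dimension:

import torch
from pyqtorch.apply import apply_operator
from pyqtorch.utils import product_state

state = product_state("00")  # assumed shape [2, 2, 1]: one dim per qubit plus batch
x = torch.tensor([[0, 1], [1, 0]], dtype=torch.cdouble).unsqueeze(2)  # X with batch dim
out = apply_operator(state, x, (0,), n_qubits=2)  # apply X on qubit 0
assert out.size() == state.size()  # apply_operator preserves the state's shape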
13 changes: 7 additions & 6 deletions pyqtorch/primitive.py
@@ -4,14 +4,15 @@
 from typing import Any, Tuple
 
 import torch
+from torch import Tensor
 
 from pyqtorch.apply import apply_operator
 from pyqtorch.matrices import OPERATIONS_DICT, _controlled, _dagger
-from pyqtorch.utils import Operator, State, product_state
+from pyqtorch.utils import product_state
 
 
 class Primitive(torch.nn.Module):
-    def __init__(self, pauli: torch.Tensor, target: int) -> None:
+    def __init__(self, pauli: Tensor, target: int) -> None:
         super().__init__()
         self.target: int = target
         self.qubit_support: Tuple[int, ...] = (target,)
@@ -40,15 +41,15 @@ def extra_repr(self) -> str:
     def param_type(self) -> None:
         return self._param_type
 
-    def unitary(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator:
+    def unitary(self, values: dict[str, Tensor] | Tensor = {}) -> Tensor:
         return self.pauli.unsqueeze(2)
 
-    def forward(self, state: State, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> State:
+    def forward(self, state: Tensor, values: dict[str, Tensor] | Tensor = {}) -> Tensor:
         return apply_operator(
             state, self.unitary(values), self.qubit_support, len(state.size()) - 1
         )
 
-    def dagger(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator:
+    def dagger(self, values: dict[str, Tensor] | Tensor = {}) -> Tensor:
         return _dagger(self.unitary(values))
 
     @property
@@ -85,7 +86,7 @@ class I(Primitive):  # noqa: E742
     def __init__(self, target: int):
         super().__init__(OPERATIONS_DICT["I"], target)
 
-    def forward(self, state: State, values: dict[str, torch.Tensor] = None) -> State:
+    def forward(self, state: Tensor, values: dict[str, Tensor] = None) -> Tensor:
         return state


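A short sketch of how the Primitive API reads once the Operator/State aliases are replaced by Tensor (shapes again assumed from unitary()'s unsqueeze(2)):

from pyqtorch.primitive import X
from pyqtorch.utils import product_state

x = X(0)                    # Primitive wrapping the Pauli-X matrix on qubit 0
print(x.unitary().shape)    # torch.Size([2, 2, 1]): 2 x 2 matrix plus batch dim
state = product_state("0")  # single-qubit |0>, assumed shape [2, 1]
out = x(state)              # forward() routes through apply_operator
print(out.shape)            # same shape as the input state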
32 changes: 32 additions & 0 deletions pyqtorch/utils.py
@@ -142,3 +142,35 @@ def density_mat(state: Tensor) -> Tensor:
     state = torch.permute(state, batch_first_perm).reshape(batch_size, 2**n_qubits)
     undo_perm = (1, 2, 0)
     return torch.permute(torch.einsum("bi,bj->bij", (state, state.conj())), undo_perm)
+
+
+def promote_op(operator: Tensor, target: int, n_qubits: int) -> Tensor:
+    """
+    Promotes `operator` to the size of the circuit (number of qubits and batch).
+
+    Args:
+        operator (Tensor): The operator tensor to be promoted.
+        target (int): The index of the target qubit the operator acts on.
+            The first qubit has index 0, so a valid target satisfies
+            target <= n_qubits - 1.
+        n_qubits (int): Number of qubits in the circuit.
+
+    Returns:
+        Tensor: The promoted operator tensor.
+
+    Raises:
+        ValueError: If `target` is outside the valid range of qubits.
+    """
+    from pyqtorch.primitive import I
+
+    if target > n_qubits - 1:
+        raise ValueError("The target must be a valid qubit index within the circuit's range.")
+    qubits = torch.arange(0, n_qubits)
+    qubits = qubits[qubits != target]
+    for qubit in qubits:
+        operator = torch.where(
+            target > qubit,
+            torch.kron(I(target).unitary(), operator.contiguous()),
+            torch.kron(operator.contiguous(), I(target).unitary()),
+        )
+    return operator
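In effect, promote_op Kronecker-multiplies identities to the left or right of `operator`, one qubit at a time, until it covers the full register. A minimal check of that intent, assuming IMAT and ZMAT are plain 2 x 2 matrices (consistent with unitary() adding the batch dimension via unsqueeze(2)):

import torch
from pyqtorch.matrices import IMAT, ZMAT
from pyqtorch.utils import promote_op

# Z on qubit 1 of a 3-qubit register should become I ⊗ Z ⊗ I.
z = ZMAT.unsqueeze(2)  # add the trailing batch dimension, mirroring unitary()
promoted = promote_op(z, target=1, n_qubits=3)
expected = torch.kron(torch.kron(IMAT, ZMAT), IMAT).unsqueeze(2)
assert torch.allclose(promoted, expected)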
24 changes: 21 additions & 3 deletions tests/test_digital.py
@@ -9,10 +9,11 @@
 from torch import Tensor
 
 import pyqtorch as pyq
-from pyqtorch.apply import apply_operator
-from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, ZMAT
+from pyqtorch.apply import apply_ope_ope, apply_operator
+from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, ZMAT, _dagger
 from pyqtorch.parametric import Parametric
-from pyqtorch.utils import ATOL, density_mat, product_state, random_state
+from pyqtorch.primitive import H, I, S, T, X, Y, Z
+from pyqtorch.utils import ATOL, density_mat, product_state, promote_op, random_state
 
 state_000 = product_state("000")
 state_001 = product_state("001")
@@ -322,3 +323,20 @@ def test_dm(n_qubits: Tensor, batch_size: Tensor) -> None:
     dm = density_mat(state_cat)
     assert dm.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size])
     assert torch.allclose(dm, dm_proj)
+
+
+size = (5, 2)
+random_param = torch.randperm(size[0] * size[1])
+random_param = random_param.reshape(size)
+random_param = torch.sort(random_param, dim=1)[0]
+
+
+@pytest.mark.parametrize("target,n_qubits", random_param)
+@pytest.mark.parametrize("operator", [I, X, Y, Z, H, T, S])
+def test_promote(target: int, n_qubits: int, operator: Tensor) -> None:
+    op_prom = promote_op(operator(target).unitary(), target, n_qubits)
+    assert op_prom.size() == torch.Size([2**n_qubits, 2**n_qubits, 1])
+    assert torch.allclose(
+        apply_ope_ope(op_prom, _dagger(op_prom), target),
+        torch.eye(2**n_qubits, dtype=torch.cdouble).unsqueeze(2),
+    )
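The final allclose asserts that promotion preserves unitarity, i.e. U U† equals the identity on the enlarged register. The same invariant can be checked in plain torch, without relying on apply_ope_ope (whose definition is not part of this diff):

import torch
from pyqtorch.primitive import X
from pyqtorch.utils import promote_op

u = promote_op(X(0).unitary(), target=0, n_qubits=2).squeeze(2)  # shape [4, 4]
assert torch.allclose(u @ u.conj().T, torch.eye(4, dtype=u.dtype))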