Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add Operator Product #156

Merged
merged 27 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions pyqtorch/apply.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from __future__ import annotations

from string import ascii_letters as ABC
from typing import Tuple

from numpy import array
import torch
from numpy import array, log2
from numpy.typing import NDArray
from torch import einsum
from torch import Tensor, einsum

from pyqtorch.utils import Operator, State
from pyqtorch.utils import batch_first, batch_last, promote_operator

ABC_ARRAY: NDArray = array(list(ABC))


def apply_operator(
state: State,
operator: Operator,
qubits: Tuple[int, ...] | list[int],
state: Tensor,
operator: Tensor,
qubits: tuple[int, ...] | list[int],
n_qubits: int = None,
batch_size: int = None,
) -> State:
) -> Tensor:
"""Applies an operator, i.e. a single tensor of shape [2, 2, ...], on a given state
of shape [2 for _ in range(n_qubits)] for a given set of (target and control) qubits.

Expand Down Expand Up @@ -57,3 +57,32 @@ def apply_operator(
map(lambda e: "".join(list(e)), [operator_dims, in_state_dims, out_state_dims])
)
return einsum(f"{operator_dims},{in_state_dims}->{out_state_dims}", operator, state)


def operator_product(op1: Tensor, op2: Tensor, target: int) -> Tensor:
"""
Compute the product of two operators.

Args:
op1 (Tensor): The first operator.
op2 (Tensor): The second operator.
target (int): The target qubit index.

Returns:
Tensor: The product of the two operators.
"""

n_qubits_1 = int(log2(op1.size(1)))
n_qubits_2 = int(log2(op2.size(1)))
batch_size_1 = op1.size(-1)
batch_size_2 = op2.size(-1)
if n_qubits_1 > n_qubits_2:
op2 = promote_operator(op2, target, n_qubits_1)
elif n_qubits_1 < n_qubits_2:
EthanObadia marked this conversation as resolved.
Show resolved Hide resolved
op1 = promote_operator(op1, target, n_qubits_2)
if batch_size_1 > batch_size_2:
op2 = op2.repeat(1, 1, batch_size_1)[:, :, :batch_size_1]
elif batch_size_2 > batch_size_1:
op1 = op1.repeat(1, 1, batch_size_2)[:, :, :batch_size_2]
EthanObadia marked this conversation as resolved.
Show resolved Hide resolved

return batch_last(torch.bmm(batch_first(op1), batch_first(op2)))
29 changes: 15 additions & 14 deletions pyqtorch/primitive.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from __future__ import annotations

from math import log2
from typing import Any, Tuple
from typing import Any

import torch
from torch import Tensor

from pyqtorch.apply import apply_operator
from pyqtorch.matrices import OPERATIONS_DICT, _controlled, _dagger
from pyqtorch.utils import Operator, State, product_state
from pyqtorch.utils import product_state


class Primitive(torch.nn.Module):
def __init__(self, pauli: torch.Tensor, target: int) -> None:
def __init__(self, pauli: Tensor, target: int) -> None:
super().__init__()
self.target: int = target
self.qubit_support: Tuple[int, ...] = (target,)
self.qubit_support: tuple[int, ...] = (target,)
self.n_qubits: int = max(self.qubit_support)
self.register_buffer("pauli", pauli)
self._param_type = None
Expand All @@ -40,15 +41,15 @@ def extra_repr(self) -> str:
def param_type(self) -> None:
return self._param_type

def unitary(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator:
def unitary(self, values: dict[str, Tensor] | Tensor = dict()) -> Tensor:
return self.pauli.unsqueeze(2)

def forward(self, state: State, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> State:
def forward(self, state: Tensor, values: dict[str, Tensor] | Tensor = dict()) -> Tensor:
return apply_operator(
state, self.unitary(values), self.qubit_support, len(state.size()) - 1
)

def dagger(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator:
def dagger(self, values: dict[str, Tensor] | Tensor = dict()) -> Tensor:
return _dagger(self.unitary(values))

@property
Expand Down Expand Up @@ -85,7 +86,7 @@ class I(Primitive): # noqa: E742
def __init__(self, target: int):
super().__init__(OPERATIONS_DICT["I"], target)

def forward(self, state: State, values: dict[str, torch.Tensor] = None) -> State:
def forward(self, state: Tensor, values: dict[str, Tensor] = None) -> Tensor:
return state


Expand Down Expand Up @@ -135,7 +136,7 @@ def __init__(self, control: int, target: int):


class CSWAP(Primitive):
def __init__(self, control: int | Tuple[int, ...], target: int):
def __init__(self, control: int | tuple[int, ...], target: int):
super().__init__(OPERATIONS_DICT["CSWAP"], target)
self.control = (control,) if isinstance(control, int) else control
self.target = target
Expand All @@ -144,7 +145,7 @@ def __init__(self, control: int | Tuple[int, ...], target: int):


class ControlledOperationGate(Primitive):
def __init__(self, gate: str, control: int | Tuple[int, ...], target: int):
def __init__(self, gate: str, control: int | tuple[int, ...], target: int):
self.control = (control,) if isinstance(control, int) else control
mat = OPERATIONS_DICT[gate]
mat = _controlled(
Expand All @@ -158,23 +159,23 @@ def __init__(self, gate: str, control: int | Tuple[int, ...], target: int):


class CNOT(ControlledOperationGate):
def __init__(self, control: int | Tuple[int, ...], target: int):
def __init__(self, control: int | tuple[int, ...], target: int):
super().__init__("X", control, target)


CX = CNOT


class CY(ControlledOperationGate):
def __init__(self, control: int | Tuple[int, ...], target: int):
def __init__(self, control: int | tuple[int, ...], target: int):
super().__init__("Y", control, target)


class CZ(ControlledOperationGate):
def __init__(self, control: int | Tuple[int, ...], target: int):
def __init__(self, control: int | tuple[int, ...], target: int):
super().__init__("Z", control, target)


class Toffoli(ControlledOperationGate):
def __init__(self, control: int | Tuple[int, ...], target: int):
def __init__(self, control: int | tuple[int, ...], target: int):
super().__init__("X", control, target)
60 changes: 60 additions & 0 deletions pyqtorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,63 @@ def density_mat(state: Tensor) -> Tensor:
state = torch.permute(state, batch_first_perm).reshape(batch_size, 2**n_qubits)
undo_perm = (1, 2, 0)
return torch.permute(torch.einsum("bi,bj->bij", (state, state.conj())), undo_perm)


def promote_operator(operator: Tensor, target: int, n_qubits: int) -> Tensor:
from pyqtorch.primitive import I

"""
Promotes `operator` to the size of the circuit (number of qubits and batch).
EthanObadia marked this conversation as resolved.
Show resolved Hide resolved
Targeting the first qubit implies target = 0, so target > n_qubits - 1.

Args:
operator (Tensor): The operator tensor to be promoted.
target (int): The index of the target qubit to which the operator is applied.
Targeting the first qubit implies target = 0, so target > n_qubits - 1.
n_qubits (int): Number of qubits in the circuit.

Returns:
Tensor: The promoted operator tensor.

Raises:
ValueError: If `target` is outside the valid range of qubits.
"""
if target > n_qubits - 1:
raise ValueError("The target must be a valid qubit index within the circuit's range.")
qubits = torch.arange(0, n_qubits)
qubits = qubits[qubits != target]
for qubit in qubits:
operator = torch.where(
target > qubit,
torch.kron(I(target).unitary(), operator.contiguous()),
torch.kron(operator.contiguous(), I(target).unitary()),
)
return operator


def batch_first(operator: Tensor) -> Tensor:
"""
Permute the operator's batch dimension on first dimension.

Args:
operator (Tensor): Operator in size [2**n_qubits, 2**n_qubits,batch_size].

Returns:
Tensor: Operator in size [batch_size, 2**n_qubits, 2**n_qubits].
"""
batch_first_perm = (2, 0, 1)
return torch.permute(operator, batch_first_perm)


def batch_last(operator: Tensor) -> Tensor:
"""
Permute the operator's batch dimension on last dimension.

Args:
operator (Tensor): Operator in size [batch_size,2**n_qubits, 2**n_qubits].

Returns:
Tensor: Operator in size [2**n_qubits, 2**n_qubits,batch_size].
"""
undo_perm = (1, 2, 0)
return torch.permute(operator, undo_perm)
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

from typing import Any

import pytest

from pyqtorch.primitive import H, I, Primitive, S, T, X, Y, Z


@pytest.fixture(params=[I, X, Y, Z, H, T, S])
def gate(request: Primitive) -> Any:
return request.param
42 changes: 31 additions & 11 deletions tests/test_digital.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from torch import Tensor

import pyqtorch as pyq
from pyqtorch.apply import apply_operator
from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, ZMAT
from pyqtorch.apply import apply_operator, operator_product
from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, ZMAT, _dagger
from pyqtorch.parametric import Parametric
from pyqtorch.utils import ATOL, density_mat, product_state, random_state
from pyqtorch.primitive import Primitive
from pyqtorch.utils import ATOL, density_mat, product_state, promote_operator, random_state

state_000 = product_state("000")
state_001 = product_state("001")
Expand Down Expand Up @@ -293,32 +294,51 @@ def test_U() -> None:

@pytest.mark.parametrize("n_qubits,batch_size", torch.randint(1, 6, (8, 2)))
def test_dm(n_qubits: Tensor, batch_size: Tensor) -> None:
# Test without batches:
state = random_state(n_qubits)
projector = torch.outer(state.flatten(), state.conj().flatten()).view(
2**n_qubits, 2**n_qubits, 1
)
dm = density_mat(state)
assert dm.size() == torch.Size([2**n_qubits, 2**n_qubits, 1])
assert torch.allclose(dm, projector)

# Test with batches:
states = []
projectors = []
# Batches creation:
for batch in range(batch_size):
# Batch state creation:
state = random_state(n_qubits)
states.append(state)
# Batch projector:
projector = torch.outer(state.flatten(), state.conj().flatten()).view(
2**n_qubits, 2**n_qubits, 1
)
projectors.append(projector)
# Concatenate all the batch projectors:
dm_proj = torch.cat(projectors, dim=2)
# Concatenate the batch state to compute the density matrix
state_cat = torch.cat(states, dim=n_qubits)
dm = density_mat(state_cat)
assert dm.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size])
assert torch.allclose(dm, dm_proj)


def test_promote(gate: Primitive) -> None:
n_qubits = torch.randint(low=1, high=8, size=(1,)).item()
target = random.choice([i for i in range(n_qubits)])
op_prom = promote_operator(gate(target).unitary(), target, n_qubits)
assert op_prom.size() == torch.Size([2**n_qubits, 2**n_qubits, 1])
assert torch.allclose(
operator_product(op_prom, _dagger(op_prom), target),
torch.eye(2**n_qubits, dtype=torch.cdouble).unsqueeze(2),
)


def test_operator_product(gate: Primitive) -> None:
n_qubits = torch.randint(low=1, high=8, size=(1,)).item()
target = random.choice([i for i in range(n_qubits)])
batch_size_1 = torch.randint(low=1, high=5, size=(1,)).item()
batch_size_2 = torch.randint(low=1, high=5, size=(1,)).item()
max_batch = max(batch_size_2, batch_size_1)
op_prom = promote_operator(gate(target).unitary(), target, n_qubits).repeat(1, 1, batch_size_1)
op_mul = operator_product(
gate(target).unitary().repeat(1, 1, batch_size_2), _dagger(op_prom), target
)
assert op_mul.size() == torch.Size([2**n_qubits, 2**n_qubits, max_batch])
assert torch.allclose(
op_mul, torch.eye(2**n_qubits, dtype=torch.cdouble).unsqueeze(2).repeat(1, 1, max_batch)
)
Loading