Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Promote operators/gates #155

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions pyqtorch/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@

from numpy import array
from numpy.typing import NDArray
from torch import einsum

from pyqtorch.utils import Operator, State
from torch import Tensor, einsum

ABC_ARRAY: NDArray = array(list(ABC))


def apply_operator(
state: State,
operator: Operator,
state: Tensor,
operator: Tensor,
qubits: Tuple[int, ...] | list[int],
n_qubits: int = None,
batch_size: int = None,
) -> State:
) -> Tensor:
"""Applies an operator, i.e. a single tensor of shape [2, 2, ...], on a given state
of shape [2 for _ in range(n_qubits)] for a given set of (target and control) qubits.

Expand Down
12 changes: 7 additions & 5 deletions pyqtorch/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pyqtorch.apply import apply_operator
from pyqtorch.matrices import OPERATIONS_DICT, _controlled, _dagger
from pyqtorch.utils import Operator, State, product_state
from pyqtorch.utils import product_state
EthanObadia marked this conversation as resolved.
Show resolved Hide resolved


class Primitive(torch.nn.Module):
Expand Down Expand Up @@ -40,15 +40,17 @@ def extra_repr(self) -> str:
def param_type(self) -> None:
return self._param_type

def unitary(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator:
def unitary(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> torch.Tensor:
EthanObadia marked this conversation as resolved.
Show resolved Hide resolved
return self.pauli.unsqueeze(2)

def forward(self, state: State, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> State:
def forward(
self, state: torch.Tensor, values: dict[str, torch.Tensor] | torch.Tensor = {}
EthanObadia marked this conversation as resolved.
Show resolved Hide resolved
) -> torch.Tensor:
return apply_operator(
state, self.unitary(values), self.qubit_support, len(state.size()) - 1
)

def dagger(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator:
def dagger(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> torch.Tensor:
return _dagger(self.unitary(values))

@property
Expand Down Expand Up @@ -85,7 +87,7 @@ class I(Primitive): # noqa: E742
def __init__(self, target: int):
super().__init__(OPERATIONS_DICT["I"], target)

def forward(self, state: State, values: dict[str, torch.Tensor] = None) -> State:
def forward(self, state: torch.Tensor, values: dict[str, torch.Tensor] = None) -> torch.Tensor:
return state


Expand Down
31 changes: 31 additions & 0 deletions pyqtorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import Tensor

from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, DEFAULT_REAL_DTYPE
from pyqtorch.primitive import I

State = Tensor
EthanObadia marked this conversation as resolved.
Show resolved Hide resolved
Operator = Tensor
Expand Down Expand Up @@ -142,3 +143,33 @@ def density_mat(state: Tensor) -> Tensor:
state = torch.permute(state, batch_first_perm).reshape(batch_size, 2**n_qubits)
undo_perm = (1, 2, 0)
return torch.permute(torch.einsum("bi,bj->bij", (state, state.conj())), undo_perm)


def promote_ope(operator: Tensor, target: int, n_qubits: int) -> Tensor:
    """
    Promotes `operator` to the size of the circuit (number of qubits and batch).

    Identity operators are kron-multiplied on the left for every qubit before
    `target` and on the right for every qubit after it, so the returned tensor
    acts as `operator` on `target` and as the identity elsewhere.

    Args:
        operator (Tensor): The operator tensor to be promoted.
        target (int): The index of the target qubit to which the operator is applied.
            Targeting the first qubit implies target = 0, so valid indices satisfy
            0 <= target <= n_qubits - 1.
        n_qubits (int): Number of qubits in the circuit.

    Returns:
        Tensor: The promoted operator tensor.

    Raises:
        ValueError: If `target` is outside the valid range of qubits.
    """
    # Reject negative indices as well as overflow: a negative target would
    # otherwise fall through the loop below and silently build a wrong operator.
    if not 0 <= target <= n_qubits - 1:
        raise ValueError("The target must be a valid qubit index within the circuit's range.")

    for qubit in range(n_qubits):
        # .contiguous() is required because torch.kron does not support
        # non-contiguous inputs (e.g. a transposed/dagger operator).
        if qubit < target:
            operator = torch.kron(I(qubit).unitary(), operator.contiguous())
        elif qubit > target:
            operator = torch.kron(operator.contiguous(), I(qubit).unitary())
    return operator
17 changes: 16 additions & 1 deletion tests/test_digital.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from pyqtorch.apply import apply_operator
from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, ZMAT
from pyqtorch.parametric import Parametric
from pyqtorch.utils import ATOL, density_mat, product_state, random_state
from pyqtorch.primitive import I, X
from pyqtorch.utils import ATOL, density_mat, product_state, promote_ope, random_state

state_000 = product_state("000")
state_001 = product_state("001")
Expand Down Expand Up @@ -322,3 +323,17 @@ def test_dm(n_qubits: Tensor, batch_size: Tensor) -> None:
dm = density_mat(state_cat)
assert dm.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size])
assert torch.allclose(dm, dm_proj)


# Build 5 sorted (target, n_qubits) pairs from a random permutation of 0..9;
# sorting each row ascending guarantees target < n_qubits for every case.
size = (5, 2)
random_param = torch.sort(torch.randperm(size[0] * size[1]).view(size), dim=1)[0]


@pytest.mark.parametrize("target,n_qubits", random_param)
def test_promote(target: int, n_qubits: int) -> None:
    # A promoted single-qubit gate must act on the full 2**n_qubits space,
    # keeping the trailing batch dimension of size 1.
    expected = torch.Size([2**n_qubits, 2**n_qubits, 1])
    for gate in (I(0), X(0)):
        assert promote_ope(gate.unitary(), target, n_qubits).size() == expected
EthanObadia marked this conversation as resolved.
Show resolved Hide resolved
Loading