diff --git a/pyqtorch/apply.py b/pyqtorch/apply.py index 859f7b82..2f10a2ad 100644 --- a/pyqtorch/apply.py +++ b/pyqtorch/apply.py @@ -1,24 +1,24 @@ from __future__ import annotations from string import ascii_letters as ABC -from typing import Tuple -from numpy import array +import torch +from numpy import array, log2 from numpy.typing import NDArray -from torch import einsum +from torch import Tensor, einsum -from pyqtorch.utils import Operator, State +from pyqtorch.utils import batch_first, batch_last, promote_operator ABC_ARRAY: NDArray = array(list(ABC)) def apply_operator( - state: State, - operator: Operator, - qubits: Tuple[int, ...] | list[int], + state: Tensor, + operator: Tensor, + qubits: tuple[int, ...] | list[int], n_qubits: int = None, batch_size: int = None, -) -> State: +) -> Tensor: """Applies an operator, i.e. a single tensor of shape [2, 2, ...], on a given state of shape [2 for _ in range(n_qubits)] for a given set of (target and control) qubits. @@ -57,3 +57,32 @@ def apply_operator( map(lambda e: "".join(list(e)), [operator_dims, in_state_dims, out_state_dims]) ) return einsum(f"{operator_dims},{in_state_dims}->{out_state_dims}", operator, state) + + +def operator_product(op1: Tensor, op2: Tensor, target: int) -> Tensor: + """ + Compute the product of two operators. + + Args: + op1 (Tensor): The first operator. + op2 (Tensor): The second operator. + target (int): The target qubit index. + + Returns: + Tensor: The product of the two operators. + """ + + n_qubits_1 = int(log2(op1.size(1))) + n_qubits_2 = int(log2(op2.size(1))) + batch_size_1 = op1.size(-1) + batch_size_2 = op2.size(-1) + if n_qubits_1 > n_qubits_2: + op2 = promote_operator(op2, target, n_qubits_1) + elif n_qubits_1 < n_qubits_2: + op1 = promote_operator(op1, target, n_qubits_2) + if batch_size_1 > batch_size_2: + op2 = op2.repeat(1, 1, batch_size_1)[:, :, :batch_size_1] + elif batch_size_2 > batch_size_1: + op1 = op1.repeat(1, 1, batch_size_2)[:, :, :batch_size_2] + + return batch_last(torch.bmm(batch_first(op1), batch_first(op2))) diff --git a/pyqtorch/primitive.py b/pyqtorch/primitive.py index bac9e4ba..af8b9b70 100644 --- a/pyqtorch/primitive.py +++ b/pyqtorch/primitive.py @@ -1,20 +1,21 @@ from __future__ import annotations from math import log2 -from typing import Any, Tuple +from typing import Any import torch +from torch import Tensor from pyqtorch.apply import apply_operator from pyqtorch.matrices import OPERATIONS_DICT, _controlled, _dagger -from pyqtorch.utils import Operator, State, product_state +from pyqtorch.utils import product_state class Primitive(torch.nn.Module): - def __init__(self, pauli: torch.Tensor, target: int) -> None: + def __init__(self, pauli: Tensor, target: int) -> None: super().__init__() self.target: int = target - self.qubit_support: Tuple[int, ...] = (target,) + self.qubit_support: tuple[int, ...] = (target,) self.n_qubits: int = max(self.qubit_support) self.register_buffer("pauli", pauli) self._param_type = None @@ -40,15 +41,15 @@ def extra_repr(self) -> str: def param_type(self) -> None: return self._param_type - def unitary(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator: + def unitary(self, values: dict[str, Tensor] | Tensor = dict()) -> Tensor: return self.pauli.unsqueeze(2) - def forward(self, state: State, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> State: + def forward(self, state: Tensor, values: dict[str, Tensor] | Tensor = dict()) -> Tensor: return apply_operator( state, self.unitary(values), self.qubit_support, len(state.size()) - 1 ) - def dagger(self, values: dict[str, torch.Tensor] | torch.Tensor = {}) -> Operator: + def dagger(self, values: dict[str, Tensor] | Tensor = dict()) -> Tensor: return _dagger(self.unitary(values)) @property @@ -85,7 +86,7 @@ class I(Primitive): # noqa: E742 def __init__(self, target: int): super().__init__(OPERATIONS_DICT["I"], target) - def forward(self, state: State, values: dict[str, torch.Tensor] = None) -> State: + def forward(self, state: Tensor, values: dict[str, Tensor] = None) -> Tensor: return state @@ -135,7 +136,7 @@ def __init__(self, control: int, target: int): class CSWAP(Primitive): - def __init__(self, control: int | Tuple[int, ...], target: int): + def __init__(self, control: int | tuple[int, ...], target: int): super().__init__(OPERATIONS_DICT["CSWAP"], target) self.control = (control,) if isinstance(control, int) else control self.target = target @@ -144,7 +145,7 @@ def __init__(self, control: int | Tuple[int, ...], target: int): class ControlledOperationGate(Primitive): - def __init__(self, gate: str, control: int | Tuple[int, ...], target: int): + def __init__(self, gate: str, control: int | tuple[int, ...], target: int): self.control = (control,) if isinstance(control, int) else control mat = OPERATIONS_DICT[gate] mat = _controlled( @@ -158,7 +159,7 @@ def __init__(self, gate: str, control: int | Tuple[int, ...], target: int): class CNOT(ControlledOperationGate): - def __init__(self, control: int | Tuple[int, ...], target: int): + def __init__(self, control: int | tuple[int, ...], target: int): super().__init__("X", control, target) @@ -166,15 +167,15 @@ def __init__(self, control: int | Tuple[int, ...], target: int): class CY(ControlledOperationGate): - def __init__(self, control: int | Tuple[int, ...], target: int): + def __init__(self, control: int | tuple[int, ...], target: int): super().__init__("Y", control, target) class CZ(ControlledOperationGate): - def __init__(self, control: int | Tuple[int, ...], target: int): + def __init__(self, control: int | tuple[int, ...], target: int): super().__init__("Z", control, target) class Toffoli(ControlledOperationGate): - def __init__(self, control: int | Tuple[int, ...], target: int): + def __init__(self, control: int | tuple[int, ...], target: int): super().__init__("X", control, target) diff --git a/pyqtorch/utils.py b/pyqtorch/utils.py index 43cbd3e7..a3ef3bd6 100644 --- a/pyqtorch/utils.py +++ b/pyqtorch/utils.py @@ -142,3 +142,63 @@ def density_mat(state: Tensor) -> Tensor: state = torch.permute(state, batch_first_perm).reshape(batch_size, 2**n_qubits) undo_perm = (1, 2, 0) return torch.permute(torch.einsum("bi,bj->bij", (state, state.conj())), undo_perm) + + +def promote_operator(operator: Tensor, target: int, n_qubits: int) -> Tensor: + from pyqtorch.primitive import I + + """ + Promotes `operator` to the size of the circuit (number of qubits and batch). + Targeting the first qubit implies target = 0, so target > n_qubits - 1. + + Args: + operator (Tensor): The operator tensor to be promoted. + target (int): The index of the target qubit to which the operator is applied. + Targeting the first qubit implies target = 0, so target > n_qubits - 1. + n_qubits (int): Number of qubits in the circuit. + + Returns: + Tensor: The promoted operator tensor. + + Raises: + ValueError: If `target` is outside the valid range of qubits. + """ + if target > n_qubits - 1: + raise ValueError("The target must be a valid qubit index within the circuit's range.") + qubits = torch.arange(0, n_qubits) + qubits = qubits[qubits != target] + for qubit in qubits: + operator = torch.where( + target > qubit, + torch.kron(I(target).unitary(), operator.contiguous()), + torch.kron(operator.contiguous(), I(target).unitary()), + ) + return operator + + +def batch_first(operator: Tensor) -> Tensor: + """ + Permute the operator's batch dimension on first dimension. + + Args: + operator (Tensor): Operator in size [2**n_qubits, 2**n_qubits,batch_size]. + + Returns: + Tensor: Operator in size [batch_size, 2**n_qubits, 2**n_qubits]. + """ + batch_first_perm = (2, 0, 1) + return torch.permute(operator, batch_first_perm) + + +def batch_last(operator: Tensor) -> Tensor: + """ + Permute the operator's batch dimension on last dimension. + + Args: + operator (Tensor): Operator in size [batch_size,2**n_qubits, 2**n_qubits]. + + Returns: + Tensor: Operator in size [2**n_qubits, 2**n_qubits,batch_size]. + """ + undo_perm = (1, 2, 0) + return torch.permute(operator, undo_perm) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..632824d9 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from pyqtorch.primitive import H, I, Primitive, S, T, X, Y, Z + + +@pytest.fixture(params=[I, X, Y, Z, H, T, S]) +def gate(request: Primitive) -> Any: + return request.param diff --git a/tests/test_digital.py b/tests/test_digital.py index d67b470a..496f721a 100644 --- a/tests/test_digital.py +++ b/tests/test_digital.py @@ -9,10 +9,11 @@ from torch import Tensor import pyqtorch as pyq -from pyqtorch.apply import apply_operator -from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, ZMAT +from pyqtorch.apply import apply_operator, operator_product +from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, ZMAT, _dagger from pyqtorch.parametric import Parametric -from pyqtorch.utils import ATOL, density_mat, product_state, random_state +from pyqtorch.primitive import Primitive +from pyqtorch.utils import ATOL, density_mat, product_state, promote_operator, random_state state_000 = product_state("000") state_001 = product_state("001") @@ -293,7 +294,6 @@ def test_U() -> None: @pytest.mark.parametrize("n_qubits,batch_size", torch.randint(1, 6, (8, 2))) def test_dm(n_qubits: Tensor, batch_size: Tensor) -> None: - # Test without batches: state = random_state(n_qubits) projector = torch.outer(state.flatten(), state.conj().flatten()).view( 2**n_qubits, 2**n_qubits, 1 @@ -301,24 +301,44 @@ def test_dm(n_qubits: Tensor, batch_size: Tensor) -> None: dm = density_mat(state) assert dm.size() == torch.Size([2**n_qubits, 2**n_qubits, 1]) assert torch.allclose(dm, projector) - - # Test with batches: states = [] projectors = [] - # Batches creation: for batch in range(batch_size): - # Batch state creation: state = random_state(n_qubits) states.append(state) - # Batch projector: projector = torch.outer(state.flatten(), state.conj().flatten()).view( 2**n_qubits, 2**n_qubits, 1 ) projectors.append(projector) - # Concatenate all the batch projectors: dm_proj = torch.cat(projectors, dim=2) - # Concatenate the batch state to compute the density matrix state_cat = torch.cat(states, dim=n_qubits) dm = density_mat(state_cat) assert dm.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size]) assert torch.allclose(dm, dm_proj) + + +def test_promote(gate: Primitive) -> None: + n_qubits = torch.randint(low=1, high=8, size=(1,)).item() + target = random.choice([i for i in range(n_qubits)]) + op_prom = promote_operator(gate(target).unitary(), target, n_qubits) + assert op_prom.size() == torch.Size([2**n_qubits, 2**n_qubits, 1]) + assert torch.allclose( + operator_product(op_prom, _dagger(op_prom), target), + torch.eye(2**n_qubits, dtype=torch.cdouble).unsqueeze(2), + ) + + +def test_operator_product(gate: Primitive) -> None: + n_qubits = torch.randint(low=1, high=8, size=(1,)).item() + target = random.choice([i for i in range(n_qubits)]) + batch_size_1 = torch.randint(low=1, high=5, size=(1,)).item() + batch_size_2 = torch.randint(low=1, high=5, size=(1,)).item() + max_batch = max(batch_size_2, batch_size_1) + op_prom = promote_operator(gate(target).unitary(), target, n_qubits).repeat(1, 1, batch_size_1) + op_mul = operator_product( + gate(target).unitary().repeat(1, 1, batch_size_2), _dagger(op_prom), target + ) + assert op_mul.size() == torch.Size([2**n_qubits, 2**n_qubits, max_batch]) + assert torch.allclose( + op_mul, torch.eye(2**n_qubits, dtype=torch.cdouble).unsqueeze(2).repeat(1, 1, max_batch) + )