Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating type hints #466

Merged
merged 6 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions mrmustard/lab/abstract/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)
import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -114,7 +110,7 @@ def __init__(
len(modes) == self.num_modes
), f"Number of modes supplied ({len(modes)}) must match the representation dimension {self.num_modes}"

def _add_parameter(self, parameter: Union[Constant, Variable]):
def _add_parameter(self, parameter: Constant | Variable):
r"""
Adds a parameter to a state.

Expand All @@ -141,7 +137,7 @@ def modes(self):
return list(range(self.num_modes))
return self._modes

def indices(self, modes) -> Union[Tuple[int], int]:
def indices(self, modes) -> int | tuple[int]:
r"""Returns the indices of the given modes.

Args:
Expand Down Expand Up @@ -175,12 +171,12 @@ def is_pure(self):
return np.isclose(self.purity, 1.0, atol=1e-6)

@property
def means(self) -> Optional[RealVector]:
def means(self) -> RealVector | None:
r"""Returns the means vector of the state."""
return self._means

@property
def cov(self) -> Optional[RealMatrix]:
def cov(self) -> RealMatrix | None:
r"""Returns the covariance matrix of the state."""
return self._cov

Expand All @@ -195,7 +191,7 @@ def number_stdev(self) -> RealVector:
)

@property
def cutoffs(self) -> List[int]:
def cutoffs(self) -> list[int]:
r"""Returns the Hilbert space dimension of each mode."""
if self._cutoffs is None:
if self._ket is None and self._dm is None:
Expand All @@ -214,7 +210,7 @@ def cutoffs(self) -> List[int]:
return self._cutoffs

@property
def shape(self) -> List[int]:
def shape(self) -> list[int]:
r"""Returns the shape of the state, accounting for ket/dm representation.

If the state is in Gaussian representation, the shape is inferred from
Expand Down Expand Up @@ -274,10 +270,10 @@ def probability(self) -> float:

def ket(
self,
cutoffs: List[int] = None,
cutoffs: list[int] | None = None,
max_prob: float = 1.0,
max_photons: int = None,
) -> Optional[ComplexTensor]:
max_photons: int | None = None,
) -> ComplexTensor | None:
r"""Returns the ket of the state in Fock representation or ``None`` if the state is mixed.

Args:
Expand Down Expand Up @@ -323,7 +319,7 @@ def ket(
return padded[tuple(slice(s) for s in cutoffs)]
return self._ket[tuple(slice(s) for s in cutoffs)]

def dm(self, cutoffs: Optional[List[int]] = None) -> ComplexTensor:
def dm(self, cutoffs: list[int] | None = None) -> ComplexTensor:
r"""Returns the density matrix of the state in Fock representation.

Args:
Expand Down Expand Up @@ -376,7 +372,7 @@ def fock_probabilities(self, cutoffs: Sequence[int]) -> RealTensor:
self._fock_probabilities = fock.ket_to_probs(ket)
return self._fock_probabilities

def primal(self, other: Union[State, Transformation]) -> State:
def primal(self, other: State | Transformation) -> State:
r"""Returns the post-measurement state after ``other`` is projected onto ``self``.

``other << self`` is other projected onto ``self``.
Expand All @@ -399,7 +395,7 @@ def primal(self, other: Union[State, Transformation]) -> State:
f"Cannot apply {other.__class__.__qualname__} to {self.__class__.__qualname__}"
) from e

def _project_onto_state(self, other: State) -> Union[State, float]:
def _project_onto_state(self, other: State) -> State | float:
"""If states are gaussian use generaldyne measurement, else use
the states' Fock representation."""

Expand All @@ -410,7 +406,7 @@ def _project_onto_state(self, other: State) -> Union[State, float]:
# either self or other is not gaussian
return self._project_onto_fock(other)

def _project_onto_fock(self, other: State) -> Union[State, float]:
def _project_onto_fock(self, other: State) -> State | float:
"""Returns the post-measurement state of the projection between two non-Gaussian
states on the remaining modes or the probability of the result. When doing homodyne sampling,
returns the post-measurement state or the measument outcome if no modes remain.
Expand Down Expand Up @@ -459,7 +455,7 @@ def _contract_with_other(self, other):

return out_fock

def _project_onto_gaussian(self, other: State) -> Union[State, float]:
def _project_onto_gaussian(self, other: State) -> State | float:
"""Returns the result of a generaldyne measurement given that states ``self`` and
``other`` are gaussian.

Expand Down Expand Up @@ -549,7 +545,7 @@ def __getitem__(self, item) -> State:
self._modes = item
return self

def bargmann(self, numpy=False) -> Optional[tuple[ComplexMatrix, ComplexVector, complex]]:
def bargmann(self, numpy=False) -> tuple[ComplexMatrix, ComplexVector, complex] | None:
r"""Returns the Bargmann representation of the state.
If numpy=True, returns the numpy arrays instead of the backend arrays.
"""
Expand Down Expand Up @@ -700,8 +696,8 @@ def _repr_markdown_(self):

def mikkel_plot(
rho: np.ndarray,
xbounds: Tuple[int] = (-6, 6),
ybounds: Tuple[int] = (-6, 6),
xbounds: tuple[int] = (-6, 6),
ybounds: tuple[int] = (-6, 6),
**kwargs,
): # pylint: disable=too-many-statements
"""Plots the Wigner function of a state given its density matrix.
Expand Down
42 changes: 21 additions & 21 deletions mrmustard/lab/abstract/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from __future__ import annotations

from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
from typing import Callable, Iterable, Sequence

import numpy as np

Expand All @@ -41,10 +41,10 @@ class Transformation(Tensor):
def __init__(
self,
name: str,
modes_in_ket: Optional[list[int]] = None,
modes_out_ket: Optional[list[int]] = None,
modes_in_bra: Optional[list[int]] = None,
modes_out_bra: Optional[list[int]] = None,
modes_in_ket: list[int] | None = None,
modes_out_ket: list[int] | None = None,
modes_in_bra: list[int] | None = None,
modes_out_bra: list[int] | None = None,
):
super().__init__(
name=name,
Expand All @@ -55,7 +55,7 @@ def __init__(
)
self._parameter_set = ParameterSet()

def _add_parameter(self, parameter: Union[Constant, Variable]):
def _add_parameter(self, parameter: Constant | Variable):
r"""
Adds a parameter to a transformation.

Expand Down Expand Up @@ -136,33 +136,33 @@ def _validate_modes(self, modes):
pass

@property
def X_matrix(self) -> Optional[RealMatrix]:
def X_matrix(self) -> RealMatrix | None:
return None

@property
def Y_matrix(self) -> Optional[RealMatrix]:
def Y_matrix(self) -> RealMatrix | None:
return None

@property
def d_vector(self) -> Optional[RealVector]:
def d_vector(self) -> RealVector | None:
return None

@property
def X_matrix_dual(self) -> Optional[RealMatrix]:
def X_matrix_dual(self) -> RealMatrix | None:
if (X := self.X_matrix) is None:
return None
return gaussian.math.inv(X)

@property
def Y_matrix_dual(self) -> Optional[RealMatrix]:
def Y_matrix_dual(self) -> RealMatrix | None:
if (Y := self.Y_matrix) is None:
return None
if (Xdual := self.X_matrix_dual) is None:
return Y
return math.matmul(math.matmul(Xdual, Y), math.transpose(Xdual))

@property
def d_vector_dual(self) -> Optional[RealVector]:
def d_vector_dual(self) -> RealVector | None:
if (d := self.d_vector) is None:
return None
if (Xdual := self.X_matrix_dual) is None:
Expand All @@ -181,8 +181,8 @@ def bargmann(self, numpy=False):

def choi(
self,
cutoffs: Optional[Sequence[int]] = None,
shape: Optional[Sequence[int]] = None,
cutoffs: Sequence[int] | None = None,
shape: Sequence[int] | None = None,
dual: bool = False,
):
r"""Returns the Choi representation of the transformation.
Expand Down Expand Up @@ -224,7 +224,7 @@ def choi(

def XYd(
self, allow_none: bool = True
) -> Tuple[Optional[RealMatrix], Optional[RealMatrix], Optional[RealVector]]:
) -> tuple[RealMatrix | None, RealMatrix | None, RealVector | None]:
r"""Returns the ```(X, Y, d)``` triple.

Override in subclasses if computing ``X``, ``Y`` and ``d`` together is more efficient.
Expand All @@ -238,7 +238,7 @@ def XYd(

def XYd_dual(
self, allow_none: bool = True
) -> tuple[Optional[RealMatrix], Optional[RealMatrix], Optional[RealVector]]:
) -> tuple[RealMatrix | None, RealMatrix | None, RealVector | None]:
r"""Returns the ```(X, Y, d)``` triple of the dual of the current transformation.

Override in subclasses if computing ``Xdual``, ``Ydual`` and ``ddual`` together is more efficient.
Expand Down Expand Up @@ -290,7 +290,7 @@ def __rshift__(self, other: Transformation):
ops2 = other._ops if isinstance(other, Circuit) else [other]
return Circuit(ops1 + ops2)

def __lshift__(self, other: Union[State, Transformation]):
def __lshift__(self, other: State | Transformation):
r"""Applies the dual of self to other.

If other is a state, the dual of self is applied to the state.
Expand Down Expand Up @@ -375,7 +375,7 @@ def __init__(self, name: str, modes: list[int]):
super().__init__(name=name, modes_in_ket=modes, modes_out_ket=modes)
self.is_unitary = True

def value(self, shape: Tuple[int]):
def value(self, shape: tuple[int]):
return self.U(shape=shape)

def _transform_fock(self, state: State, dual=False) -> State:
Expand All @@ -387,8 +387,8 @@ def _transform_fock(self, state: State, dual=False) -> State:

def U(
self,
cutoffs: Optional[Sequence[int]] = None,
shape: Optional[Sequence[int]] = None,
cutoffs: Sequence[int] | None = None,
shape: Sequence[int] | None = None,
):
r"""Returns the unitary representation of the transformation.

Expand Down Expand Up @@ -456,7 +456,7 @@ def _transform_fock(self, state: State, dual: bool = False) -> State:
return State(dm=fock.apply_choi_to_ket(choi, state.ket(), op_idx), modes=state.modes)
return State(dm=fock.apply_choi_to_dm(choi, state.dm(), op_idx), modes=state.modes)

def value(self, shape: Tuple[int]):
def value(self, shape: tuple[int]):
return self.choi(shape=shape)

def __eq__(self, other):
Expand Down
12 changes: 5 additions & 7 deletions mrmustard/lab/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

__all__ = ["Circuit"]

from typing import List, Optional, Tuple

import numpy as np

from mrmustard import settings
Expand All @@ -38,13 +36,13 @@ class Circuit(Transformation):
ops (list or none): A list of operations comprising the circuit.
"""

def __init__(self, ops: Optional[List] = None):
def __init__(self, ops: list | None = None):
self._ops = list(ops) if ops is not None else []
super().__init__(name="Circuit")
self.reset()

@property
def ops(self) -> Optional[List]:
def ops(self) -> list | None:
r"""
The list of operations comprising the circuit.
"""
Expand All @@ -53,7 +51,7 @@ def ops(self) -> Optional[List]:
def reset(self):
"""Resets the state of the circuit clearing the list of modes and setting the compiled flag to false."""
self._compiled: bool = False
self._modes: List[int] = []
self._modes: list[int] = []

@property
def num_modes(self) -> int:
Expand All @@ -73,7 +71,7 @@ def dual(self, state: State) -> State:
def XYd(
self,
allow_none: bool = True,
) -> Tuple[
) -> tuple[
RealMatrix, RealMatrix, RealVector
]: # NOTE: Overriding Transformation.XYd for efficiency
X = XPMatrix(like_1=True)
Expand Down Expand Up @@ -105,7 +103,7 @@ def is_unitary(self):
"""Returns `true` if all operations in the circuit are unitary."""
return all(op.is_unitary for op in self._ops)

def value(self, shape: Tuple[int]):
def value(self, shape: tuple[int]):
raise NotImplementedError

def __len__(self):
Expand Down
Loading
Loading