Skip to content

Commit

Permalink
more explicit typing
Browse files Browse the repository at this point in the history
  • Loading branch information
perlinm committed Jan 16, 2024
1 parent 90d5071 commit 4371d70
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions qldpc/codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@
if TYPE_CHECKING:
from typing_extensions import Self

IntegerMatrix = npt.NDArray[np.int_] | Sequence[Sequence[int]]
ObjectMatrix = npt.NDArray[np.object_] | Sequence[Sequence[object]]

DEFAULT_FIELD_ORDER = abstract.DEFAULT_FIELD_ORDER

################################################################################
Expand All @@ -50,7 +47,11 @@ class AbstractCode(abc.ABC):

_field_order: int

def __init__(self, matrix: Self | IntegerMatrix, field: int | None = None) -> None:
def __init__(
self,
matrix: Self | npt.NDArray[np.int_] | Sequence[Sequence[int]],
field: int | None = None,
) -> None:
"""Construct a code from a parity check matrix over a finite field.
The base field is taken to be F_2 by default.
Expand Down Expand Up @@ -88,7 +89,7 @@ def graph(self) -> nx.DiGraph:

@classmethod
@abc.abstractmethod
def matrix_to_graph(cls, matrix: IntegerMatrix) -> nx.DiGraph:
def matrix_to_graph(cls, matrix: npt.NDArray[np.int_] | Sequence[Sequence[int]]) -> nx.DiGraph:
"""Convert a parity check matrix into a Tanner graph."""

@classmethod
Expand All @@ -112,7 +113,7 @@ def __contains__(self, word: npt.NDArray[np.int_] | Sequence[int]) -> bool:
return not np.any(self.matrix @ self.field(word))

@classmethod
def matrix_to_graph(cls, matrix: IntegerMatrix) -> nx.DiGraph:
def matrix_to_graph(cls, matrix: npt.NDArray[np.int_] | Sequence[Sequence[int]]) -> nx.DiGraph:
"""Convert a parity check matrix H into a Tanner graph.
The Tanner graph is a bipartite graph with (num_checks, num_bits) vertices, respectively
Expand Down Expand Up @@ -322,7 +323,7 @@ def _assert_qubit_code(self) -> None:
raise ValueError("Attempted to call a qubit-only method with a non-qubit code.")

@classmethod
def matrix_to_graph(cls, matrix: IntegerMatrix) -> nx.DiGraph:
def matrix_to_graph(cls, matrix: npt.NDArray[np.int_] | Sequence[Sequence[int]]) -> nx.DiGraph:
"""Convert a parity check matrix into a Tanner graph."""
graph = nx.DiGraph()
matrix = np.reshape(matrix, (len(matrix), 2, -1))
Expand Down Expand Up @@ -395,7 +396,7 @@ def from_stabilizers(cls, stabilizers: Iterable[str], field: int | None = None)
# see https://arxiv.org/pdf/quant-ph/0408190.pdf
@classmethod
def conjugate(
cls, matrix: IntegerMatrix, qudits: slice | Sequence[int]
cls, matrix: npt.NDArray[np.int_] | Sequence[Sequence[int]], qudits: slice | Sequence[int]
) -> npt.NDArray[np.int_]:
"""Apply local Fourier transforms to the given qudits.
Expand Down Expand Up @@ -431,8 +432,8 @@ class CSSCode(QuditCode):

def __init__(
self,
code_x: ClassicalCode | IntegerMatrix,
code_z: ClassicalCode | IntegerMatrix,
code_x: ClassicalCode | npt.NDArray[np.int_] | Sequence[Sequence[int]],
code_z: ClassicalCode | npt.NDArray[np.int_] | Sequence[Sequence[int]],
field: int | None = None,
*,
conjugate: slice | Sequence[int] | None = (),
Expand Down Expand Up @@ -769,8 +770,8 @@ class GBCode(CSSCode):

def __init__(
self,
matrix_a: IntegerMatrix,
matrix_b: IntegerMatrix | None = None,
matrix_a: npt.NDArray[np.int_] | Sequence[Sequence[int]],
matrix_b: npt.NDArray[np.int_] | Sequence[Sequence[int]] | None = None,
field: int | None = None,
*,
conjugate: slice | Sequence[int] = (),
Expand Down Expand Up @@ -891,8 +892,8 @@ class HGPCode(CSSCode):

def __init__(
self,
code_a: ClassicalCode | IntegerMatrix,
code_b: ClassicalCode | IntegerMatrix | None = None,
code_a: ClassicalCode | npt.NDArray[np.int_] | Sequence[Sequence[int]],
code_b: ClassicalCode | npt.NDArray[np.int_] | Sequence[Sequence[int]] | None = None,
field: int | None = None,
*,
conjugate: bool = False,
Expand Down Expand Up @@ -1012,8 +1013,11 @@ class LPCode(CSSCode):

def __init__(
self,
protograph_a: abstract.Protograph | ObjectMatrix,
protograph_b: abstract.Protograph | ObjectMatrix | None = None,
protograph_a: abstract.Protograph | npt.NDArray[np.object_] | Sequence[Sequence[object]],
protograph_b: abstract.Protograph
| npt.NDArray[np.object_]
| Sequence[Sequence[object]]
| None = None,
*,
conjugate: bool = False,
) -> None:
Expand Down Expand Up @@ -1146,8 +1150,8 @@ def __init__(
self,
subset_a: Collection[abstract.GroupMember],
subset_b: Collection[abstract.GroupMember],
code_a: ClassicalCode | IntegerMatrix,
code_b: ClassicalCode | IntegerMatrix | None = None,
code_a: ClassicalCode | npt.NDArray[np.int_] | Sequence[Sequence[int]],
code_b: ClassicalCode | npt.NDArray[np.int_] | Sequence[Sequence[int]] | None = None,
field: int | None = None,
*,
conjugate: slice | Sequence[int] | None = (),
Expand Down

0 comments on commit 4371d70

Please sign in to comment.