Skip to content

Commit

Permalink
Merge pull request #131 from Infleqtion/simplify-conjugate
Browse files Browse the repository at this point in the history
Remove conjugated argument from `QuditCode.__init__` methods
  • Loading branch information
perlinm authored Sep 17, 2024
2 parents 57c846b + bd1adb1 commit 96681a4
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 167 deletions.
39 changes: 9 additions & 30 deletions qldpc/codes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,13 +525,9 @@ def __init__(
self,
matrix: AbstractCode | npt.NDArray[np.int_] | Sequence[Sequence[int]],
field: int | None = None,
*,
conjugate: slice | Sequence[int] | None = (),
) -> None:
"""Construct a qudit code from a parity check matrix over a finite field."""
AbstractCode.__init__(self, matrix, field)
if conjugate:
self._matrix = self.field(QuditCode.conjugate(self._matrix, conjugate))

def __str__(self) -> str:
"""Human-readable representation of this code."""
Expand Down Expand Up @@ -647,17 +643,14 @@ def from_stabilizers(cls, *stabilizers: str, field: int | None = None) -> QuditC

return QuditCode(matrix.reshape(num_checks, 2 * num_qudits), field)

@classmethod
def conjugate(
cls, matrix: npt.NDArray[np.int_] | Sequence[Sequence[int]], qudits: slice | Sequence[int]
) -> npt.NDArray[np.int_]:
"""Apply local Fourier transforms to the given qudits.
This is equivalent to swapping X-type and Z-type operators."""
num_checks = len(matrix)
matrix = np.reshape(matrix, (num_checks, 2, -1))
def conjugated(self, qudits: slice | Sequence[int] | None = None) -> QuditCode:
"""Apply local Fourier transforms to data qudits, swapping X-type and Z-type operators."""
if qudits is None:
qudits = self._default_conjugate if hasattr(self, "_default_conjugate") else ()
num_checks = len(self.matrix)
matrix = np.reshape(self.matrix.copy(), (num_checks, 2, -1))
matrix[:, :, qudits] = np.roll(matrix[:, :, qudits], 1, axis=1)
return matrix.reshape(num_checks, -1)
return QuditCode(matrix.reshape(num_checks, -1))

def get_logical_ops(self, pauli: PauliXZ | None = None) -> galois.FieldArray:
"""Complete basis of nontrivial logical operators for this code.
Expand Down Expand Up @@ -785,7 +778,6 @@ class CSSCode(QuditCode):
code_x: ClassicalCode # X-type parity checks, measuring Z-type errors
code_z: ClassicalCode # Z-type parity checks, measuring X-type errors

_conjugated: slice | Sequence[int]
_exact_distance_x: int | float | None = None
_exact_distance_z: int | float | None = None
_balanced_codes: bool
Expand All @@ -796,14 +788,10 @@ def __init__(
code_z: ClassicalCode | npt.NDArray[np.int_] | Sequence[Sequence[int]],
field: int | None = None,
*,
conjugate: slice | Sequence[int] | None = (),
promise_balanced_codes: bool = False, # do the subcodes have the same parameters [n, k, d]?
skip_validation: bool = False,
) -> None:
"""Build a CSSCode from classical subcodes that specify X-type and Z-type parity checks.
Allow specifying local Fourier transformations on the qudits specified by `conjugate`.
"""
"""Build a CSSCode from classical subcodes that specify X-type and Z-type parity checks."""
self.code_x = ClassicalCode(code_x, field)
self.code_z = ClassicalCode(code_z, field)

Expand All @@ -814,7 +802,6 @@ def __init__(
if not skip_validation and self.code_x != self.code_z:
self._validate_subcodes()

self._conjugated = conjugate or ()
self._balanced_codes = promise_balanced_codes or self.code_x == self.code_z

def _validate_subcodes(self) -> None:
Expand All @@ -834,9 +821,6 @@ def __str__(self) -> str:
text += f"{self.name} on {self.num_qudits} qudits over {self.field_name}"
text += f"\nX-type parity checks:\n{self.matrix_x}"
text += f"\nZ-type parity checks:\n{self.matrix_z}"
if self.conjugated:
qudits = "qubits" if self.field.order == 2 else "qudits"
text += f"\n{qudits} conjugated at:\n{self.conjugated}"
return text

@functools.cached_property
Expand All @@ -848,7 +832,7 @@ def matrix(self) -> galois.FieldArray:
[self.matrix_z, np.zeros_like(self.matrix_z)],
]
)
return self.field(self.conjugate(matrix, self.conjugated))
return self.field(matrix)

@property
def matrix_x(self) -> galois.FieldArray:
Expand All @@ -860,11 +844,6 @@ def matrix_z(self) -> galois.FieldArray:
"""Z-type parity checks."""
return self.code_z.matrix

@property
def conjugated(self) -> slice | Sequence[int]:
"""Which qudits are conjugated? Conjugated qudits swap their X and Z operators."""
return self._conjugated

@property
def num_checks_x(self) -> int:
"""Number of X-type parity checks in this code."""
Expand Down
19 changes: 8 additions & 11 deletions qldpc/codes/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,7 @@ def test_conversions_classical(bits: int = 5, checks: int = 3) -> None:

def get_random_qudit_code(qudits: int, checks: int, field: int = 2) -> codes.QuditCode:
"""Construct a random (but probably trivial or invalid) QuditCode."""
return codes.QuditCode(
codes.ClassicalCode.random(2 * qudits, checks, field).matrix,
conjugate=(0,), # conjugate the first qubit
)
return codes.QuditCode(codes.ClassicalCode.random(2 * qudits, checks, field).matrix)


def test_code_string() -> None:
Expand All @@ -151,8 +148,8 @@ def test_code_string() -> None:
code = codes.HGPCode(codes.RepetitionCode(2, field=2))
assert "qubits" in str(code)

code = codes.HGPCode(codes.RepetitionCode(2, field=3), conjugate=True)
assert "GF(3)" in str(code) and "conjugated" in str(code)
code = codes.HGPCode(codes.RepetitionCode(2, field=3))
assert "GF(3)" in str(code)


def test_qubit_code(num_qubits: int = 5, num_checks: int = 3) -> None:
Expand Down Expand Up @@ -192,14 +189,14 @@ def test_qudit_stabilizers(field: int, bits: int = 5, checks: int = 3) -> None:

def test_deformations(num_qudits: int = 5, num_checks: int = 3, field: int = 3) -> None:
"""Apply Pauli deformations to a qudit code."""
code = get_random_qudit_code(num_qudits, num_checks, field)
conjugate = tuple(qubit for qubit in range(num_qudits) if np.random.randint(2))
transformed_matrix = codes.QuditCode.conjugate(code.matrix, conjugate)
qudits = tuple(qubit for qubit in range(num_qudits) if np.random.randint(2))
code = get_random_qudit_code(num_qudits, num_checks, field).conjugated(qudits)
assert np.array_equal(code.matrix, code.conjugated().matrix)

transformed_matrix = transformed_matrix.reshape(num_checks, 2, num_qudits)
matrix = np.reshape(code.matrix, (num_checks, 2, num_qudits))
for node_check, node_qubit, data in code.graph.edges(data=True):
vals = data[QuditOperator].value
assert tuple(transformed_matrix[node_check.index, :, node_qubit.index]) == vals[::-1]
assert tuple(matrix[node_check.index, :, node_qubit.index]) == vals[::-1]


def test_qudit_ops() -> None:
Expand Down
Loading

0 comments on commit 96681a4

Please sign in to comment.