Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom attribute types to the decoding graph #334

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,6 @@ docs/_build/
docs/stubs/*

.DS_Store

# Cached decoding graphs
graph*.json
161 changes: 76 additions & 85 deletions src/qiskit_qec/circuits/repetition_code.py

Large diffs are not rendered by default.

42 changes: 16 additions & 26 deletions src/qiskit_qec/circuits/surface_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister

from qiskit_qec.utils import DecodingGraphNode
from qiskit_qec.circuits.code_circuit import CodeCircuit


Expand Down Expand Up @@ -107,7 +108,6 @@ def __init__(self, d: int, T: int, basis: str = "z", resets=True):
self.readout()

def _get_plaquettes(self):

"""
Returns `zplaqs` and `xplaqs`, which are lists of the Z and X type
stabilizers. Each plaquettes is specified as a list of four qubits,
Expand All @@ -122,7 +122,6 @@ def _get_plaquettes(self):
xplaq_coords = []
for y in range(-1, d):
for x in range(-1, d):

bulk = x in range(d - 1) and y in range(d - 1)
ztab = (x == -1 and y % 2 == 0) or (x == d - 1 and y % 2 == 1)
xtab = (y == -1 and x % 2 == 1) or (y == d - 1 and x % 2 == 0)
Expand Down Expand Up @@ -219,7 +218,6 @@ def syndrome_measurement(self, final=False, barrier=False):
)

for log in ["0", "1"]:

self.circuit[log].add_register(self.zplaq_bits[-1])
self.circuit[log].add_register(self.xplaq_bits[-1])

Expand Down Expand Up @@ -262,7 +260,6 @@ def readout(self):
self.circuit[log].measure(self.code_qubit, self.code_bit)

def _string2changes(self, string):

basis = self.basis

# final syndrome for plaquettes deduced from final code qubit readout
Expand Down Expand Up @@ -347,7 +344,6 @@ def string2raw_logicals(self, string):
return str(Z[0]) + " " + str(Z[1])

def _process_string(self, string):

# get logical readout
measured_Z = self.string2raw_logicals(string)

Expand All @@ -361,7 +357,6 @@ def _process_string(self, string):
return new_string

def _separate_string(self, string):

separated_string = []
for syndrome_type_string in string.split(" "):
separated_string.append(syndrome_type_string.split(" "))
Expand Down Expand Up @@ -399,26 +394,24 @@ def string2nodes(self, string, **kwargs):
boundary = separated_string[0] # [<last_elem>, <init_elem>]
for bqec_index, belement in enumerate(boundary[::-1]):
if all_logicals or belement != logical:
bnode = {"time": 0}
bnode["qubits"] = self._logicals[self.basis][-bqec_index - 1]
bnode["is_boundary"] = True
bnode["element"] = 1 - bqec_index
nodes.append(bnode)
node = DecodingGraphNode(
is_boundary=True,
qubits=self._logicals[self.basis][-bqec_index - 1],
index=1 - bqec_index,
)
nodes.append(node)

# bulk nodes
for syn_type in range(1, len(separated_string)):
for syn_round in range(len(separated_string[syn_type])):
elements = separated_string[syn_type][syn_round]
for qec_index, element in enumerate(elements[::-1]):
if element == "1":
node = {"time": syn_round}
if self.basis == "x":
qubits = self.css_x_stabilizer_ops[qec_index]
else:
qubits = self.css_z_stabilizer_ops[qec_index]
node["qubits"] = qubits
node["is_boundary"] = False
node["element"] = qec_index
node = DecodingGraphNode(time=syn_round, qubits=qubits, index=qec_index)
nodes.append(node)
return nodes

Expand All @@ -441,9 +434,9 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
num_errors (int): Minimum number of errors required to create nodes.
"""

bulk_nodes = [node for node in nodes if not node["is_boundary"]]
boundary_nodes = [node for node in nodes if node["is_boundary"]]
given_logicals = set(node["element"] for node in boundary_nodes)
bulk_nodes = [node for node in nodes if not node.is_boundary]
boundary_nodes = [node for node in nodes if node.is_boundary]
given_logicals = set(node.index for node in boundary_nodes)

if self.basis == "z":
coords = self._zplaq_coords
Expand All @@ -459,7 +452,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
xs = []
ys = []
for node in bulk_nodes:
x, y = coords[node["element"]]
x, y = coords[node.index]
xs.append(x)
ys.append(y)
dx = max(xs) - min(xs)
Expand All @@ -477,7 +470,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
# find nearest boundary
num_errors = (self.d - 1) / 2
for node in bulk_nodes:
x, y = coords[node["element"]]
x, y = coords[node.index]
if self.basis == "z":
p = y
else:
Expand All @@ -497,12 +490,9 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
# get the required boundary nodes
flipped_logical_nodes = []
for elem in flipped_logicals:
node = {
"time": 0,
"qubits": self._logicals[self.basis][elem],
"is_boundary": True,
"element": elem,
}
node = DecodingGraphNode(
is_boundary=True, qubits=self._logicals[self.basis][elem], index=elem
)
flipped_logical_nodes.append(node)

return neutral, flipped_logical_nodes, num_errors
Expand Down
2 changes: 1 addition & 1 deletion src/qiskit_qec/decoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@
from .circuit_matching_decoder import CircuitModelMatchingDecoder
from .repetition_decoder import RepetitionDecoder
from .three_bit_decoder import ThreeBitDecoder
from .hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder
from .hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder, ClAYGDecoder
70 changes: 33 additions & 37 deletions src/qiskit_qec/decoders/circuit_matching_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from qiskit import QuantumCircuit
from qiskit_qec.analysis.faultenumerator import FaultEnumerator
from qiskit_qec.decoders.decoding_graph import CSSDecodingGraph, DecodingGraph
from qiskit_qec.utils import DecodingGraphEdge
from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher
from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher
from qiskit_qec.decoders.temp_code_util import temp_gauge_products, temp_syndrome
Expand Down Expand Up @@ -173,13 +174,13 @@ def _process_graph(
n0, n1 = graph.edge_list()[j]
source = graph.nodes()[n0]
target = graph.nodes()[n1]
if source["time"] != target["time"]:
if source["is_boundary"] == target["is_boundary"] == False:
new_source = source.copy()
new_source["time"] = target["time"]
if source.time != target.time:
if source.is_boundary == target.is_boundary == False:
new_source = copy(source)
new_source.time = target.time
nn0 = graph.nodes().index(new_source)
new_target = target.copy()
new_target["time"] = source["time"]
new_target = copy(target)
new_target.time = source.time
nn1 = graph.nodes().index(new_target)
graph.add_edge(nn0, nn1, edge)

Expand All @@ -191,16 +192,16 @@ def _process_graph(

# add the required attributes
# highlighted', 'measurement_error','qubit_id' and 'error_probability'
edge["highlighted"] = False
edge["measurement_error"] = int(source["time"] != target["time"])
edge.properties["highlighted"] = False
edge.properties["measurement_error"] = int(source.time != target.time)

# make it so times of boundary/boundary nodes agree
if source["is_boundary"] and not target["is_boundary"]:
if source["time"] != target["time"]:
new_source = source.copy()
new_source["time"] = target["time"]
if source.is_boundary and not target.is_boundary:
if source.time != target.time:
new_source = copy(source)
new_source.time = target.time
n = graph.add_node(new_source)
edge["measurement_error"] = 0
edge.properties["measurement_error"] = 0
edges_to_remove.append((n0, n1))
graph.add_edge(n, n1, edge)

Expand All @@ -211,29 +212,24 @@ def _process_graph(
for n0, source in enumerate(graph.nodes()):
for n1, target in enumerate(graph.nodes()):
# add weightless nodes connecting different boundaries
if source["time"] == target["time"]:
if source["is_boundary"] and target["is_boundary"]:
if source["qubits"] != target["qubits"]:
edge = {
"measurement_error": 0,
"weight": 0,
"highlighted": False,
}
edge["qubits"] = list(
set(source["qubits"]).intersection((set(target["qubits"])))
if source.time == target.time:
if source.is_boundary and target.is_boundary:
if source.qubits != target.qubits:
edge = DecodingGraphEdge(
weight=0,
qubits=list(set(source.qubits).intersection((set(target.qubits)))),
)
edge.properties["highlighted"] = False
edge.properties["measurement_error"] = 0
if (n0, n1) not in graph.edge_list():
graph.add_edge(n0, n1, edge)

# connect one of the boundaries at different times
if target["time"] == source["time"] + 1:
if source["qubits"] == target["qubits"] == [0]:
edge = {
"qubits": [],
"measurement_error": 0,
"weight": 0,
"highlighted": False,
}
if target.time == (source.time or 0) + 1:
if source.qubits == target.qubits == [0]:
edge = DecodingGraphEdge(weight=0, qubits=[])
edge.properties["highlighted"] = False
edge.properties["measurement_error"] = 0
if (n0, n1) not in graph.edge_list():
graph.add_edge(n0, n1, edge)

Expand All @@ -245,14 +241,14 @@ def _process_graph(

idxmap = {}
for n, node in enumerate(graph.nodes()):
idxmap[node["time"], tuple(node["qubits"])] = n
idxmap[node.time, tuple(node.qubits)] = n

node_layers = []
for node in graph.nodes():
time = node["time"]
time = node.time or 0
if len(node_layers) < time + 1:
node_layers += [[]] * (time + 1 - len(node_layers))
node_layers[time].append(node["qubits"])
node_layers[time].append(node.qubits)

# create a list of decoding graph layer types
# the entries are 'g' for gauge and 's' for stabilizer
Expand Down Expand Up @@ -298,11 +294,11 @@ def _revise_decoding_graph(
# TODO: new edges may be needed for hooks, but raise exception for now
raise QiskitQECError("edge {s1} - {s2} not in decoding graph")
data = graph.get_edge_data(idxmap[s1], idxmap[s2])
data["weight_poly"] = wpoly
data.properties["weight_poly"] = wpoly
remove_list = []
for source, target in graph.edge_list():
edge_data = graph.get_edge_data(source, target)
if "weight_poly" not in edge_data and edge_data["weight"] != 0:
if "weight_poly" not in edge_data.properties and edge_data.weight != 0:
# Remove the edge
remove_list.append((source, target))
logging.info("remove edge (%d, %d)", source, target)
Expand Down Expand Up @@ -340,7 +336,7 @@ def update_edge_weights(self, model: PauliNoiseModel):
# p_i is the probability that edge i carries an error
# l(i) is 1 if the link belongs to the chain and 0 otherwise
for source, target in self.graph.edge_list():
edge_data = self.graph.get_edge_data(source, target)
edge_data = self.graph.get_edge_data(source, target).properties
if "weight_poly" in edge_data:
logging.info(
"update_edge_weights (%d, %d) %s",
Expand Down
Loading