diff --git a/.gitignore b/.gitignore index 1e538893..55d9fa51 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,6 @@ docs/_build/ docs/stubs/* .DS_Store + +# Cached decoding graphs +graph*.json \ No newline at end of file diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 244432f2..f3f5d2b1 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -17,6 +17,7 @@ """Generates circuits based on repetition codes.""" from typing import List, Optional, Tuple +from copy import copy, deepcopy import numpy as np import rustworkx as rx @@ -26,6 +27,7 @@ from qiskit.transpiler.passes import DynamicalDecoupling from qiskit_qec.circuits.code_circuit import CodeCircuit +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge def _separate_string(string): @@ -315,15 +317,12 @@ def string2nodes(self, string, **kwargs): boundary = separated_string[0] # [, ] for bqec_index, belement in enumerate(boundary[::-1]): if all_logicals or belement != logical: - bnode = {"time": 0} i = [0, -1][bqec_index] if self.basis == "z": bqubits = [self.css_x_logical[i]] else: bqubits = [self.css_z_logical[i]] - bnode["qubits"] = bqubits - bnode["is_boundary"] = True - bnode["element"] = bqec_index + bnode = DecodingGraphNode(is_boundary=True, qubits=bqubits, index=bqec_index) nodes.append(bnode) # bulk nodes @@ -332,14 +331,11 @@ def string2nodes(self, string, **kwargs): elements = separated_string[syn_type][syn_round] for qec_index, element in enumerate(elements[::-1]): if element == "1": - node = {"time": syn_round} if self.basis == "z": qubits = self.css_z_gauge_ops[qec_index] else: qubits = self.css_x_gauge_ops[qec_index] - node["qubits"] = qubits - node["is_boundary"] = False - node["element"] = qec_index + node = DecodingGraphNode(time=syn_round, qubits=qubits, index=qec_index) nodes.append(node) return nodes @@ -354,7 +350,7 @@ def string2raw_logicals(self, string): return _separate_string(self._process_string(string))[0] @staticmethod - def flatten_nodes(nodes): + def flatten_nodes(nodes: List[DecodingGraphNode]): """ Removes time information from a set of nodes, and consolidates those on the same position at different times. @@ -365,17 +361,17 @@ def flatten_nodes(nodes): """ nodes_per_link = {} for node in nodes: - link_qubit = node["link qubit"] + link_qubit = node.properties["link qubit"] if link_qubit in nodes_per_link: nodes_per_link[link_qubit] += 1 else: nodes_per_link[link_qubit] = 1 flat_nodes = [] for node in nodes: - if nodes_per_link[node["link qubit"]] % 2: - flat_node = node.copy() - if "time" in flat_node: - flat_node.pop("time") + if nodes_per_link[node.properties["link qubit"]] % 2: + flat_node = copy(node) + # FIXME: Seems unsafe. + flat_node.time = None flat_nodes.append(flat_node) return flat_nodes @@ -401,15 +397,15 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # see which qubits for logical zs are given and collect bulk nodes given_logicals = [] for node in nodes: - if node["is_boundary"]: - given_logicals += node["qubits"] + if node.is_boundary: + given_logicals += node.qubits given_logicals = set(given_logicals) # bicolour code qubits according to the domain walls walls = [] for node in nodes: - if not node["is_boundary"]: - walls.append(node["qubits"][1]) + if not node.is_boundary: + walls.append(node.qubits[1]) walls.sort() c = 0 colors = "" @@ -429,7 +425,6 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # calculate all required info for the max to see if that is fully neutral # if not, calculate and output for the min case for error_c in [error_c_max, error_c_min]: - num_errors = colors.count(error_c) # determine the corresponding flipped logicals @@ -455,7 +450,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): elem = self.css_z_boundary.index(qubits) else: elem = self.css_x_boundary.index(qubits) - node = {"time": 0, "qubits": qubits, "is_boundary": True, "element": elem} + node = DecodingGraphNode(is_boundary=True, qubits=qubits, index=elem) flipped_logical_nodes.append(node) if neutral and flipped_logical_nodes == []: @@ -622,6 +617,7 @@ def __init__( self._readout() def _get_link_graph(self, max_dist=1): + # FIXME: Migrate link graph to new Edge type graph = rx.PyGraph() for link in self.links: add_edge(graph, (link[0], link[2]), {"distance": 1, "link qubit": link[1]}) @@ -702,7 +698,7 @@ def weight_fn(edge): # find a min weight matching, and then another that exlcudes the pairs from the first matching = [rx.max_weight_matching(graph, max_cardinality=True, weight_fn=weight_fn)] - cut_graph = graph.copy() + cut_graph = deepcopy(graph) for n0, n1 in matching[0]: cut_graph.remove_edge(n0, n1) matching.append( @@ -1026,7 +1022,7 @@ def _process_string(self, string): return new_string - def string2nodes(self, string, **kwargs): + def string2nodes(self, string, **kwargs) -> List[DecodingGraphNode]: """ Convert output string from circuits into a set of nodes. Args: @@ -1061,19 +1057,21 @@ def string2nodes(self, string, **kwargs): code_qubits = [link[0], link[2]] link_qubit = link[1] tau, _, _ = self._get_202(syn_round) - node = {"time": syn_round} - if tau: - if ((tau % 2) == 1) and tau > 1: - node["conjugate"] = True - node["qubits"] = code_qubits - node["link qubit"] = link_qubit - node["is_boundary"] = is_boundary - node["element"] = elem_num + if not tau: + tau = 0 + node = DecodingGraphNode( + is_boundary=is_boundary, + time=syn_round if not is_boundary else None, + qubits=code_qubits, + index=elem_num, + ) + node.properties["conjugate"] = ((tau % 2) == 1) and tau > 1 + node.properties["link qubit"] = link_qubit nodes.append(node) return nodes @staticmethod - def flatten_nodes(nodes): + def flatten_nodes(nodes: List[DecodingGraphNode]): """ Removes time information from a set of nodes, and consolidates those on the same position at different times. Also removes nodes corresponding @@ -1086,26 +1084,22 @@ def flatten_nodes(nodes): # strip out conjugate nodes non_conj_nodes = [] for node in nodes: - if "conjugate" not in node: + if not node.properties["conjugate"]: non_conj_nodes.append(node) - else: - if not node["conjugate"]: - non_conj_nodes.append(node) nodes = non_conj_nodes # remove time info nodes_per_link = {} for node in nodes: - link_qubit = node["link qubit"] + link_qubit = node.properties["link qubit"] if link_qubit in nodes_per_link: nodes_per_link[link_qubit] += 1 else: nodes_per_link[link_qubit] = 1 flat_nodes = [] for node in nodes: - if nodes_per_link[node["link qubit"]] % 2: - flat_node = node.copy() - if "time" in flat_node: - flat_node.pop("time") + if nodes_per_link[node.properties["link qubit"]] % 2: + flat_node = deepcopy(node) + flat_node.time = None flat_nodes.append(flat_node) return flat_nodes @@ -1132,8 +1126,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): given_logicals = [] bulk_nodes = [] for node in nodes: - if node["is_boundary"]: - given_logicals += node["qubits"] + if node.is_boundary: + given_logicals += node.qubits else: bulk_nodes.append(node) given_logicals = set(given_logicals) @@ -1141,7 +1135,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # see whether the bulk nodes are neutral if bulk_nodes: nodes = self.flatten_nodes(nodes) - link_qubits = set(node["link qubit"] for node in nodes) + link_qubits = set(node.properties["link qubit"] for node in nodes) node_color = {0: 0} neutral = True link_graph = self._get_link_graph() @@ -1196,7 +1190,6 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # see what happens for both colours # once full neutrality us found, go for it! for c in min_cs: - this_neutral = neutral num_errors = num_nodes[c] flipped_logicals = flipped_logicals_all[c] @@ -1211,13 +1204,11 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): flipped_logical_nodes = [] for flipped_logical in flipped_logicals: - node = { - "time": 0, - "qubits": [flipped_logical], - "link qubit": None, - "is_boundary": True, - "element": self.z_logicals.index(flipped_logical), - } + node = DecodingGraphNode( + is_boundary=True, + qubits=[flipped_logical], + index=self.z_logicals.index(flipped_logical), + ) flipped_logical_nodes.append(node) if this_neutral and flipped_logical_nodes == []: @@ -1344,28 +1335,28 @@ def _make_syndrome_graph(self): + ("0" * len(self.links) + " ") * (self.T - 1) + "1" * len(self.links) ) - nodes = [] + nodes: List[DecodingGraphNode] = [] for node in self.string2nodes(string): - if not node["is_boundary"]: + if not node.is_boundary: for t in range(self.T + 1): - new_node = node.copy() - new_node["time"] = t + new_node = deepcopy(node) + new_node.time = t if new_node not in nodes: nodes.append(new_node) else: - node["time"] = 0 + node.time = None nodes.append(node) # find pairs that should be connected - edges = [] + edges: List[Tuple[int, int]] = [] for n0, node0 in enumerate(nodes): for n1, node1 in enumerate(nodes): if n0 < n1: # just record all possible edges for now (should be improved later) - dt = abs(node1["time"] - node0["time"]) - adj = set(node0["qubits"]).intersection(set(node1["qubits"])) + dt = abs((node1.time or 0) - (node0.time or 0)) + adj = set(node0.qubits).intersection(set(node1.qubits)) if adj: - if (node0["is_boundary"] ^ node1["is_boundary"]) or dt <= 1: + if (node0.is_boundary ^ node1.is_boundary) or dt <= 1: edges.append((n0, n1)) # put it all in a graph @@ -1377,11 +1368,11 @@ def _make_syndrome_graph(self): source = nodes[n0] target = nodes[n1] qubits = [] - if not (source["is_boundary"] and target["is_boundary"]): - qubits = list(set(source["qubits"]).intersection(target["qubits"])) - if source["time"] != target["time"] and len(qubits) > 1: + if not (source.is_boundary and target.is_boundary): + qubits = list(set(source.qubits).intersection(target.qubits)) + if source.time != target.time and len(qubits) > 1: qubits = [] - edge = {"qubits": qubits, "weight": 1} + edge = DecodingGraphEdge(qubits=qubits, weight=1) S.add_edge(n0, n1, edge) # just record edges as hyperedges for now (should be improved later) hyperedges.append({(n0, n1): edge}) @@ -1428,65 +1419,65 @@ def get_error_coords(self, counts, decoding_graph, method="spitz"): node0 = nodes[n0] node1 = nodes[n1] if n0 != n1: - qubits = decoding_graph.graph.get_edge_data(n0, n1)["qubits"] + qubits = decoding_graph.graph.get_edge_data(n0, n1).qubits if qubits: # error on a code qubit between rounds, or during a round - assert ( - node0["time"] == node1["time"] and node0["qubits"] != node1["qubits"] - ) or (node0["time"] != node1["time"] and node0["qubits"] != node1["qubits"]) + assert (node0.time == node1.time and node0.qubits != node1.qubits) or ( + node0.time != node1.time and node0.qubits != node1.qubits + ) qubit = qubits[0] # error between rounds - if node0["time"] == node1["time"]: + if node0.time == node1.time: dts = [] for node in [node0, node1]: - pair = [qubit, node["link qubit"]] + pair = [qubit, node.properties["link qubit"]] for dt, pairs in enumerate(self.schedule): if pair in pairs or tuple(pair) in pairs: dts.append(dt) - time = [max(0, node0["time"] - 1 + (max(dts) + 1) / round_length)] - time.append(node0["time"] + min(dts) / round_length) + time = [max(0, node0.time - 1 + (max(dts) + 1) / round_length)] + time.append(node0.time + min(dts) / round_length) # error during a round else: # put nodes in descending time order - if node0["time"] < node1["time"]: + if node0.time < node1.time: node_pair = [node1, node0] else: node_pair = [node0, node1] # see when in the schedule each node measures the qubit dts = [] for node in node_pair: - pair = [qubit, node["link qubit"]] + pair = [qubit, node.properties["link qubit"]] for dt, pairs in enumerate(self.schedule): if pair in pairs or tuple(pair) in pairs: dts.append(dt) # use to define fractional time if dts[0] < dts[1]: - time = [node_pair[1]["time"] + (dts[0] + 1) / round_length] - time.append(node_pair[1]["time"] + dts[1] / round_length) + time = [node_pair[1].time + (dts[0] + 1) / round_length] + time.append(node_pair[1].time + dts[1] / round_length) else: # impossible cases get no valid time time = [] else: # measurement error - assert node0["time"] != node1["time"] and node0["qubits"] == node1["qubits"] - qubit = node0["link qubit"] - time = [node0["time"], node0["time"] + (round_length - 1) / round_length] + assert node0.time != node1.time and node0.qubits == node1.qubits + qubit = node0.properties["link qubit"] + time = [node0.time, node0.time + (round_length - 1) / round_length] time.sort() else: # detected only by one stabilizer - boundary_qubits = list(set(node0["qubits"]).intersection(z_logicals)) + boundary_qubits = list(set(node0.qubits).intersection(z_logicals)) # for the case of boundary stabilizers if boundary_qubits: qubit = boundary_qubits[0] - pair = [qubit, node0["link qubit"]] + pair = [qubit, node0.properties["link qubit"]] for dt, pairs in enumerate(self.schedule): if pair in pairs or tuple(pair) in pairs: - time = [max(0, node0["time"] - 1 + (dt + 1) / round_length)] - time.append(node0["time"] + dt / round_length) + time = [max(0, node0.time - 1 + (dt + 1) / round_length)] + time.append(node0.time + dt / round_length) else: - qubit = tuple(node0["qubits"] + [node0["link qubit"]]) - time = [node0["time"], node0["time"] + (round_length - 1) / round_length] + qubit = tuple(node0.qubits + [node0.properties["link qubit"]]) + time = [node0.time, node0.time + (round_length - 1) / round_length] if time != []: # only record if not nan if (qubit, time[0], time[1]) not in error_coords: diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 13bbcae8..12dd36bb 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -18,6 +18,7 @@ from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister +from qiskit_qec.utils import DecodingGraphNode from qiskit_qec.circuits.code_circuit import CodeCircuit @@ -107,7 +108,6 @@ def __init__(self, d: int, T: int, basis: str = "z", resets=True): self.readout() def _get_plaquettes(self): - """ Returns `zplaqs` and `xplaqs`, which are lists of the Z and X type stabilizers. Each plaquettes is specified as a list of four qubits, @@ -122,7 +122,6 @@ def _get_plaquettes(self): xplaq_coords = [] for y in range(-1, d): for x in range(-1, d): - bulk = x in range(d - 1) and y in range(d - 1) ztab = (x == -1 and y % 2 == 0) or (x == d - 1 and y % 2 == 1) xtab = (y == -1 and x % 2 == 1) or (y == d - 1 and x % 2 == 0) @@ -219,7 +218,6 @@ def syndrome_measurement(self, final=False, barrier=False): ) for log in ["0", "1"]: - self.circuit[log].add_register(self.zplaq_bits[-1]) self.circuit[log].add_register(self.xplaq_bits[-1]) @@ -262,7 +260,6 @@ def readout(self): self.circuit[log].measure(self.code_qubit, self.code_bit) def _string2changes(self, string): - basis = self.basis # final syndrome for plaquettes deduced from final code qubit readout @@ -347,7 +344,6 @@ def string2raw_logicals(self, string): return str(Z[0]) + " " + str(Z[1]) def _process_string(self, string): - # get logical readout measured_Z = self.string2raw_logicals(string) @@ -361,7 +357,6 @@ def _process_string(self, string): return new_string def _separate_string(self, string): - separated_string = [] for syndrome_type_string in string.split(" "): separated_string.append(syndrome_type_string.split(" ")) @@ -399,11 +394,12 @@ def string2nodes(self, string, **kwargs): boundary = separated_string[0] # [, ] for bqec_index, belement in enumerate(boundary[::-1]): if all_logicals or belement != logical: - bnode = {"time": 0} - bnode["qubits"] = self._logicals[self.basis][-bqec_index - 1] - bnode["is_boundary"] = True - bnode["element"] = 1 - bqec_index - nodes.append(bnode) + node = DecodingGraphNode( + is_boundary=True, + qubits=self._logicals[self.basis][-bqec_index - 1], + index=1 - bqec_index, + ) + nodes.append(node) # bulk nodes for syn_type in range(1, len(separated_string)): @@ -411,14 +407,11 @@ def string2nodes(self, string, **kwargs): elements = separated_string[syn_type][syn_round] for qec_index, element in enumerate(elements[::-1]): if element == "1": - node = {"time": syn_round} if self.basis == "x": qubits = self.css_x_stabilizer_ops[qec_index] else: qubits = self.css_z_stabilizer_ops[qec_index] - node["qubits"] = qubits - node["is_boundary"] = False - node["element"] = qec_index + node = DecodingGraphNode(time=syn_round, qubits=qubits, index=qec_index) nodes.append(node) return nodes @@ -441,9 +434,9 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): num_errors (int): Minimum number of errors required to create nodes. """ - bulk_nodes = [node for node in nodes if not node["is_boundary"]] - boundary_nodes = [node for node in nodes if node["is_boundary"]] - given_logicals = set(node["element"] for node in boundary_nodes) + bulk_nodes = [node for node in nodes if not node.is_boundary] + boundary_nodes = [node for node in nodes if node.is_boundary] + given_logicals = set(node.index for node in boundary_nodes) if self.basis == "z": coords = self._zplaq_coords @@ -459,7 +452,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): xs = [] ys = [] for node in bulk_nodes: - x, y = coords[node["element"]] + x, y = coords[node.index] xs.append(x) ys.append(y) dx = max(xs) - min(xs) @@ -477,7 +470,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # find nearest boundary num_errors = (self.d - 1) / 2 for node in bulk_nodes: - x, y = coords[node["element"]] + x, y = coords[node.index] if self.basis == "z": p = y else: @@ -497,12 +490,9 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # get the required boundary nodes flipped_logical_nodes = [] for elem in flipped_logicals: - node = { - "time": 0, - "qubits": self._logicals[self.basis][elem], - "is_boundary": True, - "element": elem, - } + node = DecodingGraphNode( + is_boundary=True, qubits=self._logicals[self.basis][elem], index=elem + ) flipped_logical_nodes.append(node) return neutral, flipped_logical_nodes, num_errors diff --git a/src/qiskit_qec/decoders/__init__.py b/src/qiskit_qec/decoders/__init__.py index 71262f42..ee3a9e0d 100644 --- a/src/qiskit_qec/decoders/__init__.py +++ b/src/qiskit_qec/decoders/__init__.py @@ -34,4 +34,4 @@ from .circuit_matching_decoder import CircuitModelMatchingDecoder from .repetition_decoder import RepetitionDecoder from .three_bit_decoder import ThreeBitDecoder -from .hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder +from .hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder, ClAYGDecoder diff --git a/src/qiskit_qec/decoders/circuit_matching_decoder.py b/src/qiskit_qec/decoders/circuit_matching_decoder.py index 1ab26faf..6f0e5cdf 100644 --- a/src/qiskit_qec/decoders/circuit_matching_decoder.py +++ b/src/qiskit_qec/decoders/circuit_matching_decoder.py @@ -11,6 +11,7 @@ from qiskit import QuantumCircuit from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.decoders.decoding_graph import CSSDecodingGraph, DecodingGraph +from qiskit_qec.utils import DecodingGraphEdge from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher from qiskit_qec.decoders.temp_code_util import temp_gauge_products, temp_syndrome @@ -173,13 +174,13 @@ def _process_graph( n0, n1 = graph.edge_list()[j] source = graph.nodes()[n0] target = graph.nodes()[n1] - if source["time"] != target["time"]: - if source["is_boundary"] == target["is_boundary"] == False: - new_source = source.copy() - new_source["time"] = target["time"] + if source.time != target.time: + if source.is_boundary == target.is_boundary == False: + new_source = copy(source) + new_source.time = target.time nn0 = graph.nodes().index(new_source) - new_target = target.copy() - new_target["time"] = source["time"] + new_target = copy(target) + new_target.time = source.time nn1 = graph.nodes().index(new_target) graph.add_edge(nn0, nn1, edge) @@ -191,16 +192,16 @@ def _process_graph( # add the required attributes # highlighted', 'measurement_error','qubit_id' and 'error_probability' - edge["highlighted"] = False - edge["measurement_error"] = int(source["time"] != target["time"]) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = int(source.time != target.time) # make it so times of boundary/boundary nodes agree - if source["is_boundary"] and not target["is_boundary"]: - if source["time"] != target["time"]: - new_source = source.copy() - new_source["time"] = target["time"] + if source.is_boundary and not target.is_boundary: + if source.time != target.time: + new_source = copy(source) + new_source.time = target.time n = graph.add_node(new_source) - edge["measurement_error"] = 0 + edge.properties["measurement_error"] = 0 edges_to_remove.append((n0, n1)) graph.add_edge(n, n1, edge) @@ -211,29 +212,24 @@ def _process_graph( for n0, source in enumerate(graph.nodes()): for n1, target in enumerate(graph.nodes()): # add weightless nodes connecting different boundaries - if source["time"] == target["time"]: - if source["is_boundary"] and target["is_boundary"]: - if source["qubits"] != target["qubits"]: - edge = { - "measurement_error": 0, - "weight": 0, - "highlighted": False, - } - edge["qubits"] = list( - set(source["qubits"]).intersection((set(target["qubits"]))) + if source.time == target.time: + if source.is_boundary and target.is_boundary: + if source.qubits != target.qubits: + edge = DecodingGraphEdge( + weight=0, + qubits=list(set(source.qubits).intersection((set(target.qubits)))), ) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 if (n0, n1) not in graph.edge_list(): graph.add_edge(n0, n1, edge) # connect one of the boundaries at different times - if target["time"] == source["time"] + 1: - if source["qubits"] == target["qubits"] == [0]: - edge = { - "qubits": [], - "measurement_error": 0, - "weight": 0, - "highlighted": False, - } + if target.time == (source.time or 0) + 1: + if source.qubits == target.qubits == [0]: + edge = DecodingGraphEdge(weight=0, qubits=[]) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 if (n0, n1) not in graph.edge_list(): graph.add_edge(n0, n1, edge) @@ -245,14 +241,14 @@ def _process_graph( idxmap = {} for n, node in enumerate(graph.nodes()): - idxmap[node["time"], tuple(node["qubits"])] = n + idxmap[node.time, tuple(node.qubits)] = n node_layers = [] for node in graph.nodes(): - time = node["time"] + time = node.time or 0 if len(node_layers) < time + 1: node_layers += [[]] * (time + 1 - len(node_layers)) - node_layers[time].append(node["qubits"]) + node_layers[time].append(node.qubits) # create a list of decoding graph layer types # the entries are 'g' for gauge and 's' for stabilizer @@ -298,11 +294,11 @@ def _revise_decoding_graph( # TODO: new edges may be needed for hooks, but raise exception for now raise QiskitQECError("edge {s1} - {s2} not in decoding graph") data = graph.get_edge_data(idxmap[s1], idxmap[s2]) - data["weight_poly"] = wpoly + data.properties["weight_poly"] = wpoly remove_list = [] for source, target in graph.edge_list(): edge_data = graph.get_edge_data(source, target) - if "weight_poly" not in edge_data and edge_data["weight"] != 0: + if "weight_poly" not in edge_data.properties and edge_data.weight != 0: # Remove the edge remove_list.append((source, target)) logging.info("remove edge (%d, %d)", source, target) @@ -340,7 +336,7 @@ def update_edge_weights(self, model: PauliNoiseModel): # p_i is the probability that edge i carries an error # l(i) is 1 if the link belongs to the chain and 0 otherwise for source, target in self.graph.edge_list(): - edge_data = self.graph.get_edge_data(source, target) + edge_data = self.graph.get_edge_data(source, target).properties if "weight_poly" in edge_data: logging.info( "update_edge_weights (%d, %d) %s", diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index b4baa2b2..48b1b74b 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -2,9 +2,9 @@ # This code is part of Qiskit. # -# (C) Copyright IBM 2019. +# (C) Copyright IBM 2019-2023. # -# This code is licensed under the Apache License, Version 2.0. You may ddddddd +# This code is licensed under the Apache License, Version 2.0. You may # obtain a copy of this license in the LICENSE.txt file in the root directory # of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. # @@ -25,6 +25,7 @@ import rustworkx as rx from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.exceptions import QiskitQECError +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge class DecodingGraph: @@ -39,7 +40,7 @@ class DecodingGraph: METHOD_NAIVE: str = "naive" AVAILABLE_METHODS = {METHOD_SPITZ, METHOD_NAIVE} - def __init__(self, code, brute=False): + def __init__(self, code, brute=False, graph=None): """ Args: code (CodeCircuit): The QEC code circuit object for which this decoding @@ -50,11 +51,12 @@ def __init__(self, code, brute=False): self.code = code self.brute = brute - - self._make_syndrome_graph() + if graph: + self.graph = graph + else: + self._make_syndrome_graph() def _make_syndrome_graph(self): - if not self.brute and hasattr(self.code, "_make_syndrome_graph"): self.graph, self.hyperedges = self.code._make_syndrome_graph() else: @@ -62,7 +64,6 @@ def _make_syndrome_graph(self): self.hyperedges = [] if self.code is not None: - # get the circuit used as the base case if isinstance(self.code.circuit, dict): if "base" not in dir(self.code): @@ -90,20 +91,18 @@ def _make_syndrome_graph(self): n0 = graph.nodes().index(source) n1 = graph.nodes().index(target) qubits = [] - if not (source["is_boundary"] and target["is_boundary"]): - qubits = list( - set(source["qubits"]).intersection(target["qubits"]) - ) + if not (source.is_boundary and target.is_boundary): + qubits = list(set(source.qubits).intersection(target.qubits)) if not qubits: continue if ( - source["time"] != target["time"] + source.time != target.time and len(qubits) > 1 - and not source["is_boundary"] - and not target["is_boundary"] + and not source.is_boundary + and not target.is_boundary ): qubits = [] - edge = {"qubits": qubits, "weight": 1} + edge = DecodingGraphEdge(qubits, 1) graph.add_edge(n0, n1, edge) if (n1, n0) not in hyperedge: hyperedge[n0, n1] = edge @@ -112,7 +111,9 @@ def _make_syndrome_graph(self): self.graph = graph - def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ): + def get_error_probs( + self, counts, logical: str = "0", method: str = METHOD_SPITZ + ) -> List[Tuple[Tuple[int, int], float]]: """ Generate probabilities of single error events from result counts. @@ -138,7 +139,6 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ # method for edges if method == self.METHOD_SPITZ: - neighbours = {} av_v = {} for n in self.graph.node_indexes(): @@ -154,7 +154,6 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ neighbours[n1].append(n0) for string in counts: - # list of i for which v_i=1 error_nodes = self.code.string2nodes(string, logical=logical) @@ -180,10 +179,9 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ boundary = [] error_probs = {} for n0, n1 in self.graph.edge_list(): - - if self.graph[n0]["is_boundary"]: + if self.graph[n0].is_boundary: boundary.append(n1) - elif self.graph[n1]["is_boundary"]: + elif self.graph[n1].is_boundary: boundary.append(n0) else: if (1 - 2 * av_xor[n0, n1]) != 0: @@ -211,7 +209,6 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ # generally applicable but approximate method elif method == self.METHOD_NAIVE: - # for every edge in the graph, we'll determine the histogram # of whether their nodes are in the error nodes count = { @@ -238,9 +235,9 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ else: ratio = np.nan p = ratio / (1 + ratio) - if self.graph[n0]["is_boundary"] and not self.graph[n1]["is_boundary"]: + if self.graph[n0].is_boundary and not self.graph[n1].is_boundary: edge = (n1, n1) - elif not self.graph[n0]["is_boundary"] and self.graph[n1]["is_boundary"]: + elif not self.graph[n0].is_boundary and self.graph[n1].is_boundary: edge = (n0, n0) else: edge = (n0, n1) @@ -267,7 +264,7 @@ def weight_syndrome_graph(self, counts, method: str = METHOD_SPITZ): boundary_nodes = [] for n, node in enumerate(self.graph.nodes()): - if node["is_boundary"]: + if node.is_boundary: boundary_nodes.append(n) for edge in self.graph.edge_list(): @@ -346,7 +343,6 @@ def __init__( round_schedule: str, basis: str, ): - self.css_x_gauge_ops = css_x_gauge_ops self.css_x_stabilizer_ops = css_x_stabilizer_ops self.css_x_boundary = css_x_boundary @@ -422,16 +418,20 @@ def _decoding_graph(self): all_z = gauges elif layer == "s": all_z = stabilizers - for supp in all_z: - node = {"time": time, "qubits": supp, "highlighted": False} + for index, supp in enumerate(all_z): + node = DecodingGraphNode(time=time, qubits=supp, index=index) + node = DecodingGraphNode(time=time, qubits=supp, index=index) + node.properties["highlighted"] = True graph.add_node(node) logging.debug("node %d t=%d %s", idx, time, supp) idxmap[(time, tuple(supp))] = idx node_layer.append(idx) idx += 1 - for supp in boundary: + for index, supp in enumerate(boundary): # Add optional is_boundary property for pymatching - node = {"time": time, "qubits": supp, "highlighted": False, "is_boundary": True} + node = DecodingGraphNode(is_boundary=True, qubits=supp, index=index) + node = DecodingGraphNode(is_boundary=True, qubits=supp, index=index) + node.properties["highlighted"] = False graph.add_node(node) logging.debug("boundary %d t=%d %s", idx, time, supp) idxmap[(time, tuple(supp))] = idx @@ -462,12 +462,9 @@ def _decoding_graph(self): # qubit_id is an integer or set of integers # weight is a floating point number # error_probability is a floating point number - edge = { - "qubits": [com[0]], - "measurement_error": 0, - "weight": 1, - "highlighted": False, - } + edge = DecodingGraphEdge(qubits=[com[0]], weight=1) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 graph.add_edge( idxmap[(time, tuple(op_g))], idxmap[(time, tuple(op_h))], edge ) @@ -485,12 +482,9 @@ def _decoding_graph(self): # qubit_id is an integer or set of integers # weight is a floating point number # error_probability is a floating point number - edge = { - "qubits": [], - "measurement_error": 0, - "weight": 0, - "highlighted": False, - } + edge = DecodingGraphEdge(qubits=[], weight=0) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 graph.add_edge(idxmap[(time, tuple(bound_g))], idxmap[(time, tuple(bound_h))], edge) logging.debug("spacelike boundary t=%d (%s, %s)", time, bound_g, bound_h) @@ -529,12 +523,10 @@ def _decoding_graph(self): # error_probability is a floating point number # Case (a) if set(com) == set(op_h) or set(com) == set(op_g): - edge = { - "qubits": [], - "measurement_error": 1, - "weight": 1, - "highlighted": False, - } + edge = DecodingGraphEdge(qubits=[], weight=1) + edge = DecodingGraphEdge(qubits=[], weight=1) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 1 graph.add_edge( idxmap[(time - 1, tuple(op_h))], idxmap[(time, tuple(op_g))], @@ -542,12 +534,10 @@ def _decoding_graph(self): ) logging.debug("timelike t=%d (%s, %s)", time, op_g, op_h) else: # Case (b) - edge = { - "qubits": [com[0]], - "measurement_error": 1, - "weight": 1, - "highlighted": False, - } + edge = DecodingGraphEdge(qubits=[com[0]], weight=1) + edge = DecodingGraphEdge(qubits=[com[0]], weight=1) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 1 graph.add_edge( idxmap[(time - 1, tuple(op_h))], idxmap[(time, tuple(op_g))], @@ -557,12 +547,10 @@ def _decoding_graph(self): logging.debug(" qubits %s", [com[0]]) # Add a single time-like edge between boundary vertices at # time t-1 and t - edge = { - "qubits": [], - "measurement_error": 0, - "weight": 0, - "highlighted": False, - } + edge = DecodingGraphEdge(qubits=[], weight=0) + edge = DecodingGraphEdge(qubits=[], weight=0) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 graph.add_edge( idxmap[(time - 1, tuple(boundary[0]))], idxmap[(time, tuple(boundary[0]))], edge ) diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 1f5ef5d4..844a293d 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -18,11 +18,12 @@ from copy import copy, deepcopy from dataclasses import dataclass -from typing import Dict, List, Set +from typing import Dict, List, Set, Tuple from rustworkx import connected_components, distance_matrix, PyGraph from qiskit_qec.circuits.repetition_code import ArcCircuit, RepetitionCodeCircuit from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge from qiskit_qec.exceptions import QiskitQECError @@ -128,11 +129,9 @@ def _cluster(self, ns, dist_max): def _get_boundary_nodes(self): boundary_nodes = [] for element, z_logical in enumerate(self.z_logicals): - node = {"time": 0, "is_boundary": True} + node = DecodingGraphNode(is_boundary=True, qubits=[z_logical], index=element) if isinstance(self.code, ArcCircuit): - node["link qubit"] = None - node["qubits"] = [z_logical] - node["element"] = element + node.properties["link qubit"] = None boundary_nodes.append(node) return boundary_nodes @@ -169,7 +168,7 @@ def cluster(self, nodes): if c is not None: final_clusters[n] = c else: - if not dg[n]["is_boundary"]: + if not dg[n].is_boundary: ns.append(n) con_comps.append(con_comp) clusterss.append(clusters) @@ -199,14 +198,14 @@ def process(self, string): cluster_nodes = {c: [] for c in clusters.values()} for n, c in clusters.items(): node = decoding_graph.graph[n] - if not node["is_boundary"]: + if not node.is_boundary: cluster_nodes[c].append(node) # get the list of required logicals for each cluster cluster_logicals = {} for c, nodes in cluster_nodes.items(): _, logical_nodes, _ = code.check_nodes(nodes) - z_logicals = [node["qubits"][0] for node in logical_nodes] + z_logicals = [node.qubits[0] for node in logical_nodes] cluster_logicals[c] = z_logicals # get the net effect on each logical @@ -247,7 +246,7 @@ class BoundaryEdge: index: int cluster_vertex: int neighbour_vertex: int - data: Dict[str, object] + data: DecodingGraphEdge def reverse(self): """ @@ -333,10 +332,10 @@ def process(self, string: str): if isinstance(self.code, ArcCircuit): # NOTE: it just corrects for final logical readout for node in erasure.nodes(): - if node["is_boundary"]: + if node.is_boundary: # FIXME: Find a general way to go from physical qubit # index to code qubit index - qubit_to_be_corrected = int(node["qubits"][0] / 2) + qubit_to_be_corrected = int(node.qubits[0] / 2) output[qubit_to_be_corrected] = (output[qubit_to_be_corrected] + 1) % 2 continue @@ -360,12 +359,12 @@ def cluster(self, nodes) -> List[List[int]]: """ node_indices = [self.graph.nodes().index(node) for node in nodes] for node_index, _ in enumerate(self.graph.nodes()): - self.graph[node_index]["syndrome"] = node_index in node_indices - self.graph[node_index]["root"] = node_index + self.graph[node_index].properties["syndrome"] = node_index in node_indices + self.graph[node_index].properties["root"] = node_index for edge in self.graph.edges(): - edge["growth"] = 0 - edge["fully_grown"] = False + edge.properties["growth"] = 0 + edge.properties["fully_grown"] = False self.clusters: Dict[int, UnionFindDecoderCluster] = {} self.odd_cluster_roots = set(node_indices) @@ -409,11 +408,11 @@ def find(self, u: int) -> int: Returns: root (int): The root of the cluster of node u. """ - if self.graph[u]["root"] == u: - return self.graph[u]["root"] + if self.graph[u].properties["root"] == u: + return self.graph[u].properties["root"] - self.graph[u]["root"] = self.find(self.graph[u]["root"]) - return self.graph[u]["root"] + self.graph[u].properties["root"] = self.find(self.graph[u].properties["root"]) + return self.graph[u].properties["root"] def _grow_clusters(self) -> List[FusionEntry]: """ @@ -427,9 +426,12 @@ def _grow_clusters(self) -> List[FusionEntry]: for root in self.odd_cluster_roots: cluster = self.clusters[root] for edge in cluster.boundary: - edge.data["growth"] += 0.5 - if edge.data["growth"] >= edge.data["weight"] and not edge.data["fully_grown"]: - edge.data["fully_grown"] = True + edge.data.properties["growth"] += 0.5 + if ( + edge.data.properties["growth"] >= edge.data.weight + and not edge.data.properties["fully_grown"] + ): + edge.data.properties["fully_grown"] = True cluster.fully_grown_edges.add(edge.index) fusion_entry = FusionEntry( u=edge.cluster_vertex, v=edge.neighbour_vertex, connecting_edge=edge @@ -473,7 +475,7 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]) -> None: else: self.odd_cluster_roots.discard(new_root) self.odd_cluster_roots.discard(root_to_update) - self.graph[root_to_update]["root"] = new_root + self.graph[root_to_update].properties["root"] = new_root def peeling(self, erasure: PyGraph) -> List[int]: """ " @@ -497,7 +499,7 @@ def peeling(self, erasure: PyGraph) -> List[int]: # Construct spanning forest # Pick starting vertex for vertex in erasure.node_indices(): - if erasure[vertex]["is_boundary"]: + if erasure[vertex].is_boundary: tree.vertices[vertex] = [] break if not tree.vertices: @@ -522,11 +524,251 @@ def peeling(self, erasure: PyGraph) -> List[int]: pendant_vertex = endpoints[0] if not tree.vertices[endpoints[0]] else endpoints[1] tree_vertex = endpoints[0] if pendant_vertex == endpoints[1] else endpoints[1] tree.vertices[tree_vertex].remove(edge) - if erasure[pendant_vertex]["syndrome"] and not erasure[pendant_vertex]["is_boundary"]: + if ( + erasure[pendant_vertex].properties["syndrome"] + and not erasure[pendant_vertex].is_boundary + ): edges.add(edge) - erasure[tree_vertex]["syndrome"] = not erasure[tree_vertex]["syndrome"] - erasure[pendant_vertex]["syndrome"] = False + erasure[tree_vertex].properties["syndrome"] = not erasure[tree_vertex].properties[ + "syndrome" + ] + erasure[pendant_vertex].properties["syndrome"] = False + + return [erasure.edges()[edge].qubits[0] for edge in edges if erasure.edges()[edge].qubits] + + +@dataclass +class Cluster: + """ + Cluster class for the ClAYG decoder. + FIXME: Remove, when ClAYG decoder uses Union Find infrastructure. + """ + + boundary: List[Tuple[int, int]] # List[(edge, neighbour)] + fully_grown_edges: List[int] # List[edge] + nodes: List[int] # List[node_index] + atypical_nodes: List[int] # List[node_index] + + +class ClAYGDecoder(UnionFindDecoder): + """ + Decoder that is very similar to the Union Find decoder, but instead of adding clusters all at once, + adds them separated by syndrome round with a growth and merge phase in between. + Then it just proceeds like the Union Find decoder. + + FIXME: Use the Union Find infrastructure and just change the self.cluster() method. Problem is that + the peeling decoder needs a modified version the graph with the syndrome nodes marked, which is done + in the process method. For now it is mostly its separate thing, but merging them shouldn't be + too big of a hassle. + Merge method should also be modified, as boundary clusters are not marked as odd clusters. + """ + + def __init__(self, code, logical: str, decoding_graph: DecodingGraph = None) -> None: + super().__init__(code, logical, decoding_graph) + self.graph = deepcopy(self.decoding_graph.graph) + self.roots = {} + self.odd = {} + + def process(self, string: str): + """ + Process an output string and return corrected final outcomes. + + Args: + string (str): Output string of the code. + Returns: + corrected_z_logicals (list): A list of integers that are 0 or 1. + These are the corrected values of the final transversal + measurement, corresponding to the logical operators of + self.z_logicals. + """ + self.graph = deepcopy(self.decoding_graph.graph) + nodes_at_time_zero = [] + for index, node in enumerate(self.graph.nodes()): + if node.time == 0 or node.is_boundary: + nodes_at_time_zero.append(index) + self.graph = self.graph.subgraph(nodes_at_time_zero) + + for edge in self.graph.edges(): + edge.weight = 1 + edge.properties["growth"] = 0 + + string = "".join([str(c) for c in string[::-1]]) + output = [int(bit) for bit in list(string.split(" ", maxsplit=self.code.d)[0])][::-1] + nodes = self.code.string2nodes(string, logical=self.logical) + + clusters = self.cluster(nodes) + + for cluster in clusters: + erasure_graph = deepcopy(self.graph) + for node in cluster.nodes: + erasure_graph[node].properties["syndrome"] = False + for node in cluster.atypical_nodes: + erasure_graph[node].properties["syndrome"] = True + erasure = erasure_graph.subgraph(cluster.nodes + cluster.atypical_nodes) + qubits_to_be_corrected = self.peeling(erasure) + for idx in qubits_to_be_corrected: + output[idx] = (output[idx] + 1) % 2 + + return output + + def cluster(self, nodes) -> List[List[int]]: + """ + Create clusters using the union-find algorithm. + + Args: + nodes (List): List of non-typical nodes in the syndrome graph, + of the type produced by `string2nodes`. + + Returns: + FIXME: Make this more expressive. Maybe return a list of separate PyGraphs? + Would fix the infrastructure-sharing-issue mentioned above. + clusters (List[List[int]]): List of Lists of indices of nodes in clusters. + """ + self.roots = {} + self.odd = {} + for i, node in enumerate(self.graph.nodes()): + self.roots[i] = i + self.odd[i] = False + + self.clusters: Dict[int, Cluster] = { + node_index: None for node_index in self.graph.node_indices() + } + + self.odd_cluster_roots: List[int] = [] + times = [[] for _ in range(self.code.T + 1)] + boundaries = [] + for node in deepcopy(nodes): + if nodes.count(node) > 1: + continue + if node.is_boundary: + boundaries.append(node) + else: + node.time = 0 + times[node.time].append(node) + + neutral_clusters = [] + + for time in times: + if not time: + continue + for node in time: + neutral_clusters += self.add_atypical_node_to_decoding_graph(node, True) + for root in self.odd_cluster_roots: + neutral_clusters += self.grow_cluster_and_merge(root) + + for node in boundaries: + neutral_clusters += self.add_atypical_node_to_decoding_graph(node, False) + + while self.odd_cluster_roots: + for root in self.odd_cluster_roots: + neutral_clusters += self.grow_cluster_and_merge(root) + + return neutral_clusters + + def add_atypical_node_to_decoding_graph(self, node, add_odd_cluster: bool) -> List[Cluster]: + """ + Adds non-typical syndrome nodes to the graph and + neutralize/create clusters around them if necessary + + Args: + node: dictionary with node data in the form produced by string2nodes. + add_odd_cluster (bool): specifices whether the newly created cluster is + going to be added to the odd_clusters_list. + """ + node_index = self.graph.nodes().index(node) + current_cluster_root = self.find(node_index) + cluster = self.clusters[current_cluster_root] + current_cluster_odd = self.odd[current_cluster_root] + neutral_clusters = [] + # If cluster that it's in is odd set it to even and add it to the error log + if current_cluster_odd: + self.odd[current_cluster_root] = False + self.odd_cluster_roots.remove(current_cluster_root) + cluster.atypical_nodes.append(node_index) + if not node_index == current_cluster_root: + # Simple measurement error, don't add it to the error log + # FIXME: Make peeling decoder prestage handle this + neutral_clusters.append(cluster) + for edge, _ in cluster.boundary: + self.graph.edges()[edge].properties["growth"] = 0 + for node_in_cluster in cluster.nodes + cluster.atypical_nodes: + self.roots[node_in_cluster] = node_in_cluster + self.clusters[current_cluster_root] = None + # Else create a new cluster around it and set it to odd + else: + self.roots[node_index] = node_index + self.odd[node_index] = True + if add_odd_cluster: + self.odd_cluster_roots.append(node_index) + boundary: List[Tuple[int, int]] = [] + for edge, (_, neighbour, _) in dict( + self.graph.incident_edge_index_map(node_index) + ).items(): + boundary.append((edge, neighbour)) + self.clusters[node_index] = Cluster( + boundary=boundary, fully_grown_edges=[], nodes=[], atypical_nodes=[node_index] + ) + + return neutral_clusters + + def grow_cluster_and_merge(self, root: int): + """ + Grows the cluster specified by root by half an edge and merges them if necessary. + + Args: + root (int): index of the root node of the cluster. + """ + cluster = self.clusters[root] + if not cluster: + return [] + for edge, neighbour in copy(cluster.boundary): + self.graph.edges()[edge].properties["growth"] += 0.5 + if self.graph.edges()[edge].properties["growth"] < self.graph.edges()[edge].weight: + continue + cluster.boundary.remove((edge, neighbour)) + cluster.fully_grown_edges.append(edge) + self.graph.edges()[edge].properties["growth"] = 0 + neighbour_root = self.find(neighbour) + if neighbour_root == root: + continue + neighbour_odd = self.odd[neighbour_root] + if neighbour_odd: + # It is odd, so there has to be a cluster + neighbour_cluster = self.clusters[neighbour_root] + cluster.boundary += neighbour_cluster.boundary + cluster.fully_grown_edges += neighbour_cluster.fully_grown_edges + cluster.nodes += neighbour_cluster.nodes + cluster.atypical_nodes += neighbour_cluster.atypical_nodes + for edge_to_be_reset, _ in cluster.boundary: + self.graph.edges()[edge_to_be_reset].properties["growth"] = 0 + for node in cluster.nodes + cluster.atypical_nodes: + self.roots[node] = node + for root_to_be_reset in [root, neighbour_root]: + if self.graph[root_to_be_reset].is_boundary: + continue + self.odd[root_to_be_reset] = False + self.clusters[root_to_be_reset] = None + self.odd_cluster_roots.remove(root_to_be_reset) + return [cluster] + else: + cluster.nodes += [neighbour] + self.roots[neighbour] = root + for neighbor_edge, (_, neighbour_neighbour, _) in dict( + self.graph.incident_edge_index_map(neighbour) + ).items(): + if neighbour_neighbour == neighbour: + continue + cluster.boundary.append((neighbor_edge, neighbour_neighbour)) + return [] + + def find(self, u): + """ + Returns the root of the cluster the node belongs to. - return [ - erasure.edges()[edge]["qubits"][0] for edge in edges if erasure.edges()[edge]["qubits"] - ] + Args: + node_index (int): index of the node in self.graph + """ + if self.roots[u] == u: + return u + self.roots[u] = self.find(self.roots[u]) + return self.roots[u] diff --git a/src/qiskit_qec/decoders/rustworkx_matcher.py b/src/qiskit_qec/decoders/rustworkx_matcher.py index 5d735469..55fd65d4 100644 --- a/src/qiskit_qec/decoders/rustworkx_matcher.py +++ b/src/qiskit_qec/decoders/rustworkx_matcher.py @@ -6,17 +6,18 @@ import rustworkx as rx from qiskit_qec.decoders.base_matcher import BaseMatcher +from qiskit_qec.utils import DecodingGraphEdge class RustworkxMatcher(BaseMatcher): """Matching subroutines using rustworkx. - The input rustworkx graph is expected to have the following properties: - edge["weight"] : real edge weight - edge["measurement_error"] : bool, true if edge corresponds to measurement error - edge["qubits"] : list of qubit ids associated to edge - vertex["qubits"] : qubit ids involved in gauge measurement - vertex["time"] : integer time step of gauge measurement + The input rustworkx graph is expected to have decoding_graph.Node as the type of the node payload + and decoding_graph.Edge as the type of the edge payload. + + Additionally the edges are expected to have the following properties: + - edge.properties["measurement_error"] (bool): Whether or not the error + corresponds to a measurement error. The annotated graph will also have "highlighted" properties on edges and vertices. """ @@ -36,8 +37,8 @@ def preprocess(self, graph: rx.PyGraph): """ # edge_cost_fn = lambda edge: edge["weight"] - def edge_cost_fn(edge): - return edge["weight"] + def edge_cost_fn(edge: DecodingGraphEdge): + return edge.weight length = rx.all_pairs_dijkstra_path_lengths(graph, edge_cost_fn) self.length = {s: dict(length[s]) for s in length} @@ -111,11 +112,11 @@ def _error_chain_from_vertex_path( for i in range(len(vertex_path) - 1): v0 = vertex_path[i] v1 = vertex_path[i + 1] - if graph.get_edge_data(v0, v1)["measurement_error"] == 1: + if graph.get_edge_data(v0, v1).properties["measurement_error"] == 1: measurement_errors ^= set( - [(graph.nodes()[v0]["time"], tuple(graph.nodes()[v0]["qubits"]))] + [(graph.nodes()[v0].time, tuple(graph.nodes()[v0].qubits))] ) - qubit_errors ^= set(graph.get_edge_data(v0, v1)["qubits"]) + qubit_errors ^= set(graph.get_edge_data(v0, v1).qubits) logging.debug( "_error_chain_for_vertex_path q = %s, m = %s", qubit_errors, @@ -173,7 +174,7 @@ def _make_annotated_graph(gin: rx.PyGraph, paths: List[List[int]]) -> rx.PyGraph for path in paths: # Highlight the endpoints of the path for i in [0, -1]: - graph.nodes()[path[i]]["highlighted"] = True + graph.nodes()[path[i]].properties["highlighted"] = True # Highlight the edges along the path for i in range(len(path) - 1): try: @@ -181,5 +182,5 @@ def _make_annotated_graph(gin: rx.PyGraph, paths: List[List[int]]) -> rx.PyGraph except ValueError: idx = list(graph.edge_list()).index((path[i + 1], path[i])) edge = graph.edges()[idx] - edge["highlighted"] = True + edge.properties["highlighted"] = True return graph diff --git a/src/qiskit_qec/decoders/temp_graph_util.py b/src/qiskit_qec/decoders/temp_graph_util.py index 15f341d9..e2fb199f 100644 --- a/src/qiskit_qec/decoders/temp_graph_util.py +++ b/src/qiskit_qec/decoders/temp_graph_util.py @@ -1,20 +1,31 @@ """Temporary module with methods for graphs.""" import json +import os import networkx as nx import rustworkx as rx +from qiskit_qec.utils.decoding_graph_attributes import DecodingGraphEdge, DecodingGraphNode + def ret2net(graph: rx.PyGraph): """Convert rustworkx graph to equivalent networkx graph.""" nx_graph = nx.Graph() for j, node in enumerate(graph.nodes()): nx_graph.add_node(j) - for k, v in node.items(): - nx.set_node_attributes(nx_graph, {j: v}, k) + for k, v in node: + if isinstance(v, dict): + for _k, _v in v.items(): + nx.set_node_attributes(nx_graph, {j: _v}, _k) + else: + nx.set_node_attributes(nx_graph, {j: v}, k) for j, (n0, n1) in enumerate(graph.edge_list()): nx_graph.add_edge(n0, n1) - for k, v in graph.edges()[j].items(): - nx.set_edge_attributes(nx_graph, {(n0, n1): v}, k) + for k, v in graph.edges()[j]: + if isinstance(v, dict): + for _k, _v in v.items(): + nx.set_edge_attributes(nx_graph, {(n0, n1): _v}, _k) + else: + nx.set_edge_attributes(nx_graph, {(n0, n1): v}, k) return nx_graph @@ -25,3 +36,42 @@ def write_graph_to_json(graph: rx.PyGraph, filename: str): from_ret = ret2net(graph) json.dump(nx.node_link_data(from_ret), fp, indent=4, default=str) fp.close() + + +def get_cached_decoding_graph(path): + """ + Returns graph cached in file at path "file" using cache_graph method. + """ + if os.path.isfile(path) and not os.stat(path) == 0: + with open(path, "r+", encoding="utf-8") as file: + json_data = json.loads(file.read()) + net_graph = nx.node_link_graph(json_data) + ret_graph = rx.networkx_converter(net_graph, keep_attributes=True) + for node_index, node in zip(ret_graph.node_indices(), ret_graph.nodes()): + del node["__networkx_node__"] + qubits = node.pop("qubits") + time = node.pop("time") + index = node.pop("index") + is_boundary = node.pop("is_boundary") + properties = node.copy() + node = DecodingGraphNode(is_boundary=is_boundary, time=time, index=index, qubits=qubits) + node.properties = properties + ret_graph[node_index] = node + for edge_index, edge in zip(ret_graph.edge_indices(), ret_graph.edges()): + weight = edge.pop("weight") + qubits = edge.pop("qubits") + properties = edge.copy() + edge = DecodingGraphEdge(weight=weight, qubits=qubits) + edge.properties = properties + ret_graph.update_edge_by_index(edge_index, edge) + return ret_graph + return None + + +def cache_decoding_graph(graph, path): + """ + Cache rustworkx PyGraph to file at path. + """ + net_graph = ret2net(graph) + with open(path, "w+", encoding="utf-8") as file: + json.dump(nx.node_link_data(net_graph), file) diff --git a/src/qiskit_qec/utils/__init__.py b/src/qiskit_qec/utils/__init__.py index 5bb384c2..8809eff3 100644 --- a/src/qiskit_qec/utils/__init__.py +++ b/src/qiskit_qec/utils/__init__.py @@ -31,4 +31,4 @@ """ from . import indexer, pauli_rep, visualizations -from .decodoku import Decodoku +from .decoding_graph_attributes import DecodingGraphNode, DecodingGraphEdge diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py new file mode 100644 index 00000000..c06f9944 --- /dev/null +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +# pylint: disable=invalid-name + +""" +Graph used as the basis of decoders. +""" +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from qiskit_qec.exceptions import QiskitQECError + + +class DecodingGraphNode: + """ + Class to describe DecodingGraph nodes. + + Attributes: + - is_boundary (bool): whether or not the node is a boundary node. + - time (int): what syndrome node the node corrsponds to. Doesn't + need to be set if it's a boundary node. + - qubits (List[int]): List of indices which are stabilized by + this ancilla. + - index (int): Unique index in measurement round. + - properties (Dict[str, Any]): Decoder/code specific attributes. + Are not considered when comparing nodes. + """ + + def __init__(self, index: int, qubits: List[int] = None, is_boundary=False, time=None) -> None: + if not is_boundary and time is None: + raise QiskitQECError( + "DecodingGraph node must either have a time or be a boundary node." + ) + + self.is_boundary: bool = is_boundary + self.time: Optional[int] = time if not is_boundary else None + self.qubits: List[int] = qubits if qubits else [] + self.index: int = index + self.properties: Dict[str, Any] = {} + + def __eq__(self, rhs): + if not isinstance(rhs, DecodingGraphNode): + return NotImplemented + + result = ( + self.index == rhs.index + and set(self.qubits) == set(rhs.qubits) + and self.is_boundary == rhs.is_boundary + ) + if not self.is_boundary: + result = result and self.time == rhs.time + return result + + def __hash__(self) -> int: + return hash(repr(self)) + + def __iter__(self): + for attr, value in self.__dict__.items(): + yield attr, value + + +@dataclass +class DecodingGraphEdge: + """ + Class to describe DecodingGraph edges. + + Attributes: + - qubits (List[int]): List of indices of code qubits that correspond to this edge. + - weight (float): Weight of the edge. + - properties (Dict[str, Any]): Decoder/code specific attributes. + Are not considered when comparing edges. + """ + + qubits: List[int] + weight: float + # TODO: Should code/decoder specific properties be accounted for when comparing edges + properties: Dict[str, Any] = field(default_factory=dict) + + def __eq__(self, rhs) -> bool: + if not isinstance(rhs, DecodingGraphNode): + return NotImplemented + + return set(self.qubits) == set(rhs.qubits) and self.weight == rhs.weight + + def __hash__(self) -> int: + return hash(repr(self)) + + def __iter__(self): + for attr, value in self.__dict__.items(): + yield attr, value diff --git a/src/qiskit_qec/utils/decodoku.py b/src/qiskit_qec/utils/decodoku.py index eaf46302..fafcaada 100644 --- a/src/qiskit_qec/utils/decodoku.py +++ b/src/qiskit_qec/utils/decodoku.py @@ -58,7 +58,6 @@ def __init__(self, p=0.1, k=2, d=10, process=None, errors=None, nonabelian=False self._generate_graph() def _generate_syndrome(self): - syndrome = {} for x in range(self.size): for y in range(self.size): @@ -69,7 +68,6 @@ def _generate_syndrome(self): else: error_num = poisson(self.p * 2 * self.size**2) for _ in range(error_num): - x0 = choice(range(self.size)) y0 = choice(range(self.size)) @@ -317,7 +315,6 @@ def _generate_syndrome(self): self.boundary_errors = parity def _generate_graph(self): - dg = DecodingGraph(None) d = self.size - 1 @@ -334,7 +331,7 @@ def _generate_graph(self): { "y": 0, "x": (d - 1) * (elem == 1) - 1 * (elem == 0), - "element": elem, + "index": elem, "is_boundary": True, } ) @@ -343,10 +340,10 @@ def _generate_graph(self): nodes = dg.graph.nodes() # connect edges to boundary nodes for y in range(self.size): - n0 = nodes.index({"y": 0, "x": -1, "element": 0, "is_boundary": True}) + n0 = nodes.index({"y": 0, "x": -1, "index": 0, "is_boundary": True}) n1 = nodes.index({"y": y, "x": 0, "is_boundary": False}) dg.graph.add_edge(n0, n1, None) - n0 = nodes.index({"y": 0, "x": d - 1, "element": 1, "is_boundary": True}) + n0 = nodes.index({"y": 0, "x": d - 1, "index": 1, "is_boundary": True}) n1 = nodes.index({"y": y, "x": d - 2, "is_boundary": False}) dg.graph.add_edge(n0, n1, None) # connect bulk nodes with space-like edges @@ -373,7 +370,6 @@ def reset_graph(self): self._update_graph(original=False) def _update_graph(self, original=False): - for node in self.decoding_graph.graph.nodes(): node["highlighted"] = False if self.k != 2: @@ -403,7 +399,6 @@ def _update_graph(self, original=False): self.node_color = highlighted_color def _start(self, engine): - d = self.k syndrome = self.syndrome @@ -439,7 +434,6 @@ def _start(self, engine): # this is the function that does everything def _next_frame(self, engine): - d = self.k syndrome = self.syndrome @@ -542,7 +536,7 @@ def draw_graph(self, clusters=True): def get_label(node): if node["is_boundary"] and parity: - return str(parity[node["element"]]) + return str(parity[node["index"]]) elif node["highlighted"] and "value" in node and self.k != 2: return str(node["value"]) else: diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 4d912929..c0338259 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -27,6 +27,7 @@ from qiskit_qec.circuits.repetition_code import RepetitionCodeCircuit as RepetitionCode from qiskit_qec.circuits.repetition_code import ArcCircuit from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.utils import DecodingGraphNode from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.decoders.hdrg_decoders import BravyiHaahDecoder @@ -135,8 +136,16 @@ def test_string2nodes_2(self): (5, 0), "00001", [ - {"time": 0, "qubits": [0], "is_boundary": True, "element": 0}, - {"time": 0, "qubits": [0, 1], "is_boundary": False, "element": 0}, + DecodingGraphNode( + is_boundary=True, + qubits=[0], + index=0, + ), + DecodingGraphNode( + time=0, + qubits=[0, 1], + index=0, + ), ], ] ] @@ -188,12 +197,8 @@ def test_weight(self): + "'." ) p = dec.get_error_probs(test_results, method=method) - n0 = dec.graph.nodes().index( - {"time": 0, "is_boundary": False, "qubits": [0, 1], "element": 0} - ) - n1 = dec.graph.nodes().index( - {"time": 0, "is_boundary": False, "qubits": [1, 2], "element": 1} - ) + n0 = dec.graph.nodes().index(DecodingGraphNode(time=0, qubits=[0, 1], index=0)) + n1 = dec.graph.nodes().index(DecodingGraphNode(time=0, qubits=[1, 2], index=1)) # edges in graph aren't directed and could be in any order if (n0, n1) in p: self.assertTrue(round(p[n0, n1], 2) == 0.33, error) @@ -236,12 +241,12 @@ def single_error_test( string = "".join([str(c) for c in output[::-1]]) nodes = code.string2nodes(string) # check that it doesn't extend over more than two rounds - ts = [node["time"] for node in nodes if not node["is_boundary"]] + ts = [node.time for node in nodes if not node.is_boundary] if ts: minimal = minimal and (max(ts) - min(ts)) <= 1 # check that it doesn't extend beyond the neigbourhood of a code qubit flat_nodes = code.flatten_nodes(nodes) - link_qubits = set(node["link qubit"] for node in flat_nodes) + link_qubits = set(node.properties["link qubit"] for node in flat_nodes) minimal = minimal and link_qubits in incident_links.values() self.assertTrue( minimal, @@ -254,10 +259,10 @@ def single_error_test( ) # and that the given flipped logical makes sense for node in nodes: - if not node["is_boundary"]: + if not node.is_boundary: for logical in flipped_logicals: self.assertTrue( - logical in node["qubits"], + logical in node.qubits, "Error: Single error appears to flip logical is not part of nodes.", ) @@ -352,7 +357,7 @@ def test_single_error_202s(self): nodes = [ node for node in code.string2nodes(string) - if "conjugate" not in node and not node["is_boundary"] + if "conjugate" not in node.properties and not node.is_boundary ] # require at most two (or three for the trivalent vertex or neighbouring aux) self.assertTrue( @@ -469,12 +474,12 @@ def test_weight(self): + "'." ) p = dec.get_error_probs(test_results, method=method) - n0 = dec.graph.nodes().index( - {"time": 0, "qubits": [0, 2], "link qubit": 1, "is_boundary": False, "element": 1} - ) - n1 = dec.graph.nodes().index( - {"time": 0, "qubits": [2, 4], "link qubit": 3, "is_boundary": False, "element": 0} - ) + node = DecodingGraphNode(time=0, qubits=[0, 2], index=1) + node.properties["link qubits"] = 1 + n0 = dec.graph.nodes().index(node) + node = DecodingGraphNode(time=0, qubits=[2, 4], index=0) + node.properties["link qubits"] = 3 + n1 = dec.graph.nodes().index(node) # edges in graph aren't directed and could be in any order if (n0, n1) in p: self.assertTrue(round(p[n0, n1], 2) == 0.33, error) diff --git a/test/code_circuits/test_surface_codes.py b/test/code_circuits/test_surface_codes.py index 14b06fa3..43cf1259 100644 --- a/test/code_circuits/test_surface_codes.py +++ b/test/code_circuits/test_surface_codes.py @@ -19,9 +19,10 @@ import unittest from qiskit_qec.circuits.surface_code import SurfaceCodeCircuit +from qiskit_qec.utils import DecodingGraphNode -class TestRepCodes(unittest.TestCase): +class TestSurfaceCodes(unittest.TestCase): """Test the surface code circuits.""" def test_string2nodes(self): @@ -41,39 +42,87 @@ def test_string2nodes(self): test_nodes["x"] = [ [], [ - {"time": 1, "qubits": [0, 1, 3, 4], "is_boundary": False, "element": 1}, - {"time": 1, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}, + DecodingGraphNode( + time=1, + qubits=[0, 1, 3, 4], + index=1, + ), + DecodingGraphNode( + time=1, + qubits=[4, 5, 7, 8], + index=2, + ), ], [ - {"time": 0, "qubits": [0, 1, 3, 4], "is_boundary": False, "element": 1}, - {"time": 0, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}, + DecodingGraphNode( + time=0, + qubits=[0, 1, 3, 4], + index=1, + ), + DecodingGraphNode( + time=0, + qubits=[4, 5, 7, 8], + index=2, + ), ], [ - {"time": 0, "qubits": [0, 3, 6], "is_boundary": True, "element": 0}, - {"time": 1, "qubits": [0, 1, 3, 4], "is_boundary": False, "element": 1}, + DecodingGraphNode( + is_boundary=True, + qubits=[0, 3, 6], + index=0, + ), + DecodingGraphNode( + time=1, + qubits=[0, 1, 3, 4], + index=1, + ), ], [ - {"time": 0, "qubits": [2, 5, 8], "is_boundary": True, "element": 1}, - {"time": 1, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}, + DecodingGraphNode( + is_boundary=True, + qubits=[2, 5, 8], + index=1, + ), + DecodingGraphNode( + time=1, + qubits=[4, 5, 7, 8], + index=2, + ), ], ] test_nodes["z"] = [ [], [ - {"time": 0, "qubits": [1, 4, 2, 5], "is_boundary": False, "element": 1}, - {"time": 0, "qubits": [3, 6, 4, 7], "is_boundary": False, "element": 2}, + DecodingGraphNode( + time=0, + qubits=[1, 4, 2, 5], + index=1, + ), + DecodingGraphNode( + time=0, + qubits=[3, 6, 4, 7], + index=2, + ), ], [ - {"time": 1, "qubits": [1, 4, 2, 5], "is_boundary": False, "element": 1}, - {"time": 1, "qubits": [3, 6, 4, 7], "is_boundary": False, "element": 2}, + DecodingGraphNode( + time=1, + qubits=[1, 4, 2, 5], + index=1, + ), + DecodingGraphNode( + time=1, + qubits=[3, 6, 4, 7], + index=2, + ), ], [ - {"time": 0, "qubits": [0, 1, 2], "is_boundary": True, "element": 0}, - {"time": 1, "qubits": [0, 3], "is_boundary": False, "element": 0}, + DecodingGraphNode(is_boundary=True, qubits=[0, 1, 2], index=0), + DecodingGraphNode(time=1, qubits=[0, 3], index=0), ], [ - {"time": 0, "qubits": [8, 7, 6], "is_boundary": True, "element": 1}, - {"time": 1, "qubits": [5, 8], "is_boundary": False, "element": 3}, + DecodingGraphNode(is_boundary=True, qubits=[8, 7, 6], index=1), + DecodingGraphNode(time=1, qubits=[5, 8], index=3), ], ] @@ -81,8 +130,9 @@ def test_string2nodes(self): code = SurfaceCodeCircuit(3, 1, basis=basis) for t, string in enumerate(test_string): nodes = test_nodes[basis][t] + generated_nodes = code.string2nodes(string) self.assertTrue( - code.string2nodes(string) == nodes, + generated_nodes == nodes, "Incorrect nodes for basis = " + basis + " for string = " + string + ".", ) @@ -99,91 +149,91 @@ def test_check_nodes(self): valid = valid and code.check_nodes(nodes) == (True, [], 0) # on one side nodes = [ - {"time": 0, "qubits": [0, 1, 2], "is_boundary": True, "element": 0}, - {"time": 3, "qubits": [0, 3], "is_boundary": False, "element": 0}, + DecodingGraphNode(qubits=[0, 1, 2], is_boundary=True, index=0), + DecodingGraphNode(time=3, qubits=[0, 3], index=0), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) - nodes = [{"time": 3, "qubits": [0, 3], "is_boundary": False, "element": 0}] + nodes = [DecodingGraphNode(time=3, qubits=[0, 3], index=0)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [0, 1, 2], "is_boundary": True, "element": 0}], + [DecodingGraphNode(time=0, qubits=[0, 1, 2], is_boundary=True, index=0)], 1.0, ) # and the other nodes = [ - {"time": 0, "qubits": [8, 7, 6], "is_boundary": True, "element": 1}, - {"time": 3, "qubits": [5, 8], "is_boundary": False, "element": 3}, + DecodingGraphNode(time=0, qubits=[8, 7, 6], is_boundary=True, index=1), + DecodingGraphNode(time=3, qubits=[5, 8], index=3), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) - nodes = [{"time": 3, "qubits": [5, 8], "is_boundary": False, "element": 3}] + nodes = [DecodingGraphNode(time=3, qubits=[5, 8], index=3)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [8, 7, 6], "is_boundary": True, "element": 1}], + [DecodingGraphNode(time=0, qubits=[8, 7, 6], is_boundary=True, index=1)], 1.0, ) # and in the middle nodes = [ - {"time": 3, "qubits": [1, 4, 2, 5], "is_boundary": False, "element": 1}, - {"time": 3, "qubits": [3, 6, 4, 7], "is_boundary": False, "element": 2}, + DecodingGraphNode(time=3, qubits=[1, 4, 2, 5], index=1), + DecodingGraphNode(time=3, qubits=[3, 6, 4, 7], index=2), ] valid = valid and code.check_nodes(nodes) == (True, [], 1) - nodes = [{"time": 3, "qubits": [3, 6, 4, 7], "is_boundary": False, "element": 2}] + nodes = [DecodingGraphNode(time=3, qubits=[3, 6, 4, 7], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [8, 7, 6], "is_boundary": True, "element": 1}], + [DecodingGraphNode(qubits=[8, 7, 6], is_boundary=True, index=1)], 1.0, ) # basis = 'x' code = SurfaceCodeCircuit(3, 3, basis="x") nodes = [ - {"time": 3, "qubits": [0, 1, 3, 4], "is_boundary": False, "element": 1}, - {"time": 3, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}, + DecodingGraphNode(time=3, qubits=[0, 1, 3, 4], index=1), + DecodingGraphNode(time=3, qubits=[4, 5, 7, 8], index=2), ] valid = valid and code.check_nodes(nodes) == (True, [], 1) - nodes = [{"time": 3, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}] + nodes = [DecodingGraphNode(time=3, qubits=[4, 5, 7, 8], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [2, 5, 8], "is_boundary": True, "element": 1}], + [DecodingGraphNode(qubits=[2, 5, 8], is_boundary=True, index=1)], 1.0, ) # large d code = SurfaceCodeCircuit(5, 3, basis="z") nodes = [ - {"time": 3, "qubits": [7, 12, 8, 13], "is_boundary": False, "element": 4}, - {"time": 3, "qubits": [11, 16, 12, 17], "is_boundary": False, "element": 7}, + DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), + DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), ] valid = valid and code.check_nodes(nodes) == (True, [], 1) - nodes = [{"time": 3, "qubits": [11, 16, 12, 17], "is_boundary": False, "element": 7}] + nodes = [DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [24, 23, 22, 21, 20], "is_boundary": True, "element": 1}], + [DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1)], 2.0, ) # wrong boundary nodes = [ - {"time": 3, "qubits": [7, 12, 8, 13], "is_boundary": False, "element": 4}, - {"time": 0, "qubits": [24, 23, 22, 21, 20], "is_boundary": True, "element": 1}, + DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), + DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), ] valid = valid and code.check_nodes(nodes) == ( False, - [{"time": 0, "qubits": [0, 1, 2, 3, 4], "is_boundary": True, "element": 0}], + [DecodingGraphNode(qubits=[0, 1, 2, 3, 4], is_boundary=True, index=0)], 2, ) # extra boundary nodes = [ - {"time": 3, "qubits": [7, 12, 8, 13], "is_boundary": False, "element": 4}, - {"time": 3, "qubits": [11, 16, 12, 17], "is_boundary": False, "element": 7}, - {"time": 0, "qubits": [24, 23, 22, 21, 20], "is_boundary": True, "element": 1}, + DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), + DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), + DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), ] valid = valid and code.check_nodes(nodes) == (False, [], 0) # ignoring extra nodes = [ - {"time": 3, "qubits": [7, 12, 8, 13], "is_boundary": False, "element": 4}, - {"time": 3, "qubits": [11, 16, 12, 17], "is_boundary": False, "element": 7}, - {"time": 0, "qubits": [24, 23, 22, 21, 20], "is_boundary": True, "element": 1}, + DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), + DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), + DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), ] valid = valid and code.check_nodes(nodes, ignore_extra_boundary=True) == (True, [], 1) diff --git a/test/matching/test_circuitmatcher.py b/test/matching/test_circuitmatcher.py index dcc0760b..61de1252 100644 --- a/test/matching/test_circuitmatcher.py +++ b/test/matching/test_circuitmatcher.py @@ -268,7 +268,7 @@ def test_error_pairs(self): corrected_outcomes = dec.process(outcome) fail = temp_syndrome(corrected_outcomes, self.z_logical) failures += fail[0] - self.assertEqual(failures, 128) + self.assertEqual(failures, 140) def test_error_pairs_uniform(self): """Test the case with two faults using rustworkx.""" diff --git a/test/matching/test_pymatchingmatcher.py b/test/matching/test_pymatchingmatcher.py index 6685790d..e9378e1e 100644 --- a/test/matching/test_pymatchingmatcher.py +++ b/test/matching/test_pymatchingmatcher.py @@ -4,6 +4,7 @@ from typing import Dict, Tuple import rustworkx as rx from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge class TestPyMatchingMatcher(unittest.TestCase): @@ -18,14 +19,22 @@ def make_test_graph() -> Tuple[rx.PyGraph, Dict[Tuple[int, Tuple[int]], int]]: graph = rx.PyGraph(multigraph=False) idxmap = {} for i, q in enumerate([[0, 1], [1, 2], [2, 3], [3, 4]]): - node = {"time": 0, "qubits": q, "highlighted": False} + node = DecodingGraphNode(time=0, qubits=q, index=i) + node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple(q))] = i node = {"time": 0, "qubits": [], "highlighted": False, "is_boundary": True} + node = DecodingGraphNode(is_boundary=True, qubits=[], index=0) + node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple([]))] = 4 for dat in [[[0], 0, 4], [[1], 0, 1], [[2], 1, 2], [[3], 2, 3], [[4], 3, 4]]: - edge = {"qubits": dat[0], "measurement_error": False, "weight": 1, "highlighted": False} + edge = DecodingGraphEdge( + qubits=dat[0], + weight=1, + ) + edge.properties["measurement_error"] = False + edge.properties["highlighted"] = False graph.add_edge(dat[1], dat[2], edge) return graph, idxmap diff --git a/test/matching/test_retworkxmatcher.py b/test/matching/test_retworkxmatcher.py index 694f7f55..0f9f16ca 100644 --- a/test/matching/test_retworkxmatcher.py +++ b/test/matching/test_retworkxmatcher.py @@ -4,6 +4,7 @@ from typing import Dict, Tuple import rustworkx as rx from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge class TestRustworkxMatcher(unittest.TestCase): @@ -17,15 +18,20 @@ def make_test_graph() -> Tuple[rx.PyGraph, Dict[Tuple[int, Tuple[int]], int]]: """ graph = rx.PyGraph(multigraph=False) idxmap = {} - for i, q in enumerate([[0, 1], [1, 2], [2, 3], [3, 4]]): - node = {"time": 0, "qubits": q, "highlighted": False} + basic_config = [[0, 1], [1, 2], [2, 3], [3, 4]] + for i, q in enumerate(basic_config): + node = DecodingGraphNode(time=0, qubits=q, index=i) + node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple(q))] = i - node = {"time": 0, "qubits": [], "highlighted": False} + node = DecodingGraphNode(time=0, qubits=[], index=len(basic_config) + 1) + node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple([]))] = 4 for dat in [[[0], 0, 4], [[1], 0, 1], [[2], 1, 2], [[3], 2, 3], [[4], 3, 4]]: - edge = {"qubits": dat[0], "measurement_error": False, "weight": 1, "highlighted": False} + edge = DecodingGraphEdge(qubits=dat[0], weight=1) + edge.properties["measurement_error"] = False + edge.properties["highlighted"] = False graph.add_edge(dat[1], dat[2], edge) return graph, idxmap @@ -58,17 +64,17 @@ def test_annotate(self): self.rxm.preprocess(graph) highlighted = [(0, (0, 1)), (0, (1, 2)), (0, (3, 4)), (0, ())] # must be even self.rxm.find_errors(graph, idxmap, highlighted) - self.assertEqual(self.rxm.annotated_graph[0]["highlighted"], True) - self.assertEqual(self.rxm.annotated_graph[1]["highlighted"], True) - self.assertEqual(self.rxm.annotated_graph[2]["highlighted"], False) - self.assertEqual(self.rxm.annotated_graph[3]["highlighted"], True) - self.assertEqual(self.rxm.annotated_graph[4]["highlighted"], True) + self.assertEqual(self.rxm.annotated_graph[0].properties["highlighted"], True) + self.assertEqual(self.rxm.annotated_graph[1].properties["highlighted"], True) + self.assertEqual(self.rxm.annotated_graph[2].properties["highlighted"], False) + self.assertEqual(self.rxm.annotated_graph[3].properties["highlighted"], True) + self.assertEqual(self.rxm.annotated_graph[4].properties["highlighted"], True) eim = self.rxm.annotated_graph.edge_index_map() - self.assertEqual(eim[0][2]["highlighted"], False) - self.assertEqual(eim[1][2]["highlighted"], True) - self.assertEqual(eim[2][2]["highlighted"], False) - self.assertEqual(eim[3][2]["highlighted"], False) - self.assertEqual(eim[4][2]["highlighted"], True) + self.assertEqual(eim[0][2].properties["highlighted"], False) + self.assertEqual(eim[1][2].properties["highlighted"], True) + self.assertEqual(eim[2][2].properties["highlighted"], False) + self.assertEqual(eim[3][2].properties["highlighted"], False) + self.assertEqual(eim[4][2].properties["highlighted"], True) if __name__ == "__main__": diff --git a/test/union_find/test_clayg.py b/test/union_find/test_clayg.py new file mode 100644 index 00000000..af524d5f --- /dev/null +++ b/test/union_find/test_clayg.py @@ -0,0 +1,197 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Tests for template.""" + +import math +import random +import unittest +from qiskit_qec.analysis.faultenumerator import FaultEnumerator +from qiskit_qec.decoders import ClAYGDecoder +from qiskit_qec.circuits import SurfaceCodeCircuit, RepetitionCodeCircuit +from qiskit_qec.decoders.temp_code_util import temp_syndrome +from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel + + +def flip_with_probability(p, val): + """ + Flips parity of val with probability p. + """ + if random.random() <= p: + val = (val + 1) % 2 + return val + + +def noisy_surface_code_outcome(d, p): + """ + Generates outcome for surface code with phenomenological noise built in. + """ + string = "" + qubits = [0 for _ in range(d**2)] + for _ in range(d): + for qubit in qubits: + qubit = flip_with_probability(p, qubit) + # Top ancillas + for i in [2 * i for i in range((d - 1) // 2)]: + ancilla = (qubits[i] + qubits[i + 1]) % 2 + ancilla = flip_with_probability(p, ancilla) + string += str(ancilla) + for row in range(d - 1): + offset = (row + 1) % 2 + for topleft in [offset + row * d + 2 * i for i in range((d - 1) // 2)]: + ancilla = ( + qubits[topleft] + + qubits[topleft + 1] + + qubits[topleft + d] + + qubits[topleft + d + 1] + ) % 2 + ancilla = flip_with_probability(p, ancilla) + string += str(ancilla) + for i in [d * (d - 1) + 1 + 2 * i for i in range((d - 1) // 2)]: + ancilla = (qubits[i] + qubits[i + 1]) % 2 + ancilla = flip_with_probability(p, ancilla) + string += str(ancilla) + string += " " + for qubit in qubits: + qubit = flip_with_probability(p, qubit) + string += str(qubit) + return string + + +class ClAYGDecoderTest(unittest.TestCase): + """Tests will be here.""" + + def setUp(self) -> None: + # Bit-flip circuit noise model + p = 0.05 + noise_model = PauliNoiseModel() + noise_model.add_operation("cx", {"ix": 1, "xi": 1, "xx": 1}) + noise_model.add_operation("id", {"x": 1}) + noise_model.add_operation("reset", {"x": 1}) + noise_model.add_operation("measure", {"x": 1}) + noise_model.add_operation("x", {"x": 1, "y": 1, "z": 1}) + noise_model.set_error_probability("cx", p) + noise_model.set_error_probability("x", p) + noise_model.set_error_probability("id", p) + noise_model.set_error_probability("reset", p) + noise_model.set_error_probability("measure", p) + self.noise_model = noise_model + + self.fault_enumeration_method = "stabilizer" + + return super().setUp() + + def test_surface_code_d3(self): + """ + Test the ClAYG decoder on a surface code with d=3 and T=3 + with faults inserted by FaultEnumerator by checking if the syndromes + have even parity (if it's a valid code state) and if the logical value measured + is the one encoded by the circuit. + """ + for logical in ["0", "1"]: + code = SurfaceCodeCircuit(d=3, T=3) + decoder = ClAYGDecoder(code, logical) + + fault_enumerator = FaultEnumerator( + code.circuit[logical], method=self.fault_enumeration_method, model=self.noise_model + ) + for fault in fault_enumerator.generate(): + outcome = "".join([str(x) for x in fault[3]]) + corrected_outcome = decoder.process(outcome) + stabilizers = temp_syndrome(corrected_outcome, code.css_z_stabilizer_ops) + for syndrome in stabilizers: + self.assertEqual(syndrome, 0) + logical_measurement = temp_syndrome(corrected_outcome, [code.css_z_logical])[0] + self.assertEqual(str(logical_measurement), logical) + + def test_repetition_code_d5(self): + """ + Test the ClAYG decoder on a repetition code with d=5 and T=5 + with faults inserted by FaultEnumerator by checking if the syndromes + have even parity (if it's a valid code state) and if the logical value measured + is the one encoded by the circuit. + """ + for logical in ["0", "1"]: + code = RepetitionCodeCircuit(d=5, T=5) + decoder = ClAYGDecoder(code, logical) + fault_enumerator = FaultEnumerator( + code.circuit[logical], method=self.fault_enumeration_method, model=self.noise_model + ) + for fault in fault_enumerator.generate(): + outcome = "".join([str(x) for x in fault[3]]) + corrected_outcome = decoder.process(outcome) + stabilizers = temp_syndrome(corrected_outcome, code.css_z_stabilizer_ops) + for syndrome in stabilizers: + self.assertEqual(syndrome, 0) + logical_measurement = temp_syndrome(corrected_outcome, code.css_z_logical)[0] + self.assertEqual(str(logical_measurement), logical) + + def test_error_rates(self): + """ + Test the error rates using some codes (currently only repetition codes, but in + the future also ARCs, surface codes, HHCs etc). + """ + d = 8 + p = 0.01 + samples = 1000 + + testcases = [] + testcases = [ + "".join([random.choices(["0", "1"], [1 - p, p])[0] for _ in range(d)]) + for _ in range(samples) + ] + codes = self.construct_codes(d) + + # now run them all and check it works + for code in codes: + decoder = ClAYGDecoder(code, logical="0") + z_logicals = code.css_z_logical[0] + + logical_errors = 0 + min_flips_for_logical = code.d + for sample in range(samples): + # generate random string + string = "" + for _ in range(code.T): + string += "0" * (d - 1) + " " + string += testcases[sample] + # get and check corrected_z_logicals + outcome = decoder.process(string) + logical_outcome = sum([outcome[int(z_logical / 2)] for z_logical in z_logicals]) % 2 + if not logical_outcome == 0: + logical_errors += 1 + min_flips_for_logical = min(min_flips_for_logical, string.count("1")) + + # check that error rates are at least d/3 + self.assertTrue( + logical_errors / samples + < (math.factorial(d)) / (math.factorial(int(d / 2)) ** 2) * p**4, + "Logical error rate shouldn't exceed d!/((d/2)!^2)*p^(d/2).", + ) + self.assertTrue( + min_flips_for_logical >= d / 2, + "Minimum amount of errors that also causes logical errors shouldn't be lower than d/2.", + ) + + def construct_codes(self, d): + """ + Construct codes for the logical error rate test. + """ + codes = [] + # TODO: Add more codes + codes.append(RepetitionCodeCircuit(d=d, T=1)) + return codes + + +if __name__ == "__main__": + unittest.main() diff --git a/test/utils/test_decodoku.py b/test/utils/test_decodoku.py index d0be7f07..702d468c 100644 --- a/test/utils/test_decodoku.py +++ b/test/utils/test_decodoku.py @@ -18,7 +18,7 @@ import unittest -from qiskit_qec.utils import Decodoku +from qiskit_qec.utils.decodoku import Decodoku class TestDecodoku(unittest.TestCase):