From c5ddf4a4c208ca5d5355e77e8bb7c8c3217d9fd6 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Wed, 8 Mar 2023 17:22:07 +0100 Subject: [PATCH 01/22] make hdrg decoder more universal --- src/qiskit_qec/circuits/repetition_code.py | 14 +++++++++ src/qiskit_qec/decoders/hdrg_decoders.py | 33 ++++++++++------------ test/code_circuits/test_rep_codes.py | 6 ++-- 3 files changed, 32 insertions(+), 21 deletions(-) diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 244432f2..474be129 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -809,6 +809,20 @@ def _preparation(self): z_logicals = [min(self.code_index.keys())] self.z_logicals = z_logicals + # set css attributes for decoder + gauge_ops = [[link[0], link[2]] for link in self.links] + measured_logical = [[self.z_logicals[0]]] + flip_logical = list(range(self.d)) + boundary = [[logical] for logical in self.z_logicals] + self.css_x_gauge_ops = [] + self.css_x_stabilizer_ops = [] + self.css_x_logical = flip_logical + self.css_x_boundary = [] + self.css_z_gauge_ops = gauge_ops + self.css_z_stabilizer_ops = gauge_ops + self.css_z_logical = measured_logical + self.css_z_boundary = boundary + def _get_202(self, t): """ Returns the position within a 202 sequence for the current measurement round: diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 1f5ef5d4..893dbb17 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -21,9 +21,8 @@ from typing import Dict, List, Set from rustworkx import connected_components, distance_matrix, PyGraph -from qiskit_qec.circuits.repetition_code import ArcCircuit, RepetitionCodeCircuit +from qiskit_qec.circuits.repetition_code import ArcCircuit from qiskit_qec.decoders.decoding_graph import DecodingGraph -from qiskit_qec.exceptions import QiskitQECError class ClusteringDecoder: @@ -51,22 +50,20 @@ def __init__( code_circuit, decoding_graph: DecodingGraph = None, ): - if not isinstance(code_circuit, (ArcCircuit, RepetitionCodeCircuit)): - raise QiskitQECError("Error: code_circuit not supported.") super().__init__(code_circuit, decoding_graph) - if isinstance(self.code, ArcCircuit): - self.z_logicals = self.code.z_logicals - elif isinstance(self.code, RepetitionCodeCircuit): + if hasattr(self.code, "_xbasis"): if self.code._xbasis: - self.z_logicals = self.code.css_x_logical[0] + self.measured_logicals = self.code.css_x_logical[0] else: - self.z_logicals = self.code.css_z_logical[0] - if isinstance(self.code, ArcCircuit): + self.measured_logicals = self.code.css_z_logical[0] + else: + self.measured_logicals = self.code.css_z_logical[0] + if hasattr(self.code, "code_index"): self.code_index = self.code.code_index - elif isinstance(self.code, RepetitionCodeCircuit): - self.code_index = {2 * j: j for j in range(self.code.d)} + else: + self.code_index = {j: j for j in range(self.code.d)} def _cluster(self, ns, dist_max): """ @@ -127,7 +124,7 @@ def _cluster(self, ns, dist_max): def _get_boundary_nodes(self): boundary_nodes = [] - for element, z_logical in enumerate(self.z_logicals): + for element, z_logical in enumerate(self.measured_logicals): node = {"time": 0, "is_boundary": True} if isinstance(self.code, ArcCircuit): node["link qubit"] = None @@ -183,10 +180,10 @@ def process(self, string): Args: string (str): Output string of the code. Returns: - corrected_z_logicals (list): A list of integers that are 0 or 1. + corrected_logicals (list): A list of integers that are 0 or 1. These are the corrected values of the final transversal measurement, corresponding to the logical operators of - self.z_logicals. + self.measured_logicals. """ code = self.code decoding_graph = self.decoding_graph @@ -210,9 +207,9 @@ def process(self, string): cluster_logicals[c] = z_logicals # get the net effect on each logical - net_z_logicals = {z_logical: 0 for z_logical in self.z_logicals} + net_z_logicals = {z_logical: 0 for z_logical in self.measured_logicals} for c, z_logicals in cluster_logicals.items(): - for z_logical in self.z_logicals: + for z_logical in self.measured_logicals: if z_logical in z_logicals: net_z_logicals[z_logical] += 1 for z_logical, num in net_z_logicals.items(): @@ -220,7 +217,7 @@ def process(self, string): corrected_z_logicals = [] string = string.split(" ")[0] - for z_logical in self.z_logicals: + for z_logical in self.measured_logicals: raw_logical = int(string[-1 - self.code_index[z_logical]]) corrected_logical = (raw_logical + net_z_logicals[z_logical]) % 2 corrected_z_logicals.append(corrected_logical) diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 4d912929..98cf4343 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -532,7 +532,7 @@ def test_clustering_decoder(self): code = ArcCircuit(links, 0) decoding_graph = DecodingGraph(code) decoder = BravyiHaahDecoder(code, decoding_graph=decoding_graph) - errors = {z_logical: 0 for z_logical in decoder.z_logicals} + errors = {z_logical: 0 for z_logical in decoder.measured_logicals} min_error_num = code.d for sample in range(N): # generate random string @@ -541,14 +541,14 @@ def test_clustering_decoder(self): string = string + " " + "0" * (d - 1) # get and check corrected_z_logicals corrected_z_logicals = decoder.process(string) - for j, z_logical in enumerate(decoder.z_logicals): + for j, z_logical in enumerate(decoder.measured_logicals): error = corrected_z_logicals[j] != 1 if error: min_error_num = min(min_error_num, string.count("0")) errors[z_logical] += error # check that error rates are at least d/3 - for z_logical in decoder.z_logicals: + for z_logical in decoder.measured_logicals: self.assertTrue( errors[z_logical] / (sample + 1) < p**2, "Logical error rate greater than p^2.", From bebc10e1bd2ffa5cde1f95fe8bdf0c212e387ce9 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Wed, 8 Mar 2023 18:23:22 +0100 Subject: [PATCH 02/22] fix typos --- src/qiskit_qec/circuits/surface_code.py | 1 + src/qiskit_qec/decoders/hdrg_decoders.py | 20 ++++++++++---------- test/code_circuits/test_rep_codes.py | 6 +++--- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 13bbcae8..45b4a787 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -49,6 +49,7 @@ def __init__(self, d: int, T: int, basis: str = "z", resets=True): super().__init__() self.d = d + self.n = d**2 self.T = 0 self.basis = basis self._resets = resets diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 893dbb17..39f47c4a 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -55,15 +55,15 @@ def __init__( if hasattr(self.code, "_xbasis"): if self.code._xbasis: - self.measured_logicals = self.code.css_x_logical[0] + self.measured_logicals = self.code.css_x_logical else: - self.measured_logicals = self.code.css_z_logical[0] + self.measured_logicals = self.code.css_z_logical else: - self.measured_logicals = self.code.css_z_logical[0] + self.measured_logicals = self.code.css_z_logical if hasattr(self.code, "code_index"): self.code_index = self.code.code_index else: - self.code_index = {j: j for j in range(self.code.d)} + self.code_index = {j: j for j in range(self.code.n)} def _cluster(self, ns, dist_max): """ @@ -128,7 +128,7 @@ def _get_boundary_nodes(self): node = {"time": 0, "is_boundary": True} if isinstance(self.code, ArcCircuit): node["link qubit"] = None - node["qubits"] = [z_logical] + node["qubits"] = z_logical node["element"] = element boundary_nodes.append(node) return boundary_nodes @@ -207,19 +207,19 @@ def process(self, string): cluster_logicals[c] = z_logicals # get the net effect on each logical - net_z_logicals = {z_logical: 0 for z_logical in self.measured_logicals} + net_z_logicals = {z_logical[0]: 0 for z_logical in self.measured_logicals} for c, z_logicals in cluster_logicals.items(): for z_logical in self.measured_logicals: - if z_logical in z_logicals: - net_z_logicals[z_logical] += 1 + if z_logical[0] in z_logicals: + net_z_logicals[z_logical[0]] += 1 for z_logical, num in net_z_logicals.items(): net_z_logicals[z_logical] = num % 2 corrected_z_logicals = [] string = string.split(" ")[0] for z_logical in self.measured_logicals: - raw_logical = int(string[-1 - self.code_index[z_logical]]) - corrected_logical = (raw_logical + net_z_logicals[z_logical]) % 2 + raw_logical = int(string[-1 - self.code_index[z_logical[0]]]) + corrected_logical = (raw_logical + net_z_logicals[z_logical[0]]) % 2 corrected_z_logicals.append(corrected_logical) return corrected_z_logicals diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 98cf4343..b7976ccc 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -532,7 +532,7 @@ def test_clustering_decoder(self): code = ArcCircuit(links, 0) decoding_graph = DecodingGraph(code) decoder = BravyiHaahDecoder(code, decoding_graph=decoding_graph) - errors = {z_logical: 0 for z_logical in decoder.measured_logicals} + errors = {z_logical[0]: 0 for z_logical in decoder.measured_logicals} min_error_num = code.d for sample in range(N): # generate random string @@ -545,12 +545,12 @@ def test_clustering_decoder(self): error = corrected_z_logicals[j] != 1 if error: min_error_num = min(min_error_num, string.count("0")) - errors[z_logical] += error + errors[z_logical[0]] += error # check that error rates are at least d/3 for z_logical in decoder.measured_logicals: self.assertTrue( - errors[z_logical] / (sample + 1) < p**2, + errors[z_logical[0]] / (sample + 1) < p**2, "Logical error rate greater than p^2.", ) self.assertTrue( From 7a1ab20c68a922b29edcbc9fc840a93821e5f147 Mon Sep 17 00:00:00 2001 From: Tommaso Peduzzi Date: Mon, 6 Mar 2023 10:59:46 +0100 Subject: [PATCH 03/22] Add first version of Node and Edge types Tests won't run because of circular import issue --- src/qiskit_qec/circuits/repetition_code.py | 153 +++++++++--------- src/qiskit_qec/circuits/surface_code.py | 43 ++--- src/qiskit_qec/decoders/__init__.py | 2 +- .../decoders/circuit_matching_decoder.py | 70 ++++---- src/qiskit_qec/decoders/decoding_graph.py | 46 ++++-- src/qiskit_qec/decoders/hdrg_decoders.py | 61 +++---- src/qiskit_qec/decoders/rustworkx_matcher.py | 27 ++-- 7 files changed, 216 insertions(+), 186 deletions(-) diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index a3277bbb..2d4326d8 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -25,6 +25,7 @@ from qiskit.transpiler import PassManager, InstructionDurations from qiskit.transpiler.passes import DynamicalDecoupling +from qiskit_qec.decoders.decoding_graph import Node, Edge def _separate_string(string): separated_string = [] @@ -311,15 +312,16 @@ def string2nodes(self, string, **kwargs): boundary = separated_string[0] # [, ] for bqec_index, belement in enumerate(boundary[::-1]): if all_logicals or belement != logical: - bnode = {"time": 0} i = [0, -1][bqec_index] if self.basis == "z": bqubits = [self.css_x_logical[i]] else: bqubits = [self.css_z_logical[i]] - bnode["qubits"] = bqubits - bnode["is_boundary"] = True - bnode["element"] = bqec_index + bnode = Node( + is_boundary=True, + qubits = bqubits, + index = bqec_index + ) nodes.append(bnode) # bulk nodes @@ -328,14 +330,15 @@ def string2nodes(self, string, **kwargs): elements = separated_string[syn_type][syn_round] for qec_index, element in enumerate(elements[::-1]): if element == "1": - node = {"time": syn_round} if self.basis == "z": qubits = self.css_z_gauge_ops[qec_index] else: qubits = self.css_x_gauge_ops[qec_index] - node["qubits"] = qubits - node["is_boundary"] = False - node["element"] = qec_index + node = Node( + time=syn_round, + qubits=qubits, + index=qec_index + ) nodes.append(node) return nodes @@ -359,6 +362,7 @@ def flatten_nodes(nodes): Returns: flat_nodes (list): List of flattened nodes. """ + raise NotImplementedError("Has not been updated to new decoding graph node type") nodes_per_link = {} for node in nodes: link_qubit = node["link qubit"] @@ -397,15 +401,15 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # see which qubits for logical zs are given and collect bulk nodes given_logicals = [] for node in nodes: - if node["is_boundary"]: - given_logicals += node["qubits"] + if node.is_boundary: + given_logicals += node.qubits given_logicals = set(given_logicals) # bicolour code qubits according to the domain walls walls = [] for node in nodes: - if not node["is_boundary"]: - walls.append(node["qubits"][1]) + if not node.is_boundary: + walls.append(node.qubits[1]) walls.sort() c = 0 colors = "" @@ -451,7 +455,11 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): elem = self.css_z_boundary.index(qubits) else: elem = self.css_x_boundary.index(qubits) - node = {"time": 0, "qubits": qubits, "is_boundary": True, "element": elem} + node = Node( + is_boundary=True, + qubits=qubits, + index=elem + ) flipped_logical_nodes.append(node) if neutral and flipped_logical_nodes == []: @@ -1020,7 +1028,7 @@ def _process_string(self, string): return new_string - def string2nodes(self, string, **kwargs): + def string2nodes(self, string, **kwargs) -> List[Node]: """ Convert output string from circuits into a set of nodes. Args: @@ -1055,19 +1063,21 @@ def string2nodes(self, string, **kwargs): code_qubits = [link[0], link[2]] link_qubit = link[1] tau, _, _ = self._get_202(syn_round) + if not tau: + tau = 0 node = {"time": syn_round} - if tau: - if ((tau % 2) == 1) and tau > 1: - node["conjugate"] = True - node["qubits"] = code_qubits - node["link qubit"] = link_qubit - node["is_boundary"] = is_boundary - node["element"] = elem_num + node = Node( + time=syn_round, + qubits=code_qubits, + index=elem_num + ) + node.properties["conjugate"] = ((tau % 2) == 1) and tau > 1 + node.properties["link qubit"] = link_qubit nodes.append(node) return nodes @staticmethod - def flatten_nodes(nodes): + def flatten_nodes(nodes: List[Node]): """ Removes time information from a set of nodes, and consolidates those on the same position at different times. Also removes nodes corresponding @@ -1080,26 +1090,22 @@ def flatten_nodes(nodes): # strip out conjugate nodes non_conj_nodes = [] for node in nodes: - if "conjugate" not in node: + if not node.properties["conjugate"]: non_conj_nodes.append(node) - else: - if not node["conjugate"]: - non_conj_nodes.append(node) nodes = non_conj_nodes # remove time info nodes_per_link = {} for node in nodes: - link_qubit = node["link qubit"] + link_qubit = node.properties["link qubit"] if link_qubit in nodes_per_link: nodes_per_link[link_qubit] += 1 else: nodes_per_link[link_qubit] = 1 flat_nodes = [] for node in nodes: - if nodes_per_link[node["link qubit"]] % 2: + if nodes_per_link[node.properties["link qubit"]] % 2: flat_node = node.copy() - if "time" in flat_node: - flat_node.pop("time") + flat_node.time = None flat_nodes.append(flat_node) return flat_nodes @@ -1126,8 +1132,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): given_logicals = [] bulk_nodes = [] for node in nodes: - if node["is_boundary"]: - given_logicals += node["qubits"] + if node.is_boundary: + given_logicals += node.qubits else: bulk_nodes.append(node) given_logicals = set(given_logicals) @@ -1135,7 +1141,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # see whether the bulk nodes are neutral if bulk_nodes: nodes = self.flatten_nodes(nodes) - link_qubits = set(node["link qubit"] for node in nodes) + link_qubits = set(node.properties["link qubit"] for node in nodes) node_color = {0: 0} neutral = True link_graph = self._get_link_graph() @@ -1154,7 +1160,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): else: nn = n0 # see if the edge corresponds to one of the given nodes - dc = edge["link qubit"] in link_qubits + dc = edge.properties["link qubit"] in link_qubits # if the neighbour is not yet coloured, colour it # different color if edge is given node, same otherwise if nn not in node_color: @@ -1205,13 +1211,11 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): flipped_logical_nodes = [] for flipped_logical in flipped_logicals: - node = { - "time": 0, - "qubits": [flipped_logical], - "link qubit": None, - "is_boundary": True, - "element": self.z_logicals.index(flipped_logical), - } + node = Node( + is_boundary=True, + qubits=[flipped_logical], + index=self.z_logicals.index(flipped_logical) + ) flipped_logical_nodes.append(node) if this_neutral and flipped_logical_nodes == []: @@ -1338,28 +1342,28 @@ def _make_syndrome_graph(self): + ("0" * len(self.links) + " ") * (self.T - 1) + "1" * len(self.links) ) - nodes = [] + nodes: List[Node] = [] for node in self.string2nodes(string): - if not node["is_boundary"]: + if not node.is_boundary: for t in range(self.T + 1): new_node = node.copy() - new_node["time"] = t + new_node.time = t if new_node not in nodes: nodes.append(new_node) else: - node["time"] = 0 + node.time = 0 nodes.append(node) # find pairs that should be connected - edges = [] + edges: List[Tuple[int, int]] = [] for n0, node0 in enumerate(nodes): for n1, node1 in enumerate(nodes): if n0 < n1: # just record all possible edges for now (should be improved later) - dt = abs(node1["time"] - node0["time"]) - adj = set(node0["qubits"]).intersection(set(node1["qubits"])) + dt = abs(node1.time - node0.time) + adj = set(node0.qubits).intersection(set(node1.qubits)) if adj: - if (node0["is_boundary"] ^ node1["is_boundary"]) or dt <= 1: + if (node0.is_boundary ^ node1.is_boundary) or dt <= 1: edges.append((n0, n1)) # put it all in a graph @@ -1371,11 +1375,14 @@ def _make_syndrome_graph(self): source = nodes[n0] target = nodes[n1] qubits = [] - if not (source["is_boundary"] and target["is_boundary"]): - qubits = list(set(source["qubits"]).intersection(target["qubits"])) - if source["time"] != target["time"] and len(qubits) > 1: + if not (source.is_boundary and target.is_boundary): + qubits = list(set(source.qubits).intersection(target.qubits)) + if source.time != target.time and len(qubits) > 1: qubits = [] - edge = {"qubits": qubits, "weight": 1} + edge = Edge( + qubits=qubits, + weight=1 + ) S.add_edge(n0, n1, edge) # just record edges as hyperedges for now (should be improved later) hyperedges.append({(n0, n1): edge}) @@ -1422,65 +1429,65 @@ def get_error_coords(self, counts, decoding_graph, method="spitz"): node0 = nodes[n0] node1 = nodes[n1] if n0 != n1: - qubits = decoding_graph.graph.get_edge_data(n0, n1)["qubits"] + qubits = decoding_graph.graph.get_edge_data(n0, n1).qubits if qubits: # error on a code qubit between rounds, or during a round assert ( - node0["time"] == node1["time"] and node0["qubits"] != node1["qubits"] - ) or (node0["time"] != node1["time"] and node0["qubits"] != node1["qubits"]) + node0.time == node1.time and node0.qubits != node1.qubits + ) or (node0.time != node1.time and node0.qubits != node1.qubits) qubit = qubits[0] # error between rounds - if node0["time"] == node1["time"]: + if node0.time == node1.time: dts = [] for node in [node0, node1]: - pair = [qubit, node["link qubit"]] + pair = [qubit, node.properties["link qubit"]] for dt, pairs in enumerate(self.schedule): if pair in pairs or tuple(pair) in pairs: dts.append(dt) - time = [max(0, node0["time"] - 1 + (max(dts) + 1) / round_length)] - time.append(node0["time"] + min(dts) / round_length) + time = [max(0, node0.time - 1 + (max(dts) + 1) / round_length)] + time.append(node0.time + min(dts) / round_length) # error during a round else: # put nodes in descending time order - if node0["time"] < node1["time"]: + if node0.time < node1.time: node_pair = [node1, node0] else: node_pair = [node0, node1] # see when in the schedule each node measures the qubit dts = [] for node in node_pair: - pair = [qubit, node["link qubit"]] + pair = [qubit, node.properties["link qubit"]] for dt, pairs in enumerate(self.schedule): if pair in pairs or tuple(pair) in pairs: dts.append(dt) # use to define fractional time if dts[0] < dts[1]: - time = [node_pair[1]["time"] + (dts[0] + 1) / round_length] - time.append(node_pair[1]["time"] + dts[1] / round_length) + time = [node_pair[1].time + (dts[0] + 1) / round_length] + time.append(node_pair[1].time + dts[1] / round_length) else: # impossible cases get no valid time time = [] else: # measurement error - assert node0["time"] != node1["time"] and node0["qubits"] == node1["qubits"] - qubit = node0["link qubit"] - time = [node0["time"], node0["time"] + (round_length - 1) / round_length] + assert node0.time != node1.time and node0.qubits == node1.qubits + qubit = node0.properties["link qubit"] + time = [node0.time, node0.time + (round_length - 1) / round_length] time.sort() else: # detected only by one stabilizer - boundary_qubits = list(set(node0["qubits"]).intersection(z_logicals)) + boundary_qubits = list(set(node0.qubits).intersection(z_logicals)) # for the case of boundary stabilizers if boundary_qubits: qubit = boundary_qubits[0] - pair = [qubit, node0["link qubit"]] + pair = [qubit, node0.properties["link qubit"]] for dt, pairs in enumerate(self.schedule): if pair in pairs or tuple(pair) in pairs: - time = [max(0, node0["time"] - 1 + (dt + 1) / round_length)] - time.append(node0["time"] + dt / round_length) + time = [max(0, node0.time - 1 + (dt + 1) / round_length)] + time.append(node0.time + dt / round_length) else: - qubit = tuple(node0["qubits"] + [node0["link qubit"]]) - time = [node0["time"], node0["time"] + (round_length - 1) / round_length] + qubit = tuple(node0.qubits + [node0.properties["link qubit"]]) + time = [node0.time, node0.time + (round_length - 1) / round_length] if time != []: # only record if not nan if (qubit, time[0], time[1]) not in error_coords: diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 7b7a2975..a81fbc6c 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -19,9 +19,9 @@ from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister +from qiskit_qec.decoders.decoding_graph import Node, Edge class SurfaceCodeCircuit: - """ Implementation of a distance d rotated surface code, implemented over T syndrome measurement rounds. @@ -391,11 +391,12 @@ def string2nodes(self, string, logical="0", all_logicals=False): boundary = separated_string[0] # [, ] for bqec_index, belement in enumerate(boundary[::-1]): if all_logicals or belement != logical: - bnode = {"time": 0} - bnode["qubits"] = self._logicals[self.basis][-bqec_index - 1] - bnode["is_boundary"] = True - bnode["element"] = 1 - bqec_index - nodes.append(bnode) + node = Node( + is_boundary=True, + qubits = self._logicals[self.basis][-bqec_index - 1], + index = 1 - bqec_index + ) + nodes.append(node) # bulk nodes for syn_type in range(1, len(separated_string)): @@ -403,14 +404,15 @@ def string2nodes(self, string, logical="0", all_logicals=False): elements = separated_string[syn_type][syn_round] for qec_index, element in enumerate(elements[::-1]): if element == "1": - node = {"time": syn_round} if self.basis == "x": qubits = self.css_x_stabilizer_ops[qec_index] else: qubits = self.css_z_stabilizer_ops[qec_index] - node["qubits"] = qubits - node["is_boundary"] = False - node["element"] = qec_index + node = Node( + time = syn_round, + qubits=qubits, + index = qec_index + ) nodes.append(node) return nodes @@ -433,9 +435,9 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): num_errors (int): Minimum number of errors required to create nodes. """ - bulk_nodes = [node for node in nodes if not node["is_boundary"]] - boundary_nodes = [node for node in nodes if node["is_boundary"]] - given_logicals = set(node["element"] for node in boundary_nodes) + bulk_nodes = [node for node in nodes if not node.is_boundary] + boundary_nodes = [node for node in nodes if node.is_boundary] + given_logicals = set(node.index for node in boundary_nodes) if self.basis == "z": coords = self._zplaq_coords @@ -451,7 +453,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): xs = [] ys = [] for node in bulk_nodes: - x, y = coords[node["element"]] + x, y = coords[node.index] xs.append(x) ys.append(y) dx = max(xs) - min(xs) @@ -469,7 +471,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # find nearest boundary num_errors = (self.d - 1) / 2 for node in bulk_nodes: - x, y = coords[node["element"]] + x, y = coords[node.index] if self.basis == "z": p = y else: @@ -489,12 +491,11 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # get the required boundary nodes flipped_logical_nodes = [] for elem in flipped_logicals: - node = { - "time": 0, - "qubits": self._logicals[self.basis][elem], - "is_boundary": True, - "element": elem, - } + node = Node( + is_boundary=True, + qubits=self._logicals[self.basis][elem], + index=elem + ) flipped_logical_nodes.append(node) return neutral, flipped_logical_nodes, num_errors diff --git a/src/qiskit_qec/decoders/__init__.py b/src/qiskit_qec/decoders/__init__.py index 71262f42..9cce5d5f 100644 --- a/src/qiskit_qec/decoders/__init__.py +++ b/src/qiskit_qec/decoders/__init__.py @@ -30,7 +30,7 @@ UnionFindDecoder """ -from .decoding_graph import DecodingGraph +from .decoding_graph import DecodingGraph, Node, Edge from .circuit_matching_decoder import CircuitModelMatchingDecoder from .repetition_decoder import RepetitionDecoder from .three_bit_decoder import ThreeBitDecoder diff --git a/src/qiskit_qec/decoders/circuit_matching_decoder.py b/src/qiskit_qec/decoders/circuit_matching_decoder.py index 1ab26faf..9ad9cd75 100644 --- a/src/qiskit_qec/decoders/circuit_matching_decoder.py +++ b/src/qiskit_qec/decoders/circuit_matching_decoder.py @@ -10,7 +10,7 @@ import rustworkx as rx from qiskit import QuantumCircuit from qiskit_qec.analysis.faultenumerator import FaultEnumerator -from qiskit_qec.decoders.decoding_graph import CSSDecodingGraph, DecodingGraph +from qiskit_qec.decoders.decoding_graph import CSSDecodingGraph, DecodingGraph, Node, Edge from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher from qiskit_qec.decoders.temp_code_util import temp_gauge_products, temp_syndrome @@ -173,13 +173,13 @@ def _process_graph( n0, n1 = graph.edge_list()[j] source = graph.nodes()[n0] target = graph.nodes()[n1] - if source["time"] != target["time"]: - if source["is_boundary"] == target["is_boundary"] == False: + if source.time != target.time: + if source.is_boundary == target.is_boundary == False: new_source = source.copy() - new_source["time"] = target["time"] + new_source.time = target.time nn0 = graph.nodes().index(new_source) new_target = target.copy() - new_target["time"] = source["time"] + new_target.time = source.time nn1 = graph.nodes().index(new_target) graph.add_edge(nn0, nn1, edge) @@ -191,16 +191,16 @@ def _process_graph( # add the required attributes # highlighted', 'measurement_error','qubit_id' and 'error_probability' - edge["highlighted"] = False - edge["measurement_error"] = int(source["time"] != target["time"]) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = int(source["time"] != target["time"]) # make it so times of boundary/boundary nodes agree - if source["is_boundary"] and not target["is_boundary"]: - if source["time"] != target["time"]: + if source.is_boundary and not target.is_boundary: + if source.time != target.time: new_source = source.copy() - new_source["time"] = target["time"] + new_source.time = target.time n = graph.add_node(new_source) - edge["measurement_error"] = 0 + edge.properties["measurement_error"] = 0 edges_to_remove.append((n0, n1)) graph.add_edge(n, n1, edge) @@ -211,29 +211,29 @@ def _process_graph( for n0, source in enumerate(graph.nodes()): for n1, target in enumerate(graph.nodes()): # add weightless nodes connecting different boundaries - if source["time"] == target["time"]: - if source["is_boundary"] and target["is_boundary"]: - if source["qubits"] != target["qubits"]: - edge = { - "measurement_error": 0, - "weight": 0, - "highlighted": False, - } - edge["qubits"] = list( - set(source["qubits"]).intersection((set(target["qubits"]))) + if source.time == target.time: + if source.is_boundary and target.is_boundary: + if source.qubits != target.qubits: + edge = Edge( + weight= 0, + qubits = list( + set(source["qubits"]).intersection((set(target["qubits"]))) + ) ) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 if (n0, n1) not in graph.edge_list(): graph.add_edge(n0, n1, edge) # connect one of the boundaries at different times - if target["time"] == source["time"] + 1: - if source["qubits"] == target["qubits"] == [0]: - edge = { - "qubits": [], - "measurement_error": 0, - "weight": 0, - "highlighted": False, - } + if target.time == source.time + 1: + if source.qubits == target.qubits == [0]: + edge = Edge( + weight= 0, + qubits = [] + ) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 if (n0, n1) not in graph.edge_list(): graph.add_edge(n0, n1, edge) @@ -245,14 +245,14 @@ def _process_graph( idxmap = {} for n, node in enumerate(graph.nodes()): - idxmap[node["time"], tuple(node["qubits"])] = n + idxmap[node.time, tuple(node.qubits)] = n node_layers = [] for node in graph.nodes(): - time = node["time"] + time = node.time if len(node_layers) < time + 1: node_layers += [[]] * (time + 1 - len(node_layers)) - node_layers[time].append(node["qubits"]) + node_layers[time].append(node.qubits) # create a list of decoding graph layer types # the entries are 'g' for gauge and 's' for stabilizer @@ -298,10 +298,10 @@ def _revise_decoding_graph( # TODO: new edges may be needed for hooks, but raise exception for now raise QiskitQECError("edge {s1} - {s2} not in decoding graph") data = graph.get_edge_data(idxmap[s1], idxmap[s2]) - data["weight_poly"] = wpoly + data.properties["weight_poly"] = wpoly remove_list = [] for source, target in graph.edge_list(): - edge_data = graph.get_edge_data(source, target) + edge_data = graph.get_edge_data(source, target).properties if "weight_poly" not in edge_data and edge_data["weight"] != 0: # Remove the edge remove_list.append((source, target)) @@ -340,7 +340,7 @@ def update_edge_weights(self, model: PauliNoiseModel): # p_i is the probability that edge i carries an error # l(i) is 1 if the link belongs to the chain and 0 otherwise for source, target in self.graph.edge_list(): - edge_data = self.graph.get_edge_data(source, target) + edge_data = self.graph.get_edge_data(source, target).properties if "weight_poly" in edge_data: logging.info( "update_edge_weights (%d, %d) %s", diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index b4baa2b2..05165312 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -17,9 +17,10 @@ """ Graph used as the basis of decoders. """ +from dataclasses import dataclass, field import itertools import logging -from typing import List, Tuple +from typing import Any, Dict, List, Tuple, Optional import numpy as np import rustworkx as rx @@ -27,6 +28,25 @@ from qiskit_qec.exceptions import QiskitQECError +class Node: + def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) -> None: + if not is_boundary and not time: + raise QiskitQECError("DecodingGraph node must either have a time or be a boundary node.") + + self.is_boundary: bool = is_boundary + self.time: Optional[int] = time if not is_boundary else None + self.qubits: List[int] = qubits + self.index: int = index + # TODO: Should code/decoder specific properties be accounted for when comparing nodes + self.properties: Dict[str, Any] = {} + +@dataclass +class Edge: + qubits: List[int] + weight: float + # TODO: Should code/decoder specific properties be accounted for when comparing edges + properties: Dict[str, Any] = field(default_factory=dict) + class DecodingGraph: """ Class to construct the decoding graph for the code given by a CodeCircuit object, @@ -90,20 +110,20 @@ def _make_syndrome_graph(self): n0 = graph.nodes().index(source) n1 = graph.nodes().index(target) qubits = [] - if not (source["is_boundary"] and target["is_boundary"]): + if not (source.is_boundary and target.is_boundary): qubits = list( - set(source["qubits"]).intersection(target["qubits"]) + set(source.qubits).intersection(target.qubits) ) if not qubits: continue if ( - source["time"] != target["time"] + source.time != target.time and len(qubits) > 1 - and not source["is_boundary"] - and not target["is_boundary"] + and not source.is_boundary + and not target.is_boundary ): qubits = [] - edge = {"qubits": qubits, "weight": 1} + edge = Edge(qubits, 1) graph.add_edge(n0, n1, edge) if (n1, n0) not in hyperedge: hyperedge[n0, n1] = edge @@ -112,7 +132,7 @@ def _make_syndrome_graph(self): self.graph = graph - def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ): + def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ) -> List[Tuple[Tuple[int, int], float]]: """ Generate probabilities of single error events from result counts. @@ -181,9 +201,9 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ error_probs = {} for n0, n1 in self.graph.edge_list(): - if self.graph[n0]["is_boundary"]: + if self.graph[n0].is_boundary: boundary.append(n1) - elif self.graph[n1]["is_boundary"]: + elif self.graph[n1].is_boundary: boundary.append(n0) else: if (1 - 2 * av_xor[n0, n1]) != 0: @@ -238,9 +258,9 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ else: ratio = np.nan p = ratio / (1 + ratio) - if self.graph[n0]["is_boundary"] and not self.graph[n1]["is_boundary"]: + if self.graph[n0].is_boundary and not self.graph[n1].is_boundary: edge = (n1, n1) - elif not self.graph[n0]["is_boundary"] and self.graph[n1]["is_boundary"]: + elif not self.graph[n0].is_boundary and self.graph[n1].is_boundary: edge = (n0, n0) else: edge = (n0, n1) @@ -267,7 +287,7 @@ def weight_syndrome_graph(self, counts, method: str = METHOD_SPITZ): boundary_nodes = [] for n, node in enumerate(self.graph.nodes()): - if node["is_boundary"]: + if node.is_boundary: boundary_nodes.append(n) for edge in self.graph.edge_list(): diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 1f5ef5d4..0bdd07eb 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -18,11 +18,12 @@ from copy import copy, deepcopy from dataclasses import dataclass -from typing import Dict, List, Set + +from typing import Dict, List, Set, Tuple, Tuple from rustworkx import connected_components, distance_matrix, PyGraph from qiskit_qec.circuits.repetition_code import ArcCircuit, RepetitionCodeCircuit -from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.decoders.decoding_graph import DecodingGraph, Node, Edge from qiskit_qec.exceptions import QiskitQECError @@ -128,11 +129,13 @@ def _cluster(self, ns, dist_max): def _get_boundary_nodes(self): boundary_nodes = [] for element, z_logical in enumerate(self.z_logicals): - node = {"time": 0, "is_boundary": True} + node = Node( + is_boundary=True, + qubits=[z_logical], + index=element + ) if isinstance(self.code, ArcCircuit): - node["link qubit"] = None - node["qubits"] = [z_logical] - node["element"] = element + node.properties["link qubit"] = None boundary_nodes.append(node) return boundary_nodes @@ -169,7 +172,7 @@ def cluster(self, nodes): if c is not None: final_clusters[n] = c else: - if not dg[n]["is_boundary"]: + if not dg[n].is_boundary: ns.append(n) con_comps.append(con_comp) clusterss.append(clusters) @@ -199,14 +202,14 @@ def process(self, string): cluster_nodes = {c: [] for c in clusters.values()} for n, c in clusters.items(): node = decoding_graph.graph[n] - if not node["is_boundary"]: + if not node.is_boundary: cluster_nodes[c].append(node) # get the list of required logicals for each cluster cluster_logicals = {} for c, nodes in cluster_nodes.items(): _, logical_nodes, _ = code.check_nodes(nodes) - z_logicals = [node["qubits"][0] for node in logical_nodes] + z_logicals = [node.qubits[0] for node in logical_nodes] cluster_logicals[c] = z_logicals # get the net effect on each logical @@ -247,7 +250,7 @@ class BoundaryEdge: index: int cluster_vertex: int neighbour_vertex: int - data: Dict[str, object] + data: Edge def reverse(self): """ @@ -333,10 +336,10 @@ def process(self, string: str): if isinstance(self.code, ArcCircuit): # NOTE: it just corrects for final logical readout for node in erasure.nodes(): - if node["is_boundary"]: + if node.is_boundary: # FIXME: Find a general way to go from physical qubit # index to code qubit index - qubit_to_be_corrected = int(node["qubits"][0] / 2) + qubit_to_be_corrected = int(node.qubits[0] / 2) output[qubit_to_be_corrected] = (output[qubit_to_be_corrected] + 1) % 2 continue @@ -360,12 +363,12 @@ def cluster(self, nodes) -> List[List[int]]: """ node_indices = [self.graph.nodes().index(node) for node in nodes] for node_index, _ in enumerate(self.graph.nodes()): - self.graph[node_index]["syndrome"] = node_index in node_indices - self.graph[node_index]["root"] = node_index + self.graph[node_index].properties["syndrome"] = node_index in node_indices + self.graph[node_index].properties["root"] = node_index for edge in self.graph.edges(): - edge["growth"] = 0 - edge["fully_grown"] = False + edge.properties["growth"] = 0 + edge.properties["fully_grown"] = False self.clusters: Dict[int, UnionFindDecoderCluster] = {} self.odd_cluster_roots = set(node_indices) @@ -409,11 +412,11 @@ def find(self, u: int) -> int: Returns: root (int): The root of the cluster of node u. """ - if self.graph[u]["root"] == u: - return self.graph[u]["root"] + if self.graph[u].properties["root"] == u: + return self.graph[u].properties["root"] - self.graph[u]["root"] = self.find(self.graph[u]["root"]) - return self.graph[u]["root"] + self.graph[u].properties["root"] = self.find(self.graph[u].properties["root"]) + return self.graph[u].properties["root"] def _grow_clusters(self) -> List[FusionEntry]: """ @@ -427,9 +430,9 @@ def _grow_clusters(self) -> List[FusionEntry]: for root in self.odd_cluster_roots: cluster = self.clusters[root] for edge in cluster.boundary: - edge.data["growth"] += 0.5 - if edge.data["growth"] >= edge.data["weight"] and not edge.data["fully_grown"]: - edge.data["fully_grown"] = True + edge.data.properties["growth"] += 0.5 + if edge.data.properties["growth"] >= edge.data.properties["weight"] and not edge.data.properties["fully_grown"]: + edge.data.properties["fully_grown"] = True cluster.fully_grown_edges.add(edge.index) fusion_entry = FusionEntry( u=edge.cluster_vertex, v=edge.neighbour_vertex, connecting_edge=edge @@ -473,7 +476,7 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]) -> None: else: self.odd_cluster_roots.discard(new_root) self.odd_cluster_roots.discard(root_to_update) - self.graph[root_to_update]["root"] = new_root + self.graph[root_to_update].properties["root"] = new_root def peeling(self, erasure: PyGraph) -> List[int]: """ " @@ -497,7 +500,7 @@ def peeling(self, erasure: PyGraph) -> List[int]: # Construct spanning forest # Pick starting vertex for vertex in erasure.node_indices(): - if erasure[vertex]["is_boundary"]: + if erasure[vertex].is_boundary: tree.vertices[vertex] = [] break if not tree.vertices: @@ -522,11 +525,11 @@ def peeling(self, erasure: PyGraph) -> List[int]: pendant_vertex = endpoints[0] if not tree.vertices[endpoints[0]] else endpoints[1] tree_vertex = endpoints[0] if pendant_vertex == endpoints[1] else endpoints[1] tree.vertices[tree_vertex].remove(edge) - if erasure[pendant_vertex]["syndrome"] and not erasure[pendant_vertex]["is_boundary"]: + if erasure[pendant_vertex].properties["syndrome"] and not erasure[pendant_vertex].is_boundary: edges.add(edge) - erasure[tree_vertex]["syndrome"] = not erasure[tree_vertex]["syndrome"] - erasure[pendant_vertex]["syndrome"] = False + erasure[tree_vertex].properties["syndrome"] = not erasure[tree_vertex].properties["syndrome"] + erasure[pendant_vertex].properties["syndrome"] = False return [ - erasure.edges()[edge]["qubits"][0] for edge in edges if erasure.edges()[edge]["qubits"] + erasure.edges()[edge].qubits[0] for edge in edges if erasure.edges()[edge].qubits ] diff --git a/src/qiskit_qec/decoders/rustworkx_matcher.py b/src/qiskit_qec/decoders/rustworkx_matcher.py index 5d735469..da3dfcf9 100644 --- a/src/qiskit_qec/decoders/rustworkx_matcher.py +++ b/src/qiskit_qec/decoders/rustworkx_matcher.py @@ -6,17 +6,16 @@ import rustworkx as rx from qiskit_qec.decoders.base_matcher import BaseMatcher - +from qiskit_qec.decoders.decoding_graph import Node, Edge class RustworkxMatcher(BaseMatcher): """Matching subroutines using rustworkx. - The input rustworkx graph is expected to have the following properties: - edge["weight"] : real edge weight - edge["measurement_error"] : bool, true if edge corresponds to measurement error - edge["qubits"] : list of qubit ids associated to edge - vertex["qubits"] : qubit ids involved in gauge measurement - vertex["time"] : integer time step of gauge measurement + The input rustworkx graph is expected to have decoding_graph.Node as the type of the node payload + and decoding_graph.Edge as the type of the edge payload. + + Additionally the edges are expected to have the following properties: + - edge.properties["measurement_error"] (bool): Whether or not the error corresponds to a measurement error. The annotated graph will also have "highlighted" properties on edges and vertices. """ @@ -36,8 +35,8 @@ def preprocess(self, graph: rx.PyGraph): """ # edge_cost_fn = lambda edge: edge["weight"] - def edge_cost_fn(edge): - return edge["weight"] + def edge_cost_fn(edge: Edge): + return edge.weight length = rx.all_pairs_dijkstra_path_lengths(graph, edge_cost_fn) self.length = {s: dict(length[s]) for s in length} @@ -111,11 +110,11 @@ def _error_chain_from_vertex_path( for i in range(len(vertex_path) - 1): v0 = vertex_path[i] v1 = vertex_path[i + 1] - if graph.get_edge_data(v0, v1)["measurement_error"] == 1: + if graph.get_edge_data(v0, v1).properties["measurement_error"] == 1: measurement_errors ^= set( - [(graph.nodes()[v0]["time"], tuple(graph.nodes()[v0]["qubits"]))] + [(graph.nodes()[v0].time, tuple(graph.nodes()[v0].qubits))] ) - qubit_errors ^= set(graph.get_edge_data(v0, v1)["qubits"]) + qubit_errors ^= set(graph.get_edge_data(v0, v1).qubits) logging.debug( "_error_chain_for_vertex_path q = %s, m = %s", qubit_errors, @@ -173,7 +172,7 @@ def _make_annotated_graph(gin: rx.PyGraph, paths: List[List[int]]) -> rx.PyGraph for path in paths: # Highlight the endpoints of the path for i in [0, -1]: - graph.nodes()[path[i]]["highlighted"] = True + graph.nodes()[path[i]].properties["highlighted"] = True # Highlight the edges along the path for i in range(len(path) - 1): try: @@ -181,5 +180,5 @@ def _make_annotated_graph(gin: rx.PyGraph, paths: List[List[int]]) -> rx.PyGraph except ValueError: idx = list(graph.edge_list()).index((path[i + 1], path[i])) edge = graph.edges()[idx] - edge["highlighted"] = True + edge.properties["highlighted"] = True return graph From ca1d710e1b5162a270b233e1d3af167116e7f82b Mon Sep 17 00:00:00 2001 From: James Wootton Date: Thu, 2 Mar 2023 16:54:30 +0100 Subject: [PATCH 04/22] create CodeCircuit class --- src/qiskit_qec/circuits/__init__.py | 2 +- src/qiskit_qec/circuits/code_circuit.py | 83 +++++++++++++ src/qiskit_qec/circuits/qec_circuit.py | 132 --------------------- src/qiskit_qec/circuits/repetition_code.py | 16 ++- src/qiskit_qec/circuits/surface_code.py | 25 ++-- 5 files changed, 114 insertions(+), 144 deletions(-) create mode 100644 src/qiskit_qec/circuits/code_circuit.py delete mode 100644 src/qiskit_qec/circuits/qec_circuit.py diff --git a/src/qiskit_qec/circuits/__init__.py b/src/qiskit_qec/circuits/__init__.py index 153d176d..63447aaa 100644 --- a/src/qiskit_qec/circuits/__init__.py +++ b/src/qiskit_qec/circuits/__init__.py @@ -27,6 +27,6 @@ CSSCircuit """ +from .code_circuit import CodeCircuit from .repetition_code import RepetitionCodeCircuit, ArcCircuit from .surface_code import SurfaceCodeCircuit -from .qec_circuit import CSSCircuit diff --git a/src/qiskit_qec/circuits/code_circuit.py b/src/qiskit_qec/circuits/code_circuit.py new file mode 100644 index 00000000..c0f91ccb --- /dev/null +++ b/src/qiskit_qec/circuits/code_circuit.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2019. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +# pylint: disable=invalid-name + +"""Class that manage circuits for codes.""" +from abc import ABC, abstractmethod + + +class CodeCircuit(ABC): + """ + Abstract class to manage circuits for codes, as well + as other fault-tolerant circuits. + + A CodeCircuit requires the methods `string2nodes`, + `check_nodes` and `is_cluster_neutral` in order to + interface with its `DecodingGraph` and decoders. + """ + + def __init__(self): + """ + Initialization of classes that inherent from CodeCircuit can + be done in various ways, depending on the code or code family + to be initialized. In all cases, the initialization must define + attributes `circuit` and `base`. The former is a dictionary with + circuits as values, and labels as keys. The latter is the label + regarded as the base case, used in decoding graph generation. + """ + pass + + @abstractmethod + def string2nodes(self, string, **kwargs): + """ + Convert output string from circuits into a set of nodes for + `DecodingGraph`. + Args: + string (string): Results string to convert. + kwargs (dict): Any additional keyword arguments. + """ + pass + + @abstractmethod + def check_nodes(self, nodes, ignore_extra_boundary=False): + """ + Determines whether a given set of nodes are neutral. If so, also + determines any additional logical readout qubits that would be + flipped by the errors creating such a cluster and how many errors + would be required to make the cluster. + Args: + nodes (list): List of nodes, of the type produced by `string2nodes`. + ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are + ignored. + Returns: + neutral (bool): Whether the nodes independently correspond to a valid + set of errors. + flipped_logical_nodes (list): List of qubits nodes for logical + operators that are flipped by the errors, that were not included + in the original nodes. + num_errors (int): Minimum number of errors required to create nodes. + """ + pass + + @abstractmethod + def is_cluster_neutral(self, atypical_nodes): + """ + Determines whether or not the cluster is neutral, meaning that one or more + errors could have caused the set of atypical nodes (syndrome changes) passed + to the method. + Args: + atypical_nodes (dictionary in the form of the return value of string2nodes) + """ + pass diff --git a/src/qiskit_qec/circuits/qec_circuit.py b/src/qiskit_qec/circuits/qec_circuit.py deleted file mode 100644 index 6bdbe209..00000000 --- a/src/qiskit_qec/circuits/qec_circuit.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- - -# This code is part of Qiskit. -# -# (C) Copyright IBM 2019. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. - -# pylint: disable=invalid-name - -"""Classes that create and manage circuits for codes.""" -from typing import List, Tuple - -from qiskit import QuantumCircuit - - -class CSSCircuit: - """CSSCircuit class.""" - - def __init__( - self, - n: int, - css_x_gauge_ops: List[Tuple[int]], - css_x_stabilizer_ops: List[Tuple[int]], - css_x_boundary: List[int], - css_z_gauge_ops: List[Tuple[int]], - css_z_stabilizer_ops: List[Tuple[int]], - css_z_boundary: List[int], - basis: str, - round_schedule: str, - blocks: int, - resets: bool, - delay: float, - ): - """Create and manage circuits for generic CSS codes. - - Args: - n : number of code qubits - css_x_gauge_ops : list of supports of X gauge operators - css_x_stabilizer_ops : list of supports of X stabilizers - css_x_boundary : list of qubits along the X-type boundary - css_z_gauge_ops : list of supports of Z gauge operators - css_z_stabilizer_ops : list of supports of Z stabilizers - css_x_boundary : list of qubits along the Z-type boundary - basis : initializaton and measurement basis ("x" or "z") - round_schedule : gauge measurements in each block - blocks : number of measurement blocks - resets : Whether to include a reset gate after mid-circuit measurements. - delay: Time (in dt) to delay after mid-circuit measurements (and reset). - """ - self.n = n - self.css_x_gauge_ops = css_x_gauge_ops - self.css_x_stabilizer_ops = css_x_stabilizer_ops - self.css_x_boundary = css_x_boundary - self.css_z_gauge_ops = css_z_gauge_ops - self.css_z_stabilizer_ops = css_z_stabilizer_ops - self.css_z_boundary = css_z_boundary - self.basis = basis - self.round_schedule = round_schedule - self.blocks = blocks - self.resets = resets - self.delay = delay - - self.circuit = QuantumCircuit() - - def x(self, barrier=False): - """ - Applies a logical x to the circuit. - - Args: - barrier (bool): Boolean denoting whether to include a barrier at - the end. - """ - pass - - def z(self, barrier=False): - """ - Applies a logical x to the circuit. - - Args: - barrier (bool): Boolean denoting whether to include a barrier at - the end. - """ - pass - - def syndrome_measurement(self, final: bool = False, barrier: bool = False, delay: int = 0): - """ - Application of a syndrome measurement round. - - Args: - final (bool): Whether to disregard the reset (if applicable) due to this - being the final syndrome measurement round. - barrier (bool): Whether to include a barrier at the end. - delay (float): Time (in dt) to delay after mid-circuit measurements (and reset). - """ - pass - - def readout(self): - """ - Readout of all code qubits, which corresponds to a logical measurement - as well as allowing for a measurement of the syndrome to be inferred. - """ - pass - - def string2nodes(self, string, all_logicals=False): - """ - Convert output string from running the circuit into a set of nodes. - Args: - string (string): Results string to convert. - all_logicals (bool): Whether to include logical nodes - irrespective of value. - Returns: - dict: List of nodes corresponding to to the non-trivial - elements in the string. - """ - pass - - def string2raw_logicals(self, string): - """ - Extracts raw logical measurement outcomes from output string. - Args: - string (string): Results string from which to extract logicals - Returns: - list: Raw values for logical operators that correspond to nodes. - """ - pass diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 2d4326d8..bf636467 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -25,6 +25,10 @@ from qiskit.transpiler import PassManager, InstructionDurations from qiskit.transpiler.passes import DynamicalDecoupling +from qiskit_qec.circuits.code_circuit import CodeCircuit + +from qiskit_qec.circuits.code_circuit import CodeCircuit + from qiskit_qec.decoders.decoding_graph import Node, Edge def _separate_string(string): @@ -34,7 +38,7 @@ def _separate_string(string): return separated_string -class RepetitionCodeCircuit: +class RepetitionCodeCircuit(CodeCircuit): """RepetitionCodeCircuit class.""" def __init__( @@ -69,6 +73,8 @@ def __init__( syndrome measurement round). """ + super().__init__() + self.n = d self.d = d self.T = 0 @@ -469,7 +475,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): def is_cluster_neutral(self, atypical_nodes): """ - Determines whether or not the cluster is even. Even means that one or more + Determines whether or not the cluster is neutral, meaning that one or more errors could have caused the set of atypical nodes (syndrome changes) passed to the method. Args: @@ -520,7 +526,7 @@ def add_edge(graph, pair, edge=None): return ns -class ArcCircuit: +class ArcCircuit(CodeCircuit): """Anisotropic repetition code class.""" METHOD_SPITZ: str = "spitz" @@ -571,6 +577,8 @@ def __init__( previous measurement), rather than a reset gate. """ + super().__init__() + self.links = links self.basis = basis self.logical = logical @@ -1238,7 +1246,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): def is_cluster_neutral(self, atypical_nodes): """ - Determines whether or not the cluster is even. Even means that one or more + Determines whether or not the cluster is neutral, meaning that one or more errors could have caused the set of atypical nodes (syndrome changes) passed to the method. Args: diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index a81fbc6c..d2248f83 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -16,12 +16,16 @@ """Generates circuits based on repetition codes.""" - from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister +from qiskit_qec.circuits.code_circuit import CodeCircuit + +from qiskit_qec.circuits.code_circuit import CodeCircuit + from qiskit_qec.decoders.decoding_graph import Node, Edge -class SurfaceCodeCircuit: +class SurfaceCodeCircuit(CodeCircuit): + """ Implementation of a distance d rotated surface code, implemented over T syndrome measurement rounds. @@ -45,6 +49,7 @@ def __init__(self, d: int, T: int, basis: str = "z", resets=True): qubits (corresponding to a logical measurement and final syndrome measurement round). """ + super().__init__() self.d = d self.T = 0 @@ -365,14 +370,15 @@ def _separate_string(self, string): separated_string.append(syndrome_type_string.split(" ")) return separated_string - def string2nodes(self, string, logical="0", all_logicals=False): + def string2nodes(self, string, **kwargs): """ Convert output string from circuits into a set of nodes. Args: string (string): Results string to convert. - logical (string): Logical value whose results are used. - all_logicals (bool): Whether to include logical nodes - irrespective of value. + kwargs (dict): Additional keyword arguments. + logical (str): Logical value whose results are used ('0' as default). + all_logicals (bool): Whether to include logical nodes + irrespective of value. (False as default). Returns: dict: List of nodes corresponding to to the non-trivial elements in the string. @@ -383,6 +389,11 @@ def string2nodes(self, string, logical="0", all_logicals=False): code whenever we're dealing with both strings and lists. """ + all_logicals = kwargs.get("all_logicals") + logical = kwargs.get("logical") + if logical is None: + logical = "0" + string = self._process_string(string) separated_string = self._separate_string(string) nodes = [] @@ -502,7 +513,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): def is_cluster_neutral(self, atypical_nodes): """ - Determines whether or not the cluster is even. Even means that one or more + Determines whether or not the cluster is neutral, meaning that one or more errors could have caused the set of atypical nodes (syndrome changes) passed to the method. Args: From 8c02c66699eb5aa32f86363cc67d71d7ef29428d Mon Sep 17 00:00:00 2001 From: James Wootton Date: Fri, 3 Mar 2023 15:47:57 +0100 Subject: [PATCH 05/22] add more detail to init --- src/qiskit_qec/circuits/code_circuit.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/qiskit_qec/circuits/code_circuit.py b/src/qiskit_qec/circuits/code_circuit.py index c0f91ccb..b6e1a8a6 100644 --- a/src/qiskit_qec/circuits/code_circuit.py +++ b/src/qiskit_qec/circuits/code_circuit.py @@ -33,9 +33,13 @@ def __init__(self): Initialization of classes that inherent from CodeCircuit can be done in various ways, depending on the code or code family to be initialized. In all cases, the initialization must define - attributes `circuit` and `base`. The former is a dictionary with - circuits as values, and labels as keys. The latter is the label - regarded as the base case, used in decoding graph generation. + the following attributes: + circuit (dict): A dictionary with circuits as values, and + labels (typically strings) as keys. + base (string) The label for the above regarded as the base case, + used in decoding graph generation. + d (int): Code distance. + T (int): number of syndrome measurement rounds. """ pass From cc7088b4cc1b1760aab7a0e349242d9dd866dcea Mon Sep 17 00:00:00 2001 From: James Wootton Date: Fri, 3 Mar 2023 16:15:34 +0100 Subject: [PATCH 06/22] add default is_cluster_neutral --- src/qiskit_qec/circuits/code_circuit.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/qiskit_qec/circuits/code_circuit.py b/src/qiskit_qec/circuits/code_circuit.py index b6e1a8a6..d0b6f489 100644 --- a/src/qiskit_qec/circuits/code_circuit.py +++ b/src/qiskit_qec/circuits/code_circuit.py @@ -81,7 +81,12 @@ def is_cluster_neutral(self, atypical_nodes): Determines whether or not the cluster is neutral, meaning that one or more errors could have caused the set of atypical nodes (syndrome changes) passed to the method. + + Default version here assumes that it is as simple as an an even/odd assessment + (as for repetition codes, surface codes, etc). This should be overwritten for + more complex codes. It also should be used with care, by only supplying sets + of nodes for which the even/odd assessment is valid. Args: atypical_nodes (dictionary in the form of the return value of string2nodes) """ - pass + return not bool(len(atypical_nodes) % 2) From de39a2a5e5382ebf4a7156062fbd6a499c5cd19f Mon Sep 17 00:00:00 2001 From: Tommaso Peduzzi Date: Tue, 7 Mar 2023 16:56:11 +0100 Subject: [PATCH 07/22] Add new DecodinGraph Node type support everywhere This patch adds support for the new Node everywhere, such that it passes all tests. It also moves DecodingGraph to the analysis module due to a circular dependency issue. --- src/qiskit_qec/analysis/__init__.py | 1 + .../{decoders => analysis}/decoding_graph.py | 111 ++++++++---- src/qiskit_qec/circuits/repetition_code.py | 30 ++-- src/qiskit_qec/circuits/surface_code.py | 7 +- src/qiskit_qec/decoders/__init__.py | 1 - .../decoders/circuit_matching_decoder.py | 20 +-- src/qiskit_qec/decoders/hdrg_decoders.py | 4 +- src/qiskit_qec/decoders/hhc_decoder.py | 2 +- src/qiskit_qec/decoders/repetition_decoder.py | 2 +- src/qiskit_qec/decoders/rustworkx_matcher.py | 2 +- src/qiskit_qec/decoders/temp_graph_util.py | 16 +- src/qiskit_qec/utils/__init__.py | 1 - src/qiskit_qec/utils/decodoku.py | 2 +- test/code_circuits/test_rep_codes.py | 54 ++++-- test/code_circuits/test_surface_codes.py | 166 ++++++++++++------ test/matching/test_circuitmatcher.py | 2 +- test/matching/test_pymatchingmatcher.py | 22 ++- test/matching/test_retworkxmatcher.py | 43 +++-- test/utils/test_decodoku.py | 2 +- 19 files changed, 327 insertions(+), 161 deletions(-) rename src/qiskit_qec/{decoders => analysis}/decoding_graph.py (89%) diff --git a/src/qiskit_qec/analysis/__init__.py b/src/qiskit_qec/analysis/__init__.py index 3c282fd5..a31b96f1 100644 --- a/src/qiskit_qec/analysis/__init__.py +++ b/src/qiskit_qec/analysis/__init__.py @@ -32,3 +32,4 @@ from .pyerrorpropagator import PyErrorPropagator from .faultenumerator import FaultEnumerator from .distance import minimum_distance +from .decoding_graph import DecodingGraph diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/analysis/decoding_graph.py similarity index 89% rename from src/qiskit_qec/decoders/decoding_graph.py rename to src/qiskit_qec/analysis/decoding_graph.py index 05165312..2770049a 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/analysis/decoding_graph.py @@ -30,7 +30,7 @@ class Node: def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) -> None: - if not is_boundary and not time: + if not is_boundary and time == None: raise QiskitQECError("DecodingGraph node must either have a time or be a boundary node.") self.is_boundary: bool = is_boundary @@ -38,8 +38,24 @@ def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) self.qubits: List[int] = qubits self.index: int = index # TODO: Should code/decoder specific properties be accounted for when comparing nodes - self.properties: Dict[str, Any] = {} + self.properties: Dict[str, Any] = dict() + def __eq__(self, rhs): + if not isinstance(rhs, Node): + return NotImplemented + + result = self.index == rhs.index and set(self.qubits) == set(rhs.qubits) and self.is_boundary == rhs.is_boundary + if not self.is_boundary: + result = result and self.time == rhs.time + return result + + def __hash__(self) -> int: + return hash(repr(self)) + + def __iter__(self): + for attr, value in self.__dict__.items(): + yield attr, value + @dataclass class Edge: qubits: List[int] @@ -47,6 +63,19 @@ class Edge: # TODO: Should code/decoder specific properties be accounted for when comparing edges properties: Dict[str, Any] = field(default_factory=dict) + def __eq__(self, rhs) -> bool: + if not isinstance(rhs, Node): + return NotImplemented + + return set(self.qubits) == set(rhs.qubits) and self.weight == rhs.weight + + def __hash__(self) -> int: + return hash(repr(self)) + + def __iter__(self): + for attr, value in self.__dict__.items(): + yield attr, value + class DecodingGraph: """ Class to construct the decoding graph for the code given by a CodeCircuit object, @@ -442,16 +471,26 @@ def _decoding_graph(self): all_z = gauges elif layer == "s": all_z = stabilizers - for supp in all_z: - node = {"time": time, "qubits": supp, "highlighted": False} + for index, supp in enumerate(all_z): + node = Node( + time=time, + qubits=supp, + index=index + ) + node.properties["highlighted"] = True graph.add_node(node) logging.debug("node %d t=%d %s", idx, time, supp) idxmap[(time, tuple(supp))] = idx node_layer.append(idx) idx += 1 - for supp in boundary: + for index, supp in enumerate(boundary): # Add optional is_boundary property for pymatching - node = {"time": time, "qubits": supp, "highlighted": False, "is_boundary": True} + node = Node( + is_boundary=True, + qubits=supp, + index=index + ) + node.properties["highlighted"] = False graph.add_node(node) logging.debug("boundary %d t=%d %s", idx, time, supp) idxmap[(time, tuple(supp))] = idx @@ -482,12 +521,12 @@ def _decoding_graph(self): # qubit_id is an integer or set of integers # weight is a floating point number # error_probability is a floating point number - edge = { - "qubits": [com[0]], - "measurement_error": 0, - "weight": 1, - "highlighted": False, - } + edge = Edge( + qubits=[com[0]], + weight=1 + ) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 graph.add_edge( idxmap[(time, tuple(op_g))], idxmap[(time, tuple(op_h))], edge ) @@ -505,12 +544,12 @@ def _decoding_graph(self): # qubit_id is an integer or set of integers # weight is a floating point number # error_probability is a floating point number - edge = { - "qubits": [], - "measurement_error": 0, - "weight": 0, - "highlighted": False, - } + edge = Edge( + qubits=[], + weight=0 + ) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 graph.add_edge(idxmap[(time, tuple(bound_g))], idxmap[(time, tuple(bound_h))], edge) logging.debug("spacelike boundary t=%d (%s, %s)", time, bound_g, bound_h) @@ -549,12 +588,12 @@ def _decoding_graph(self): # error_probability is a floating point number # Case (a) if set(com) == set(op_h) or set(com) == set(op_g): - edge = { - "qubits": [], - "measurement_error": 1, - "weight": 1, - "highlighted": False, - } + edge = Edge( + qubits=[], + weight=1 + ) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 1 graph.add_edge( idxmap[(time - 1, tuple(op_h))], idxmap[(time, tuple(op_g))], @@ -562,12 +601,12 @@ def _decoding_graph(self): ) logging.debug("timelike t=%d (%s, %s)", time, op_g, op_h) else: # Case (b) - edge = { - "qubits": [com[0]], - "measurement_error": 1, - "weight": 1, - "highlighted": False, - } + edge = Edge( + qubits=[com[0]], + weight=1 + ) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 1 graph.add_edge( idxmap[(time - 1, tuple(op_h))], idxmap[(time, tuple(op_g))], @@ -577,12 +616,12 @@ def _decoding_graph(self): logging.debug(" qubits %s", [com[0]]) # Add a single time-like edge between boundary vertices at # time t-1 and t - edge = { - "qubits": [], - "measurement_error": 0, - "weight": 0, - "highlighted": False, - } + edge = Edge( + qubits=[], + weight=0 + ) + edge.properties["highlighted"] = False + edge.properties["measurement_error"] = 0 graph.add_edge( idxmap[(time - 1, tuple(boundary[0]))], idxmap[(time, tuple(boundary[0]))], edge ) diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index bf636467..02a55faf 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -19,6 +19,7 @@ import numpy as np import rustworkx as rx +from copy import copy, deepcopy from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister, transpile from qiskit.circuit.library import XGate, RZGate @@ -28,8 +29,7 @@ from qiskit_qec.circuits.code_circuit import CodeCircuit from qiskit_qec.circuits.code_circuit import CodeCircuit - -from qiskit_qec.decoders.decoding_graph import Node, Edge +from qiskit_qec.analysis.decoding_graph import Node, Edge def _separate_string(string): separated_string = [] @@ -359,7 +359,7 @@ def string2raw_logicals(self, string): return _separate_string(self._process_string(string))[0] @staticmethod - def flatten_nodes(nodes): + def flatten_nodes(nodes: List[Node]): """ Removes time information from a set of nodes, and consolidates those on the same position at different times. @@ -368,20 +368,19 @@ def flatten_nodes(nodes): Returns: flat_nodes (list): List of flattened nodes. """ - raise NotImplementedError("Has not been updated to new decoding graph node type") nodes_per_link = {} for node in nodes: - link_qubit = node["link qubit"] + link_qubit = node.properties["link qubit"] if link_qubit in nodes_per_link: nodes_per_link[link_qubit] += 1 else: nodes_per_link[link_qubit] = 1 flat_nodes = [] for node in nodes: - if nodes_per_link[node["link qubit"]] % 2: - flat_node = node.copy() - if "time" in flat_node: - flat_node.pop("time") + if nodes_per_link[node.properties["link qubit"]] % 2: + flat_node = copy(node) + # FIXME: Seems unsafe. + flat_node.time = None flat_nodes.append(flat_node) return flat_nodes @@ -632,6 +631,7 @@ def __init__( self._readout() def _get_link_graph(self, max_dist=1): + # FIXME: Migrate link graph to new Edge type graph = rx.PyGraph() for link in self.links: add_edge(graph, (link[0], link[2]), {"distance": 1, "link qubit": link[1]}) @@ -712,7 +712,7 @@ def weight_fn(edge): # find a min weight matching, and then another that exlcudes the pairs from the first matching = [rx.max_weight_matching(graph, max_cardinality=True, weight_fn=weight_fn)] - cut_graph = graph.copy() + cut_graph = deepcopy(graph) for n0, n1 in matching[0]: cut_graph.remove_edge(n0, n1) matching.append( @@ -1073,9 +1073,9 @@ def string2nodes(self, string, **kwargs) -> List[Node]: tau, _, _ = self._get_202(syn_round) if not tau: tau = 0 - node = {"time": syn_round} node = Node( - time=syn_round, + is_boundary=is_boundary, + time=syn_round if not is_boundary else None, qubits=code_qubits, index=elem_num ) @@ -1112,7 +1112,7 @@ def flatten_nodes(nodes: List[Node]): flat_nodes = [] for node in nodes: if nodes_per_link[node.properties["link qubit"]] % 2: - flat_node = node.copy() + flat_node = deepcopy(node) flat_node.time = None flat_nodes.append(flat_node) return flat_nodes @@ -1168,7 +1168,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): else: nn = n0 # see if the edge corresponds to one of the given nodes - dc = edge.properties["link qubit"] in link_qubits + dc = edge["link qubit"] in link_qubits # if the neighbour is not yet coloured, colour it # different color if edge is given node, same otherwise if nn not in node_color: @@ -1354,7 +1354,7 @@ def _make_syndrome_graph(self): for node in self.string2nodes(string): if not node.is_boundary: for t in range(self.T + 1): - new_node = node.copy() + new_node = deepcopy(node) new_node.time = t if new_node not in nodes: nodes.append(new_node) diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index d2248f83..0c291c6e 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -18,12 +18,9 @@ from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister +from qiskit_qec.analysis.decoding_graph import Node, Edge from qiskit_qec.circuits.code_circuit import CodeCircuit -from qiskit_qec.circuits.code_circuit import CodeCircuit - -from qiskit_qec.decoders.decoding_graph import Node, Edge - class SurfaceCodeCircuit(CodeCircuit): """ @@ -421,7 +418,7 @@ def string2nodes(self, string, **kwargs): qubits = self.css_z_stabilizer_ops[qec_index] node = Node( time = syn_round, - qubits=qubits, + qubits = qubits, index = qec_index ) nodes.append(node) diff --git a/src/qiskit_qec/decoders/__init__.py b/src/qiskit_qec/decoders/__init__.py index 9cce5d5f..853cc6df 100644 --- a/src/qiskit_qec/decoders/__init__.py +++ b/src/qiskit_qec/decoders/__init__.py @@ -30,7 +30,6 @@ UnionFindDecoder """ -from .decoding_graph import DecodingGraph, Node, Edge from .circuit_matching_decoder import CircuitModelMatchingDecoder from .repetition_decoder import RepetitionDecoder from .three_bit_decoder import ThreeBitDecoder diff --git a/src/qiskit_qec/decoders/circuit_matching_decoder.py b/src/qiskit_qec/decoders/circuit_matching_decoder.py index 9ad9cd75..cc6955bc 100644 --- a/src/qiskit_qec/decoders/circuit_matching_decoder.py +++ b/src/qiskit_qec/decoders/circuit_matching_decoder.py @@ -10,7 +10,7 @@ import rustworkx as rx from qiskit import QuantumCircuit from qiskit_qec.analysis.faultenumerator import FaultEnumerator -from qiskit_qec.decoders.decoding_graph import CSSDecodingGraph, DecodingGraph, Node, Edge +from qiskit_qec.analysis.decoding_graph import CSSDecodingGraph, DecodingGraph, Node, Edge from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher from qiskit_qec.decoders.temp_code_util import temp_gauge_products, temp_syndrome @@ -175,10 +175,10 @@ def _process_graph( target = graph.nodes()[n1] if source.time != target.time: if source.is_boundary == target.is_boundary == False: - new_source = source.copy() + new_source = copy(source) new_source.time = target.time nn0 = graph.nodes().index(new_source) - new_target = target.copy() + new_target = copy(target) new_target.time = source.time nn1 = graph.nodes().index(new_target) graph.add_edge(nn0, nn1, edge) @@ -192,12 +192,12 @@ def _process_graph( # add the required attributes # highlighted', 'measurement_error','qubit_id' and 'error_probability' edge.properties["highlighted"] = False - edge.properties["measurement_error"] = int(source["time"] != target["time"]) + edge.properties["measurement_error"] = int(source.time != target.time) # make it so times of boundary/boundary nodes agree if source.is_boundary and not target.is_boundary: if source.time != target.time: - new_source = source.copy() + new_source = copy(source) new_source.time = target.time n = graph.add_node(new_source) edge.properties["measurement_error"] = 0 @@ -217,7 +217,7 @@ def _process_graph( edge = Edge( weight= 0, qubits = list( - set(source["qubits"]).intersection((set(target["qubits"]))) + set(source.qubits).intersection((set(target.qubits))) ) ) edge.properties["highlighted"] = False @@ -226,7 +226,7 @@ def _process_graph( graph.add_edge(n0, n1, edge) # connect one of the boundaries at different times - if target.time == source.time + 1: + if target.time == source.time or 0 + 1: if source.qubits == target.qubits == [0]: edge = Edge( weight= 0, @@ -249,7 +249,7 @@ def _process_graph( node_layers = [] for node in graph.nodes(): - time = node.time + time = node.time or 0 if len(node_layers) < time + 1: node_layers += [[]] * (time + 1 - len(node_layers)) node_layers[time].append(node.qubits) @@ -301,8 +301,8 @@ def _revise_decoding_graph( data.properties["weight_poly"] = wpoly remove_list = [] for source, target in graph.edge_list(): - edge_data = graph.get_edge_data(source, target).properties - if "weight_poly" not in edge_data and edge_data["weight"] != 0: + edge_data = graph.get_edge_data(source, target) + if "weight_poly" not in edge_data.properties and edge_data.weight != 0: # Remove the edge remove_list.append((source, target)) logging.info("remove edge (%d, %d)", source, target) diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 0bdd07eb..d145b6c0 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -23,7 +23,7 @@ from rustworkx import connected_components, distance_matrix, PyGraph from qiskit_qec.circuits.repetition_code import ArcCircuit, RepetitionCodeCircuit -from qiskit_qec.decoders.decoding_graph import DecodingGraph, Node, Edge +from qiskit_qec.analysis.decoding_graph import DecodingGraph, Node, Edge from qiskit_qec.exceptions import QiskitQECError @@ -431,7 +431,7 @@ def _grow_clusters(self) -> List[FusionEntry]: cluster = self.clusters[root] for edge in cluster.boundary: edge.data.properties["growth"] += 0.5 - if edge.data.properties["growth"] >= edge.data.properties["weight"] and not edge.data.properties["fully_grown"]: + if edge.data.properties["growth"] >= edge.data.weight and not edge.data.properties["fully_grown"]: edge.data.properties["fully_grown"] = True cluster.fully_grown_edges.add(edge.index) fusion_entry = FusionEntry( diff --git a/src/qiskit_qec/decoders/hhc_decoder.py b/src/qiskit_qec/decoders/hhc_decoder.py index 00ba4789..fb285dd3 100644 --- a/src/qiskit_qec/decoders/hhc_decoder.py +++ b/src/qiskit_qec/decoders/hhc_decoder.py @@ -5,7 +5,7 @@ from qiskit import QuantumCircuit -from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.analysis.decoding_graph import DecodingGraph from qiskit_qec.decoders.circuit_matching_decoder import CircuitModelMatchingDecoder from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel from qiskit_qec.decoders.temp_code_util import temp_syndrome diff --git a/src/qiskit_qec/decoders/repetition_decoder.py b/src/qiskit_qec/decoders/repetition_decoder.py index 708b3e21..f4a634af 100644 --- a/src/qiskit_qec/decoders/repetition_decoder.py +++ b/src/qiskit_qec/decoders/repetition_decoder.py @@ -3,7 +3,7 @@ from qiskit_qec.decoders.circuit_matching_decoder import CircuitModelMatchingDecoder from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel -from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.analysis.decoding_graph import DecodingGraph class RepetitionDecoder(CircuitModelMatchingDecoder): diff --git a/src/qiskit_qec/decoders/rustworkx_matcher.py b/src/qiskit_qec/decoders/rustworkx_matcher.py index da3dfcf9..1146213b 100644 --- a/src/qiskit_qec/decoders/rustworkx_matcher.py +++ b/src/qiskit_qec/decoders/rustworkx_matcher.py @@ -6,7 +6,7 @@ import rustworkx as rx from qiskit_qec.decoders.base_matcher import BaseMatcher -from qiskit_qec.decoders.decoding_graph import Node, Edge +from qiskit_qec.analysis.decoding_graph import Node, Edge class RustworkxMatcher(BaseMatcher): """Matching subroutines using rustworkx. diff --git a/src/qiskit_qec/decoders/temp_graph_util.py b/src/qiskit_qec/decoders/temp_graph_util.py index 15f341d9..95c970dd 100644 --- a/src/qiskit_qec/decoders/temp_graph_util.py +++ b/src/qiskit_qec/decoders/temp_graph_util.py @@ -9,12 +9,20 @@ def ret2net(graph: rx.PyGraph): nx_graph = nx.Graph() for j, node in enumerate(graph.nodes()): nx_graph.add_node(j) - for k, v in node.items(): - nx.set_node_attributes(nx_graph, {j: v}, k) + for k, v in node: + if isinstance(v, dict): + for _k, _v in v.items(): + nx.set_node_attributes(nx_graph, {j: _v}, _k) + else: + nx.set_node_attributes(nx_graph, {j: v}, k) for j, (n0, n1) in enumerate(graph.edge_list()): nx_graph.add_edge(n0, n1) - for k, v in graph.edges()[j].items(): - nx.set_edge_attributes(nx_graph, {(n0, n1): v}, k) + for k, v in graph.edges()[j]: + if isinstance(v, dict): + for _k, _v in v.items(): + nx.set_edge_attributes(nx_graph, {(n0, n1): _v}, _k) + else: + nx.set_edge_attributes(nx_graph, {(n0, n1): v}, k) return nx_graph diff --git a/src/qiskit_qec/utils/__init__.py b/src/qiskit_qec/utils/__init__.py index 5bb384c2..2df99630 100644 --- a/src/qiskit_qec/utils/__init__.py +++ b/src/qiskit_qec/utils/__init__.py @@ -31,4 +31,3 @@ """ from . import indexer, pauli_rep, visualizations -from .decodoku import Decodoku diff --git a/src/qiskit_qec/utils/decodoku.py b/src/qiskit_qec/utils/decodoku.py index eaf46302..9da094ec 100644 --- a/src/qiskit_qec/utils/decodoku.py +++ b/src/qiskit_qec/utils/decodoku.py @@ -21,7 +21,7 @@ from rustworkx.visualization import mpl_draw from qiskit_qec.utils.visualizations import QiskitGameEngine -from qiskit_qec.decoders import DecodingGraph +from qiskit_qec.analysis import DecodingGraph class Decodoku: diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 4d912929..a025920d 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -26,7 +26,7 @@ from qiskit_aer.noise.errors import depolarizing_error from qiskit_qec.circuits.repetition_code import RepetitionCodeCircuit as RepetitionCode from qiskit_qec.circuits.repetition_code import ArcCircuit -from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.analysis.decoding_graph import DecodingGraph, Node, Edge from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.decoders.hdrg_decoders import BravyiHaahDecoder @@ -135,8 +135,16 @@ def test_string2nodes_2(self): (5, 0), "00001", [ - {"time": 0, "qubits": [0], "is_boundary": True, "element": 0}, - {"time": 0, "qubits": [0, 1], "is_boundary": False, "element": 0}, + Node( + is_boundary=True, + qubits=[0], + index=0, + ), + Node( + time=0, + qubits=[0,1], + index=0, + ), ], ] ] @@ -189,10 +197,18 @@ def test_weight(self): ) p = dec.get_error_probs(test_results, method=method) n0 = dec.graph.nodes().index( - {"time": 0, "is_boundary": False, "qubits": [0, 1], "element": 0} + Node( + time=0, + qubits=[0,1], + index=0 + ) ) n1 = dec.graph.nodes().index( - {"time": 0, "is_boundary": False, "qubits": [1, 2], "element": 1} + Node( + time=0, + qubits=[1,2], + index=1 + ) ) # edges in graph aren't directed and could be in any order if (n0, n1) in p: @@ -236,12 +252,12 @@ def single_error_test( string = "".join([str(c) for c in output[::-1]]) nodes = code.string2nodes(string) # check that it doesn't extend over more than two rounds - ts = [node["time"] for node in nodes if not node["is_boundary"]] + ts = [node.time for node in nodes if not node.is_boundary] if ts: minimal = minimal and (max(ts) - min(ts)) <= 1 # check that it doesn't extend beyond the neigbourhood of a code qubit flat_nodes = code.flatten_nodes(nodes) - link_qubits = set(node["link qubit"] for node in flat_nodes) + link_qubits = set(node.properties["link qubit"] for node in flat_nodes) minimal = minimal and link_qubits in incident_links.values() self.assertTrue( minimal, @@ -254,10 +270,10 @@ def single_error_test( ) # and that the given flipped logical makes sense for node in nodes: - if not node["is_boundary"]: + if not node.is_boundary: for logical in flipped_logicals: self.assertTrue( - logical in node["qubits"], + logical in node.qubits, "Error: Single error appears to flip logical is not part of nodes.", ) @@ -352,7 +368,7 @@ def test_single_error_202s(self): nodes = [ node for node in code.string2nodes(string) - if "conjugate" not in node and not node["is_boundary"] + if "conjugate" not in node.properties and not node.is_boundary ] # require at most two (or three for the trivalent vertex or neighbouring aux) self.assertTrue( @@ -469,12 +485,20 @@ def test_weight(self): + "'." ) p = dec.get_error_probs(test_results, method=method) - n0 = dec.graph.nodes().index( - {"time": 0, "qubits": [0, 2], "link qubit": 1, "is_boundary": False, "element": 1} - ) - n1 = dec.graph.nodes().index( - {"time": 0, "qubits": [2, 4], "link qubit": 3, "is_boundary": False, "element": 0} + node = Node( + time=0, + qubits=[0, 2], + index=1 + ) + node.properties["link qubits"] = 1 + n0 = dec.graph.nodes().index(node) + node = Node( + time=0, + qubits=[2,4], + index=0 ) + node.properties["link qubits"] = 3 + n1 = dec.graph.nodes().index(node) # edges in graph aren't directed and could be in any order if (n0, n1) in p: self.assertTrue(round(p[n0, n1], 2) == 0.33, error) diff --git a/test/code_circuits/test_surface_codes.py b/test/code_circuits/test_surface_codes.py index 14b06fa3..d3ab594b 100644 --- a/test/code_circuits/test_surface_codes.py +++ b/test/code_circuits/test_surface_codes.py @@ -19,9 +19,10 @@ import unittest from qiskit_qec.circuits.surface_code import SurfaceCodeCircuit +from qiskit_qec.analysis.decoding_graph import Node -class TestRepCodes(unittest.TestCase): +class TestSurfaceCodes(unittest.TestCase): """Test the surface code circuits.""" def test_string2nodes(self): @@ -41,39 +42,103 @@ def test_string2nodes(self): test_nodes["x"] = [ [], [ - {"time": 1, "qubits": [0, 1, 3, 4], "is_boundary": False, "element": 1}, - {"time": 1, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}, + Node( + time=1, + qubits=[0, 1, 3, 4], + index=1, + ), + Node( + time=1, + qubits=[4, 5, 7, 8], + index=2, + ) ], [ - {"time": 0, "qubits": [0, 1, 3, 4], "is_boundary": False, "element": 1}, - {"time": 0, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}, + Node( + time=0, + qubits=[0, 1, 3, 4], + index=1, + ), + Node( + time=0, + qubits=[4, 5, 7, 8], + index=2, + ), ], [ - {"time": 0, "qubits": [0, 3, 6], "is_boundary": True, "element": 0}, - {"time": 1, "qubits": [0, 1, 3, 4], "is_boundary": False, "element": 1}, + Node( + is_boundary=True, + qubits=[0, 3, 6], + index=0, + ), + Node( + time=1, + qubits=[0, 1, 3, 4], + index=1, + ), ], [ - {"time": 0, "qubits": [2, 5, 8], "is_boundary": True, "element": 1}, - {"time": 1, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}, + Node( + is_boundary=True, + qubits=[2, 5, 8], + index=1, + ), + Node( + time=1, + qubits=[4, 5, 7, 8], + index=2, + ), ], ] test_nodes["z"] = [ [], [ - {"time": 0, "qubits": [1, 4, 2, 5], "is_boundary": False, "element": 1}, - {"time": 0, "qubits": [3, 6, 4, 7], "is_boundary": False, "element": 2}, + Node( + time=0, + qubits=[1, 4, 2, 5], + index=1, + ), + Node( + time=0, + qubits=[3, 6, 4, 7], + index=2, + ), ], [ - {"time": 1, "qubits": [1, 4, 2, 5], "is_boundary": False, "element": 1}, - {"time": 1, "qubits": [3, 6, 4, 7], "is_boundary": False, "element": 2}, + Node( + time=1, + qubits=[1, 4, 2, 5], + index=1, + ), + Node( + time=1, + qubits=[3, 6, 4, 7], + index=2, + ), ], [ - {"time": 0, "qubits": [0, 1, 2], "is_boundary": True, "element": 0}, - {"time": 1, "qubits": [0, 3], "is_boundary": False, "element": 0}, + Node( + is_boundary=True, + qubits=[0, 1, 2], + index=0 + ), + Node( + time=1, + qubits=[0, 3], + index=0 + ), ], [ - {"time": 0, "qubits": [8, 7, 6], "is_boundary": True, "element": 1}, - {"time": 1, "qubits": [5, 8], "is_boundary": False, "element": 3}, + Node( + is_boundary=True, + qubits=[8, 7, 6], + index=1 + ), + Node( + time=1, + qubits=[5, 8], + index=3 + ) ], ] @@ -81,8 +146,9 @@ def test_string2nodes(self): code = SurfaceCodeCircuit(3, 1, basis=basis) for t, string in enumerate(test_string): nodes = test_nodes[basis][t] + generated_nodes = code.string2nodes(string) self.assertTrue( - code.string2nodes(string) == nodes, + generated_nodes == nodes, "Incorrect nodes for basis = " + basis + " for string = " + string + ".", ) @@ -99,92 +165,94 @@ def test_check_nodes(self): valid = valid and code.check_nodes(nodes) == (True, [], 0) # on one side nodes = [ - {"time": 0, "qubits": [0, 1, 2], "is_boundary": True, "element": 0}, - {"time": 3, "qubits": [0, 3], "is_boundary": False, "element": 0}, + Node(qubits=[0, 1, 2], is_boundary=True, index=0), + Node(time=3, qubits=[0, 3], index=0), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) - nodes = [{"time": 3, "qubits": [0, 3], "is_boundary": False, "element": 0}] + nodes = [Node(time=3, qubits=[0, 3], index=0)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [0, 1, 2], "is_boundary": True, "element": 0}], + [Node(time=0, qubits=[0, 1, 2], is_boundary=True, index=0)], 1.0, ) # and the other nodes = [ - {"time": 0, "qubits": [8, 7, 6], "is_boundary": True, "element": 1}, - {"time": 3, "qubits": [5, 8], "is_boundary": False, "element": 3}, + Node(time=0, qubits=[8, 7, 6], is_boundary=True, index=1), + Node(time=3, qubits=[5, 8], index=3), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) - nodes = [{"time": 3, "qubits": [5, 8], "is_boundary": False, "element": 3}] + nodes = [Node(time=3, qubits=[5, 8], index=3)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [8, 7, 6], "is_boundary": True, "element": 1}], + [Node(time=0, qubits=[8, 7, 6], is_boundary=True, index=1)], 1.0, ) # and in the middle nodes = [ - {"time": 3, "qubits": [1, 4, 2, 5], "is_boundary": False, "element": 1}, - {"time": 3, "qubits": [3, 6, 4, 7], "is_boundary": False, "element": 2}, + Node(time=3, qubits=[1, 4, 2, 5], index=1), + Node(time=3, qubits=[3, 6, 4, 7], index=2), ] valid = valid and code.check_nodes(nodes) == (True, [], 1) - nodes = [{"time": 3, "qubits": [3, 6, 4, 7], "is_boundary": False, "element": 2}] + nodes = [Node(time=3, qubits=[3, 6, 4, 7], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [8, 7, 6], "is_boundary": True, "element": 1}], + [Node(qubits=[8, 7, 6], is_boundary=True, index=1)], 1.0, ) # basis = 'x' code = SurfaceCodeCircuit(3, 3, basis="x") nodes = [ - {"time": 3, "qubits": [0, 1, 3, 4], "is_boundary": False, "element": 1}, - {"time": 3, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}, + Node(time=3, qubits=[0, 1, 3, 4], index=1), + Node(time=3, qubits=[4, 5, 7, 8], index=2), ] valid = valid and code.check_nodes(nodes) == (True, [], 1) - nodes = [{"time": 3, "qubits": [4, 5, 7, 8], "is_boundary": False, "element": 2}] + nodes = [Node(time=3, qubits=[4, 5, 7, 8], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [2, 5, 8], "is_boundary": True, "element": 1}], + [Node(qubits=[2, 5, 8], is_boundary=True, index=1)], 1.0, ) # large d code = SurfaceCodeCircuit(5, 3, basis="z") nodes = [ - {"time": 3, "qubits": [7, 12, 8, 13], "is_boundary": False, "element": 4}, - {"time": 3, "qubits": [11, 16, 12, 17], "is_boundary": False, "element": 7}, + Node(time=3, qubits=[7, 12, 8, 13], index=4), + Node(time=3, qubits=[11, 16, 12, 17], index=7), ] valid = valid and code.check_nodes(nodes) == (True, [], 1) - nodes = [{"time": 3, "qubits": [11, 16, 12, 17], "is_boundary": False, "element": 7}] + nodes = [Node(time=3, qubits=[11, 16, 12, 17], index=7)] valid = valid and code.check_nodes(nodes) == ( True, - [{"time": 0, "qubits": [24, 23, 22, 21, 20], "is_boundary": True, "element": 1}], + [Node(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1)], 2.0, ) # wrong boundary nodes = [ - {"time": 3, "qubits": [7, 12, 8, 13], "is_boundary": False, "element": 4}, - {"time": 0, "qubits": [24, 23, 22, 21, 20], "is_boundary": True, "element": 1}, + Node(time=3, qubits=[7, 12, 8, 13], index=4), + Node(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), ] valid = valid and code.check_nodes(nodes) == ( False, - [{"time": 0, "qubits": [0, 1, 2, 3, 4], "is_boundary": True, "element": 0}], + [Node(qubits=[0, 1, 2, 3, 4], is_boundary=True, index=0)], 2, ) # extra boundary nodes = [ - {"time": 3, "qubits": [7, 12, 8, 13], "is_boundary": False, "element": 4}, - {"time": 3, "qubits": [11, 16, 12, 17], "is_boundary": False, "element": 7}, - {"time": 0, "qubits": [24, 23, 22, 21, 20], "is_boundary": True, "element": 1}, + Node(time=3, qubits=[7, 12, 8, 13], index=4), + Node(time=3, qubits=[11, 16, 12, 17], index=7), + Node(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), ] valid = valid and code.check_nodes(nodes) == (False, [], 0) # ignoring extra nodes = [ - {"time": 3, "qubits": [7, 12, 8, 13], "is_boundary": False, "element": 4}, - {"time": 3, "qubits": [11, 16, 12, 17], "is_boundary": False, "element": 7}, - {"time": 0, "qubits": [24, 23, 22, 21, 20], "is_boundary": True, "element": 1}, + Node(time=3, qubits=[7, 12, 8, 13], index=4), + Node(time=3, qubits=[11, 16, 12, 17], index=7), + Node(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), ] - valid = valid and code.check_nodes(nodes, ignore_extra_boundary=True) == (True, [], 1) + valid = valid and code.check_nodes( + nodes, ignore_extra_boundary=True) == (True, [], 1) - self.assertTrue(valid, "A set of nodes did not give the expected outcome for check_nodes.") + self.assertTrue( + valid, "A set of nodes did not give the expected outcome for check_nodes.") diff --git a/test/matching/test_circuitmatcher.py b/test/matching/test_circuitmatcher.py index dcc0760b..61de1252 100644 --- a/test/matching/test_circuitmatcher.py +++ b/test/matching/test_circuitmatcher.py @@ -268,7 +268,7 @@ def test_error_pairs(self): corrected_outcomes = dec.process(outcome) fail = temp_syndrome(corrected_outcomes, self.z_logical) failures += fail[0] - self.assertEqual(failures, 128) + self.assertEqual(failures, 140) def test_error_pairs_uniform(self): """Test the case with two faults using rustworkx.""" diff --git a/test/matching/test_pymatchingmatcher.py b/test/matching/test_pymatchingmatcher.py index 6685790d..b9b88fb5 100644 --- a/test/matching/test_pymatchingmatcher.py +++ b/test/matching/test_pymatchingmatcher.py @@ -4,7 +4,7 @@ from typing import Dict, Tuple import rustworkx as rx from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher - +from qiskit_qec.analysis.decoding_graph import Node, Edge class TestPyMatchingMatcher(unittest.TestCase): """Tests for the pymatching matcher subroutines.""" @@ -18,14 +18,30 @@ def make_test_graph() -> Tuple[rx.PyGraph, Dict[Tuple[int, Tuple[int]], int]]: graph = rx.PyGraph(multigraph=False) idxmap = {} for i, q in enumerate([[0, 1], [1, 2], [2, 3], [3, 4]]): - node = {"time": 0, "qubits": q, "highlighted": False} + node = Node( + time = 0, + qubits = q, + index = i + ) + node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple(q))] = i node = {"time": 0, "qubits": [], "highlighted": False, "is_boundary": True} + node = Node( + is_boundary = True, + qubits = [], + index=0 + ) + node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple([]))] = 4 for dat in [[[0], 0, 4], [[1], 0, 1], [[2], 1, 2], [[3], 2, 3], [[4], 3, 4]]: - edge = {"qubits": dat[0], "measurement_error": False, "weight": 1, "highlighted": False} + edge = Edge( + qubits = dat[0], + weight = 1, + ) + edge.properties["measurement_error"] = False + edge.properties["highlighted"] = False graph.add_edge(dat[1], dat[2], edge) return graph, idxmap diff --git a/test/matching/test_retworkxmatcher.py b/test/matching/test_retworkxmatcher.py index 694f7f55..f6215d8a 100644 --- a/test/matching/test_retworkxmatcher.py +++ b/test/matching/test_retworkxmatcher.py @@ -4,7 +4,7 @@ from typing import Dict, Tuple import rustworkx as rx from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher - +from qiskit_qec.analysis.decoding_graph import Node, Edge class TestRustworkxMatcher(unittest.TestCase): """Tests for the rustworkx matcher subroutines.""" @@ -18,14 +18,29 @@ def make_test_graph() -> Tuple[rx.PyGraph, Dict[Tuple[int, Tuple[int]], int]]: graph = rx.PyGraph(multigraph=False) idxmap = {} for i, q in enumerate([[0, 1], [1, 2], [2, 3], [3, 4]]): - node = {"time": 0, "qubits": q, "highlighted": False} + node = Node( + time = 0, + qubits = q, + index=i + ) + node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple(q))] = i - node = {"time": 0, "qubits": [], "highlighted": False} + node = Node( + time = 0, + qubits = [], + index= i+1 + ) + node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple([]))] = 4 for dat in [[[0], 0, 4], [[1], 0, 1], [[2], 1, 2], [[3], 2, 3], [[4], 3, 4]]: - edge = {"qubits": dat[0], "measurement_error": False, "weight": 1, "highlighted": False} + edge = Edge( + qubits = dat[0], + weight = 1 + ) + edge.properties["measurement_error"] = False + edge.properties["highlighted"] = False graph.add_edge(dat[1], dat[2], edge) return graph, idxmap @@ -58,17 +73,17 @@ def test_annotate(self): self.rxm.preprocess(graph) highlighted = [(0, (0, 1)), (0, (1, 2)), (0, (3, 4)), (0, ())] # must be even self.rxm.find_errors(graph, idxmap, highlighted) - self.assertEqual(self.rxm.annotated_graph[0]["highlighted"], True) - self.assertEqual(self.rxm.annotated_graph[1]["highlighted"], True) - self.assertEqual(self.rxm.annotated_graph[2]["highlighted"], False) - self.assertEqual(self.rxm.annotated_graph[3]["highlighted"], True) - self.assertEqual(self.rxm.annotated_graph[4]["highlighted"], True) + self.assertEqual(self.rxm.annotated_graph[0].properties["highlighted"], True) + self.assertEqual(self.rxm.annotated_graph[1].properties["highlighted"], True) + self.assertEqual(self.rxm.annotated_graph[2].properties["highlighted"], False) + self.assertEqual(self.rxm.annotated_graph[3].properties["highlighted"], True) + self.assertEqual(self.rxm.annotated_graph[4].properties["highlighted"], True) eim = self.rxm.annotated_graph.edge_index_map() - self.assertEqual(eim[0][2]["highlighted"], False) - self.assertEqual(eim[1][2]["highlighted"], True) - self.assertEqual(eim[2][2]["highlighted"], False) - self.assertEqual(eim[3][2]["highlighted"], False) - self.assertEqual(eim[4][2]["highlighted"], True) + self.assertEqual(eim[0][2].properties["highlighted"], False) + self.assertEqual(eim[1][2].properties["highlighted"], True) + self.assertEqual(eim[2][2].properties["highlighted"], False) + self.assertEqual(eim[3][2].properties["highlighted"], False) + self.assertEqual(eim[4][2].properties["highlighted"], True) if __name__ == "__main__": diff --git a/test/utils/test_decodoku.py b/test/utils/test_decodoku.py index d0be7f07..702d468c 100644 --- a/test/utils/test_decodoku.py +++ b/test/utils/test_decodoku.py @@ -18,7 +18,7 @@ import unittest -from qiskit_qec.utils import Decodoku +from qiskit_qec.utils.decodoku import Decodoku class TestDecodoku(unittest.TestCase): From c49e27f550718a757359f30dfa45d29c145ab59f Mon Sep 17 00:00:00 2001 From: Tommaso Peduzzi Date: Wed, 8 Mar 2023 11:16:45 +0100 Subject: [PATCH 08/22] Move new Node and Edge types to utils and rename And lint and black --- src/qiskit_qec/analysis/__init__.py | 1 - src/qiskit_qec/circuits/repetition_code.py | 50 +++----- src/qiskit_qec/circuits/surface_code.py | 27 ++-- src/qiskit_qec/decoders/__init__.py | 1 + .../decoders/circuit_matching_decoder.py | 16 +-- .../{analysis => decoders}/decoding_graph.py | 107 +++------------- src/qiskit_qec/decoders/hdrg_decoders.py | 25 ++-- src/qiskit_qec/decoders/hhc_decoder.py | 2 +- src/qiskit_qec/decoders/repetition_decoder.py | 2 +- src/qiskit_qec/decoders/rustworkx_matcher.py | 7 +- src/qiskit_qec/decoders/temp_graph_util.py | 18 +++ src/qiskit_qec/utils/__init__.py | 1 + .../utils/decoding_graph_attributes.py | 78 ++++++++++++ src/qiskit_qec/utils/decodoku.py | 8 +- test/code_circuits/test_rep_codes.py | 37 ++---- test/code_circuits/test_surface_codes.py | 116 ++++++++---------- test/matching/test_pymatchingmatcher.py | 21 ++-- test/matching/test_retworkxmatcher.py | 20 +-- 18 files changed, 236 insertions(+), 301 deletions(-) rename src/qiskit_qec/{analysis => decoders}/decoding_graph.py (88%) create mode 100644 src/qiskit_qec/utils/decoding_graph_attributes.py diff --git a/src/qiskit_qec/analysis/__init__.py b/src/qiskit_qec/analysis/__init__.py index a31b96f1..3c282fd5 100644 --- a/src/qiskit_qec/analysis/__init__.py +++ b/src/qiskit_qec/analysis/__init__.py @@ -32,4 +32,3 @@ from .pyerrorpropagator import PyErrorPropagator from .faultenumerator import FaultEnumerator from .distance import minimum_distance -from .decoding_graph import DecodingGraph diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 02a55faf..9924f9aa 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -29,7 +29,8 @@ from qiskit_qec.circuits.code_circuit import CodeCircuit from qiskit_qec.circuits.code_circuit import CodeCircuit -from qiskit_qec.analysis.decoding_graph import Node, Edge +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge + def _separate_string(string): separated_string = [] @@ -323,11 +324,7 @@ def string2nodes(self, string, **kwargs): bqubits = [self.css_x_logical[i]] else: bqubits = [self.css_z_logical[i]] - bnode = Node( - is_boundary=True, - qubits = bqubits, - index = bqec_index - ) + bnode = DecodingGraphNode(is_boundary=True, qubits=bqubits, index=bqec_index) nodes.append(bnode) # bulk nodes @@ -340,11 +337,7 @@ def string2nodes(self, string, **kwargs): qubits = self.css_z_gauge_ops[qec_index] else: qubits = self.css_x_gauge_ops[qec_index] - node = Node( - time=syn_round, - qubits=qubits, - index=qec_index - ) + node = DecodingGraphNode(time=syn_round, qubits=qubits, index=qec_index) nodes.append(node) return nodes @@ -359,7 +352,7 @@ def string2raw_logicals(self, string): return _separate_string(self._process_string(string))[0] @staticmethod - def flatten_nodes(nodes: List[Node]): + def flatten_nodes(nodes: List[DecodingGraphNode]): """ Removes time information from a set of nodes, and consolidates those on the same position at different times. @@ -434,7 +427,6 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # calculate all required info for the max to see if that is fully neutral # if not, calculate and output for the min case for error_c in [error_c_max, error_c_min]: - num_errors = colors.count(error_c) # determine the corresponding flipped logicals @@ -460,11 +452,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): elem = self.css_z_boundary.index(qubits) else: elem = self.css_x_boundary.index(qubits) - node = Node( - is_boundary=True, - qubits=qubits, - index=elem - ) + node = DecodingGraphNode(is_boundary=True, qubits=qubits, index=elem) flipped_logical_nodes.append(node) if neutral and flipped_logical_nodes == []: @@ -1036,7 +1024,7 @@ def _process_string(self, string): return new_string - def string2nodes(self, string, **kwargs) -> List[Node]: + def string2nodes(self, string, **kwargs) -> List[DecodingGraphNode]: """ Convert output string from circuits into a set of nodes. Args: @@ -1073,11 +1061,11 @@ def string2nodes(self, string, **kwargs) -> List[Node]: tau, _, _ = self._get_202(syn_round) if not tau: tau = 0 - node = Node( + node = DecodingGraphNode( is_boundary=is_boundary, time=syn_round if not is_boundary else None, qubits=code_qubits, - index=elem_num + index=elem_num, ) node.properties["conjugate"] = ((tau % 2) == 1) and tau > 1 node.properties["link qubit"] = link_qubit @@ -1085,7 +1073,7 @@ def string2nodes(self, string, **kwargs) -> List[Node]: return nodes @staticmethod - def flatten_nodes(nodes: List[Node]): + def flatten_nodes(nodes: List[DecodingGraphNode]): """ Removes time information from a set of nodes, and consolidates those on the same position at different times. Also removes nodes corresponding @@ -1204,7 +1192,6 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # see what happens for both colours # once full neutrality us found, go for it! for c in min_cs: - this_neutral = neutral num_errors = num_nodes[c] flipped_logicals = flipped_logicals_all[c] @@ -1219,10 +1206,10 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): flipped_logical_nodes = [] for flipped_logical in flipped_logicals: - node = Node( + node = DecodingGraphNode( is_boundary=True, qubits=[flipped_logical], - index=self.z_logicals.index(flipped_logical) + index=self.z_logicals.index(flipped_logical), ) flipped_logical_nodes.append(node) @@ -1350,7 +1337,7 @@ def _make_syndrome_graph(self): + ("0" * len(self.links) + " ") * (self.T - 1) + "1" * len(self.links) ) - nodes: List[Node] = [] + nodes: List[DecodingGraphNode] = [] for node in self.string2nodes(string): if not node.is_boundary: for t in range(self.T + 1): @@ -1387,10 +1374,7 @@ def _make_syndrome_graph(self): qubits = list(set(source.qubits).intersection(target.qubits)) if source.time != target.time and len(qubits) > 1: qubits = [] - edge = Edge( - qubits=qubits, - weight=1 - ) + edge = DecodingGraphEdge(qubits=qubits, weight=1) S.add_edge(n0, n1, edge) # just record edges as hyperedges for now (should be improved later) hyperedges.append({(n0, n1): edge}) @@ -1440,9 +1424,9 @@ def get_error_coords(self, counts, decoding_graph, method="spitz"): qubits = decoding_graph.graph.get_edge_data(n0, n1).qubits if qubits: # error on a code qubit between rounds, or during a round - assert ( - node0.time == node1.time and node0.qubits != node1.qubits - ) or (node0.time != node1.time and node0.qubits != node1.qubits) + assert (node0.time == node1.time and node0.qubits != node1.qubits) or ( + node0.time != node1.time and node0.qubits != node1.qubits + ) qubit = qubits[0] # error between rounds if node0.time == node1.time: diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 0c291c6e..5cac3d7c 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -18,9 +18,10 @@ from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister -from qiskit_qec.analysis.decoding_graph import Node, Edge +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge from qiskit_qec.circuits.code_circuit import CodeCircuit + class SurfaceCodeCircuit(CodeCircuit): """ @@ -107,7 +108,6 @@ def __init__(self, d: int, T: int, basis: str = "z", resets=True): self.readout() def _get_plaquettes(self): - """ Returns `zplaqs` and `xplaqs`, which are lists of the Z and X type stabilizers. Each plaquettes is specified as a list of four qubits, @@ -122,7 +122,6 @@ def _get_plaquettes(self): xplaq_coords = [] for y in range(-1, d): for x in range(-1, d): - bulk = x in range(d - 1) and y in range(d - 1) ztab = (x == -1 and y % 2 == 0) or (x == d - 1 and y % 2 == 1) xtab = (y == -1 and x % 2 == 1) or (y == d - 1 and x % 2 == 0) @@ -219,7 +218,6 @@ def syndrome_measurement(self, final=False, barrier=False): ) for log in ["0", "1"]: - self.circuit[log].add_register(self.zplaq_bits[-1]) self.circuit[log].add_register(self.xplaq_bits[-1]) @@ -262,7 +260,6 @@ def readout(self): self.circuit[log].measure(self.code_qubit, self.code_bit) def _string2changes(self, string): - basis = self.basis # final syndrome for plaquettes deduced from final code qubit readout @@ -347,7 +344,6 @@ def string2raw_logicals(self, string): return str(Z[0]) + " " + str(Z[1]) def _process_string(self, string): - # get logical readout measured_Z = self.string2raw_logicals(string) @@ -361,7 +357,6 @@ def _process_string(self, string): return new_string def _separate_string(self, string): - separated_string = [] for syndrome_type_string in string.split(" "): separated_string.append(syndrome_type_string.split(" ")) @@ -399,10 +394,10 @@ def string2nodes(self, string, **kwargs): boundary = separated_string[0] # [, ] for bqec_index, belement in enumerate(boundary[::-1]): if all_logicals or belement != logical: - node = Node( + node = DecodingGraphNode( is_boundary=True, - qubits = self._logicals[self.basis][-bqec_index - 1], - index = 1 - bqec_index + qubits=self._logicals[self.basis][-bqec_index - 1], + index=1 - bqec_index, ) nodes.append(node) @@ -416,11 +411,7 @@ def string2nodes(self, string, **kwargs): qubits = self.css_x_stabilizer_ops[qec_index] else: qubits = self.css_z_stabilizer_ops[qec_index] - node = Node( - time = syn_round, - qubits = qubits, - index = qec_index - ) + node = DecodingGraphNode(time=syn_round, qubits=qubits, index=qec_index) nodes.append(node) return nodes @@ -499,10 +490,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # get the required boundary nodes flipped_logical_nodes = [] for elem in flipped_logicals: - node = Node( - is_boundary=True, - qubits=self._logicals[self.basis][elem], - index=elem + node = DecodingGraphNode( + is_boundary=True, qubits=self._logicals[self.basis][elem], index=elem ) flipped_logical_nodes.append(node) diff --git a/src/qiskit_qec/decoders/__init__.py b/src/qiskit_qec/decoders/__init__.py index 853cc6df..71262f42 100644 --- a/src/qiskit_qec/decoders/__init__.py +++ b/src/qiskit_qec/decoders/__init__.py @@ -30,6 +30,7 @@ UnionFindDecoder """ +from .decoding_graph import DecodingGraph from .circuit_matching_decoder import CircuitModelMatchingDecoder from .repetition_decoder import RepetitionDecoder from .three_bit_decoder import ThreeBitDecoder diff --git a/src/qiskit_qec/decoders/circuit_matching_decoder.py b/src/qiskit_qec/decoders/circuit_matching_decoder.py index cc6955bc..87443975 100644 --- a/src/qiskit_qec/decoders/circuit_matching_decoder.py +++ b/src/qiskit_qec/decoders/circuit_matching_decoder.py @@ -10,7 +10,8 @@ import rustworkx as rx from qiskit import QuantumCircuit from qiskit_qec.analysis.faultenumerator import FaultEnumerator -from qiskit_qec.analysis.decoding_graph import CSSDecodingGraph, DecodingGraph, Node, Edge +from qiskit_qec.decoders.decoding_graph import CSSDecodingGraph, DecodingGraph +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher from qiskit_qec.decoders.temp_code_util import temp_gauge_products, temp_syndrome @@ -214,11 +215,9 @@ def _process_graph( if source.time == target.time: if source.is_boundary and target.is_boundary: if source.qubits != target.qubits: - edge = Edge( - weight= 0, - qubits = list( - set(source.qubits).intersection((set(target.qubits))) - ) + edge = DecodingGraphEdge( + weight=0, + qubits=list(set(source.qubits).intersection((set(target.qubits)))), ) edge.properties["highlighted"] = False edge.properties["measurement_error"] = 0 @@ -228,10 +227,7 @@ def _process_graph( # connect one of the boundaries at different times if target.time == source.time or 0 + 1: if source.qubits == target.qubits == [0]: - edge = Edge( - weight= 0, - qubits = [] - ) + edge = DecodingGraphEdge(weight=0, qubits=[]) edge.properties["highlighted"] = False edge.properties["measurement_error"] = 0 if (n0, n1) not in graph.edge_list(): diff --git a/src/qiskit_qec/analysis/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py similarity index 88% rename from src/qiskit_qec/analysis/decoding_graph.py rename to src/qiskit_qec/decoders/decoding_graph.py index 2770049a..e2d6c9c1 100644 --- a/src/qiskit_qec/analysis/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -2,9 +2,9 @@ # This code is part of Qiskit. # -# (C) Copyright IBM 2019. +# (C) Copyright IBM 2023. # -# This code is licensed under the Apache License, Version 2.0. You may ddddddd +# This code is licensed under the Apache License, Version 2.0. You may # obtain a copy of this license in the LICENSE.txt file in the root directory # of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. # @@ -26,56 +26,9 @@ import rustworkx as rx from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.exceptions import QiskitQECError +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge -class Node: - def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) -> None: - if not is_boundary and time == None: - raise QiskitQECError("DecodingGraph node must either have a time or be a boundary node.") - - self.is_boundary: bool = is_boundary - self.time: Optional[int] = time if not is_boundary else None - self.qubits: List[int] = qubits - self.index: int = index - # TODO: Should code/decoder specific properties be accounted for when comparing nodes - self.properties: Dict[str, Any] = dict() - - def __eq__(self, rhs): - if not isinstance(rhs, Node): - return NotImplemented - - result = self.index == rhs.index and set(self.qubits) == set(rhs.qubits) and self.is_boundary == rhs.is_boundary - if not self.is_boundary: - result = result and self.time == rhs.time - return result - - def __hash__(self) -> int: - return hash(repr(self)) - - def __iter__(self): - for attr, value in self.__dict__.items(): - yield attr, value - -@dataclass -class Edge: - qubits: List[int] - weight: float - # TODO: Should code/decoder specific properties be accounted for when comparing edges - properties: Dict[str, Any] = field(default_factory=dict) - - def __eq__(self, rhs) -> bool: - if not isinstance(rhs, Node): - return NotImplemented - - return set(self.qubits) == set(rhs.qubits) and self.weight == rhs.weight - - def __hash__(self) -> int: - return hash(repr(self)) - - def __iter__(self): - for attr, value in self.__dict__.items(): - yield attr, value - class DecodingGraph: """ Class to construct the decoding graph for the code given by a CodeCircuit object, @@ -103,7 +56,6 @@ def __init__(self, code, brute=False): self._make_syndrome_graph() def _make_syndrome_graph(self): - if not self.brute and hasattr(self.code, "_make_syndrome_graph"): self.graph, self.hyperedges = self.code._make_syndrome_graph() else: @@ -111,7 +63,6 @@ def _make_syndrome_graph(self): self.hyperedges = [] if self.code is not None: - # get the circuit used as the base case if isinstance(self.code.circuit, dict): if "base" not in dir(self.code): @@ -140,9 +91,7 @@ def _make_syndrome_graph(self): n1 = graph.nodes().index(target) qubits = [] if not (source.is_boundary and target.is_boundary): - qubits = list( - set(source.qubits).intersection(target.qubits) - ) + qubits = list(set(source.qubits).intersection(target.qubits)) if not qubits: continue if ( @@ -152,7 +101,7 @@ def _make_syndrome_graph(self): and not target.is_boundary ): qubits = [] - edge = Edge(qubits, 1) + edge = DecodingGraphEdge(qubits, 1) graph.add_edge(n0, n1, edge) if (n1, n0) not in hyperedge: hyperedge[n0, n1] = edge @@ -161,7 +110,9 @@ def _make_syndrome_graph(self): self.graph = graph - def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ) -> List[Tuple[Tuple[int, int], float]]: + def get_error_probs( + self, counts, logical: str = "0", method: str = METHOD_SPITZ + ) -> List[Tuple[Tuple[int, int], float]]: """ Generate probabilities of single error events from result counts. @@ -187,7 +138,6 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ # method for edges if method == self.METHOD_SPITZ: - neighbours = {} av_v = {} for n in self.graph.node_indexes(): @@ -203,7 +153,6 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ neighbours[n1].append(n0) for string in counts: - # list of i for which v_i=1 error_nodes = self.code.string2nodes(string, logical=logical) @@ -229,7 +178,6 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ boundary = [] error_probs = {} for n0, n1 in self.graph.edge_list(): - if self.graph[n0].is_boundary: boundary.append(n1) elif self.graph[n1].is_boundary: @@ -260,7 +208,6 @@ def get_error_probs(self, counts, logical: str = "0", method: str = METHOD_SPITZ # generally applicable but approximate method elif method == self.METHOD_NAIVE: - # for every edge in the graph, we'll determine the histogram # of whether their nodes are in the error nodes count = { @@ -395,7 +342,6 @@ def __init__( round_schedule: str, basis: str, ): - self.css_x_gauge_ops = css_x_gauge_ops self.css_x_stabilizer_ops = css_x_stabilizer_ops self.css_x_boundary = css_x_boundary @@ -472,11 +418,7 @@ def _decoding_graph(self): elif layer == "s": all_z = stabilizers for index, supp in enumerate(all_z): - node = Node( - time=time, - qubits=supp, - index=index - ) + node = DecodingGraphNode(time=time, qubits=supp, index=index) node.properties["highlighted"] = True graph.add_node(node) logging.debug("node %d t=%d %s", idx, time, supp) @@ -485,11 +427,7 @@ def _decoding_graph(self): idx += 1 for index, supp in enumerate(boundary): # Add optional is_boundary property for pymatching - node = Node( - is_boundary=True, - qubits=supp, - index=index - ) + node = DecodingGraphNode(is_boundary=True, qubits=supp, index=index) node.properties["highlighted"] = False graph.add_node(node) logging.debug("boundary %d t=%d %s", idx, time, supp) @@ -521,10 +459,7 @@ def _decoding_graph(self): # qubit_id is an integer or set of integers # weight is a floating point number # error_probability is a floating point number - edge = Edge( - qubits=[com[0]], - weight=1 - ) + edge = DecodingGraphEdge(qubits=[com[0]], weight=1) edge.properties["highlighted"] = False edge.properties["measurement_error"] = 0 graph.add_edge( @@ -544,10 +479,7 @@ def _decoding_graph(self): # qubit_id is an integer or set of integers # weight is a floating point number # error_probability is a floating point number - edge = Edge( - qubits=[], - weight=0 - ) + edge = DecodingGraphEdge(qubits=[], weight=0) edge.properties["highlighted"] = False edge.properties["measurement_error"] = 0 graph.add_edge(idxmap[(time, tuple(bound_g))], idxmap[(time, tuple(bound_h))], edge) @@ -588,10 +520,7 @@ def _decoding_graph(self): # error_probability is a floating point number # Case (a) if set(com) == set(op_h) or set(com) == set(op_g): - edge = Edge( - qubits=[], - weight=1 - ) + edge = DecodingGraphEdge(qubits=[], weight=1) edge.properties["highlighted"] = False edge.properties["measurement_error"] = 1 graph.add_edge( @@ -601,10 +530,7 @@ def _decoding_graph(self): ) logging.debug("timelike t=%d (%s, %s)", time, op_g, op_h) else: # Case (b) - edge = Edge( - qubits=[com[0]], - weight=1 - ) + edge = DecodingGraphEdge(qubits=[com[0]], weight=1) edge.properties["highlighted"] = False edge.properties["measurement_error"] = 1 graph.add_edge( @@ -616,10 +542,7 @@ def _decoding_graph(self): logging.debug(" qubits %s", [com[0]]) # Add a single time-like edge between boundary vertices at # time t-1 and t - edge = Edge( - qubits=[], - weight=0 - ) + edge = DecodingGraphEdge(qubits=[], weight=0) edge.properties["highlighted"] = False edge.properties["measurement_error"] = 0 graph.add_edge( diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index d145b6c0..c2e938d0 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -23,7 +23,8 @@ from rustworkx import connected_components, distance_matrix, PyGraph from qiskit_qec.circuits.repetition_code import ArcCircuit, RepetitionCodeCircuit -from qiskit_qec.analysis.decoding_graph import DecodingGraph, Node, Edge +from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge from qiskit_qec.exceptions import QiskitQECError @@ -129,11 +130,7 @@ def _cluster(self, ns, dist_max): def _get_boundary_nodes(self): boundary_nodes = [] for element, z_logical in enumerate(self.z_logicals): - node = Node( - is_boundary=True, - qubits=[z_logical], - index=element - ) + node = DecodingGraphNode(is_boundary=True, qubits=[z_logical], index=element) if isinstance(self.code, ArcCircuit): node.properties["link qubit"] = None boundary_nodes.append(node) @@ -250,7 +247,7 @@ class BoundaryEdge: index: int cluster_vertex: int neighbour_vertex: int - data: Edge + data: DecodingGraphEdge def reverse(self): """ @@ -431,7 +428,10 @@ def _grow_clusters(self) -> List[FusionEntry]: cluster = self.clusters[root] for edge in cluster.boundary: edge.data.properties["growth"] += 0.5 - if edge.data.properties["growth"] >= edge.data.weight and not edge.data.properties["fully_grown"]: + if ( + edge.data.properties["growth"] >= edge.data.weight + and not edge.data.properties["fully_grown"] + ): edge.data.properties["fully_grown"] = True cluster.fully_grown_edges.add(edge.index) fusion_entry = FusionEntry( @@ -525,9 +525,14 @@ def peeling(self, erasure: PyGraph) -> List[int]: pendant_vertex = endpoints[0] if not tree.vertices[endpoints[0]] else endpoints[1] tree_vertex = endpoints[0] if pendant_vertex == endpoints[1] else endpoints[1] tree.vertices[tree_vertex].remove(edge) - if erasure[pendant_vertex].properties["syndrome"] and not erasure[pendant_vertex].is_boundary: + if ( + erasure[pendant_vertex].properties["syndrome"] + and not erasure[pendant_vertex].is_boundary + ): edges.add(edge) - erasure[tree_vertex].properties["syndrome"] = not erasure[tree_vertex].properties["syndrome"] + erasure[tree_vertex].properties["syndrome"] = not erasure[tree_vertex].properties[ + "syndrome" + ] erasure[pendant_vertex].properties["syndrome"] = False return [ diff --git a/src/qiskit_qec/decoders/hhc_decoder.py b/src/qiskit_qec/decoders/hhc_decoder.py index fb285dd3..00ba4789 100644 --- a/src/qiskit_qec/decoders/hhc_decoder.py +++ b/src/qiskit_qec/decoders/hhc_decoder.py @@ -5,7 +5,7 @@ from qiskit import QuantumCircuit -from qiskit_qec.analysis.decoding_graph import DecodingGraph +from qiskit_qec.decoders.decoding_graph import DecodingGraph from qiskit_qec.decoders.circuit_matching_decoder import CircuitModelMatchingDecoder from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel from qiskit_qec.decoders.temp_code_util import temp_syndrome diff --git a/src/qiskit_qec/decoders/repetition_decoder.py b/src/qiskit_qec/decoders/repetition_decoder.py index f4a634af..708b3e21 100644 --- a/src/qiskit_qec/decoders/repetition_decoder.py +++ b/src/qiskit_qec/decoders/repetition_decoder.py @@ -3,7 +3,7 @@ from qiskit_qec.decoders.circuit_matching_decoder import CircuitModelMatchingDecoder from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel -from qiskit_qec.analysis.decoding_graph import DecodingGraph +from qiskit_qec.decoders.decoding_graph import DecodingGraph class RepetitionDecoder(CircuitModelMatchingDecoder): diff --git a/src/qiskit_qec/decoders/rustworkx_matcher.py b/src/qiskit_qec/decoders/rustworkx_matcher.py index 1146213b..59a9bdbf 100644 --- a/src/qiskit_qec/decoders/rustworkx_matcher.py +++ b/src/qiskit_qec/decoders/rustworkx_matcher.py @@ -6,7 +6,8 @@ import rustworkx as rx from qiskit_qec.decoders.base_matcher import BaseMatcher -from qiskit_qec.analysis.decoding_graph import Node, Edge +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge + class RustworkxMatcher(BaseMatcher): """Matching subroutines using rustworkx. @@ -14,7 +15,7 @@ class RustworkxMatcher(BaseMatcher): The input rustworkx graph is expected to have decoding_graph.Node as the type of the node payload and decoding_graph.Edge as the type of the edge payload. - Additionally the edges are expected to have the following properties: + Additionally the edges are expected to have the following properties: - edge.properties["measurement_error"] (bool): Whether or not the error corresponds to a measurement error. The annotated graph will also have "highlighted" properties on edges and vertices. @@ -35,7 +36,7 @@ def preprocess(self, graph: rx.PyGraph): """ # edge_cost_fn = lambda edge: edge["weight"] - def edge_cost_fn(edge: Edge): + def edge_cost_fn(edge: DecodingGraphEdge): return edge.weight length = rx.all_pairs_dijkstra_path_lengths(graph, edge_cost_fn) diff --git a/src/qiskit_qec/decoders/temp_graph_util.py b/src/qiskit_qec/decoders/temp_graph_util.py index 95c970dd..396c4e06 100644 --- a/src/qiskit_qec/decoders/temp_graph_util.py +++ b/src/qiskit_qec/decoders/temp_graph_util.py @@ -33,3 +33,21 @@ def write_graph_to_json(graph: rx.PyGraph, filename: str): from_ret = ret2net(graph) json.dump(nx.node_link_data(from_ret), fp, indent=4, default=str) fp.close() + + +def get_cached_graph(file): + if os.path.isfile(file) and not os.stat(file) == 0: + with open(file, "r+") as f: + json_data = json.loads(f.read()) + net_graph = nx.node_link_graph(json_data) + ret_graph = rx.networkx_converter(net_graph, keep_attributes=True) + for node in ret_graph.nodes(): + del node["__networkx_node__"] + return ret_graph + return None + + +def cache_graph(graph, file): + net_graph = ret2net(graph) + with open(file, "w+") as f: + json.dump(nx.node_link_data(net_graph), f) diff --git a/src/qiskit_qec/utils/__init__.py b/src/qiskit_qec/utils/__init__.py index 2df99630..8809eff3 100644 --- a/src/qiskit_qec/utils/__init__.py +++ b/src/qiskit_qec/utils/__init__.py @@ -31,3 +31,4 @@ """ from . import indexer, pauli_rep, visualizations +from .decoding_graph_attributes import DecodingGraphNode, DecodingGraphEdge diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py new file mode 100644 index 00000000..55850c56 --- /dev/null +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +# pylint: disable=invalid-name + +""" +Graph used as the basis of decoders. +""" +from dataclasses import dataclass, field +import itertools +from typing import Any, Dict, List, Tuple, Optional + + +class DecodingGraphNode: + def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) -> None: + if not is_boundary and time == None: + raise QiskitQECError( + "DecodingGraph node must either have a time or be a boundary node." + ) + + self.is_boundary: bool = is_boundary + self.time: Optional[int] = time if not is_boundary else None + self.qubits: List[int] = qubits + self.index: int = index + # TODO: Should code/decoder specific properties be accounted for when comparing nodes + self.properties: Dict[str, Any] = dict() + + def __eq__(self, rhs): + if not isinstance(rhs, DecodingGraphNode): + return NotImplemented + + result = ( + self.index == rhs.index + and set(self.qubits) == set(rhs.qubits) + and self.is_boundary == rhs.is_boundary + ) + if not self.is_boundary: + result = result and self.time == rhs.time + return result + + def __hash__(self) -> int: + return hash(repr(self)) + + def __iter__(self): + for attr, value in self.__dict__.items(): + yield attr, value + + +@dataclass +class DecodingGraphEdge: + qubits: List[int] + weight: float + # TODO: Should code/decoder specific properties be accounted for when comparing edges + properties: Dict[str, Any] = field(default_factory=dict) + + def __eq__(self, rhs) -> bool: + if not isinstance(rhs, DecodingGraphNode): + return NotImplemented + + return set(self.qubits) == set(rhs.qubits) and self.weight == rhs.weight + + def __hash__(self) -> int: + return hash(repr(self)) + + def __iter__(self): + for attr, value in self.__dict__.items(): + yield attr, value diff --git a/src/qiskit_qec/utils/decodoku.py b/src/qiskit_qec/utils/decodoku.py index 9da094ec..b147866f 100644 --- a/src/qiskit_qec/utils/decodoku.py +++ b/src/qiskit_qec/utils/decodoku.py @@ -21,7 +21,7 @@ from rustworkx.visualization import mpl_draw from qiskit_qec.utils.visualizations import QiskitGameEngine -from qiskit_qec.analysis import DecodingGraph +from qiskit_qec.decoders import DecodingGraph class Decodoku: @@ -58,7 +58,6 @@ def __init__(self, p=0.1, k=2, d=10, process=None, errors=None, nonabelian=False self._generate_graph() def _generate_syndrome(self): - syndrome = {} for x in range(self.size): for y in range(self.size): @@ -69,7 +68,6 @@ def _generate_syndrome(self): else: error_num = poisson(self.p * 2 * self.size**2) for _ in range(error_num): - x0 = choice(range(self.size)) y0 = choice(range(self.size)) @@ -317,7 +315,6 @@ def _generate_syndrome(self): self.boundary_errors = parity def _generate_graph(self): - dg = DecodingGraph(None) d = self.size - 1 @@ -373,7 +370,6 @@ def reset_graph(self): self._update_graph(original=False) def _update_graph(self, original=False): - for node in self.decoding_graph.graph.nodes(): node["highlighted"] = False if self.k != 2: @@ -403,7 +399,6 @@ def _update_graph(self, original=False): self.node_color = highlighted_color def _start(self, engine): - d = self.k syndrome = self.syndrome @@ -439,7 +434,6 @@ def _start(self, engine): # this is the function that does everything def _next_frame(self, engine): - d = self.k syndrome = self.syndrome diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index a025920d..5b5265dd 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -26,7 +26,8 @@ from qiskit_aer.noise.errors import depolarizing_error from qiskit_qec.circuits.repetition_code import RepetitionCodeCircuit as RepetitionCode from qiskit_qec.circuits.repetition_code import ArcCircuit -from qiskit_qec.analysis.decoding_graph import DecodingGraph, Node, Edge +from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.decoders.hdrg_decoders import BravyiHaahDecoder @@ -135,14 +136,14 @@ def test_string2nodes_2(self): (5, 0), "00001", [ - Node( + DecodingGraphNode( is_boundary=True, qubits=[0], index=0, ), - Node( + DecodingGraphNode( time=0, - qubits=[0,1], + qubits=[0, 1], index=0, ), ], @@ -196,20 +197,8 @@ def test_weight(self): + "'." ) p = dec.get_error_probs(test_results, method=method) - n0 = dec.graph.nodes().index( - Node( - time=0, - qubits=[0,1], - index=0 - ) - ) - n1 = dec.graph.nodes().index( - Node( - time=0, - qubits=[1,2], - index=1 - ) - ) + n0 = dec.graph.nodes().index(DecodingGraphNode(time=0, qubits=[0, 1], index=0)) + n1 = dec.graph.nodes().index(DecodingGraphNode(time=0, qubits=[1, 2], index=1)) # edges in graph aren't directed and could be in any order if (n0, n1) in p: self.assertTrue(round(p[n0, n1], 2) == 0.33, error) @@ -485,18 +474,10 @@ def test_weight(self): + "'." ) p = dec.get_error_probs(test_results, method=method) - node = Node( - time=0, - qubits=[0, 2], - index=1 - ) + node = DecodingGraphNode(time=0, qubits=[0, 2], index=1) node.properties["link qubits"] = 1 n0 = dec.graph.nodes().index(node) - node = Node( - time=0, - qubits=[2,4], - index=0 - ) + node = DecodingGraphNode(time=0, qubits=[2, 4], index=0) node.properties["link qubits"] = 3 n1 = dec.graph.nodes().index(node) # edges in graph aren't directed and could be in any order diff --git a/test/code_circuits/test_surface_codes.py b/test/code_circuits/test_surface_codes.py index d3ab594b..25300b2e 100644 --- a/test/code_circuits/test_surface_codes.py +++ b/test/code_circuits/test_surface_codes.py @@ -19,7 +19,7 @@ import unittest from qiskit_qec.circuits.surface_code import SurfaceCodeCircuit -from qiskit_qec.analysis.decoding_graph import Node +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge class TestSurfaceCodes(unittest.TestCase): @@ -42,48 +42,48 @@ def test_string2nodes(self): test_nodes["x"] = [ [], [ - Node( + DecodingGraphNode( time=1, qubits=[0, 1, 3, 4], index=1, ), - Node( + DecodingGraphNode( time=1, qubits=[4, 5, 7, 8], index=2, - ) + ), ], [ - Node( + DecodingGraphNode( time=0, qubits=[0, 1, 3, 4], index=1, ), - Node( + DecodingGraphNode( time=0, qubits=[4, 5, 7, 8], index=2, ), ], [ - Node( + DecodingGraphNode( is_boundary=True, qubits=[0, 3, 6], index=0, ), - Node( + DecodingGraphNode( time=1, qubits=[0, 1, 3, 4], index=1, ), ], [ - Node( + DecodingGraphNode( is_boundary=True, qubits=[2, 5, 8], index=1, ), - Node( + DecodingGraphNode( time=1, qubits=[4, 5, 7, 8], index=2, @@ -93,52 +93,36 @@ def test_string2nodes(self): test_nodes["z"] = [ [], [ - Node( + DecodingGraphNode( time=0, qubits=[1, 4, 2, 5], index=1, ), - Node( + DecodingGraphNode( time=0, qubits=[3, 6, 4, 7], index=2, ), ], [ - Node( + DecodingGraphNode( time=1, qubits=[1, 4, 2, 5], index=1, ), - Node( + DecodingGraphNode( time=1, qubits=[3, 6, 4, 7], index=2, ), ], [ - Node( - is_boundary=True, - qubits=[0, 1, 2], - index=0 - ), - Node( - time=1, - qubits=[0, 3], - index=0 - ), + DecodingGraphNode(is_boundary=True, qubits=[0, 1, 2], index=0), + DecodingGraphNode(time=1, qubits=[0, 3], index=0), ], [ - Node( - is_boundary=True, - qubits=[8, 7, 6], - index=1 - ), - Node( - time=1, - qubits=[5, 8], - index=3 - ) + DecodingGraphNode(is_boundary=True, qubits=[8, 7, 6], index=1), + DecodingGraphNode(time=1, qubits=[5, 8], index=3), ], ] @@ -165,94 +149,92 @@ def test_check_nodes(self): valid = valid and code.check_nodes(nodes) == (True, [], 0) # on one side nodes = [ - Node(qubits=[0, 1, 2], is_boundary=True, index=0), - Node(time=3, qubits=[0, 3], index=0), + DecodingGraphNode(qubits=[0, 1, 2], is_boundary=True, index=0), + DecodingGraphNode(time=3, qubits=[0, 3], index=0), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) - nodes = [Node(time=3, qubits=[0, 3], index=0)] + nodes = [DecodingGraphNode(time=3, qubits=[0, 3], index=0)] valid = valid and code.check_nodes(nodes) == ( True, - [Node(time=0, qubits=[0, 1, 2], is_boundary=True, index=0)], + [DecodingGraphNode(time=0, qubits=[0, 1, 2], is_boundary=True, index=0)], 1.0, ) # and the other nodes = [ - Node(time=0, qubits=[8, 7, 6], is_boundary=True, index=1), - Node(time=3, qubits=[5, 8], index=3), + DecodingGraphNode(time=0, qubits=[8, 7, 6], is_boundary=True, index=1), + DecodingGraphNode(time=3, qubits=[5, 8], index=3), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) - nodes = [Node(time=3, qubits=[5, 8], index=3)] + nodes = [DecodingGraphNode(time=3, qubits=[5, 8], index=3)] valid = valid and code.check_nodes(nodes) == ( True, - [Node(time=0, qubits=[8, 7, 6], is_boundary=True, index=1)], + [DecodingGraphNode(time=0, qubits=[8, 7, 6], is_boundary=True, index=1)], 1.0, ) # and in the middle nodes = [ - Node(time=3, qubits=[1, 4, 2, 5], index=1), - Node(time=3, qubits=[3, 6, 4, 7], index=2), + DecodingGraphNode(time=3, qubits=[1, 4, 2, 5], index=1), + DecodingGraphNode(time=3, qubits=[3, 6, 4, 7], index=2), ] valid = valid and code.check_nodes(nodes) == (True, [], 1) - nodes = [Node(time=3, qubits=[3, 6, 4, 7], index=2)] + nodes = [DecodingGraphNode(time=3, qubits=[3, 6, 4, 7], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [Node(qubits=[8, 7, 6], is_boundary=True, index=1)], + [DecodingGraphNode(qubits=[8, 7, 6], is_boundary=True, index=1)], 1.0, ) # basis = 'x' code = SurfaceCodeCircuit(3, 3, basis="x") nodes = [ - Node(time=3, qubits=[0, 1, 3, 4], index=1), - Node(time=3, qubits=[4, 5, 7, 8], index=2), + DecodingGraphNode(time=3, qubits=[0, 1, 3, 4], index=1), + DecodingGraphNode(time=3, qubits=[4, 5, 7, 8], index=2), ] valid = valid and code.check_nodes(nodes) == (True, [], 1) - nodes = [Node(time=3, qubits=[4, 5, 7, 8], index=2)] + nodes = [DecodingGraphNode(time=3, qubits=[4, 5, 7, 8], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [Node(qubits=[2, 5, 8], is_boundary=True, index=1)], + [DecodingGraphNode(qubits=[2, 5, 8], is_boundary=True, index=1)], 1.0, ) # large d code = SurfaceCodeCircuit(5, 3, basis="z") nodes = [ - Node(time=3, qubits=[7, 12, 8, 13], index=4), - Node(time=3, qubits=[11, 16, 12, 17], index=7), + DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), + DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), ] valid = valid and code.check_nodes(nodes) == (True, [], 1) - nodes = [Node(time=3, qubits=[11, 16, 12, 17], index=7)] + nodes = [DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7)] valid = valid and code.check_nodes(nodes) == ( True, - [Node(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1)], + [DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1)], 2.0, ) # wrong boundary nodes = [ - Node(time=3, qubits=[7, 12, 8, 13], index=4), - Node(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), + DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), + DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), ] valid = valid and code.check_nodes(nodes) == ( False, - [Node(qubits=[0, 1, 2, 3, 4], is_boundary=True, index=0)], + [DecodingGraphNode(qubits=[0, 1, 2, 3, 4], is_boundary=True, index=0)], 2, ) # extra boundary nodes = [ - Node(time=3, qubits=[7, 12, 8, 13], index=4), - Node(time=3, qubits=[11, 16, 12, 17], index=7), - Node(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), + DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), + DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), + DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), ] valid = valid and code.check_nodes(nodes) == (False, [], 0) # ignoring extra nodes = [ - Node(time=3, qubits=[7, 12, 8, 13], index=4), - Node(time=3, qubits=[11, 16, 12, 17], index=7), - Node(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), + DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), + DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), + DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), ] - valid = valid and code.check_nodes( - nodes, ignore_extra_boundary=True) == (True, [], 1) + valid = valid and code.check_nodes(nodes, ignore_extra_boundary=True) == (True, [], 1) - self.assertTrue( - valid, "A set of nodes did not give the expected outcome for check_nodes.") + self.assertTrue(valid, "A set of nodes did not give the expected outcome for check_nodes.") diff --git a/test/matching/test_pymatchingmatcher.py b/test/matching/test_pymatchingmatcher.py index b9b88fb5..e9378e1e 100644 --- a/test/matching/test_pymatchingmatcher.py +++ b/test/matching/test_pymatchingmatcher.py @@ -4,7 +4,8 @@ from typing import Dict, Tuple import rustworkx as rx from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher -from qiskit_qec.analysis.decoding_graph import Node, Edge +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge + class TestPyMatchingMatcher(unittest.TestCase): """Tests for the pymatching matcher subroutines.""" @@ -18,27 +19,19 @@ def make_test_graph() -> Tuple[rx.PyGraph, Dict[Tuple[int, Tuple[int]], int]]: graph = rx.PyGraph(multigraph=False) idxmap = {} for i, q in enumerate([[0, 1], [1, 2], [2, 3], [3, 4]]): - node = Node( - time = 0, - qubits = q, - index = i - ) + node = DecodingGraphNode(time=0, qubits=q, index=i) node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple(q))] = i node = {"time": 0, "qubits": [], "highlighted": False, "is_boundary": True} - node = Node( - is_boundary = True, - qubits = [], - index=0 - ) + node = DecodingGraphNode(is_boundary=True, qubits=[], index=0) node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple([]))] = 4 for dat in [[[0], 0, 4], [[1], 0, 1], [[2], 1, 2], [[3], 2, 3], [[4], 3, 4]]: - edge = Edge( - qubits = dat[0], - weight = 1, + edge = DecodingGraphEdge( + qubits=dat[0], + weight=1, ) edge.properties["measurement_error"] = False edge.properties["highlighted"] = False diff --git a/test/matching/test_retworkxmatcher.py b/test/matching/test_retworkxmatcher.py index f6215d8a..855a505a 100644 --- a/test/matching/test_retworkxmatcher.py +++ b/test/matching/test_retworkxmatcher.py @@ -4,7 +4,8 @@ from typing import Dict, Tuple import rustworkx as rx from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher -from qiskit_qec.analysis.decoding_graph import Node, Edge +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge + class TestRustworkxMatcher(unittest.TestCase): """Tests for the rustworkx matcher subroutines.""" @@ -18,27 +19,16 @@ def make_test_graph() -> Tuple[rx.PyGraph, Dict[Tuple[int, Tuple[int]], int]]: graph = rx.PyGraph(multigraph=False) idxmap = {} for i, q in enumerate([[0, 1], [1, 2], [2, 3], [3, 4]]): - node = Node( - time = 0, - qubits = q, - index=i - ) + node = DecodingGraphNode(time=0, qubits=q, index=i) node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple(q))] = i - node = Node( - time = 0, - qubits = [], - index= i+1 - ) + node = DecodingGraphNode(time=0, qubits=[], index=i + 1) node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple([]))] = 4 for dat in [[[0], 0, 4], [[1], 0, 1], [[2], 1, 2], [[3], 2, 3], [[4], 3, 4]]: - edge = Edge( - qubits = dat[0], - weight = 1 - ) + edge = DecodingGraphEdge(qubits=dat[0], weight=1) edge.properties["measurement_error"] = False edge.properties["highlighted"] = False graph.add_edge(dat[1], dat[2], edge) From 73256085f090ea4da2f654d64512f42ea9303618 Mon Sep 17 00:00:00 2001 From: Tommaso Peduzzi Date: Wed, 8 Mar 2023 11:16:45 +0100 Subject: [PATCH 09/22] Move new Node and Edge types to utils and rename And lint and black --- src/qiskit_qec/utils/decoding_graph_attributes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py index 55850c56..8bfe5cbd 100644 --- a/src/qiskit_qec/utils/decoding_graph_attributes.py +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -21,6 +21,8 @@ import itertools from typing import Any, Dict, List, Tuple, Optional +from qiskit_qec.exceptions import QiskitQECError + class DecodingGraphNode: def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) -> None: From 8a81c7e3aeda47eaf35108aed00c5847f2d7b8f7 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Mon, 6 Mar 2023 21:38:27 +0100 Subject: [PATCH 10/22] create CodeCircuit class (#329) * create CodeCircuit class * add more detail to init * add default is_cluster_neutral --- src/qiskit_qec/circuits/repetition_code.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 9924f9aa..ea841a92 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -26,8 +26,6 @@ from qiskit.transpiler import PassManager, InstructionDurations from qiskit.transpiler.passes import DynamicalDecoupling -from qiskit_qec.circuits.code_circuit import CodeCircuit - from qiskit_qec.circuits.code_circuit import CodeCircuit from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge From 7f5aaea3833503567d16ffb54acffbaff08872fc Mon Sep 17 00:00:00 2001 From: Tommaso Peduzzi Date: Wed, 8 Mar 2023 12:28:38 +0100 Subject: [PATCH 11/22] Lint and Black --- src/qiskit_qec/circuits/repetition_code.py | 2 +- src/qiskit_qec/circuits/surface_code.py | 2 +- .../decoders/circuit_matching_decoder.py | 4 +-- src/qiskit_qec/decoders/decoding_graph.py | 3 +- src/qiskit_qec/decoders/hdrg_decoders.py | 3 +- src/qiskit_qec/decoders/rustworkx_matcher.py | 5 +-- src/qiskit_qec/decoders/temp_graph_util.py | 21 ++++++++---- .../utils/decoding_graph_attributes.py | 34 ++++++++++++++++--- test/code_circuits/test_rep_codes.py | 2 +- test/code_circuits/test_surface_codes.py | 2 +- test/matching/test_retworkxmatcher.py | 5 +-- 11 files changed, 57 insertions(+), 26 deletions(-) diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index ea841a92..3c53bb06 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -17,9 +17,9 @@ """Generates circuits based on repetition codes.""" from typing import List, Optional, Tuple +from copy import copy, deepcopy import numpy as np import rustworkx as rx -from copy import copy, deepcopy from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister, transpile from qiskit.circuit.library import XGate, RZGate diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 5cac3d7c..12dd36bb 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -18,7 +18,7 @@ from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.utils import DecodingGraphNode from qiskit_qec.circuits.code_circuit import CodeCircuit diff --git a/src/qiskit_qec/decoders/circuit_matching_decoder.py b/src/qiskit_qec/decoders/circuit_matching_decoder.py index 87443975..6f0e5cdf 100644 --- a/src/qiskit_qec/decoders/circuit_matching_decoder.py +++ b/src/qiskit_qec/decoders/circuit_matching_decoder.py @@ -11,7 +11,7 @@ from qiskit import QuantumCircuit from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.decoders.decoding_graph import CSSDecodingGraph, DecodingGraph -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.utils import DecodingGraphEdge from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher from qiskit_qec.decoders.temp_code_util import temp_gauge_products, temp_syndrome @@ -225,7 +225,7 @@ def _process_graph( graph.add_edge(n0, n1, edge) # connect one of the boundaries at different times - if target.time == source.time or 0 + 1: + if target.time == (source.time or 0) + 1: if source.qubits == target.qubits == [0]: edge = DecodingGraphEdge(weight=0, qubits=[]) edge.properties["highlighted"] = False diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index e2d6c9c1..d6d6c39d 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -17,10 +17,9 @@ """ Graph used as the basis of decoders. """ -from dataclasses import dataclass, field import itertools import logging -from typing import Any, Dict, List, Tuple, Optional +from typing import List, Tuple import numpy as np import rustworkx as rx diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index c2e938d0..7d6ec401 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -18,8 +18,7 @@ from copy import copy, deepcopy from dataclasses import dataclass - -from typing import Dict, List, Set, Tuple, Tuple +from typing import Dict, List, Set, Tuple from rustworkx import connected_components, distance_matrix, PyGraph from qiskit_qec.circuits.repetition_code import ArcCircuit, RepetitionCodeCircuit diff --git a/src/qiskit_qec/decoders/rustworkx_matcher.py b/src/qiskit_qec/decoders/rustworkx_matcher.py index 59a9bdbf..55fd65d4 100644 --- a/src/qiskit_qec/decoders/rustworkx_matcher.py +++ b/src/qiskit_qec/decoders/rustworkx_matcher.py @@ -6,7 +6,7 @@ import rustworkx as rx from qiskit_qec.decoders.base_matcher import BaseMatcher -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.utils import DecodingGraphEdge class RustworkxMatcher(BaseMatcher): @@ -16,7 +16,8 @@ class RustworkxMatcher(BaseMatcher): and decoding_graph.Edge as the type of the edge payload. Additionally the edges are expected to have the following properties: - - edge.properties["measurement_error"] (bool): Whether or not the error corresponds to a measurement error. + - edge.properties["measurement_error"] (bool): Whether or not the error + corresponds to a measurement error. The annotated graph will also have "highlighted" properties on edges and vertices. """ diff --git a/src/qiskit_qec/decoders/temp_graph_util.py b/src/qiskit_qec/decoders/temp_graph_util.py index 396c4e06..fa4ef2b2 100644 --- a/src/qiskit_qec/decoders/temp_graph_util.py +++ b/src/qiskit_qec/decoders/temp_graph_util.py @@ -1,5 +1,6 @@ """Temporary module with methods for graphs.""" import json +import os import networkx as nx import rustworkx as rx @@ -35,10 +36,13 @@ def write_graph_to_json(graph: rx.PyGraph, filename: str): fp.close() -def get_cached_graph(file): - if os.path.isfile(file) and not os.stat(file) == 0: - with open(file, "r+") as f: - json_data = json.loads(f.read()) +def get_cached_graph(path): + """ + Returns graph cached in file at path "file" using cache_graph method. + """ + if os.path.isfile(path) and not os.stat(path) == 0: + with open(path, "r+", encoding="utf-8") as file: + json_data = json.loads(file.read()) net_graph = nx.node_link_graph(json_data) ret_graph = rx.networkx_converter(net_graph, keep_attributes=True) for node in ret_graph.nodes(): @@ -47,7 +51,10 @@ def get_cached_graph(file): return None -def cache_graph(graph, file): +def cache_graph(graph, path): + """ + Cache rustworkx PyGraph to file at path. + """ net_graph = ret2net(graph) - with open(file, "w+") as f: - json.dump(nx.node_link_data(net_graph), f) + with open(path, "w+", encoding="utf-8") as file: + json.dump(nx.node_link_data(net_graph), file) diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py index 8bfe5cbd..4e67b2d0 100644 --- a/src/qiskit_qec/utils/decoding_graph_attributes.py +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -18,15 +18,30 @@ Graph used as the basis of decoders. """ from dataclasses import dataclass, field -import itertools -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Optional + +from qiskit_qec.exceptions import QiskitQECError from qiskit_qec.exceptions import QiskitQECError class DecodingGraphNode: + """ + Class to describe DecodingGraph nodes. + + Attributes: + - is_boundary (bool): whether or not the node is a boundary node. + - time (int): what syndrome node the node corrsponds to. Doesn't + need to be set if it's a boundary node. + - qubits (List[int]): List of indices which are stabilized by + this ancilla. + - index (int): Unique index in measurement round. + - properties (Dict[str, Any]): Decoder/code specific attributes. + Are not considered when comparing nodes. + """ + def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) -> None: - if not is_boundary and time == None: + if not is_boundary and time is None: raise QiskitQECError( "DecodingGraph node must either have a time or be a boundary node." ) @@ -35,8 +50,7 @@ def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) self.time: Optional[int] = time if not is_boundary else None self.qubits: List[int] = qubits self.index: int = index - # TODO: Should code/decoder specific properties be accounted for when comparing nodes - self.properties: Dict[str, Any] = dict() + self.properties: Dict[str, Any] = {} def __eq__(self, rhs): if not isinstance(rhs, DecodingGraphNode): @@ -61,6 +75,16 @@ def __iter__(self): @dataclass class DecodingGraphEdge: + """ + Class to describe DecodingGraph edges. + + Attributes: + - qubits (List[int]): List of indices of code qubits that correspond to this edge. + - weight (float): Weight of the edge. + - properties (Dict[str, Any]): Decoder/code specific attributes. + Are not considered when comparing edges. + """ + qubits: List[int] weight: float # TODO: Should code/decoder specific properties be accounted for when comparing edges diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 5b5265dd..c0338259 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -27,7 +27,7 @@ from qiskit_qec.circuits.repetition_code import RepetitionCodeCircuit as RepetitionCode from qiskit_qec.circuits.repetition_code import ArcCircuit from qiskit_qec.decoders.decoding_graph import DecodingGraph -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.utils import DecodingGraphNode from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.decoders.hdrg_decoders import BravyiHaahDecoder diff --git a/test/code_circuits/test_surface_codes.py b/test/code_circuits/test_surface_codes.py index 25300b2e..43cf1259 100644 --- a/test/code_circuits/test_surface_codes.py +++ b/test/code_circuits/test_surface_codes.py @@ -19,7 +19,7 @@ import unittest from qiskit_qec.circuits.surface_code import SurfaceCodeCircuit -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.utils import DecodingGraphNode class TestSurfaceCodes(unittest.TestCase): diff --git a/test/matching/test_retworkxmatcher.py b/test/matching/test_retworkxmatcher.py index 855a505a..0f9f16ca 100644 --- a/test/matching/test_retworkxmatcher.py +++ b/test/matching/test_retworkxmatcher.py @@ -18,12 +18,13 @@ def make_test_graph() -> Tuple[rx.PyGraph, Dict[Tuple[int, Tuple[int]], int]]: """ graph = rx.PyGraph(multigraph=False) idxmap = {} - for i, q in enumerate([[0, 1], [1, 2], [2, 3], [3, 4]]): + basic_config = [[0, 1], [1, 2], [2, 3], [3, 4]] + for i, q in enumerate(basic_config): node = DecodingGraphNode(time=0, qubits=q, index=i) node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple(q))] = i - node = DecodingGraphNode(time=0, qubits=[], index=i + 1) + node = DecodingGraphNode(time=0, qubits=[], index=len(basic_config) + 1) node.properties["highlighted"] = False graph.add_node(node) idxmap[(0, tuple([]))] = 4 From 1cfce65130d783c75e8d051b6887c24a274d6f52 Mon Sep 17 00:00:00 2001 From: Tommaso Peduzzi Date: Thu, 9 Mar 2023 16:28:54 +0100 Subject: [PATCH 12/22] Update decoding graph caching to support new node tzpes --- src/qiskit_qec/decoders/temp_graph_util.py | 23 +++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/qiskit_qec/decoders/temp_graph_util.py b/src/qiskit_qec/decoders/temp_graph_util.py index fa4ef2b2..e2fb199f 100644 --- a/src/qiskit_qec/decoders/temp_graph_util.py +++ b/src/qiskit_qec/decoders/temp_graph_util.py @@ -4,6 +4,8 @@ import networkx as nx import rustworkx as rx +from qiskit_qec.utils.decoding_graph_attributes import DecodingGraphEdge, DecodingGraphNode + def ret2net(graph: rx.PyGraph): """Convert rustworkx graph to equivalent networkx graph.""" @@ -36,7 +38,7 @@ def write_graph_to_json(graph: rx.PyGraph, filename: str): fp.close() -def get_cached_graph(path): +def get_cached_decoding_graph(path): """ Returns graph cached in file at path "file" using cache_graph method. """ @@ -45,13 +47,28 @@ def get_cached_graph(path): json_data = json.loads(file.read()) net_graph = nx.node_link_graph(json_data) ret_graph = rx.networkx_converter(net_graph, keep_attributes=True) - for node in ret_graph.nodes(): + for node_index, node in zip(ret_graph.node_indices(), ret_graph.nodes()): del node["__networkx_node__"] + qubits = node.pop("qubits") + time = node.pop("time") + index = node.pop("index") + is_boundary = node.pop("is_boundary") + properties = node.copy() + node = DecodingGraphNode(is_boundary=is_boundary, time=time, index=index, qubits=qubits) + node.properties = properties + ret_graph[node_index] = node + for edge_index, edge in zip(ret_graph.edge_indices(), ret_graph.edges()): + weight = edge.pop("weight") + qubits = edge.pop("qubits") + properties = edge.copy() + edge = DecodingGraphEdge(weight=weight, qubits=qubits) + edge.properties = properties + ret_graph.update_edge_by_index(edge_index, edge) return ret_graph return None -def cache_graph(graph, path): +def cache_decoding_graph(graph, path): """ Cache rustworkx PyGraph to file at path. """ From e8b489725262397182b623d4139208fe00ab956d Mon Sep 17 00:00:00 2001 From: Drew Vandeth <57962926+dsvandet@users.noreply.github.com> Date: Thu, 9 Mar 2023 11:26:25 -0500 Subject: [PATCH 13/22] Added Cmake as a requirement (#336) --- pyproject.toml | 2 +- requirements-dev.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b9a77a43..4327b14b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "wheel", "pybind11~=2.6.1"] +requires = ["setuptools", "wheel", "cmake!=3.17.1,!=3.17.0", "pybind11~=2.6.1"] build-backend = "setuptools.build_meta" [tool.black] diff --git a/requirements-dev.txt b/requirements-dev.txt index 22f107be..56c9319f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ pylint==2.11.1 coverage==6.1.1 +cmake!=3.17.1,!=3.17.0 qiskit-sphinx-theme>=1.6 sphinx-autodoc-typehints jupyter-sphinx From d236e36c8197f6ed90ec43211ae12e524aff5b71 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Fri, 10 Mar 2023 12:11:26 +0100 Subject: [PATCH 14/22] element->index in decodoku --- src/qiskit_qec/utils/decodoku.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/qiskit_qec/utils/decodoku.py b/src/qiskit_qec/utils/decodoku.py index b147866f..fafcaada 100644 --- a/src/qiskit_qec/utils/decodoku.py +++ b/src/qiskit_qec/utils/decodoku.py @@ -331,7 +331,7 @@ def _generate_graph(self): { "y": 0, "x": (d - 1) * (elem == 1) - 1 * (elem == 0), - "element": elem, + "index": elem, "is_boundary": True, } ) @@ -340,10 +340,10 @@ def _generate_graph(self): nodes = dg.graph.nodes() # connect edges to boundary nodes for y in range(self.size): - n0 = nodes.index({"y": 0, "x": -1, "element": 0, "is_boundary": True}) + n0 = nodes.index({"y": 0, "x": -1, "index": 0, "is_boundary": True}) n1 = nodes.index({"y": y, "x": 0, "is_boundary": False}) dg.graph.add_edge(n0, n1, None) - n0 = nodes.index({"y": 0, "x": d - 1, "element": 1, "is_boundary": True}) + n0 = nodes.index({"y": 0, "x": d - 1, "index": 1, "is_boundary": True}) n1 = nodes.index({"y": y, "x": d - 2, "is_boundary": False}) dg.graph.add_edge(n0, n1, None) # connect bulk nodes with space-like edges @@ -536,7 +536,7 @@ def draw_graph(self, clusters=True): def get_label(node): if node["is_boundary"] and parity: - return str(parity[node["element"]]) + return str(parity[node["index"]]) elif node["highlighted"] and "value" in node and self.k != 2: return str(node["value"]) else: From 764f70653c6e55470f8524236f3d6a2d13b1d776 Mon Sep 17 00:00:00 2001 From: Tommaso Peduzzi Date: Fri, 10 Mar 2023 13:30:11 +0100 Subject: [PATCH 15/22] DecodingGraph, Tests, ARC: Fix things for PR --- src/qiskit_qec/circuits/repetition_code.py | 4 ++-- src/qiskit_qec/decoders/decoding_graph.py | 2 +- src/qiskit_qec/decoders/hdrg_decoders.py | 4 +--- src/qiskit_qec/utils/decoding_graph_attributes.py | 4 ++-- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 3c53bb06..f3f5d2b1 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -1344,7 +1344,7 @@ def _make_syndrome_graph(self): if new_node not in nodes: nodes.append(new_node) else: - node.time = 0 + node.time = None nodes.append(node) # find pairs that should be connected @@ -1353,7 +1353,7 @@ def _make_syndrome_graph(self): for n1, node1 in enumerate(nodes): if n0 < n1: # just record all possible edges for now (should be improved later) - dt = abs(node1.time - node0.time) + dt = abs((node1.time or 0) - (node0.time or 0)) adj = set(node0.qubits).intersection(set(node1.qubits)) if adj: if (node0.is_boundary ^ node1.is_boundary) or dt <= 1: diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index d6d6c39d..648cf749 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -2,7 +2,7 @@ # This code is part of Qiskit. # -# (C) Copyright IBM 2023. +# (C) Copyright IBM 2019-2023. # # This code is licensed under the Apache License, Version 2.0. You may # obtain a copy of this license in the LICENSE.txt file in the root directory diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 7d6ec401..6fe88c5d 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -534,6 +534,4 @@ def peeling(self, erasure: PyGraph) -> List[int]: ] erasure[pendant_vertex].properties["syndrome"] = False - return [ - erasure.edges()[edge].qubits[0] for edge in edges if erasure.edges()[edge].qubits - ] + return [erasure.edges()[edge].qubits[0] for edge in edges if erasure.edges()[edge].qubits] diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py index 4e67b2d0..0d6d8098 100644 --- a/src/qiskit_qec/utils/decoding_graph_attributes.py +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -40,7 +40,7 @@ class DecodingGraphNode: Are not considered when comparing nodes. """ - def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) -> None: + def __init__(self, index: int, qubits: List[int] = None, is_boundary=False, time=None) -> None: if not is_boundary and time is None: raise QiskitQECError( "DecodingGraph node must either have a time or be a boundary node." @@ -48,7 +48,7 @@ def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) self.is_boundary: bool = is_boundary self.time: Optional[int] = time if not is_boundary else None - self.qubits: List[int] = qubits + self.qubits: List[int] = qubits if qubits else [] self.index: int = index self.properties: Dict[str, Any] = {} From 8f056f67f2baf87868d37d18f90ffaaf4b376b8a Mon Sep 17 00:00:00 2001 From: Tommaso Peduzzi Date: Fri, 10 Mar 2023 13:30:11 +0100 Subject: [PATCH 16/22] DecodingGraph, Tests, ARC: Fix things for PR --- src/qiskit_qec/circuits/repetition_code.py | 4 ++-- src/qiskit_qec/decoders/decoding_graph.py | 2 +- src/qiskit_qec/decoders/hdrg_decoders.py | 4 +--- src/qiskit_qec/utils/decoding_graph_attributes.py | 6 ++---- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 3c53bb06..f3f5d2b1 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -1344,7 +1344,7 @@ def _make_syndrome_graph(self): if new_node not in nodes: nodes.append(new_node) else: - node.time = 0 + node.time = None nodes.append(node) # find pairs that should be connected @@ -1353,7 +1353,7 @@ def _make_syndrome_graph(self): for n1, node1 in enumerate(nodes): if n0 < n1: # just record all possible edges for now (should be improved later) - dt = abs(node1.time - node0.time) + dt = abs((node1.time or 0) - (node0.time or 0)) adj = set(node0.qubits).intersection(set(node1.qubits)) if adj: if (node0.is_boundary ^ node1.is_boundary) or dt <= 1: diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index d6d6c39d..648cf749 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -2,7 +2,7 @@ # This code is part of Qiskit. # -# (C) Copyright IBM 2023. +# (C) Copyright IBM 2019-2023. # # This code is licensed under the Apache License, Version 2.0. You may # obtain a copy of this license in the LICENSE.txt file in the root directory diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 7d6ec401..6fe88c5d 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -534,6 +534,4 @@ def peeling(self, erasure: PyGraph) -> List[int]: ] erasure[pendant_vertex].properties["syndrome"] = False - return [ - erasure.edges()[edge].qubits[0] for edge in edges if erasure.edges()[edge].qubits - ] + return [erasure.edges()[edge].qubits[0] for edge in edges if erasure.edges()[edge].qubits] diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py index 4e67b2d0..c06f9944 100644 --- a/src/qiskit_qec/utils/decoding_graph_attributes.py +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -22,8 +22,6 @@ from qiskit_qec.exceptions import QiskitQECError -from qiskit_qec.exceptions import QiskitQECError - class DecodingGraphNode: """ @@ -40,7 +38,7 @@ class DecodingGraphNode: Are not considered when comparing nodes. """ - def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) -> None: + def __init__(self, index: int, qubits: List[int] = None, is_boundary=False, time=None) -> None: if not is_boundary and time is None: raise QiskitQECError( "DecodingGraph node must either have a time or be a boundary node." @@ -48,7 +46,7 @@ def __init__(self, qubits: List[int], index: int, is_boundary=False, time=None) self.is_boundary: bool = is_boundary self.time: Optional[int] = time if not is_boundary else None - self.qubits: List[int] = qubits + self.qubits: List[int] = qubits if qubits else [] self.index: int = index self.properties: Dict[str, Any] = {} From 02b817f85b4ef51b12e48d8111b67ef7bb84ce06 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Fri, 10 Mar 2023 15:23:22 +0100 Subject: [PATCH 17/22] remove unused import --- src/qiskit_qec/decoders/hdrg_decoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 6fe88c5d..e2d26c2b 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -18,7 +18,7 @@ from copy import copy, deepcopy from dataclasses import dataclass -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Set from rustworkx import connected_components, distance_matrix, PyGraph from qiskit_qec.circuits.repetition_code import ArcCircuit, RepetitionCodeCircuit From cab22a9c6b3aaab7df05ecda757b4837f0e10614 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Tue, 14 Mar 2023 11:14:06 +0100 Subject: [PATCH 18/22] use all logicals --- src/qiskit_qec/decoders/hdrg_decoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 39f47c4a..d3c7be75 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -189,7 +189,7 @@ def process(self, string): decoding_graph = self.decoding_graph # turn string into nodes and cluster - nodes = code.string2nodes(string) + nodes = code.string2nodes(string, all_logicals=True) clusters = self.cluster(nodes) # get the list of bulk nodes for each cluster From 2c8ea1b853a2b1bef2f45daab9b6292ffb933294 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Tue, 14 Mar 2023 11:17:04 +0100 Subject: [PATCH 19/22] Added Cmake as a requirement (#336) (#345) Co-authored-by: Drew Vandeth <57962926+dsvandet@users.noreply.github.com> --- pyproject.toml | 2 +- requirements-dev.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b9a77a43..4327b14b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "wheel", "pybind11~=2.6.1"] +requires = ["setuptools", "wheel", "cmake!=3.17.1,!=3.17.0", "pybind11~=2.6.1"] build-backend = "setuptools.build_meta" [tool.black] diff --git a/requirements-dev.txt b/requirements-dev.txt index 22f107be..56c9319f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ pylint==2.11.1 coverage==6.1.1 +cmake!=3.17.1,!=3.17.0 qiskit-sphinx-theme>=1.6 sphinx-autodoc-typehints jupyter-sphinx From fc893049d5456d1beec641d1a934d71497121a54 Mon Sep 17 00:00:00 2001 From: Tommaso Peduzzi Date: Tue, 14 Mar 2023 11:39:32 +0100 Subject: [PATCH 20/22] DecodingGraph: Add optional graph attribute for cached graphs --- src/qiskit_qec/decoders/decoding_graph.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index 648cf749..bf0e6e4e 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -40,7 +40,7 @@ class DecodingGraph: METHOD_NAIVE: str = "naive" AVAILABLE_METHODS = {METHOD_SPITZ, METHOD_NAIVE} - def __init__(self, code, brute=False): + def __init__(self, code, brute=False, graph=None): """ Args: code (CodeCircuit): The QEC code circuit object for which this decoding @@ -52,7 +52,10 @@ def __init__(self, code, brute=False): self.code = code self.brute = brute - self._make_syndrome_graph() + if graph: + self.graph = graph + else: + self._make_syndrome_graph() def _make_syndrome_graph(self): if not self.brute and hasattr(self.code, "_make_syndrome_graph"): From 4f795afcf3515a218b3424b20cda4cdd175ae7c1 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Tue, 14 Mar 2023 14:45:38 +0100 Subject: [PATCH 21/22] Update hdrg_decoders.py --- src/qiskit_qec/decoders/hdrg_decoders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 3a7b635f..a9ce0ed8 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -24,7 +24,6 @@ from qiskit_qec.circuits.repetition_code import ArcCircuit from qiskit_qec.decoders.decoding_graph import DecodingGraph from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge -from qiskit_qec.exceptions import QiskitQECError class ClusteringDecoder: From a285e66b7631407997cef6dfb510af9381409bf7 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Tue, 14 Mar 2023 14:55:12 +0100 Subject: [PATCH 22/22] Update hdrg_decoders.py --- src/qiskit_qec/decoders/hdrg_decoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index a9ce0ed8..052d4066 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -126,7 +126,7 @@ def _cluster(self, ns, dist_max): def _get_boundary_nodes(self): boundary_nodes = [] for element, z_logical in enumerate(self.measured_logicals): - node = DecodingGraphNode(is_boundary=True, qubits=[z_logical], index=element) + node = DecodingGraphNode(is_boundary=True, qubits=z_logical, index=element) if isinstance(self.code, ArcCircuit): node.properties["link qubit"] = None boundary_nodes.append(node)