Skip to content

Commit

Permalink
Fix drawing of stabilizer graph for zero syndrome (#9)
Browse files Browse the repository at this point in the history
* Fix drawing of stabilizer graph for zero syndrome

* Ran black

* Fix docstring

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* Fix docstring

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* Remove unused 'code' argument

* Update stabilizer_graph.py

* Generalize stab graph edge weights assignement

* Add conversion to nx and default drawing to rx

* Update flamingpy/codes/graphs/stabilizer_graph.py

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* Update flamingpy/codes/graphs/stabilizer_graph.py

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* Add check for high low nodes before setting weight

* Ran formatter

* Update flamingpy/codes/graphs/stabilizer_graph.py

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* Fix rx edge weight function

* Update flamingpy/codes/graphs/stabilizer_graph.py

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* Remove code argument for shortest path

* Ran formatter

* Update stabilizer_graph.py

* Run black, docformatter

* Update _version.py

Co-authored-by: maxtremblay <m@xtremblay.ca>
Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 12, 2022
1 parent bd85114 commit b905698
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 63 deletions.
2 changes: 1 addition & 1 deletion flamingpy/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
"""Version number (major.minor.patch[label])"""


__version__ = "0.5.0a2"
__version__ = "0.6.0a3"
127 changes: 80 additions & 47 deletions flamingpy/codes/graphs/stabilizer_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ class StabilizerGraph(ABC):
many nodes in order to construct the matching graph.
Parameters:
ec (str): the error complex ("primal" or "dual"). Determines whether
ec (str, optional): the error complex ("primal" or "dual"). Determines whether
the graph is generated from primal or dual stabilizers in the code.
code (SurfaceCode): the code from which to initialize the graph.
code (SurfaceCode, optional): the code from which to initialize the graph.
Note:
Both ec and code must be provided to initialize the graph with the proper edges.
If one of them is not provided, the graph is left empty.
Attributes:
stabilizers (List[Stabilizer]): All the stabilizer nodes of the graph.
Expand All @@ -64,13 +68,14 @@ class StabilizerGraph(ABC):
translating the outcomes.
"""

def __init__(self, ec, code=None):
# pylint: disable=too-many-public-methods
def __init__(self, ec=None, code=None):
self.add_node("low")
self.add_node("high")
self.stabilizers = []
self.low_bound_points = []
self.high_bound_points = []
if code is not None:
if code is not None and ec is not None:
self.add_stabilizers(getattr(code, ec + "_stabilizers"))
bound_points = getattr(code, ec + "_bound_points")
mid = int(len(bound_points) / 2)
Expand All @@ -92,6 +97,10 @@ def add_node(self, node):
"""
raise NotImplementedError

def nodes(self):
"""Return an iterable of all nodes in the graph."""
raise NotImplementedError

def add_edge(
self,
node1,
Expand Down Expand Up @@ -131,12 +140,15 @@ def edges(self):
graph."""
raise NotImplementedError

def shortest_paths_without_high_low(self, source, code):
def shortest_paths_without_high_low(self, source):
"""Compute the shortest path from source to every other node in the
graph, except the 'high' and 'low' connector.
This assumes that the edge weights are asssigned.
Note: a path can't use the 'high' and 'low' node.
Arguments:
source: The source node for each path.
Returns:
Expand All @@ -147,21 +159,25 @@ def shortest_paths_without_high_low(self, source, code):

raise NotImplementedError

def shortest_paths_from_high(self, code):
def shortest_paths_from_high(self):
"""Compute the shortest path from the 'high' node to every other node
in the graph.
This assumes that the edge weights are asssigned.
Returns:
(dict, dict): The first dictionary maps a target node to the weight
of the corresponding path. The second one maps a target node to
the list of nodes along the corresponding path.
"""
raise NotImplementedError

def shortest_paths_from_low(self, code):
def shortest_paths_from_low(self):
"""Compute the shortest path from the 'low' node to every other node in
the graph.
This assumes that the edge weights are asssigned.
Returns:
(dict, dict): The first dictionary maps a target node to the weight
of the corresponding path. The second one maps a target node to
Expand Down Expand Up @@ -289,20 +305,44 @@ def real_nodes(self):
def real_edges(self):
"""Returns an iterable of all edges excluding the ones connected to the
'low' or 'high' points."""
is_high_or_low = lambda n: n in ("low", "high")
return filter(
lambda edge: not (is_high_or_low(edge[0]) or is_high_or_low(edge[1])),
lambda edge: edge[0] not in ("low", "high") and edge[1] not in ("low", "high"),
self.edges(),
)

def assign_weights(self, code):
"""Assign the weights to the graph based on the weight of the common
vertex of each stabilizer pair of the code."""
for edge in self.edges():
data = self.edge_data(*edge)
if data["common_vertex"] is not None:
data["weight"] = code.graph.nodes[data["common_vertex"]].get("weight")
elif "high" in edge or "low" in edge:
data["weight"] = 0

def to_nx(self):
"""Convert the same graph into a NxStabilizerGraph.
This involves converting the graph representation to a networkx
graph representation.
"""
if isinstance(self, NxStabilizerGraph):
return self
nx_graph = NxStabilizerGraph()
for edge in self.edges():
nx_graph.add_edge(*edge, self.edge_data(*edge)["common_vertex"])
if "weight" in self.edge_data(*edge):
nx_graph.edge_data(*edge)["weight"] = self.edge_data(*edge)["weight"]
return nx_graph

def draw(self, **kwargs):
"""Draw the stabilizer graph with matplotlib.
See flamingpy.utils.viz.draw_dec_graph for more details.
"""
from flamingpy.utils.viz import draw_dec_graph

draw_dec_graph(self, **kwargs)
draw_dec_graph(self.to_nx(), **kwargs)


class NxStabilizerGraph(StabilizerGraph):
Expand All @@ -314,14 +354,17 @@ class NxStabilizerGraph(StabilizerGraph):
graph (networkx.Graph): The actual graph backend.
"""

def __init__(self, ec, code=None):
def __init__(self, ec=None, code=None):
self.graph = nx.Graph()
StabilizerGraph.__init__(self, ec, code)

def add_node(self, node):
self.graph.add_node(node)
return self

def nodes(self):
return self.graph.nodes

def add_edge(self, node1, node2, common_vertex=None):
self.graph.add_edge(node1, node2, common_vertex=common_vertex)
return self
Expand All @@ -332,26 +375,21 @@ def edge_data(self, node1, node2):
def edges(self):
return self.graph.edges()

def shortest_paths_without_high_low(self, source, code):
def shortest_paths_without_high_low(self, source):
subgraph = self.graph.subgraph(
self.stabilizers + self.low_bound_points + self.high_bound_points
)
return nx_shortest_paths_from(subgraph, source, code)
return nx_shortest_paths_from(subgraph, source)

def shortest_paths_from_high(self, code):
return nx_shortest_paths_from(self.graph, "high", code)
def shortest_paths_from_high(self):
return nx_shortest_paths_from(self.graph, "high")

def shortest_paths_from_low(self, code):
return nx_shortest_paths_from(self.graph, "low", code)
def shortest_paths_from_low(self):
return nx_shortest_paths_from(self.graph, "low")


def nx_shortest_paths_from(graph, source, code):
def nx_shortest_paths_from(graph, source):
"""The NetworkX shortest path implementation."""
for edge in graph.edges.data():
if edge[2].get("common_vertex") is not None:
edge[2]["weight"] = code.graph.nodes[edge[2]["common_vertex"]]["weight"]
else:
edge[2]["weight"] = 0.0
(weights, paths) = sp.single_source_dijkstra(graph, source)
del weights[source]
del paths[source]
Expand All @@ -372,7 +410,7 @@ class RxStabilizerGraph(StabilizerGraph):
corresponding nodes.
"""

def __init__(self, ec, code=None):
def __init__(self, ec=None, code=None):
self.graph = rx.PyGraph()
self.node_to_index = {}
self.index_to_node = {}
Expand All @@ -384,6 +422,9 @@ def add_node(self, node):
self.index_to_node[index] = node
return self

def nodes(self):
return self.graph.nodes()

def add_edge(self, node1, node2, common_vertex=None):
index1 = self.node_to_index[node1]
index2 = self.node_to_index[node2]
Expand All @@ -401,37 +442,35 @@ def edges(self):
for edge in self.graph.edge_list()
)

def shortest_paths_without_high_low(self, source, code):
def shortest_paths_without_high_low(self, source):
subgraph = self.graph.copy() # This is a shallow copy.
# We know that nodes 0 and 1 are the 'high' and 'low' nodes.
subgraph.remove_nodes_from([0, 1])
return self._shortest_paths_from(subgraph, source, code)
return self._shortest_paths_from(subgraph, source)

def shortest_paths_from_high(self, code):
return self._shortest_paths_from(self.graph, "high", code)
def shortest_paths_from_high(self):
return self._shortest_paths_from(self.graph, "high")

def shortest_paths_from_low(self, code):
return self._shortest_paths_from(self.graph, "low", code)
def shortest_paths_from_low(self):
return self._shortest_paths_from(self.graph, "low")

# The following methods are helpers for the shortest paths methods.

def _shortest_paths_from(self, graph, source, code):
def _shortest_paths_from(self, graph, source):
paths = rx.graph_dijkstra_shortest_paths(
graph, self.node_to_index[source], weight_fn=rx_weight_fn(code)
graph, self.node_to_index[source], weight_fn=rx_weight_fn
)
return self._all_path_weights(paths, code), self._all_path_nodes(paths)
return self._all_path_weights(paths), self._all_path_nodes(paths)

def _all_path_weights(self, paths, code):
def _all_path_weights(self, paths):
return {
self.index_to_node[target]: self._path_weight(path, code)
for (target, path) in paths.items()
self.index_to_node[target]: self._path_weight(path) for (target, path) in paths.items()
}

def _path_weight(self, path, code):
weight_fn = rx_weight_fn(code)
def _path_weight(self, path):
weight = 0
for e in range(len(path) - 1):
weight += int(weight_fn(self.graph.get_edge_data(path[e], path[e + 1])))
weight += int(rx_weight_fn(self.graph.get_edge_data(path[e], path[e + 1])))
return weight

def _all_path_nodes(self, paths):
Expand All @@ -446,12 +485,6 @@ def _path_nodes(self, path):
return nodes


def rx_weight_fn(code):
def rx_weight_fn(edge):
"""A function for returning the weight from the common vertex."""

def fn(edge):
if edge["common_vertex"] is not None:
return float(code.graph.nodes[edge["common_vertex"]]["weight"])
return 0.0

return fn
return float(edge["weight"])
9 changes: 9 additions & 0 deletions flamingpy/codes/surface_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,12 @@ def draw(self, **kwargs):
}
updated_opts = {**default_opts, **kwargs}
return self.graph.draw(**updated_opts)

def draw_stabilizer_graph(self, ec, **kwargs):
"""Draw the stabilizer graph with matplotlib.
See flamingpy.utils.viz.draw_dec_graph for more details.
"""
graph = getattr(self, ec + "_stab_graph")
graph.assign_weights(self)
return graph.draw(**kwargs)
3 changes: 3 additions & 0 deletions flamingpy/decoders/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def assign_weights(code, **kwargs):
if weight_options.get("method") == "unit":
for node in syndrome_coords:
G.nodes[node]["weight"] = 1
# Also assign the weights to the stabilizer graph edges.
for ec in code.ec:
getattr(code, f"{ec}_stab_graph").assign_weights(code)


def CV_decoder(code, translator=GKP_binner):
Expand Down
16 changes: 10 additions & 6 deletions flamingpy/decoders/mwpm/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,13 @@ def with_edges_from(self, code, ec):
code (SurfaceCode): The code from which to build the edges.
"""
stab_graph = getattr(code, ec + "_stab_graph")
self = self._with_edges_between_real_odd_nodes(code, ec)
if stab_graph.has_bound_points():
return self._with_edges_from_low_or_high_connector(code, ec)
odd_stabilizers = list(stab_graph.odd_parity_stabilizers())
# If the syndrome is trivial, we don't need to add edges into
# the matching graph.
if len(odd_stabilizers) > 0:
self = self._with_edges_between_real_odd_nodes(code, ec)
if stab_graph.has_bound_points():
return self._with_edges_from_low_or_high_connector(code, ec)
return self

def _with_edges_between_real_odd_nodes(self, code, ec):
Expand All @@ -128,7 +132,7 @@ def _with_edges_between_real_odd_nodes(self, code, ec):
odd_adjacency[pair[0]] += [pair[1]]
# Find the shortest paths between odd-parity stabs.
for stab1 in odd_parity_stabs[:-1]:
lengths, paths = stab_graph.shortest_paths_without_high_low(stab1, code)
lengths, paths = stab_graph.shortest_paths_without_high_low(stab1)
for stab2 in odd_adjacency[stab1]:
length = lengths[stab2]
path = paths[stab2]
Expand All @@ -139,8 +143,8 @@ def _with_edges_between_real_odd_nodes(self, code, ec):

def _with_edges_from_low_or_high_connector(self, code, ec):
stab_graph = getattr(code, ec + "_stab_graph")
low_lengths, low_paths = stab_graph.shortest_paths_from_low(code)
high_lengths, high_paths = stab_graph.shortest_paths_from_high(code)
low_lengths, low_paths = stab_graph.shortest_paths_from_low()
high_lengths, high_paths = stab_graph.shortest_paths_from_high()
for i, cube in enumerate(stab_graph.odd_parity_stabilizers()):
distances = (low_lengths[cube], high_lengths[cube])
where_shortest = np.argmin(distances)
Expand Down
13 changes: 10 additions & 3 deletions flamingpy/examples/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@

# Code and code lattice (cluster state)
RHG_code = SurfaceCode(
distance=distance, ec=ec, boundaries=boundaries, polarity=alternating_polarity
distance=distance,
ec=ec,
boundaries=boundaries,
polarity=alternating_polarity,
backend="retworkx",
)
RHG_lattice = RHG_code.graph

Expand Down Expand Up @@ -93,11 +97,14 @@

# An integer label for each nodes in the stabilizer and matching graphs.
# This is useful to identify the nodes in the plots.
node_labels = {node: index for index, node in enumerate(G_stabilizer.graph)}
node_labels = {node: index for index, node in enumerate(G_stabilizer.nodes())}

# The draw_dec_graph function requires the networkx backend. Most backends implement
# the to_nx() method to perform the conversion if needed.
G_stabilizer.draw(title=ec.capitalize() + " stabilizer graph", node_labels=node_labels)
RHG_code.draw_stabilizer_graph(
ec, title=ec.capitalize() + " stabilizer graph", node_labels=node_labels
)

ax = viz.syndrome_plot(RHG_code, ec, drawing_opts=dw, index_dict=node_labels)
viz.draw_matching_on_syndrome_plot(ax, matching, G_stabilizer, G_match, dw.get("label_edges"))
if len(G_match.graph):
Expand Down
Loading

0 comments on commit b905698

Please sign in to comment.