diff --git a/pyzx/hrules.py b/pyzx/hrules.py index 13d36f44..f362d2fc 100644 --- a/pyzx/hrules.py +++ b/pyzx/hrules.py @@ -97,17 +97,17 @@ def fuse_hboxes(g: BaseGraph[VT,ET], matches: List[ET]) -> rules.RewriteOutputTy etab[upair(v1,n)] = [1,0] else: etab[upair(v1,n)] = [0,1] - + return (etab, rem_verts, [], True) -MatchCopyType = Tuple[VT,VT,VertexType,FractionLike,FractionLike,List[VT]] +MatchCopyType = Tuple[VT,VT,VertexType,FractionLike,FractionLike,List[ET]] def match_copy( g: BaseGraph[VT,ET], vertexf:Optional[Callable[[VT],bool]]=None - ) -> List[MatchCopyType[VT]]: + ) -> List[MatchCopyType[VT, ET]]: """Finds arity-1 spiders (with a 0 or pi phase) that can be copied through their neighbor.""" if vertexf is not None: candidates = set([v for v in g.vertices() if vertexf(v)]) else: candidates = g.vertex_set() @@ -159,7 +159,8 @@ def match_copy( else: continue neigh = [n for n in g.neighbors(w) if n != v] - m.append((v,w,copy_type,phases[v],phases[w],neigh)) + neigh_edges = [e for e in g.incident_edges(w) if v not in g.edge_st(e)] + m.append((v,w,copy_type,phases[v],phases[w],neigh_edges)) candidates.discard(w) candidates.difference_update(neigh) taken.add(w) @@ -168,36 +169,37 @@ def match_copy( return m def apply_copy( - g: BaseGraph[VT,ET], - matches: List[MatchCopyType[VT]] + g: BaseGraph[VT,ET], + matches: List[MatchCopyType[VT, ET]] ) -> rules.RewriteOutputType[VT,ET]: """Copy arity-1 spider through their neighbor.""" rem = [] types = g.types() - for v,w,copy_type,a,alpha,neigh in matches: + for v,w,copy_type,a,alpha,neigh_edges in matches: rem.append(v) - if copy_type == VertexType.BOUNDARY: + if copy_type == VertexType.BOUNDARY: g.scalar.add_power(1) continue # Don't have to do anything more for this case rem.append(w) if vertex_is_zx(types[w]): if a: g.scalar.add_phase(alpha) - g.scalar.add_power(-(len(neigh)-1)) + g.scalar.add_power(-(len(neigh_edges)-1)) else: #types[w] == H_BOX if copy_type == VertexType.Z: g.scalar.add_power(1) else: - g.scalar.add_power(-(len(neigh)-2)) + g.scalar.add_power(-(len(neigh_edges)-2)) if alpha != 1: g.scalar.add_power(-2) g.scalar.add_node(alpha+1) - for n in neigh: + for edge in neigh_edges: + st = g.edge_st(edge) + n = st[0] if st[1] == w else st[1] r = 0.7*g.row(w) + 0.3*g.row(n) q = 0.7*g.qubit(w) + 0.3*g.qubit(n) - + u = g.add_vertex(copy_type, q, r, a) - e = g.edge(n,w) - et = g.edge_type(e) + et = g.edge_type(edge) g.add_edge((n,u), et) return ({}, rem, [], True)