From 031fb3491d19b7a2cd95e76bbaf44c45d61a824c Mon Sep 17 00:00:00 2001 From: "Davide Gessa (dakk)" Date: Wed, 25 Oct 2023 14:01:31 +0200 Subject: [PATCH] fix return base_name, qcircuit.uncompute_all --- qlasskit/ast2logic/t_ast.py | 2 +- qlasskit/ast2logic/typing.py | 2 +- qlasskit/compiler/poccompiler2.py | 34 ++++++++++--------------- qlasskit/qcircuit.py | 42 +++++++++++++++++++++---------- test/test_qcircuit.py | 21 ++++++++++++++-- 5 files changed, 63 insertions(+), 38 deletions(-) diff --git a/qlasskit/ast2logic/t_ast.py b/qlasskit/ast2logic/t_ast.py index cdedfd4b..7959460d 100644 --- a/qlasskit/ast2logic/t_ast.py +++ b/qlasskit/ast2logic/t_ast.py @@ -38,7 +38,7 @@ def translate_ast(fun, types) -> LogicFun: if not fun.returns: raise exceptions.NoReturnTypeException() - ret_ = translate_argument(fun.returns, env) + ret_ = translate_argument(fun.returns, env, base="_ret") exps = [] for stmt in fun.body: diff --git a/qlasskit/ast2logic/typing.py b/qlasskit/ast2logic/typing.py index 819215e7..c7dd29ca 100644 --- a/qlasskit/ast2logic/typing.py +++ b/qlasskit/ast2logic/typing.py @@ -25,7 +25,7 @@ def __init__(self, name: str, ttype: object, bitvec: List[str]): self.bitvec = bitvec def __repr__(self): - return f"{self.name} - {self.ttype} - {', '.join(self.bitvec)}" + return f"{self.name} - {self.ttype} - [{', '.join(self.bitvec)}]" def __len__(self) -> int: return len(self.bitvec) diff --git a/qlasskit/compiler/poccompiler2.py b/qlasskit/compiler/poccompiler2.py index 4b77498b..852cd3ff 100644 --- a/qlasskit/compiler/poccompiler2.py +++ b/qlasskit/compiler/poccompiler2.py @@ -24,12 +24,9 @@ class POCCompiler2(Compiler): """POC2 compiler translating an expression list to quantum circuit""" - def garbage_collect(self, qc): - uncomputed = qc.uncompute() - self.expqmap.remove(uncomputed) - def compile(self, name, args: Args, returns: Arg, exprs: BoolExpList) -> QCircuit: qc = QCircuit(name=name) + self.expqmap = ExpQMap() for arg in args: for arg_b in arg.bitvec: @@ -37,7 +34,8 @@ def compile(self, name, args: Args, returns: Arg, exprs: BoolExpList) -> QCircui qc.add_qubit(arg_b) # qc.ancilla_lst.add(qi) - self.expqmap = ExpQMap() + # TODO: this is redundant, since we also have qc[] + # self.expqmap[Symbol(arg_b)] = qi for sym, exp in exprs: is_temp = sym.name[0:2] == "__" @@ -48,19 +46,21 @@ def compile(self, name, args: Args, returns: Arg, exprs: BoolExpList) -> QCircui self.expqmap[sym] = iret qc.map_qubit(sym, iret, promote=not is_temp) - self.garbage_collect(qc) - - # print(sym, exp) - # circ_qi = qc.export("circuit", "qiskit") - # print(circ_qi.draw("text")) - # print() - # print() + # Remove all the temp qubits + self.expqmap.remove(qc.uncompute()) qc.remove_identities() + qc.uncompute_all(keep=[qc[r] for r in returns.bitvec]) + + # circ_qi = qc.export("circuit", "qiskit") + # print(circ_qi.draw("text")) + # print() + # print() + return qc def compile_expr(self, qc: QCircuit, expr: Boolean, dest=None) -> int: # noqa: C901 - if isinstance(expr, Symbol): + if isinstance(expr, Symbol) and expr.name in qc: return qc[expr.name] elif expr in self.expqmap: @@ -81,8 +81,6 @@ def compile_expr(self, qc: QCircuit, expr: Boolean, dest=None) -> int: # noqa: qc.cx(eret, dest) qc.x(dest) qc.mark_ancilla(eret) - - self.garbage_collect(qc) self.expqmap[expr] = dest return dest @@ -93,13 +91,9 @@ def compile_expr(self, qc: QCircuit, expr: Boolean, dest=None) -> int: # noqa: dest = qc.get_free_ancilla() qc.barrier("and") - qc.mcx(erets, dest) [qc.mark_ancilla(eret) for eret in erets] - - self.garbage_collect(qc) - self.expqmap[expr] = dest return dest @@ -125,8 +119,6 @@ def compile_expr(self, qc: QCircuit, expr: Boolean, dest=None) -> int: # noqa: qc.cx(x, dest) [qc.mark_ancilla(eret) for eret in erets] - self.garbage_collect(qc) - return dest elif isinstance(expr, BooleanFalse): diff --git a/qlasskit/qcircuit.py b/qlasskit/qcircuit.py index 8b81a7b7..e5f99eb7 100644 --- a/qlasskit/qcircuit.py +++ b/qlasskit/qcircuit.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy from typing import List, Literal, Union, get_args from sympy import Symbol @@ -130,38 +131,53 @@ def mark_ancilla(self, w): if w in self.ancilla_lst: self.marked_ancillas.add(w) - def uncompute_all(self, keep=[]): - """Uncompute the whole circuit expect for the keep""" - pass + def uncompute_all(self, keep: List[Union[Symbol, int]] = []): + """Uncompute the whole circuit expect for the keep (symbols or qubit)""" + # TODO: replace with + invert(keep) + scopy = copy.deepcopy(self.gates) + uncomputed = set() + self.barrier(label="un_all") + for g, qbs in reversed(scopy): + if qbs[-1] in keep or g == "bar" or qbs[-1] in self.free_ancilla_lst: + continue + uncomputed.add(qbs[-1]) + + if qbs[-1] in self.ancilla_lst: + self.free_ancilla_lst.add(qbs[-1]) + + self.append(g, qbs) + + # Remove barrier if no uncomputed + if len(uncomputed) == 0: + self.gates.pop() + + return uncomputed def uncompute(self, to_mark=[]): """Uncompute all the marked ancillas plus the to_mark list""" [self.mark_ancilla(x) for x in to_mark] + if len(self.marked_ancillas) == 0: + return [] + self.barrier(label="un") uncomputed = set() new_gates_comp = [] - not_to_uncompute = set() - - for g, ws in self.gates_computed[::-1]: - if ws[-1] in self.marked_ancillas and not all( - [ww in self.marked_ancillas for ww in ws[:-1]] - ): - not_to_uncompute.add(ws[-1]) - for g, ws in self.gates_computed[::-1]: - if ws[-1] in self.marked_ancillas and ws[-1] not in not_to_uncompute: + for g, ws in reversed(self.gates_computed): + if ws[-1] in self.marked_ancillas: uncomputed.add(ws[-1]) self.append(g, ws) else: new_gates_comp.append((g, ws)) - for x in uncomputed: + for x in self.marked_ancillas: self.free_ancilla_lst.add(x) self.marked_ancillas = self.marked_ancillas - uncomputed self.gates_computed = new_gates_comp[::-1] + # Remove barrier if no uncomputed if len(uncomputed) == 0: self.gates.pop() diff --git a/test/test_qcircuit.py b/test/test_qcircuit.py index 7b95758c..1cc47c36 100644 --- a/test/test_qcircuit.py +++ b/test/test_qcircuit.py @@ -83,10 +83,27 @@ def test2(self): qc.mcx(q, a[0]) qc.mcx(q + [a[0]], a[1]) - qc.mcx(q + a[:1], a[2]) - qc.mcx(q + a[:2], a[3]) + qc.mcx(q + a[:2], a[2]) + qc.mcx(q + a[:3], a[3]) qc.cx(a[3], r) qc.uncompute(a) + + self.assertEqual(len(qc.free_ancilla_lst), len(a)) + # qc.draw() + + def test_uncompute_all(self): + qc = QCircuit() + q = [qc.add_qubit() for x in range(4)] + a = [qc.add_ancilla(is_free=False) for x in range(4)] + r = qc.add_qubit() + + qc.mcx(q, a[0]) + qc.mcx(q + [a[0]], a[1]) + qc.mcx(q + a[:2], a[2]) + qc.mcx(q + a[:3], a[3]) + qc.cx(a[3], r) + qc.uncompute(a) + qc.uncompute_all([r]) # qc.draw()