Skip to content

Commit

Permalink
fix return base_name, qcircuit.uncompute_all
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 25, 2023
1 parent 640096e commit 031fb34
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 38 deletions.
2 changes: 1 addition & 1 deletion qlasskit/ast2logic/t_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def translate_ast(fun, types) -> LogicFun:
if not fun.returns:
raise exceptions.NoReturnTypeException()

ret_ = translate_argument(fun.returns, env)
ret_ = translate_argument(fun.returns, env, base="_ret")

exps = []
for stmt in fun.body:
Expand Down
2 changes: 1 addition & 1 deletion qlasskit/ast2logic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, name: str, ttype: object, bitvec: List[str]):
self.bitvec = bitvec

def __repr__(self):
return f"{self.name} - {self.ttype} - {', '.join(self.bitvec)}"
return f"{self.name} - {self.ttype} - [{', '.join(self.bitvec)}]"

def __len__(self) -> int:
return len(self.bitvec)
Expand Down
34 changes: 13 additions & 21 deletions qlasskit/compiler/poccompiler2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,18 @@
class POCCompiler2(Compiler):
"""POC2 compiler translating an expression list to quantum circuit"""

def garbage_collect(self, qc):
uncomputed = qc.uncompute()
self.expqmap.remove(uncomputed)

def compile(self, name, args: Args, returns: Arg, exprs: BoolExpList) -> QCircuit:
qc = QCircuit(name=name)
self.expqmap = ExpQMap()

for arg in args:
for arg_b in arg.bitvec:
# qi =
qc.add_qubit(arg_b)
# qc.ancilla_lst.add(qi)

self.expqmap = ExpQMap()
# TODO: this is redundant, since we also have qc[]
# self.expqmap[Symbol(arg_b)] = qi

for sym, exp in exprs:
is_temp = sym.name[0:2] == "__"
Expand All @@ -48,19 +46,21 @@ def compile(self, name, args: Args, returns: Arg, exprs: BoolExpList) -> QCircui
self.expqmap[sym] = iret
qc.map_qubit(sym, iret, promote=not is_temp)

self.garbage_collect(qc)

# print(sym, exp)
# circ_qi = qc.export("circuit", "qiskit")
# print(circ_qi.draw("text"))
# print()
# print()
# Remove all the temp qubits
self.expqmap.remove(qc.uncompute())

qc.remove_identities()
qc.uncompute_all(keep=[qc[r] for r in returns.bitvec])

# circ_qi = qc.export("circuit", "qiskit")
# print(circ_qi.draw("text"))
# print()
# print()

return qc

def compile_expr(self, qc: QCircuit, expr: Boolean, dest=None) -> int: # noqa: C901
if isinstance(expr, Symbol):
if isinstance(expr, Symbol) and expr.name in qc:
return qc[expr.name]

elif expr in self.expqmap:
Expand All @@ -81,8 +81,6 @@ def compile_expr(self, qc: QCircuit, expr: Boolean, dest=None) -> int: # noqa:
qc.cx(eret, dest)
qc.x(dest)
qc.mark_ancilla(eret)

self.garbage_collect(qc)
self.expqmap[expr] = dest

return dest
Expand All @@ -93,13 +91,9 @@ def compile_expr(self, qc: QCircuit, expr: Boolean, dest=None) -> int: # noqa:
dest = qc.get_free_ancilla()

qc.barrier("and")

qc.mcx(erets, dest)

[qc.mark_ancilla(eret) for eret in erets]

self.garbage_collect(qc)

self.expqmap[expr] = dest

return dest
Expand All @@ -125,8 +119,6 @@ def compile_expr(self, qc: QCircuit, expr: Boolean, dest=None) -> int: # noqa:
qc.cx(x, dest)

[qc.mark_ancilla(eret) for eret in erets]
self.garbage_collect(qc)

return dest

elif isinstance(expr, BooleanFalse):
Expand Down
42 changes: 29 additions & 13 deletions qlasskit/qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List, Literal, Union, get_args

from sympy import Symbol
Expand Down Expand Up @@ -130,38 +131,53 @@ def mark_ancilla(self, w):
if w in self.ancilla_lst:
self.marked_ancillas.add(w)

def uncompute_all(self, keep=[]):
"""Uncompute the whole circuit expect for the keep"""
pass
def uncompute_all(self, keep: List[Union[Symbol, int]] = []):
"""Uncompute the whole circuit expect for the keep (symbols or qubit)"""
# TODO: replace with + invert(keep)
scopy = copy.deepcopy(self.gates)
uncomputed = set()
self.barrier(label="un_all")
for g, qbs in reversed(scopy):
if qbs[-1] in keep or g == "bar" or qbs[-1] in self.free_ancilla_lst:
continue
uncomputed.add(qbs[-1])

if qbs[-1] in self.ancilla_lst:
self.free_ancilla_lst.add(qbs[-1])

self.append(g, qbs)

# Remove barrier if no uncomputed
if len(uncomputed) == 0:
self.gates.pop()

return uncomputed

def uncompute(self, to_mark=[]):
"""Uncompute all the marked ancillas plus the to_mark list"""
[self.mark_ancilla(x) for x in to_mark]

if len(self.marked_ancillas) == 0:
return []

self.barrier(label="un")

uncomputed = set()
new_gates_comp = []
not_to_uncompute = set()

for g, ws in self.gates_computed[::-1]:
if ws[-1] in self.marked_ancillas and not all(
[ww in self.marked_ancillas for ww in ws[:-1]]
):
not_to_uncompute.add(ws[-1])

for g, ws in self.gates_computed[::-1]:
if ws[-1] in self.marked_ancillas and ws[-1] not in not_to_uncompute:
for g, ws in reversed(self.gates_computed):
if ws[-1] in self.marked_ancillas:
uncomputed.add(ws[-1])
self.append(g, ws)
else:
new_gates_comp.append((g, ws))

for x in uncomputed:
for x in self.marked_ancillas:
self.free_ancilla_lst.add(x)
self.marked_ancillas = self.marked_ancillas - uncomputed
self.gates_computed = new_gates_comp[::-1]

# Remove barrier if no uncomputed
if len(uncomputed) == 0:
self.gates.pop()

Expand Down
21 changes: 19 additions & 2 deletions test/test_qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,27 @@ def test2(self):

qc.mcx(q, a[0])
qc.mcx(q + [a[0]], a[1])
qc.mcx(q + a[:1], a[2])
qc.mcx(q + a[:2], a[3])
qc.mcx(q + a[:2], a[2])
qc.mcx(q + a[:3], a[3])
qc.cx(a[3], r)
qc.uncompute(a)

self.assertEqual(len(qc.free_ancilla_lst), len(a))
# qc.draw()

def test_uncompute_all(self):
qc = QCircuit()
q = [qc.add_qubit() for x in range(4)]
a = [qc.add_ancilla(is_free=False) for x in range(4)]
r = qc.add_qubit()

qc.mcx(q, a[0])
qc.mcx(q + [a[0]], a[1])
qc.mcx(q + a[:2], a[2])
qc.mcx(q + a[:3], a[3])
qc.cx(a[3], r)
qc.uncompute(a)
qc.uncompute_all([r])
# qc.draw()


Expand Down

0 comments on commit 031fb34

Please sign in to comment.