Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aggregate uncomputing #2

Merged
merged 2 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions qlasskit/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


from sympy import simplify, symbols
from sympy.logic import ITE, And, Implies, Not, Or, boolalg
from sympy.logic import ITE, And, Implies, Not, Or # , boolalg

from .. import QCircuit
from ..ast2logic.typing import Args, BoolExpList
Expand All @@ -38,7 +38,7 @@ def _symplify_exp(self, exp):

# Simplify the expression
exp = simplify(exp)
exp = boolalg.to_cnf(exp)
# exp = boolalg.to_cnf(exp)
return exp

def compile(
Expand Down
30 changes: 25 additions & 5 deletions qlasskit/compiler/poccompiler2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License..

from typing import Dict

from sympy import Symbol
from sympy.logic import And, Not, Or
from sympy.logic.boolalg import Boolean, BooleanFalse, BooleanTrue
Expand All @@ -31,17 +33,29 @@ def compile(self, name, args: Args, ret_size: int, exprs: BoolExpList) -> QCircu
for arg_b in arg.bitvec:
qc.add_qubit(arg_b)

self.mapped: Dict[Boolean, int] = {}

for sym, exp in exprs:
# print(sym, self._symplify_exp(exp))
iret = self.compile_expr(qc, self._symplify_exp(exp))
print("iret", iret)
# print("iret", iret)
qc.map_qubit(sym, iret, promote=True)
uncomputed = qc.uncompute()

for k in self.mapped.keys():
if self.mapped[k] in uncomputed:
del self.mapped[k]

return qc

def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:
if isinstance(expr, Symbol):
return qc[expr.name]

elif expr in self.mapped:
print("!!cachehit!!", expr)
return self.mapped[expr]

elif isinstance(expr, Not):
fa = qc.get_free_ancilla()
eret = self.compile_expr(qc, expr.args[0])
Expand All @@ -51,7 +65,11 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:
qc.cx(eret, fa)
qc.x(fa)

#qc.free_ancilla(eret)
qc.uncompute()

qc.mark_ancilla(eret)

self.mapped[expr] = fa

return fa

Expand All @@ -63,13 +81,17 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:

qc.mcx(erets, fa)

qc.free_ancillas(erets)
qc.uncompute()

[qc.mark_ancilla(eret) for eret in erets]
self.mapped[expr] = fa

return fa

elif isinstance(expr, Or):
# Translate or to and
expr = Not(And(*[Not(e) for e in expr.args]))
# print("trans", expr)
return self.compile_expr(qc, expr)

# OLD TRANSLATOR
Expand All @@ -87,8 +109,6 @@ def compile_expr(self, qc: QCircuit, expr: Boolean) -> int:
# for j in range(i + 1, nclau - i):
# qc.x(iclau[j])

# qc.free_ancillas(iclau)

# return fa

elif isinstance(expr, BooleanFalse) or isinstance(expr, BooleanTrue):
Expand Down
76 changes: 18 additions & 58 deletions qlasskit/qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, num_qubits=0, name="qc"):

self.ancilla_lst = set()
self.free_ancilla_lst = set()
self.marked_ancillas = []

for x in range(num_qubits):
self.qubit_map[f"q{x}"] = x
Expand Down Expand Up @@ -72,24 +73,6 @@ def add_ancilla(self, name=None, is_free=True):
self.free_ancilla_lst.add(i)
return i

def free_ancilla(self, w):
"""Freeing of an ancilla qubit"""
w = self[w]
if w not in self.ancilla_lst:
return # we don't care
# raise Exception(f"Qubit {w} is not in the ancilla set")

if w in self.free_ancilla_lst:
raise Exception(f"Ancilla {w} is already free")

self.uncompute(w)
self.free_ancilla_lst.add(w)

def free_ancillas(self, wl):
"""Freeing of a list of ancilla qubits"""
for w in wl:
self.free_ancilla(w)

def get_free_ancilla(self):
"""Get the first free ancilla available"""
if len(self.free_ancilla_lst) == 0:
Expand All @@ -99,50 +82,27 @@ def get_free_ancilla(self):

return anc

def uncompute(self, w):
"""Uncompute a specific ancilla qubit.
def mark_ancilla(self, w):
self.marked_ancillas.append(w)

Args:
w (int): The index of the qubit to be uncomputed.
"""
w = self[w]
if w not in self.ancilla_lst:
raise Exception("qubit not in the ancilla list")

print("uncomputing ", w)
def uncompute(self, to_mark=[]):
"""Uncompute all the marked ancillas plus the to_mark list"""
[self.mark_ancilla(x) for x in to_mark]

g_comp = []
self.barrier(label=f"U{w}")
uncomputed = set()
for g, ws in self.gates_computed[::-1]:
# w is the target
if w == ws[-1]:
if (
ws[-1] in self.marked_ancillas
and ws[-1] in self.ancilla_lst
and ws[-1] not in self.free_ancilla_lst
):
uncomputed.add(ws[-1])
self.append(g, ws)
# w is a control
# elif w in ws[:-1]:
# self.append(g, ws)
else:
g_comp.append((g, ws))

self.barrier(label=f"EU{w}")
self.gates_computed = g_comp[::-1]

# def uncompute(self):
# """Uncompute released ancilla qubits"""

# g_comp = []
# self.barrier(label=f"U{''.join(map(str,self.uncomputable))}")
# for g, ws in self.gates_computed[::-1]:
# # w is the target
# if ws[-1] in self.uncomputable:
# self.append(g, ws)
# # w is a control
# # elif w in ws[:-1]:
# # self.append(g, ws)
# else:
# g_comp.append((g, ws))

# self.barrier(label=f"EU{''.join(map(str,self.uncomputable))}")
# self.gates_computed = g_comp[::-1]
self.free_ancilla_lst.add(ws[-1])

self.marked_ancillas = [] # self.marked_ancillas - uncomputed
self.gates_computed = []
return uncomputed

def map_qubit(self, name, index, promote=False):
"""Map a name to a qubit
Expand Down
13 changes: 9 additions & 4 deletions test/test_qcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,31 @@ def test_base_mapping(self):
class TestQCircuitUncomputing(unittest.TestCase):
def test1(self):
qc = QCircuit()
a, b, c, d = qc.add_qubit(), qc.add_qubit(), qc.add_ancilla(), qc.add_ancilla()
a, b, c, d = (
qc.add_qubit(),
qc.add_qubit(),
qc.add_ancilla(is_free=False),
qc.add_ancilla(is_free=False),
)
f = qc.add_qubit("res")
qc.mcx([a, b], c)
qc.mcx([a, b, c], d)
qc.cx(d, f)
qc.uncompute(c)
qc.uncompute(d) # this is invalidated
qc.uncompute([c, d])
qc.draw()

def test2(self):
qc = QCircuit()
q = [qc.add_qubit() for x in range(4)]
a = [qc.add_ancilla() for x in range(4)]
a = [qc.add_ancilla(is_free=False) for x in range(4)]
r = qc.add_qubit()

qc.mcx(q, a[0])
qc.mcx(q + [a[0]], a[1])
qc.mcx(q + a[:1], a[2])
qc.mcx(q + a[:2], a[3])
qc.cx(a[3], r)
qc.uncompute(a)
qc.draw()


Expand Down
10 changes: 5 additions & 5 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,28 @@ def compare_circuit_truth_table(cls, qf):
for i in range(qf.input_size):
qc.initialize(1 if truth_line[i] else 0, i)

#(truth_line)
# (truth_line)

qc.append(gate, list(range(qf.num_qubits)))
# print(qc.decompose().draw("text"))
print(circ_qi.draw("text"))
# print(circ_qi.draw("text"))

counts = qiskit_measure_and_count(qc)
#print(counts, circ.qubit_map)
# print(counts, circ.qubit_map)

truth_str = "".join(
map(lambda x: "1" if x else "0", truth_line[-qf.ret_size :])
)

#print(truth_str)
# print(truth_str)

res = list(counts.keys())[0][::-1]
res_str = ""
for qname in qf.truth_table_header()[-qf.ret_size :]:
res_str += res[circ.qubit_map[qname]]

# res = res[0 : len(truth_str)][::-1]
#print(res_str)
# print(res_str)

cls.assertEqual(len(counts), 1)
cls.assertEqual(truth_str, res_str)