Skip to content

Commit

Permalink
access tuple of tuple by var
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Jul 6, 2024
1 parent 6d40588 commit 1217a17
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 58 deletions.
2 changes: 2 additions & 0 deletions qlasskit/ast2ast/ast2ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def ast2ast(a_tree):
if sys.version_info < (3, 9):
a_tree = IndexReplacer().visit(a_tree)

# Matrix translator

# Fold constants
a_tree = ConstantFolder().visit(a_tree)

Expand Down
73 changes: 72 additions & 1 deletion qlasskit/ast2ast/astrewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,84 @@ def __unroll_arg(self, arg):
def generic_visit(self, node):
return super().generic_visit(node)

def visit_Subscript(self, node):
def visit_Subscript(self, node): # noqa: C901
_sval = node.slice

# Replace L[a] with const a, to L[const]
if isinstance(_sval, ast.Name) and _sval.id in self.const:
node.slice = self.const[_sval.id]

# Handle inner access L[i][j]
elif (
isinstance(node, ast.Subscript)
and isinstance(node.value, ast.Subscript)
and isinstance(node.value.value, ast.Name)
and isinstance(node.value.slice, ast.Name)
and isinstance(node.slice, ast.Name)
):
nname = node.value.value.id
iname = node.value.slice.id
jname = node.slice.id

def create_if_exp(i, j, max_i, max_j):
if i == max_i and j == max_j:
return ast.Subscript(
value=ast.Subscript(
value=ast.Name(id=nname, ctx=ast.Load()),
slice=ast.Constant(value=i),
ctx=ast.Load(),
),
slice=ast.Constant(value=j),
ctx=ast.Load(),
)
else:
next_j = j + 1 if j < max_j else 0
next_i = i if j < max_j else i + 1
return ast.IfExp(
test=ast.BoolOp(
op=ast.And(),
values=[
ast.Compare(
left=ast.Name(id=iname, ctx=ast.Load()),
ops=[ast.Eq()],
comparators=[ast.Constant(value=i)],
),
ast.Compare(
left=ast.Name(id=jname, ctx=ast.Load()),
ops=[ast.Eq()],
comparators=[ast.Constant(value=j)],
),
],
),
body=ast.Subscript(
value=ast.Subscript(
value=ast.Name(id=nname, ctx=ast.Load()),
slice=ast.Constant(value=i),
ctx=ast.Load(),
),
slice=ast.Constant(value=j),
ctx=ast.Load(),
),
orelse=create_if_exp(next_i, next_j, max_i, max_j),
)

# Infer i and j sizes from env['a']
a_type = self.env[nname]

# self.env[nname] is a constant
if isinstance(a_type, ast.Tuple):
max_i = len(a_type.elts) - 1
max_j = len(a_type.elts[0].elts) - 1 # type: ignore
# self.env[nname] is a type annotation
else:
outer_tuple = a_type.slice
max_i = len(outer_tuple.elts) - 1
inner_tuple = outer_tuple.elts
max_j = len(inner_tuple) - 1

# Create the IfExp structure
return create_if_exp(0, 0, max_i, max_j)

# Unroll L[a] with (L[0] if a == 0 else L[1] if a == 1 ...)
elif (isinstance(_sval, ast.Name) and _sval.id not in self.const) or isinstance(
_sval, ast.Subscript
Expand Down
80 changes: 39 additions & 41 deletions test/qlassf/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,44 +89,42 @@ def test_matrix_len(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

# TODO: this raises Not a tuple in ast2ast visit subscript with not constant _sval: Subscript
# (value=Name(id='a', ctx=Load()), slice=Name(id='i', ctx=Load()), ctx=Load())
# def test_matrix_access2(self):
# f = (
# "def test(a: Qmatrix[Qint[2], 2, 2]) -> Qint[2]:\n\ti = 1\n"
# "\tj = i + 1\n\treturn a[i][j]"
# )
# qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
# compute_and_compare_results(self, qf)
# def test_matrix_access3(self):
# f = (
# "def test(a: Qmatrix[Qint[2], 2, 2], i: Qint[2], j: Qint[2]) -> Qint[2]:\n"
# "\treturn a[i][j]"
# )
# qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
# compute_and_compare_results(self, qf)

# def test_matrix_access_with_var(self):
# f = "def test(a: Qint[2]) -> Qint[2]:\n\tc = [[1,2],[3,4]]\n\tb = c[a][a]\n\treturn b"
# qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
# compute_and_compare_results(self, qf)

# def test_list_access_with_var_on_tuple(self):
# # TODO: this fails on internal compiler
# if self.compiler == "internal":
# return

# f = ("def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n\tai,bi = ab\n"
# "\td = c[ai] + c[bi]\n\treturn d")
# qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
# compute_and_compare_results(self, qf)

# def test_list_access_with_var_on_tuple2(self):
# # TODO: this fails on internal compiler
# if self.compiler == "internal":
# return

# f = ("def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n"
# "\td = c[ab[0]] + c[ab[1]]\n\treturn d")
# qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
# compute_and_compare_results(self, qf)
def test_matrix_access2(self):
f = (
"def test(a: Qmatrix[Qint[2], 2, 2]) -> Qint[2]:\n\ti = 0\n"
"\tj = i + 1\n\treturn a[i][j]"
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_matrix_access3(self):
f = (
"def test(a: Qmatrix[Qint[2], 2, 2], i: Qint[2], j: Qint[2]) -> Qint[2]:\n"
"\treturn a[i][j]"
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_matrix_access_with_var(self):
f = (
"def test(a: Qint[2]) -> Qint[4]:\n\tc = [[1,2,7,8],[3,4,8,8],[5,6,9,1],[1,2,7,8]]\n"
"\tb = c[a][a]\n\treturn b"
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_list_access_with_var_on_tuple(self):
f = (
"def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n\tai,bi = ab\n"
"\td = c[ai] + c[bi]\n\treturn d"
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

def test_list_access_with_var_on_tuple2(self):
f = (
"def test(ab: Tuple[Qint[2], Qint[2]]) -> Qint[2]:\n\tc = [1,2,3,2]\n"
"\td = c[ab[0]] + c[ab[1]]\n\treturn d"
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)
17 changes: 8 additions & 9 deletions test/qlassf/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,11 @@ def test_tuple_iterator_vartuple(self):
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)

# TODO: failing for #63
# def test_tuple_of_tuple_var_access(self):
# f = (
# "def test(a: Tuple[Tuple[bool, bool], Tuple[bool, bool]], "
# "i: Qint[2], j: Qint[2]) -> bool:\n"
# "\treturn a[i][j]"
# )
# qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
# compute_and_compare_results(self, qf)
def test_tuple_of_tuple_var_access(self):
f = (
"def test(a: Tuple[Tuple[bool, bool, bool], Tuple[bool, bool, bool], "
"Tuple[bool, bool, bool]], i: Qint[2], j: Qint[2]) -> bool:\n"
"\treturn a[i][j]"
)
qf = qlassf(f, to_compile=COMPILATION_ENABLED, compiler=self.compiler)
compute_and_compare_results(self, qf)
17 changes: 10 additions & 7 deletions test/test_decopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import random
import itertools
import random
import unittest

from qlasskit import QCircuit, boolopt, qlassf
from qlasskit.decompiler import circuit_boolean_optimizer
Expand Down Expand Up @@ -95,14 +95,17 @@ def test_circuit_boolean_optimizer_random_2(self):
qc_n_un = qiskit_unitary(qc_n.export())
self.assertEqual(qc_un, qc_n_un)


def test_circuit_boolean_optimizer_random_x_cx(self):
g_simp = 0

possib = [(gates.CX, x[0], x[1]) for x in itertools.permutations([0,1,2],r=2)]
possib += [(gates.X, x[0]) for x in itertools.permutations([0,1,2],r=1)]

for i in random.choices(list(itertools.combinations_with_replacement(possib, r=8)), k=32):
possib = [
(gates.CX, x[0], x[1]) for x in itertools.permutations([0, 1, 2], r=2)
]
possib += [(gates.X, x[0]) for x in itertools.permutations([0, 1, 2], r=1)]

for i in random.choices(
list(itertools.combinations_with_replacement(possib, r=8)), k=32
):
qc = QCircuit(3)
for g in i:
qc.append(g[0](), g[1:])
Expand Down

0 comments on commit 1217a17

Please sign in to comment.