Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expr typecheck #1

Merged
merged 4 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@
- [x] Poc compiler 2 using qcircuit abstraction
- [x] OpenQASM3 exporter

#### Typechecker branch
- [x] Translate_expr should returns ttype*expr
- [x] Args should also hold the original type
- [x] Transform Env to a class holding also the original types
- [x] Typecheck all the expressions

### Week 3: (9 Oct 23)
- [ ] Int: comparison - eq
- [ ] Qubit garbage uncomputing and recycling
- [ ] Test: add qubit usage check
- [ ] Compiler: remove consecutive X gates
Expand Down Expand Up @@ -97,4 +104,8 @@
- [ ] QuTip
- [ ] Pennylane
- [ ] Cirq
- [ ] Sympy quantum computing expressions
- [ ] Sympy quantum computing expressions

### Tools

- [ ] py2qasm tools
2 changes: 1 addition & 1 deletion qlasskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@

from .qcircuit import QCircuit # noqa: F401
from .qlassf import QlassF, qlassf # noqa: F401
from .typing import Qint2, Qint4, Qint8, Qint12, Qint16, Qtype # noqa: F401
from .typing import Qint, Qint2, Qint4, Qint8, Qint12, Qint16, Qtype # noqa: F401
from .ast2logic import exceptions # noqa: F401
5 changes: 1 addition & 4 deletions qlasskit/ast2logic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
# limitations under the License.
# isort:skip_file

from typing import List

Env = List[str]

from .env import Env, Binding # noqa: F401, E402
from .utils import flatten # noqa: F401, E402
from .t_arguments import translate_argument, translate_arguments # noqa: F401, E402
from .t_expression import translate_expression, type_of_exp # noqa: F401, E402
Expand Down
43 changes: 43 additions & 0 deletions qlasskit/ast2logic/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2023 Davide Gessa

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from typing_extensions import TypeAlias

from ..typing import Arg, Qint # noqa: F401, E402
from . import exceptions

Binding: TypeAlias = Arg


class Env:
def __init__(self):
self.bindings: List[Binding] = []

def bind(self, bb: Binding):
if bb.name in self:
raise Exception("duplicate bind")

self.bindings.append(bb)

def __contains__(self, key):
if len(list(filter(lambda x: x.name == key, self.bindings))) == 1:
return True

def __getitem__(self, key):
try:
return list(filter(lambda x: x.name == key, self.bindings))[0]
except:
raise exceptions.UnboundException(key, self)
10 changes: 10 additions & 0 deletions qlasskit/ast2logic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
import ast


class TypeErrorException(Exception):
def __init__(self, got, excepted):
super().__init__(f"Got '{got}' excepted '{excepted}'")


class NoReturnTypeException(Exception):
def __init__(self):
super().__init__("Return type is mandatory")
Expand All @@ -35,6 +40,11 @@ def __init__(self, ob, message=None):
super().__init__(ast.dump(ob) + f": {message}" if message else "")


class OutOfBoundException(Exception):
def __init__(self, size, i):
super().__init__(f"size is {size}: {i} accessed")


class UnboundException(Exception):
def __init__(self, symbol, env):
super().__init__(f"{symbol} in {env}")
Expand Down
25 changes: 15 additions & 10 deletions qlasskit/ast2logic/t_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
from typing import List
from typing import List, Tuple

from ..typing import Args
from . import exceptions, flatten
from ..typing import Arg, Args, Qint, Qint2, Qint4, Qint8, Qint12, Qint16 # noqa: F
from . import exceptions
from .t_expression import TType


def translate_argument(ann, base="") -> List[str]:
def translate_argument(ann, base="") -> Arg:
def to_name(a):
return a.attr if isinstance(a, ast.Attribute) else a.id

# Tuple
if isinstance(ann, ast.Subscript) and ann.value.id == "Tuple": # type: ignore
al = []
ind = 0
ttypes: List[TType] = []
for i in ann.slice.value.elts: # type: ignore
if isinstance(i, ast.Name) and to_name(i) == "bool":
al.append(f"{base}.{ind}")
ttypes.append(bool)
else:
inner_list = translate_argument(i, base=f"{base}.{ind}")
al.extend(inner_list)
inner_arg = translate_argument(i, base=f"{base}.{ind}")
ttypes.append(inner_arg.ttype)
al.extend(inner_arg.bitvec)
ind += 1
return al
ttypes_t = tuple(ttypes)
return Arg(base, Tuple[ttypes_t], al)

# QintX
elif to_name(ann)[0:4] == "Qint":
n = int(to_name(ann)[4::])
arg_list = [f"{base}.{i}" for i in range(n)]
# arg_list.append((f"{base}{arg.arg}", n))
return arg_list
return Arg(base, eval(to_name(ann)), arg_list)

# Bool
elif to_name(ann) == "bool":
return [f"{base}"]
return Arg(base, bool, [f"{base}"])

else:
raise exceptions.UnknownTypeException(ann)
Expand All @@ -55,4 +60,4 @@ def translate_arguments(args) -> Args:
args_unrolled = map(
lambda arg: translate_argument(arg.annotation, base=arg.arg), args
)
return flatten(list(args_unrolled))
return list(args_unrolled)
10 changes: 5 additions & 5 deletions qlasskit/ast2logic/t_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ def translate_ast(fun) -> LogicFun:
fun_name: str = fun.name

# env contains names visible from the current scope
env: Env = []
env = Env()

args: Args = translate_arguments(fun.args.args)
# TODO: types are string; maybe a translate_type?
for a_name in args:
env.append(a_name)

[env.bind(arg) for arg in args]

if not fun.returns:
raise exceptions.NoReturnTypeException()

ret_size = len(translate_argument(fun.returns))
ret_ = translate_argument(fun.returns) # TODO: we need to preserve this
ret_size = len(ret_)

exps = []
for stmt in fun.body:
Expand Down
95 changes: 59 additions & 36 deletions qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
from typing import List, Tuple, get_args

from sympy import Symbol
from sympy.logic import ITE, And, Not, Or, false, true
from sympy.logic.boolalg import Boolean
from typing_extensions import TypeAlias

from . import Env, exceptions

TType: TypeAlias = object

def type_of_exp(vlist, base, env, res=[]):

def type_of_exp(vlist, base, res=[]) -> List[Symbol]:
"""Type inference for expressions: iterate over val, and decompose to bool"""
if isinstance(vlist, list):
i = 0
res = []
for in_val in vlist:
r_new, env = type_of_exp(in_val, f"{base}.{i}", env, res)
r_new = type_of_exp(in_val, f"{base}.{i}", res)
if isinstance(r_new, list):
res.extend(r_new)
else:
res.append(r_new)
i += 1
return res, env
return res
else:
new_symb = (f"{base}", vlist)
env.append(new_symb[0])
return [new_symb], env
return [new_symb]


def translate_expression(expr, env: Env) -> Boolean: # noqa: C901
def translate_expression(expr, env: Env) -> Tuple[TType, Boolean]: # noqa: C901
"""Translate an expression"""

# Name reference
if isinstance(expr, ast.Name):
if expr.id not in env:
# Handle complex types
rl = []
for sym in env:
if sym[0 : (len(expr.id) + 1)] == f"{expr.id}.":
rl.append(Symbol(sym))

if len(rl) == 0:
raise exceptions.UnboundException(expr.id, env)

return rl
return Symbol(expr.id)
binding = env[expr.id]
return (binding.ttype, binding.to_exp())

# Subscript: a[0][1]
elif isinstance(expr, ast.Subscript):
Expand All @@ -77,10 +70,24 @@ def unroll_subscripts(sub, st):
else:
sn = unroll_subscripts(expr, "")

if sn not in env:
if sn.split(".")[0] not in env:
raise exceptions.UnboundException(sn, env)

return Symbol(sn)
# Get the inner type
inner_type = env[sn.split(".")[0]].ttype
for i in sn.split(".")[1:]:
if hasattr(inner_type, "BIT_SIZE"):
if int(i) < inner_type.BIT_SIZE:
inner_type = bool
else:
raise exceptions.OutOfBoundException(inner_type.BIT_SIZE, i)
else:
if int(i) < len(get_args(inner_type)):
inner_type = get_args(inner_type)[int(i)]
else:
raise exceptions.OutOfBoundException(len(get_args(inner_type)), i)

return (inner_type, Symbol(sn))

# Boolop: and, or
elif isinstance(expr, ast.BoolOp):
Expand All @@ -90,43 +97,59 @@ def unfold(v_exps, op):
op(v_exps[0], unfold(v_exps[1::], op)) if len(v_exps) > 1 else v_exps[0]
)

v_exps = [translate_expression(e_in, env) for e_in in expr.values]
vt_exps = [translate_expression(e_in, env) for e_in in expr.values]
v_exps = [x[1] for x in vt_exps]
for x in vt_exps:
if x[0] != bool:
raise exceptions.TypeErrorException(x[0], bool)

return unfold(v_exps, And if isinstance(expr.op, ast.And) else Or)
return (bool, unfold(v_exps, And if isinstance(expr.op, ast.And) else Or))

# Unary: not
elif isinstance(expr, ast.UnaryOp):
if isinstance(expr.op, ast.Not):
return Not(translate_expression(expr.operand, env))
texp, exp = translate_expression(expr.operand, env)

if texp != bool:
raise exceptions.TypeErrorException(texp, bool)

return (bool, Not(exp))
else:
raise exceptions.ExpressionNotHandledException(expr)

# If expression
elif isinstance(expr, ast.IfExp):
# (condition) and (true_value) or (not condition) and (false_value)
# return Or(
# And(translate_expression(expr.test, env), translate_expression(expr.body, env)),
# And(Not(translate_expression(expr.test, env)), translate_expression(expr.orelse, env))
# )
return ITE(
translate_expression(expr.test, env),
translate_expression(expr.body, env),
translate_expression(expr.orelse, env),
te_test = translate_expression(expr.test, env)
te_true = translate_expression(expr.body, env)
te_false = translate_expression(expr.orelse, env)

if te_test[0] != bool:
raise exceptions.TypeErrorException(te_test[0], bool)

if te_true[0] != te_false[0]:
raise exceptions.TypeErrorException(te_false[0], te_true[0])

return (
te_true[0],
ITE(te_test[1], te_true[1], te_false[1]),
)

# Constant
elif isinstance(expr, ast.Constant):
if expr.value is True:
return true
return (bool, true)
elif expr.value is False:
return false
return (bool, false)
else:
raise exceptions.ExpressionNotHandledException(expr)

# Tuple
elif isinstance(expr, ast.Tuple):
elts = [translate_expression(elt, env) for elt in expr.elts]
return elts
telts = [translate_expression(elt, env) for elt in expr.elts]
elts = [x[1] for x in telts]
tlts = [x[0] for x in telts]

return (Tuple[tuple(tlts)], elts)

# Compare operator
elif isinstance(expr, ast.Compare):
Expand Down
12 changes: 7 additions & 5 deletions qlasskit/ast2logic/t_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sympy import Symbol
from sympy.logic.boolalg import Boolean

from . import Env, exceptions, translate_expression, type_of_exp
from . import Binding, Env, exceptions, translate_expression, type_of_exp


def translate_statement( # noqa: C901
Expand Down Expand Up @@ -57,14 +57,16 @@ def translate_statement( # noqa: C901
if target in env:
raise exceptions.SymbolReassignedException(target)

val = translate_expression(stmt.value, env)
res, env = type_of_exp(val, f"{target}", env)
tval, val = translate_expression(stmt.value, env) # TODO: typecheck
res = type_of_exp(val, f"{target}")
env.bind(Binding(target, tval, [x[0] for x in res]))
res = list(map(lambda x: (Symbol(x[0]), x[1]), res))
return res, env

elif isinstance(stmt, ast.Return):
vexp = translate_expression(stmt.value, env)
res, env = type_of_exp(vexp, "_ret", env)
texp, vexp = translate_expression(stmt.value, env) # TODO: typecheck
res = type_of_exp(vexp, "_ret")
env.bind(Binding("_ret", texp, [x[0] for x in res]))
res = list(map(lambda x: (Symbol(x[0]), x[1]), res))
return res, env

Expand Down
Loading
Loading