Skip to content

Commit

Permalink
type*expr as return type of transalte_expression
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Oct 6, 2023
1 parent 756250f commit 4ad476b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 23 deletions.
51 changes: 30 additions & 21 deletions qlasskit/ast2logic/t_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
from typing import Any, List, Tuple

from sympy import Symbol
from sympy.logic import ITE, And, Not, Or, false, true
from sympy.logic.boolalg import Boolean
from typing_extensions import TypeAlias

from . import Env, exceptions

TType: TypeAlias = object

def type_of_exp(vlist, base, env, res=[]):

def type_of_exp(vlist, base, env, res=[]) -> Tuple[List[Symbol], Env]:
"""Type inference for expressions: iterate over val, and decompose to bool"""
if isinstance(vlist, list):
i = 0
Expand All @@ -39,7 +43,7 @@ def type_of_exp(vlist, base, env, res=[]):
return [new_symb], env


def translate_expression(expr, env: Env) -> Boolean: # noqa: C901
def translate_expression(expr, env: Env) -> Tuple[TType, Boolean]: # noqa: C901
"""Translate an expression"""

# Name reference
Expand All @@ -54,8 +58,8 @@ def translate_expression(expr, env: Env) -> Boolean: # noqa: C901
if len(rl) == 0:
raise exceptions.UnboundException(expr.id, env)

return rl
return Symbol(expr.id)
return (Any, rl) # TODO: typecheck
return (Any, Symbol(expr.id)) # TODO: typecheck

# Subscript: a[0][1]
elif isinstance(expr, ast.Subscript):
Expand All @@ -80,7 +84,7 @@ def unroll_subscripts(sub, st):
if sn not in env:
raise exceptions.UnboundException(sn, env)

return Symbol(sn)
return (Any, Symbol(sn))

# Boolop: and, or
elif isinstance(expr, ast.BoolOp):
Expand All @@ -90,43 +94,48 @@ def unfold(v_exps, op):
op(v_exps[0], unfold(v_exps[1::], op)) if len(v_exps) > 1 else v_exps[0]
)

v_exps = [translate_expression(e_in, env) for e_in in expr.values]
v_exps = [
translate_expression(e_in, env)[1] for e_in in expr.values
] # TODO: typecheck

return unfold(v_exps, And if isinstance(expr.op, ast.And) else Or)
return (bool, unfold(v_exps, And if isinstance(expr.op, ast.And) else Or))

# Unary: not
elif isinstance(expr, ast.UnaryOp):
if isinstance(expr.op, ast.Not):
return Not(translate_expression(expr.operand, env))
return (
bool,
Not(translate_expression(expr.operand, env)[1]),
) # TODO: typecheck
else:
raise exceptions.ExpressionNotHandledException(expr)

# If expression
elif isinstance(expr, ast.IfExp):
# (condition) and (true_value) or (not condition) and (false_value)
# return Or(
# And(translate_expression(expr.test, env), translate_expression(expr.body, env)),
# And(Not(translate_expression(expr.test, env)), translate_expression(expr.orelse, env))
# )
return ITE(
translate_expression(expr.test, env),
translate_expression(expr.body, env),
translate_expression(expr.orelse, env),
return (
bool,
ITE(
translate_expression(expr.test, env)[1], # TODO: typecheck
translate_expression(expr.body, env)[1], # TODO: typecheck
translate_expression(expr.orelse, env)[1], # TODO: typecheck
),
)

# Constant
elif isinstance(expr, ast.Constant):
if expr.value is True:
return true
return (bool, true)
elif expr.value is False:
return false
return (bool, false)
else:
raise exceptions.ExpressionNotHandledException(expr)

# Tuple
elif isinstance(expr, ast.Tuple):
elts = [translate_expression(elt, env) for elt in expr.elts]
return elts
elts = [
translate_expression(elt, env)[1] for elt in expr.elts
] # TODO: typecheck
return (Any, elts) # TODO: typecheck

# Compare operator
elif isinstance(expr, ast.Compare):
Expand Down
4 changes: 2 additions & 2 deletions qlasskit/ast2logic/t_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ def translate_statement( # noqa: C901
if target in env:
raise exceptions.SymbolReassignedException(target)

val = translate_expression(stmt.value, env)
tval, val = translate_expression(stmt.value, env) # TODO: typecheck
res, env = type_of_exp(val, f"{target}", env)
res = list(map(lambda x: (Symbol(x[0]), x[1]), res))
return res, env

elif isinstance(stmt, ast.Return):
vexp = translate_expression(stmt.value, env)
texp, vexp = translate_expression(stmt.value, env) # TODO: typecheck
res, env = type_of_exp(vexp, "_ret", env)
res = list(map(lambda x: (Symbol(x[0]), x[1]), res))
return res, env
Expand Down

0 comments on commit 4ad476b

Please sign in to comment.