Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Array comprehension #613

Merged
merged 16 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions guppylang/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from typing_extensions import Self

from guppylang.ast_util import AstNode, name_nodes_in_ast
from guppylang.nodes import DesugaredListComp, NestedFunctionDef, PyExpr
from guppylang.nodes import (
DesugaredArrayComp,
DesugaredGenerator,
DesugaredGeneratorExpr,
DesugaredListComp,
NestedFunctionDef,
PyExpr,
)

if TYPE_CHECKING:
from guppylang.cfg.cfg import BaseCFG
Expand Down Expand Up @@ -144,19 +151,30 @@ def _handle_assign_target(self, lhs: ast.expr, node: ast.stmt) -> None:
self.visit(value)

def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:
self._handle_comprehension(node.generators, node.elt)

def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> None:
self._handle_comprehension([node.generator], node.elt)

def visit_DesugaredGeneratorExpr(self, node: DesugaredGeneratorExpr) -> None:
self._handle_comprehension(node.generators, node.elt)

def _handle_comprehension(
self, generators: list[DesugaredGenerator], elt: ast.expr
) -> None:
# Names bound in the comprehension are only available inside, so we shouldn't
# update `self.stats` with assignments
inner_visitor = VariableVisitor(self.bb)
inner_stats = inner_visitor.stats

# The generators are evaluated left to right
for gen in node.generators:
for gen in generators:
inner_visitor.visit(gen.iter_assign)
inner_visitor.visit(gen.hasnext_assign)
inner_visitor.visit(gen.next_assign)
for cond in gen.ifs:
inner_visitor.visit(cond)
inner_visitor.visit(node.elt)
inner_visitor.visit(elt)

self.stats.used |= {
x: n for x, n in inner_stats.used.items() if x not in self.stats.assigned
Expand Down
19 changes: 15 additions & 4 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from guppylang.experimental import check_lists_enabled
from guppylang.nodes import (
DesugaredGenerator,
DesugaredGeneratorExpr,
DesugaredListComp,
IterEnd,
IterHasNext,
Expand Down Expand Up @@ -313,8 +314,18 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name:
# The final value is stored in the temporary variable
return make_var(tmp, node)

def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
def visit_ListComp(self, node: ast.ListComp) -> DesugaredListComp:
check_lists_enabled(node)
generators, elt = self._build_comprehension(node.generators, node.elt, node)
return with_loc(node, DesugaredListComp(elt=elt, generators=generators))

def visit_GeneratorExp(self, node: ast.GeneratorExp) -> DesugaredGeneratorExpr:
generators, elt = self._build_comprehension(node.generators, node.elt, node)
return with_loc(node, DesugaredGeneratorExpr(elt=elt, generators=generators))

def _build_comprehension(
self, generators: list[ast.comprehension], elt: ast.expr, node: ast.AST
) -> tuple[list[DesugaredGenerator], ast.expr]:
# Check for illegal expressions
illegals = find_nodes(is_illegal_in_list_comp, node)
if illegals:
Expand All @@ -329,7 +340,7 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
# Desugar into statements that create the iterator, check for a next element,
# get the next element, and finalise the iterator.
gens = []
for g in node.generators:
for g in generators:
if g.is_async:
raise GuppyError(UnsupportedError(g, "Async generators"))
g.iter = self.visit(g.iter)
Expand All @@ -352,8 +363,8 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
)
gens.append(desugared)

node.elt = self.visit(node.elt)
return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens))
elt = self.visit(elt)
return gens, elt

def visit_Call(self, node: ast.Call) -> ast.AST:
return is_py_expression(node) or self.generic_visit(node)
Expand Down
17 changes: 17 additions & 0 deletions guppylang/checker/errors/type_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,20 @@ class AssignNonPlaceHelp(Help):
"field `{field.name}`"
)
field: StructField


@dataclass(frozen=True)
class ArrayComprUnknownSizeError(Error):
title: ClassVar[str] = "Array comprehension with nonstatic size"
span_label: ClassVar[str] = "Cannot infer the size of this array comprehension ..."

@dataclass(frozen=True)
class IfGuard(Note):
span_label: ClassVar[str] = "since it depends on this condition"

@dataclass(frozen=True)
class DynamicIterator(Note):
span_label: ClassVar[str] = (
"since the number of elements yielded by this iterator is not statically "
"known"
)
32 changes: 23 additions & 9 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
from guppylang.experimental import check_function_tensors_enabled, check_lists_enabled
from guppylang.nodes import (
DesugaredGenerator,
DesugaredGeneratorExpr,
DesugaredListComp,
FieldAccessAndDrop,
GenericParamValue,
Expand Down Expand Up @@ -254,7 +255,9 @@ def visit_DesugaredListComp(
) -> tuple[ast.expr, Subst]:
if not is_list_type(ty):
return self._fail(ty, node)
node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx)
node.generators, node.elt, elt_ty = synthesize_comprehension(
node, node.generators, node.elt, self.ctx
)
subst = unify(get_element_type(ty), elt_ty, {})
if subst is None:
actual = list_type(elt_ty)
Expand Down Expand Up @@ -475,10 +478,21 @@ def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]:
return node, list_type(el_ty)

def visit_DesugaredListComp(self, node: DesugaredListComp) -> tuple[ast.expr, Type]:
node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx)
node.generators, node.elt, elt_ty = synthesize_comprehension(
node, node.generators, node.elt, self.ctx
)
result_ty = list_type(elt_ty)
return node, result_ty

def visit_DesugaredGeneratorExpr(
self, node: DesugaredGeneratorExpr
) -> tuple[ast.expr, Type]:
# This is a generator in an arbitrary expression position. We don't support
# generators as first-class value yet, so we always error out here. Special
# cases where generator are allowed need to explicitly check for them (e.g. see
# the handling of array comprehensions in the compiler for the `array` function)
raise GuppyError(UnsupportedError(node, "Generator expressions"))

def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]:
# We need to synthesise the argument type, so we can look up dunder methods
node.operand, op_ty = self.synthesize(node.operand)
Expand Down Expand Up @@ -665,7 +679,7 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
node.value, [], "__iter__", "iterable", exp_sig, True
)
# Unwrap the size hint if present
if is_sized_iter_type(ty):
if is_sized_iter_type(ty) and node.unwrap_size_hint:
expr, ty = self.synthesize_instance_func(expr, [], "unwrap_iter", "")

# If the iterator was created by a `for` loop, we can add some extra checks to
Expand Down Expand Up @@ -1033,15 +1047,15 @@ def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type


def synthesize_comprehension(
node: DesugaredListComp, gens: list[DesugaredGenerator], ctx: Context
) -> tuple[DesugaredListComp, Type]:
node: AstNode, gens: list[DesugaredGenerator], elt: ast.expr, ctx: Context
) -> tuple[list[DesugaredGenerator], ast.expr, Type]:
"""Helper function to synthesise the element type of a list comprehension."""
from guppylang.checker.stmt_checker import StmtChecker

# If there are no more generators left, we can check the list element
if not gens:
node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt)
return node, elt_ty
elt, elt_ty = ExprSynthesizer(ctx).synthesize(elt)
return gens, elt, elt_ty

# Check the iterator in the outer context
gen, *gens = gens
Expand All @@ -1065,12 +1079,12 @@ def synthesize_comprehension(
gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx)

# Check remaining generators
node, elt_ty = synthesize_comprehension(node, gens, inner_ctx)
gens, elt, elt_ty = synthesize_comprehension(node, gens, elt, inner_ctx)

# The iter finalizer is again checked in the outer context
gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend)
gen.iterend = with_type(iterend_ty, gen.iterend)
return node, elt_ty
return [gen, *gens], elt, elt_ty


def eval_py_expr(node: PyExpr, ctx: Context) -> Any:
Expand Down
12 changes: 8 additions & 4 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from guppylang.nodes import (
AnyCall,
CheckedNestedFunctionDef,
DesugaredArrayComp,
DesugaredGenerator,
DesugaredListComp,
FieldAccessAndDrop,
Expand Down Expand Up @@ -417,7 +418,10 @@ def visit_Expr(self, node: ast.Expr) -> None:
raise GuppyTypeError(err)

def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:
self._check_comprehension(node, node.generators)
self._check_comprehension(node.generators, node.elt)

def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> None:
self._check_comprehension([node.generator], node.elt)

def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None:
# Linearity of the nested function has already been checked. We just need to
Expand Down Expand Up @@ -457,11 +461,11 @@ def _check_assign_targets(self, targets: list[ast.expr]) -> None:
self.scope.assign(tgt_place)

def _check_comprehension(
self, node: DesugaredListComp, gens: list[DesugaredGenerator]
self, gens: list[DesugaredGenerator], elt: ast.expr
) -> None:
"""Helper function to recursively check list comprehensions."""
if not gens:
self.visit(node.elt)
self.visit(elt)
return

# Check the iterator expression in the current scope
Expand Down Expand Up @@ -502,7 +506,7 @@ def _check_comprehension(
self.visit(expr)

# Recursively check the remaining generators
self._check_comprehension(node, gens)
self._check_comprehension(gens, elt)

# Check the iter finalizer so we record a final use of the iterator
self.visit(gen.iterend)
Expand Down
20 changes: 18 additions & 2 deletions guppylang/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from hugr import Wire, ops
from hugr.build.dfg import DP, DefinitionBuilder, DfBase

from guppylang.checker.core import FieldAccess, Place, PlaceId, Variable
from guppylang.checker.core import FieldAccess, Globals, Place, PlaceId, Variable
from guppylang.definition.common import CheckedDef, CompilableDef, CompiledDef, DefId
from guppylang.definition.ty import TypeDef
from guppylang.definition.value import CompiledCallableDef
from guppylang.error import InternalGuppyError
from guppylang.tys.ty import StructType
from guppylang.tys.ty import StructType, Type

CompiledLocals = dict[PlaceId, Wire]

Expand All @@ -26,15 +28,19 @@ class CompiledGlobals:
compiled: dict[DefId, CompiledDef]
worklist: set[DefId]

checked_globals: Globals

def __init__(
self,
checked: dict[DefId, CheckedDef],
module: DefinitionBuilder[ops.Module],
checked_globals: Globals,
) -> None:
self.module = module
self.checked = checked
self.worklist = set()
self.compiled = {}
self.checked_globals = checked_globals

def build_compiled_def(self, def_id: DefId) -> CompiledDef:
"""Returns the compiled definitions corresponding to the given ID.
Expand Down Expand Up @@ -65,6 +71,16 @@ def compile(self, defn: CheckedDef) -> None:
next_def = self.build_compiled_def(next_id)
next_def.compile_inner(self)

def get_instance_func(
self, ty: Type | TypeDef, name: str
) -> CompiledCallableDef | None:
checked_func = self.checked_globals.get_instance_func(ty, name)
if checked_func is None:
return None
compiled_func = self.build_compiled_def(checked_func.id)
assert isinstance(compiled_func, CompiledCallableDef)
return compiled_func


@dataclass
class DFContainer:
Expand Down
Loading
Loading