diff --git a/guppylang/cfg/bb.py b/guppylang/cfg/bb.py index aea8b7a8..6bb6b75a 100644 --- a/guppylang/cfg/bb.py +++ b/guppylang/cfg/bb.py @@ -7,7 +7,14 @@ from typing_extensions import Self from guppylang.ast_util import AstNode, name_nodes_in_ast -from guppylang.nodes import DesugaredListComp, NestedFunctionDef, PyExpr +from guppylang.nodes import ( + DesugaredArrayComp, + DesugaredGenerator, + DesugaredGeneratorExpr, + DesugaredListComp, + NestedFunctionDef, + PyExpr, +) if TYPE_CHECKING: from guppylang.cfg.cfg import BaseCFG @@ -144,19 +151,30 @@ def _handle_assign_target(self, lhs: ast.expr, node: ast.stmt) -> None: self.visit(value) def visit_DesugaredListComp(self, node: DesugaredListComp) -> None: + self._handle_comprehension(node.generators, node.elt) + + def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> None: + self._handle_comprehension([node.generator], node.elt) + + def visit_DesugaredGeneratorExpr(self, node: DesugaredGeneratorExpr) -> None: + self._handle_comprehension(node.generators, node.elt) + + def _handle_comprehension( + self, generators: list[DesugaredGenerator], elt: ast.expr + ) -> None: # Names bound in the comprehension are only available inside, so we shouldn't # update `self.stats` with assignments inner_visitor = VariableVisitor(self.bb) inner_stats = inner_visitor.stats # The generators are evaluated left to right - for gen in node.generators: + for gen in generators: inner_visitor.visit(gen.iter_assign) inner_visitor.visit(gen.hasnext_assign) inner_visitor.visit(gen.next_assign) for cond in gen.ifs: inner_visitor.visit(cond) - inner_visitor.visit(node.elt) + inner_visitor.visit(elt) self.stats.used |= { x: n for x, n in inner_stats.used.items() if x not in self.stats.assigned diff --git a/guppylang/cfg/builder.py b/guppylang/cfg/builder.py index d2ca6590..29316a0e 100644 --- a/guppylang/cfg/builder.py +++ b/guppylang/cfg/builder.py @@ -22,6 +22,7 @@ from guppylang.experimental import check_lists_enabled from guppylang.nodes import ( DesugaredGenerator, + DesugaredGeneratorExpr, DesugaredListComp, IterEnd, IterHasNext, @@ -313,8 +314,18 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name: # The final value is stored in the temporary variable return make_var(tmp, node) - def visit_ListComp(self, node: ast.ListComp) -> ast.AST: + def visit_ListComp(self, node: ast.ListComp) -> DesugaredListComp: check_lists_enabled(node) + generators, elt = self._build_comprehension(node.generators, node.elt, node) + return with_loc(node, DesugaredListComp(elt=elt, generators=generators)) + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> DesugaredGeneratorExpr: + generators, elt = self._build_comprehension(node.generators, node.elt, node) + return with_loc(node, DesugaredGeneratorExpr(elt=elt, generators=generators)) + + def _build_comprehension( + self, generators: list[ast.comprehension], elt: ast.expr, node: ast.AST + ) -> tuple[list[DesugaredGenerator], ast.expr]: # Check for illegal expressions illegals = find_nodes(is_illegal_in_list_comp, node) if illegals: @@ -329,7 +340,7 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST: # Desugar into statements that create the iterator, check for a next element, # get the next element, and finalise the iterator. gens = [] - for g in node.generators: + for g in generators: if g.is_async: raise GuppyError(UnsupportedError(g, "Async generators")) g.iter = self.visit(g.iter) @@ -352,8 +363,8 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST: ) gens.append(desugared) - node.elt = self.visit(node.elt) - return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) + elt = self.visit(elt) + return gens, elt def visit_Call(self, node: ast.Call) -> ast.AST: return is_py_expression(node) or self.generic_visit(node) diff --git a/guppylang/checker/errors/type_errors.py b/guppylang/checker/errors/type_errors.py index c87bbdc9..3aa5d406 100644 --- a/guppylang/checker/errors/type_errors.py +++ b/guppylang/checker/errors/type_errors.py @@ -212,3 +212,20 @@ class AssignNonPlaceHelp(Help): "field `{field.name}`" ) field: StructField + + +@dataclass(frozen=True) +class ArrayComprUnknownSizeError(Error): + title: ClassVar[str] = "Array comprehension with nonstatic size" + span_label: ClassVar[str] = "Cannot infer the size of this array comprehension ..." + + @dataclass(frozen=True) + class IfGuard(Note): + span_label: ClassVar[str] = "since it depends on this condition" + + @dataclass(frozen=True) + class DynamicIterator(Note): + span_label: ClassVar[str] = ( + "since the number of elements yielded by this iterator is not statically " + "known" + ) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index e4b9f8dd..5a829d58 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -85,6 +85,7 @@ from guppylang.experimental import check_function_tensors_enabled, check_lists_enabled from guppylang.nodes import ( DesugaredGenerator, + DesugaredGeneratorExpr, DesugaredListComp, FieldAccessAndDrop, GenericParamValue, @@ -254,7 +255,9 @@ def visit_DesugaredListComp( ) -> tuple[ast.expr, Subst]: if not is_list_type(ty): return self._fail(ty, node) - node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + node.generators, node.elt, elt_ty = synthesize_comprehension( + node, node.generators, node.elt, self.ctx + ) subst = unify(get_element_type(ty), elt_ty, {}) if subst is None: actual = list_type(elt_ty) @@ -475,10 +478,21 @@ def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]: return node, list_type(el_ty) def visit_DesugaredListComp(self, node: DesugaredListComp) -> tuple[ast.expr, Type]: - node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + node.generators, node.elt, elt_ty = synthesize_comprehension( + node, node.generators, node.elt, self.ctx + ) result_ty = list_type(elt_ty) return node, result_ty + def visit_DesugaredGeneratorExpr( + self, node: DesugaredGeneratorExpr + ) -> tuple[ast.expr, Type]: + # This is a generator in an arbitrary expression position. We don't support + # generators as first-class value yet, so we always error out here. Special + # cases where generator are allowed need to explicitly check for them (e.g. see + # the handling of array comprehensions in the compiler for the `array` function) + raise GuppyError(UnsupportedError(node, "Generator expressions")) + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]: # We need to synthesise the argument type, so we can look up dunder methods node.operand, op_ty = self.synthesize(node.operand) @@ -665,7 +679,7 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]: node.value, [], "__iter__", "iterable", exp_sig, True ) # Unwrap the size hint if present - if is_sized_iter_type(ty): + if is_sized_iter_type(ty) and node.unwrap_size_hint: expr, ty = self.synthesize_instance_func(expr, [], "unwrap_iter", "") # If the iterator was created by a `for` loop, we can add some extra checks to @@ -1033,15 +1047,15 @@ def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type def synthesize_comprehension( - node: DesugaredListComp, gens: list[DesugaredGenerator], ctx: Context -) -> tuple[DesugaredListComp, Type]: + node: AstNode, gens: list[DesugaredGenerator], elt: ast.expr, ctx: Context +) -> tuple[list[DesugaredGenerator], ast.expr, Type]: """Helper function to synthesise the element type of a list comprehension.""" from guppylang.checker.stmt_checker import StmtChecker # If there are no more generators left, we can check the list element if not gens: - node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt) - return node, elt_ty + elt, elt_ty = ExprSynthesizer(ctx).synthesize(elt) + return gens, elt, elt_ty # Check the iterator in the outer context gen, *gens = gens @@ -1065,12 +1079,12 @@ def synthesize_comprehension( gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx) # Check remaining generators - node, elt_ty = synthesize_comprehension(node, gens, inner_ctx) + gens, elt, elt_ty = synthesize_comprehension(node, gens, elt, inner_ctx) # The iter finalizer is again checked in the outer context gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend) gen.iterend = with_type(iterend_ty, gen.iterend) - return node, elt_ty + return [gen, *gens], elt, elt_ty def eval_py_expr(node: PyExpr, ctx: Context) -> Any: diff --git a/guppylang/checker/linearity_checker.py b/guppylang/checker/linearity_checker.py index 17add223..5ff7534e 100644 --- a/guppylang/checker/linearity_checker.py +++ b/guppylang/checker/linearity_checker.py @@ -43,6 +43,7 @@ from guppylang.nodes import ( AnyCall, CheckedNestedFunctionDef, + DesugaredArrayComp, DesugaredGenerator, DesugaredListComp, FieldAccessAndDrop, @@ -417,7 +418,10 @@ def visit_Expr(self, node: ast.Expr) -> None: raise GuppyTypeError(err) def visit_DesugaredListComp(self, node: DesugaredListComp) -> None: - self._check_comprehension(node, node.generators) + self._check_comprehension(node.generators, node.elt) + + def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> None: + self._check_comprehension([node.generator], node.elt) def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None: # Linearity of the nested function has already been checked. We just need to @@ -457,11 +461,11 @@ def _check_assign_targets(self, targets: list[ast.expr]) -> None: self.scope.assign(tgt_place) def _check_comprehension( - self, node: DesugaredListComp, gens: list[DesugaredGenerator] + self, gens: list[DesugaredGenerator], elt: ast.expr ) -> None: """Helper function to recursively check list comprehensions.""" if not gens: - self.visit(node.elt) + self.visit(elt) return # Check the iterator expression in the current scope @@ -502,7 +506,7 @@ def _check_comprehension( self.visit(expr) # Recursively check the remaining generators - self._check_comprehension(node, gens) + self._check_comprehension(gens, elt) # Check the iter finalizer so we record a final use of the iterator self.visit(gen.iterend) diff --git a/guppylang/compiler/core.py b/guppylang/compiler/core.py index ebaf0784..7e1c2fe3 100644 --- a/guppylang/compiler/core.py +++ b/guppylang/compiler/core.py @@ -5,10 +5,12 @@ from hugr import Wire, ops from hugr.build.dfg import DP, DefinitionBuilder, DfBase -from guppylang.checker.core import FieldAccess, Place, PlaceId, Variable +from guppylang.checker.core import FieldAccess, Globals, Place, PlaceId, Variable from guppylang.definition.common import CheckedDef, CompilableDef, CompiledDef, DefId +from guppylang.definition.ty import TypeDef +from guppylang.definition.value import CompiledCallableDef from guppylang.error import InternalGuppyError -from guppylang.tys.ty import StructType +from guppylang.tys.ty import StructType, Type CompiledLocals = dict[PlaceId, Wire] @@ -26,15 +28,19 @@ class CompiledGlobals: compiled: dict[DefId, CompiledDef] worklist: set[DefId] + checked_globals: Globals + def __init__( self, checked: dict[DefId, CheckedDef], module: DefinitionBuilder[ops.Module], + checked_globals: Globals, ) -> None: self.module = module self.checked = checked self.worklist = set() self.compiled = {} + self.checked_globals = checked_globals def build_compiled_def(self, def_id: DefId) -> CompiledDef: """Returns the compiled definitions corresponding to the given ID. @@ -65,6 +71,16 @@ def compile(self, defn: CheckedDef) -> None: next_def = self.build_compiled_def(next_id) next_def.compile_inner(self) + def get_instance_func( + self, ty: Type | TypeDef, name: str + ) -> CompiledCallableDef | None: + checked_func = self.checked_globals.get_instance_func(ty, name) + if checked_func is None: + return None + compiled_func = self.build_compiled_def(checked_func.id) + assert isinstance(compiled_func, CompiledCallableDef) + return compiled_func + @dataclass class DFContainer: diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 4f607fa3..b1d24fe7 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -1,6 +1,7 @@ import ast from collections.abc import Iterable, Iterator, Sequence -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager +from functools import partial from typing import Any, TypeGuard, TypeVar import hugr @@ -15,7 +16,7 @@ from hugr.std.collections import ListVal from typing_extensions import assert_never -from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type +from guppylang.ast_util import AstNode, AstVisitor, get_type from guppylang.cfg.builder import tmp_vars from guppylang.checker.core import Variable from guppylang.checker.errors.generic import UnsupportedError @@ -23,9 +24,14 @@ from guppylang.compiler.core import CompilerBase, DFContainer from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp from guppylang.definition.custom import CustomFunctionDef -from guppylang.definition.value import CompiledCallableDef, CompiledValueDef +from guppylang.definition.value import ( + CallReturnWires, + CompiledCallableDef, + CompiledValueDef, +) from guppylang.error import GuppyError, InternalGuppyError from guppylang.nodes import ( + DesugaredArrayComp, DesugaredGenerator, DesugaredListComp, FieldAccessAndDrop, @@ -41,12 +47,14 @@ TensorCall, TypeApply, ) +from guppylang.std._internal.compiler.array import array_repeat from guppylang.std._internal.compiler.list import ( list_new, - list_push, ) +from guppylang.tys.arg import Argument from guppylang.tys.builtin import ( get_element_type, + int_type, is_bool_type, is_list_type, ) @@ -59,6 +67,7 @@ InputFlags, NoneType, NumericType, + OpaqueType, TupleType, Type, type_to_row, @@ -459,68 +468,97 @@ def visit_ResultExpr(self, node: ResultExpr) -> Wire: return self._pack_returns([], NoneType()) def visit_DesugaredListComp(self, node: DesugaredListComp) -> Wire: - from guppylang.compiler.stmt_compiler import StmtCompiler - - compiler = StmtCompiler(self.globals) - # Make up a name for the list under construction and bind it to an empty list list_ty = get_type(node) + assert isinstance(list_ty, OpaqueType) elem_ty = get_element_type(list_ty) list_place = Variable(next(tmp_vars), list_ty, node) - list_name = with_type(list_ty, with_loc(node, PlaceNode(place=list_place))) self.dfg[list_place] = list_new(self.builder, elem_ty.to_hugr(), []) + with self._build_generators(node.generators, [list_place]): + elt_port = self.visit(node.elt) + list_port = self.dfg[list_place] + [], [self.dfg[list_place]] = self._build_method_call( + list_ty, "append", node, [list_port, elt_port], list_ty.args + ) + return self.dfg[list_place] + + def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> Wire: + # Allocate an uninitialised array of the desired size and a counter variable + array_ty = get_type(node) + assert isinstance(array_ty, OpaqueType) + array_var = Variable(next(tmp_vars), array_ty, node) + count_var = Variable(next(tmp_vars), int_type(), node) + # See https://github.com/CQCL/guppylang/issues/629 + hugr_elt_ty = ht.Option(node.elt_ty.to_hugr()) + # Initialise array with `None`s + make_none = self.builder.define_function("init_none", [], [hugr_elt_ty]) + make_none.set_outputs(make_none.add_op(ops.Tag(0, hugr_elt_ty))) + make_none = self.builder.load_function(make_none) + self.dfg[array_var] = self.builder.add_op( + array_repeat(hugr_elt_ty, node.length.to_arg().to_hugr()), make_none + ) + self.dfg[count_var] = self.builder.load( + hugr.std.int.IntVal(0, width=NumericType.INT_WIDTH) + ) + with self._build_generators([node.generator], [array_var, count_var]): + elt = self.visit(node.elt) + array, count = self.dfg[array_var], self.dfg[count_var] + [], [self.dfg[array_var]] = self._build_method_call( + array_ty, "__setitem__", node, [array, count, elt], array_ty.args + ) + # Update `count += 1` + one = self.builder.load(hugr.std.int.IntVal(1, width=NumericType.INT_WIDTH)) + [self.dfg[count_var]], [] = self._build_method_call( + int_type(), "__add__", node, [count, one] + ) + return self.dfg[array_var] - def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: - """Helper function to generate nested TailLoop nodes for generators""" - # If there are no more generators left, just append the element to the list - if not gens: - list_port, elt_port = self.visit(list_name), self.visit(elt) - elt_ty = get_type(elt) - if elt_ty.linear: - elt_ty_opt = ht.Option(elt_ty.to_hugr()) - elt_opt_port = self.builder.add_op(ops.Tag(1, elt_ty_opt), elt_port) - push = self.builder.add_op( - list_push(elt_ty_opt), list_port, elt_opt_port - ) - else: - push = self.builder.add_op( - list_push(elt_ty.to_hugr()), list_port, elt_port - ) - self.dfg[list_place] = push - return - - # Otherwise, compile the first iterator and construct a TailLoop - gen, *gens = gens - compiler.compile_stmts([gen.iter_assign], self.dfg) - assert isinstance(gen.iter, PlaceNode) - assert isinstance(gen.hasnext, PlaceNode) - inputs = [gen.iter, list_name] - with self._new_loop(inputs, gen.hasnext): - # If there is a next element, compile it and continue with the next - # generator + def _build_method_call( + self, + ty: Type, + method: str, + node: AstNode, + args: list[Wire], + type_args: Sequence[Argument] | None = None, + ) -> CallReturnWires: + func = self.globals.get_instance_func(ty, method) + assert func is not None + return func.compile_call(args, type_args or [], self.dfg, self.globals, node) + + @contextmanager + def _build_generators( + self, gens: list[DesugaredGenerator], loop_vars: list[Variable] + ) -> Iterator[None]: + """Context manager to build and enter the `TailLoop`s for a list of generators. + + The provided `loop_vars` will be threaded through and will be available inside + the loops. + """ + from guppylang.compiler.stmt_compiler import StmtCompiler + + compiler = StmtCompiler(self.globals) + with ExitStack() as stack: + for gen in gens: + # Build the generator + compiler.compile_stmts([gen.iter_assign], self.dfg) + assert isinstance(gen.iter, PlaceNode) + assert isinstance(gen.hasnext, PlaceNode) + inputs = [gen.iter] + [PlaceNode(place=var) for var in loop_vars] + # Remember to finalize the iterator once we are done with it. Note that + # we need to use partial in the callback, so that we bind the *current* + # value of `gen` instead of only last + stack.callback(partial(lambda gen: self.visit(gen.iterend), gen)) + # Enter a new tail loop + stack.enter_context(self._new_loop(inputs, gen.hasnext)) + # Enter a conditional checking if we have a next element compiler.compile_stmts([gen.hasnext_assign], self.dfg) - with self._if_true(gen.hasnext, inputs): - - def compile_ifs(ifs: list[ast.expr]) -> None: - """Helper function to compile a series of if-guards into nested - Conditional nodes.""" - if ifs: - if_expr, *ifs = ifs - # If the condition is true, continue with the next one - with self._if_true(if_expr, inputs): - compile_ifs(ifs) - else: - # If there are no guards left, compile the next generator - compile_generators(elt, gens) - - compiler.compile_stmts([gen.next_assign], self.dfg) - compile_ifs(gen.ifs) - - # After the loop is done, we have to finalize the iterator - self.visit(gen.iterend) - - compile_generators(node.elt, node.generators) - return self.visit(list_name) + stack.enter_context(self._if_true(gen.hasnext, inputs)) + compiler.compile_stmts([gen.next_assign], self.dfg) + # Enter nested conditionals for each if guard on the generator + for if_expr in gen.ifs: + stack.enter_context(self._if_true(if_expr, inputs)) + # Yield control to the caller to build inside the loop + yield def visit_BinOp(self, node: ast.BinOp) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/module.py b/guppylang/module.py index 792270c6..443b66db 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -340,7 +340,9 @@ def compile(self) -> ModulePointer: graph.metadata["name"] = self.name # Lower definitions to Hugr - ctx = CompiledGlobals(checked_defs, graph) + ctx = CompiledGlobals( + checked_defs, graph, self._imported_globals | self._globals + ) for defn in self._checked_defs.values(): ctx.compile(defn) diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 7f286e0a..101be0ab 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -142,6 +142,7 @@ class MakeIter(ast.expr): """ value: ast.expr + unwrap_size_hint: bool # Node that triggered the creation of this iterator. For example, a for loop stmt. # It is not mentioned in `_fields` so that it is not visible to AST visitors @@ -149,6 +150,13 @@ class MakeIter(ast.expr): _fields = ("value",) + def __init__( + self, value: ast.expr, origin_node: ast.AST, unwrap_size_hint: bool = True + ) -> None: + super().__init__(value) + self.origin_node = origin_node + self.unwrap_size_hint = unwrap_size_hint + class IterHasNext(ast.expr): """Checks if an iterator has a next element using the `__hasnext__` magic method. @@ -210,6 +218,18 @@ class DesugaredGenerator(ast.expr): ) +class DesugaredGeneratorExpr(ast.expr): + """A desugared generator expression.""" + + elt: ast.expr + generators: list[DesugaredGenerator] + + _fields = ( + "elt", + "generators", + ) + + class DesugaredListComp(ast.expr): """A desugared list comprehension.""" @@ -222,6 +242,22 @@ class DesugaredListComp(ast.expr): ) +class DesugaredArrayComp(ast.expr): + """A desugared array comprehension.""" + + elt: ast.expr + generator: DesugaredGenerator + length: Const + elt_ty: Type + + _fields = ( + "elt", + "generator", + "length", + "elt_ty", + ) + + class PyExpr(ast.expr): """A compile-time evaluated `py(...)` expression.""" diff --git a/guppylang/std/_internal/checker.py b/guppylang/std/_internal/checker.py index 9afc259b..3487585e 100644 --- a/guppylang/std/_internal/checker.py +++ b/guppylang/std/_internal/checker.py @@ -2,10 +2,15 @@ from dataclasses import dataclass from typing import ClassVar, cast +from typing_extensions import assert_never + from guppylang.ast_util import AstNode, with_loc, with_type from guppylang.checker.core import Context from guppylang.checker.errors.generic import ExpectedError, UnsupportedError -from guppylang.checker.errors.type_errors import TypeMismatchError +from guppylang.checker.errors.type_errors import ( + ArrayComprUnknownSizeError, + TypeMismatchError, +) from guppylang.checker.expr_checker import ( ExprChecker, ExprSynthesizer, @@ -13,6 +18,7 @@ check_num_args, check_type_against, synthesize_call, + synthesize_comprehension, ) from guppylang.definition.custom import ( CustomCallChecker, @@ -22,15 +28,24 @@ from guppylang.definition.struct import CheckedStructDef, RawStructDef from guppylang.diagnostic import Error, Note from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppylang.nodes import GlobalCall, ResultExpr +from guppylang.nodes import ( + DesugaredArrayComp, + DesugaredGeneratorExpr, + GlobalCall, + MakeIter, + ResultExpr, +) from guppylang.tys.arg import ConstArg, TypeArg from guppylang.tys.builtin import ( array_type, array_type_def, bool_type, + get_iter_size, int_type, is_array_type, is_bool_type, + is_sized_iter_type, + nat_type, sized_iter_type, ) from guppylang.tys.const import Const, ConstValue @@ -41,6 +56,7 @@ NumericType, StructType, Type, + unify, ) @@ -176,21 +192,28 @@ class Suggestion(Note): ) def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: - if len(args) == 0: - err = NewArrayChecker.InferenceError(self.node) - err.add_sub_diagnostic(NewArrayChecker.InferenceError.Suggestion(None)) - raise GuppyTypeError(err) - [fst, *rest] = args - fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) - checker = ExprChecker(self.ctx) - for i in range(len(rest)): - rest[i], subst = checker.check(rest[i], ty) - assert len(subst) == 0, "Array element type is closed" - result_ty = array_type(ty, len(args)) - call = GlobalCall( - def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args - ) - return with_loc(self.node, call), result_ty + match args: + case []: + err = NewArrayChecker.InferenceError(self.node) + err.add_sub_diagnostic(NewArrayChecker.InferenceError.Suggestion(None)) + raise GuppyTypeError(err) + # Either an array comprehension + case [DesugaredGeneratorExpr() as compr]: + return self.synthesize_array_comprehension(compr) + # Or a list of array elements + case [fst, *rest]: + fst, ty = ExprSynthesizer(self.ctx).synthesize(fst) + checker = ExprChecker(self.ctx) + for i in range(len(rest)): + rest[i], subst = checker.check(rest[i], ty) + assert len(subst) == 0, "Array element type is closed" + result_ty = array_type(ty, len(args)) + call = GlobalCall( + def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args + ) + return with_loc(self.node, call), result_ty + case args: + return assert_never(args) # type: ignore[arg-type] def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: if not is_array_type(ty): @@ -200,22 +223,80 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: self.node, ) raise GuppyTypeError(TypeMismatchError(self.node, ty, dummy_array_ty)) + subst: Subst = {} match ty.args: - case [TypeArg(ty=elem_ty), ConstArg(ConstValue(value=int(length)))]: - subst: Subst = {} - checker = ExprChecker(self.ctx) - for i in range(len(args)): - args[i], s = checker.check(args[i], elem_ty.substitute(subst)) - subst |= s - if len(args) != length: - raise GuppyTypeError( - TypeMismatchError(self.node, ty, array_type(elem_ty, len(args))) - ) - call = GlobalCall(def_id=self.func.id, args=args, type_args=ty.args) - return with_loc(self.node, call), subst + case [TypeArg(ty=elem_ty), ConstArg(length)]: + match args: + # Either an array comprehension + case [DesugaredGeneratorExpr() as compr]: + # TODO: We could use the type information to infer some stuff + # in the comprehension + arr_compr, res_ty = self.synthesize_array_comprehension(compr) + subst, _ = check_type_against(res_ty, ty, self.node) + return arr_compr, subst + # Or a list of array elements + case args: + checker = ExprChecker(self.ctx) + for i in range(len(args)): + args[i], s = checker.check( + args[i], elem_ty.substitute(subst) + ) + subst |= s + ls = unify(length, ConstValue(nat_type(), len(args)), {}) + if ls is None: + raise GuppyTypeError( + TypeMismatchError( + self.node, ty, array_type(elem_ty, len(args)) + ) + ) + subst |= ls + call = GlobalCall( + def_id=self.func.id, args=args, type_args=ty.args + ) + return with_loc(self.node, call), subst case type_args: raise InternalGuppyError(f"Invalid array type args: {type_args}") + def synthesize_array_comprehension( + self, compr: DesugaredGeneratorExpr + ) -> tuple[DesugaredArrayComp, Type]: + # Array comprehensions require a static size. To keep things simple, we'll only + # allow a single generator for now, so we don't have to reason about products + # of iterator sizes. + if len(compr.generators) > 1: + # Individual generator objects unfortunately don't have a span in Python's + # AST, so we have to use the whole expression span + raise GuppyError(UnsupportedError(compr, "Nested array comprehensions")) + [gen] = compr.generators + # Similarly, dynamic if guards are not allowed + if gen.ifs: + err = ArrayComprUnknownSizeError(compr) + err.add_sub_diagnostic(ArrayComprUnknownSizeError.IfGuard(gen.ifs[0])) + raise GuppyError(err) + # Extract the iterator size + match gen.iter_assign: + case ast.Assign(value=MakeIter() as make_iter): + sized_make_iter = MakeIter( + make_iter.value, make_iter.origin_node, unwrap_size_hint=False + ) + _, iter_ty = ExprSynthesizer(self.ctx).synthesize(sized_make_iter) + # The iterator must have a static size hint + if not is_sized_iter_type(iter_ty): + err = ArrayComprUnknownSizeError(compr) + err.add_sub_diagnostic( + ArrayComprUnknownSizeError.DynamicIterator(make_iter) + ) + raise GuppyError(err) + size = get_iter_size(iter_ty) + case _: + raise InternalGuppyError("Invalid iterator assign statement") + # Finally, type check the comprehension + [gen], elt, elt_ty = synthesize_comprehension(compr, [gen], compr.elt, self.ctx) + array_compr = DesugaredArrayComp( + elt=elt, generator=gen, length=size, elt_ty=elt_ty + ) + return with_loc(compr, array_compr), array_type(elt_ty, size) + #: Maximum length of a tag in the `result` function. TAG_MAX_LEN = 200 diff --git a/guppylang/std/_internal/compiler/array.py b/guppylang/std/_internal/compiler/array.py index 6f32ff0e..b6d90b12 100644 --- a/guppylang/std/_internal/compiler/array.py +++ b/guppylang/std/_internal/compiler/array.py @@ -82,6 +82,16 @@ def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops ).ext_op +def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: + """Returns an array `repeat` operation.""" + # TODO + return UnsupportedOp( + op_name="array.repeat", + inputs=[ht.FunctionType([], [elem_ty])], + outputs=[array_type(elem_ty, length)], + ).ext_op + + # ------------------------------------------------------ # --------- Custom compilers for non-native ops -------- # ------------------------------------------------------ diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index ec6b63f1..46a09ee8 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -94,7 +94,7 @@ class nat: class array(Generic[_T, _n]): """Class to import in order to use arrays.""" - def __init__(self, *args: _T): + def __init__(self, *args: Any): pass @@ -594,7 +594,7 @@ def _array_unsafe_getitem(xs: array[L, n], idx: int) -> L: ... @guppy.extend_type(sized_iter_type_def) class SizedIter: - """A wrapper around an iterator type `T` promising that the iterator will yield + """A wrapper around an iterator type `L` promising that the iterator will yield exactly `n` values. Annotating an iterator with an incorrect size is undefined behaviour. @@ -614,8 +614,8 @@ def unwrap_iter(self: "SizedIter[L, n]" @ owned) -> L: """Extracts the actual iterator.""" @guppy.custom(NoopCompiler()) - def __iter__(self: "SizedIter[L, n]" @ owned) -> L: - """Extracts the actual iterator.""" + def __iter__(self: "SizedIter[L, n]" @ owned) -> "SizedIter[L, n]": # type: ignore[type-arg] + """Dummy implementation making sized iterators iterable themselves.""" # TODO: This is a temporary hack until we have implemented the proper results mechanism. diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index a42aa877..89ffc5c8 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -12,7 +12,7 @@ from guppylang.error import GuppyError, InternalGuppyError from guppylang.experimental import check_lists_enabled from guppylang.tys.arg import Argument, ConstArg, TypeArg -from guppylang.tys.const import ConstValue +from guppylang.tys.const import Const, ConstValue from guppylang.tys.errors import WrongNumberOfTypeArgsError from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.ty import ( @@ -203,6 +203,10 @@ def bool_type() -> OpaqueType: return OpaqueType([], bool_type_def) +def nat_type() -> NumericType: + return NumericType(NumericType.Kind.Nat) + + def int_type() -> NumericType: return NumericType(NumericType.Kind.Int) @@ -211,18 +215,16 @@ def list_type(element_ty: Type) -> OpaqueType: return OpaqueType([TypeArg(element_ty)], list_type_def) -def array_type(element_ty: Type, length: int) -> OpaqueType: - nat_type = NumericType(NumericType.Kind.Nat) - return OpaqueType( - [TypeArg(element_ty), ConstArg(ConstValue(nat_type, length))], array_type_def - ) +def array_type(element_ty: Type, length: int | Const) -> OpaqueType: + if isinstance(length, int): + length = ConstValue(nat_type(), length) + return OpaqueType([TypeArg(element_ty), ConstArg(length)], array_type_def) -def sized_iter_type(iter_type: Type, size: int) -> OpaqueType: - nat_type = NumericType(NumericType.Kind.Nat) - return OpaqueType( - [TypeArg(iter_type), ConstArg(ConstValue(nat_type, size))], sized_iter_type_def - ) +def sized_iter_type(iter_type: Type, size: int | Const) -> OpaqueType: + if isinstance(size, int): + size = ConstValue(nat_type(), size) + return OpaqueType([TypeArg(iter_type), ConstArg(size)], sized_iter_type_def) def is_bool_type(ty: Type) -> bool: @@ -243,17 +245,17 @@ def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]: def get_element_type(ty: Type) -> Type: assert isinstance(ty, OpaqueType) - assert ty.defn == list_type_def - (arg,) = ty.args + assert ty.defn == list_type_def or ty.defn == array_type_def + (arg, *_) = ty.args assert isinstance(arg, TypeArg) return arg.ty -def get_iter_size(ty: Type) -> int: +def get_iter_size(ty: Type) -> Const: assert isinstance(ty, OpaqueType) assert ty.defn == sized_iter_type_def match ty.args: - case [_, ConstArg(ConstValue(value=int(size)))]: - return size + case [_, ConstArg(const)]: + return const case _: raise InternalGuppyError("Unexpected type args") diff --git a/tests/error/array_errors/comprehension_wrong_length.err b/tests/error/array_errors/comprehension_wrong_length.err new file mode 100644 index 00000000..c8cfdd68 --- /dev/null +++ b/tests/error/array_errors/comprehension_wrong_length.err @@ -0,0 +1,9 @@ +Error: Type mismatch (at $FILE:13:11) + | +11 | @guppy(module) +12 | def main() -> array[int, 42]: +13 | return array(i for i in range(10)) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ Expected expression of type `array[int, 42]`, got + | `array[int, 10]` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/array_errors/comprehension_wrong_length.py b/tests/error/array_errors/comprehension_wrong_length.py new file mode 100644 index 00000000..2714de4a --- /dev/null +++ b/tests/error/array_errors/comprehension_wrong_length.py @@ -0,0 +1,16 @@ +import guppylang.std.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> array[int, 42]: + return array(i for i in range(10)) + + +module.compile() diff --git a/tests/error/array_errors/guarded_comprehension.err b/tests/error/array_errors/guarded_comprehension.err new file mode 100644 index 00000000..aea96825 --- /dev/null +++ b/tests/error/array_errors/guarded_comprehension.err @@ -0,0 +1,11 @@ +Error: Array comprehension with nonstatic size (at $FILE:13:9) + | +11 | @guppy(module) +12 | def main() -> None: +13 | array(i for i in range(100) if i % 2 == 0) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Cannot infer the size of this array comprehension ... + | +13 | array(i for i in range(100) if i % 2 == 0) + | ---------- since it depends on this condition + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/array_errors/guarded_comprehension.py b/tests/error/array_errors/guarded_comprehension.py new file mode 100644 index 00000000..21557107 --- /dev/null +++ b/tests/error/array_errors/guarded_comprehension.py @@ -0,0 +1,16 @@ +import guppylang.std.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> None: + array(i for i in range(100) if i % 2 == 0) + + +module.compile() diff --git a/tests/error/array_errors/nested_comprehension.err b/tests/error/array_errors/nested_comprehension.err new file mode 100644 index 00000000..24d545da --- /dev/null +++ b/tests/error/array_errors/nested_comprehension.err @@ -0,0 +1,8 @@ +Error: Unsupported (at $FILE:13:16) + | +11 | @guppy(module) +12 | def main() -> array[int, 50]: +13 | return array(0 for _ in range(10) for _ in range(5)) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Nested array comprehensions are not supported + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/array_errors/nested_comprehension.py b/tests/error/array_errors/nested_comprehension.py new file mode 100644 index 00000000..a643d96c --- /dev/null +++ b/tests/error/array_errors/nested_comprehension.py @@ -0,0 +1,16 @@ +import guppylang.std.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main() -> array[int, 50]: + return array(0 for _ in range(10) for _ in range(5)) + + +module.compile() diff --git a/tests/error/array_errors/non_static_comprehension.err b/tests/error/array_errors/non_static_comprehension.err new file mode 100644 index 00000000..0a7e8eae --- /dev/null +++ b/tests/error/array_errors/non_static_comprehension.err @@ -0,0 +1,12 @@ +Error: Array comprehension with nonstatic size (at $FILE:13:9) + | +11 | @guppy(module) +12 | def main(n: int) -> None: +13 | array(i for i in range(n)) + | ^^^^^^^^^^^^^^^^^^^^^ Cannot infer the size of this array comprehension ... + | +13 | array(i for i in range(n)) + | -------- since the number of elements yielded by this iterator is + | not statically known + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/array_errors/non_static_comprehension.py b/tests/error/array_errors/non_static_comprehension.py new file mode 100644 index 00000000..0752b316 --- /dev/null +++ b/tests/error/array_errors/non_static_comprehension.py @@ -0,0 +1,16 @@ +import guppylang.std.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.builtins import array + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main(n: int) -> None: + array(i for i in range(n)) + + +module.compile() diff --git a/tests/integration/test_array_comprehension.py b/tests/integration/test_array_comprehension.py new file mode 100644 index 00000000..bdd4a05a --- /dev/null +++ b/tests/integration/test_array_comprehension.py @@ -0,0 +1,90 @@ +import pytest + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.builtins import array +from guppylang.std.quantum import qubit + +import guppylang.std.quantum_functional as quantum +from tests.util import compile_guppy + + +def test_basic(validate): + @compile_guppy + def test() -> array[int, 10]: + return array(i + 1 for i in range(10)) + + validate(test) + + +def test_basic_linear(validate): + module = GuppyModule("test") + module.load_all(quantum) + module.load(qubit) + + @guppy(module) + def test() -> array[qubit, 42]: + return array(qubit() for _ in range(42)) + + validate(module.compile()) + + +def test_zero_length(validate): + @compile_guppy + def test() -> array[float, 0]: + return array(i / 0 for i in range(0)) + + validate(test) + + +def test_capture(validate): + @compile_guppy + def test(x: int) -> array[int, 42]: + return array(i + x for i in range(42)) + + validate(test) + + +@pytest.mark.skip("See https://github.com/CQCL/hugr/issues/1625") +def test_capture_struct(validate): + module = GuppyModule("test") + + @guppy.struct(module) + class MyStruct: + x: int + y: float + + @guppy(module) + def test(s: MyStruct) -> array[int, 42]: + return array(i + s.x for i in range(42)) + + validate(module.compile()) + + +def test_scope(validate): + @compile_guppy + def test() -> float: + x = 42.0 + array(x for x in range(10)) + return x + + validate(test) + + +def test_nested_left(validate): + @compile_guppy + def test() -> array[array[int, 10], 20]: + return array(array(x + y for y in range(10)) for x in range(20)) + + validate(test) + + +def test_generic(validate): + module = GuppyModule("test") + n = guppy.nat_var("n", module) + + @guppy(module) + def test(xs: array[int, n]) -> array[int, n]: + return array(x + 1 for x in xs) + + validate(module.compile())