diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index fccb2cd9..1b92c01a 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -3,7 +3,15 @@ import itertools from collections.abc import Iterable, Iterator, Mapping from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, TypeVar +from functools import cached_property +from typing import ( + TYPE_CHECKING, + Any, + Generic, + NamedTuple, + TypeAlias, + TypeVar, +) from typing_extensions import assert_never @@ -37,15 +45,104 @@ Type, ) +if TYPE_CHECKING: + from guppylang.definition.struct import StructField + + +#: A "place" is a description for a storage location of a local value that users +#: can refer to in their program. +#: +#: Roughly, these are values that can be lowered to a static wire within the Hugr +#: representation. The most basic example of a place is a single local variable. Beyond +#: that, we also treat some projections of local variables (e.g. nested struct field +#: accesses) as places. +#: +#: All places are equipped with a unique id, a type and an optional definition AST +#: location. During linearity checking, they are tracked separately. +Place: TypeAlias = "Variable | FieldAccess" + +#: Unique identifier for a `Place`. +PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id" + @dataclass(frozen=True) class Variable: - """Class holding data associated with a local variable.""" + """A place identifying a local variable.""" name: str ty: Type defined_at: AstNode | None + @dataclass(frozen=True) + class Id: + """Identifier for variable places.""" + + name: str + + @cached_property + def id(self) -> "Variable.Id": + """The unique `PlaceId` identifier for this place.""" + return Variable.Id(self.name) + + @property + def describe(self) -> str: + """A human-readable description of this place for error messages.""" + return f"Variable `{self}`" + + def __str__(self) -> str: + """String representation of this place.""" + return self.name + + +@dataclass(frozen=True) +class FieldAccess: + """A place identifying a field access on a local struct.""" + + parent: Place + field: "StructField" + exact_defined_at: AstNode | None + + @dataclass(frozen=True) + class Id: + """Identifier for field places.""" + + parent: PlaceId + field: str + + def __post_init__(self) -> None: + # Check that the field access is consistent + assert self.struct_ty.field_dict[self.field.name] == self.field + + @cached_property + def id(self) -> "FieldAccess.Id": + """The unique `PlaceId` identifier for this place.""" + return FieldAccess.Id(self.parent.id, self.field.name) + + @property + def ty(self) -> Type: + """The type of this place.""" + return self.field.ty + + @cached_property + def struct_ty(self) -> StructType: + """The type of the struct whose field is accessed.""" + assert isinstance(self.parent.ty, StructType) + return self.parent.ty + + @cached_property + def defined_at(self) -> AstNode | None: + """Optional location where this place was last assigned to.""" + return self.exact_defined_at or self.parent.defined_at + + @cached_property + def describe(self) -> str: + """A human-readable description of this place for error messages.""" + return f"Field `{self}`" + + def __str__(self) -> str: + """String representation of this place.""" + return f"{self.parent}.{self.field.name}" + PyScope = dict[str, Any] diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index ae67091b..c222b487 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -38,6 +38,7 @@ from guppylang.checker.core import ( Context, DummyEvalDict, + FieldAccess, Globals, Locals, Variable, @@ -53,13 +54,14 @@ from guppylang.nodes import ( DesugaredGenerator, DesugaredListComp, + FieldAccessAndDrop, GlobalName, IterEnd, IterHasNext, IterNext, LocalCall, - LocalName, MakeIter, + PlaceNode, PyExpr, TensorCall, TypeApply, @@ -81,6 +83,7 @@ FunctionType, NoneType, OpaqueType, + StructType, TupleType, Type, TypeBase, @@ -329,11 +332,11 @@ def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, Type]: raise GuppyError("Unsupported constant", node) return node, ty - def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]: + def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]: x = node.id if x in self.ctx.locals: var = self.ctx.locals[x] - return with_loc(node, LocalName(id=x)), var.ty + return with_loc(node, PlaceNode(place=var)), var.ty elif x in self.ctx.globals: # Look-up what kind of definition it is match self.ctx.globals[x]: @@ -353,6 +356,28 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]: "been caught by program analysis!" ) + def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]: + # A `value.attr` attribute access + node.value, ty = self.synthesize(node.value) + if isinstance(ty, StructType) and node.attr in ty.field_dict: + field = ty.field_dict[node.attr] + expr: ast.expr + if isinstance(node.value, PlaceNode): + # Field access on a place is itself a place + expr = PlaceNode(place=FieldAccess(node.value.place, field, None)) + else: + # If the struct is not in a place, then there is no way to address the + # other fields after this one has been projected (e.g. `f().a` makes + # you loose access to all fields besides `a`). + expr = FieldAccessAndDrop(value=node.value, struct_ty=ty, field=field) + return with_loc(node, expr), field.ty + raise GuppyTypeError( + f"Expression of type `{ty}` has no attribute `{node.attr}`", + # Unfortunately, `node.attr` doesn't contain source annotations, so we have + # to use `node` as the error location + node, + ) + def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]: elems = [self.synthesize(elem) for elem in node.elts] @@ -882,9 +907,9 @@ def synthesize_comprehension( expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx) gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign) gen.next_assign = stmt_chk.visit_Assign(gen.next_assign) - gen.hasnext, hasnext_ty = expr_sth.visit_Name(gen.hasnext) + gen.hasnext, hasnext_ty = expr_sth.visit(gen.hasnext) gen.hasnext = with_type(hasnext_ty, gen.hasnext) - gen.iter, iter_ty = expr_sth.visit_Name(gen.iter) + gen.iter, iter_ty = expr_sth.visit(gen.iter) gen.iter = with_type(iter_ty, gen.iter) # Check `if` guards diff --git a/guppylang/checker/stmt_checker.py b/guppylang/checker/stmt_checker.py index 1f7d5050..b6b6cb46 100644 --- a/guppylang/checker/stmt_checker.py +++ b/guppylang/checker/stmt_checker.py @@ -11,15 +11,15 @@ import ast from collections.abc import Sequence -from guppylang.ast_util import AstVisitor, with_loc, with_type +from guppylang.ast_util import AstVisitor, with_loc from guppylang.cfg.bb import BB, BBStatement -from guppylang.checker.core import Context, Variable +from guppylang.checker.core import Context, FieldAccess, Variable from guppylang.checker.expr_checker import ExprChecker, ExprSynthesizer from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppylang.nodes import NestedFunctionDef +from guppylang.nodes import NestedFunctionDef, PlaceNode from guppylang.tys.parsing import type_from_ast from guppylang.tys.subst import Subst -from guppylang.tys.ty import NoneType, TupleType, Type +from guppylang.tys.ty import NoneType, StructType, TupleType, Type class StmtChecker(AstVisitor[BBStatement]): @@ -46,17 +46,57 @@ def _check_expr( ) -> tuple[ast.expr, Subst]: return ExprChecker(self.ctx).check(node, ty, kind) - def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> None: + def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr: """Helper function to check assignments with patterns.""" match lhs: # Easiest case is if the LHS pattern is a single variable. case ast.Name(id=x): - # Store the type in the AST - with_type(ty, lhs) - self.ctx.locals[x] = Variable(x, ty, lhs) + var = Variable(x, ty, lhs) + self.ctx.locals[x] = var + return with_loc(lhs, PlaceNode(place=var)) + + # The LHS could also be a field `expr.field` + case ast.Attribute(value=value, attr=attr): + value, struct_ty = self._synth_expr(value) + if ( + not isinstance(struct_ty, StructType) + or attr not in struct_ty.field_dict + ): + raise GuppyTypeError( + f"Expression of type `{struct_ty}` has no attribute `{attr}`", + # Unfortunately, `attr` doesn't contain source annotations, so + # we have to use `lhs` as the error location + lhs, + ) + field = struct_ty.field_dict[attr] + # TODO: In the future, we could infer some type args here + if field.ty != ty: + raise GuppyTypeError( + f"Cannot assign expression of type `{ty}` to field with type " + f"`{field.ty}`", + lhs, + ) + if not isinstance(value, PlaceNode): + # For now we complain if someone tries to assign to something that + # is not a place, e.g. `f().a = 4`. This would only make sense if + # there is another reference to the return value of `f`, otherwise + # the mutation cannot be observed. We can start supporting this once + # we have proper reference semantics. + raise GuppyError( + "Assigning to this expression is not supported yet. Consider " + "binding the expression to variable and mutate that variable " + "instead.", + value + ) + if not field.ty.linear: + raise GuppyError( + "Mutation of classical fields is not supported yet", lhs + ) + place = FieldAccess(value.place, struct_ty.field_dict[attr], lhs) + return with_loc(lhs, PlaceNode(place=place)) # The only other thing we support right now are tuples - case ast.Tuple(elts=elts): + case ast.Tuple(elts=elts) as lhs: tys = ty.element_types if isinstance(ty, TupleType) else [ty] n, m = len(elts), len(tys) if n != m: @@ -65,8 +105,11 @@ def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> None: f"(expected {n}, got {m})", node, ) - for pat, el_ty in zip(elts, tys, strict=True): + lhs.elts = [ self._check_assign(pat, el_ty, node) + for pat, el_ty in zip(elts, tys, strict=True) + ] + return lhs # TODO: Python also supports assignments like `[a, b] = [1, 2]` or # `a, *b = ...`. The former would require some runtime checks but @@ -81,7 +124,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign: [target] = node.targets node.value, ty = self._synth_expr(node.value) - self._check_assign(target, ty, node) + node.targets = [self._check_assign(target, ty, node)] return node def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: @@ -93,8 +136,8 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: node.value, subst = self._check_expr(node.value, ty) assert not ty.unsolved_vars # `ty` must be closed! assert len(subst) == 0 - self._check_assign(node.target, ty, node) - return node + target = self._check_assign(node.target, ty, node) + return with_loc(node, ast.Assign(targets=[target], value=node.value)) def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: bin_op = with_loc( diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 286d7bf0..3885c822 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -5,19 +5,20 @@ from typing import TYPE_CHECKING, Any from guppylang.tys.subst import Inst -from guppylang.tys.ty import FunctionType, Type +from guppylang.tys.ty import FunctionType, StructType, Type if TYPE_CHECKING: from guppylang.cfg.cfg import CFG from guppylang.checker.cfg_checker import CheckedCFG - from guppylang.checker.core import Variable + from guppylang.checker.core import Place, Variable from guppylang.definition.common import DefId + from guppylang.definition.struct import StructField -class LocalName(ast.Name): - id: str +class PlaceNode(ast.expr): + place: "Place" - _fields = ("id",) + _fields = ("place",) class GlobalName(ast.Name): @@ -77,6 +78,20 @@ class TypeApply(ast.expr): ) +class FieldAccessAndDrop(ast.expr): + """A field access on a struct, dropping all the remaining other fields.""" + + value: ast.expr + struct_ty: "StructType" + field: "StructField" + + _fields = ( + "value", + "struct_ty", + "field", + ) + + class MakeIter(ast.expr): """Creates an iterator using the `__iter__` magic method. @@ -137,8 +152,8 @@ class DesugaredGenerator(ast.expr): hasnext_assign: ast.Assign next_assign: ast.Assign iterend: ast.expr - iter: ast.Name - hasnext: ast.Name + iter: ast.expr + hasnext: ast.expr ifs: list[ast.expr] _fields = ( diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 50a1dcb0..f8625ebd 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -561,6 +561,11 @@ def fields(self) -> list["StructField"]: inst = Instantiator(self.args) return [StructField(f.name, f.ty.transform(inst)) for f in self.defn.fields] + @cached_property + def field_dict(self) -> "dict[str, StructField]": + """Mapping from names to fields of this struct type.""" + return {field.name: field for field in self.fields} + @cached_property def intrinsically_linear(self) -> bool: """Whether this type is linear, independent of the arguments."""