Skip to content

Commit

Permalink
feat: Add Place union and lower AST to it during checking (#289)
Browse files Browse the repository at this point in the history
See #295 for context.

The tests are at #293
  • Loading branch information
mark-koch authored Jul 23, 2024
1 parent e965679 commit acf1242
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 27 deletions.
101 changes: 99 additions & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import itertools
from collections.abc import Iterable, Iterator, Mapping
from dataclasses import dataclass
from typing import Any, Generic, NamedTuple, TypeVar
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Generic,
NamedTuple,
TypeAlias,
TypeVar,
)

from typing_extensions import assert_never

Expand Down Expand Up @@ -37,15 +45,104 @@
Type,
)

if TYPE_CHECKING:
from guppylang.definition.struct import StructField


#: A "place" is a description for a storage location of a local value that users
#: can refer to in their program.
#:
#: Roughly, these are values that can be lowered to a static wire within the Hugr
#: representation. The most basic example of a place is a single local variable. Beyond
#: that, we also treat some projections of local variables (e.g. nested struct field
#: accesses) as places.
#:
#: All places are equipped with a unique id, a type and an optional definition AST
#: location. During linearity checking, they are tracked separately.
Place: TypeAlias = "Variable | FieldAccess"

#: Unique identifier for a `Place`.
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id"


@dataclass(frozen=True)
class Variable:
"""Class holding data associated with a local variable."""
"""A place identifying a local variable."""

name: str
ty: Type
defined_at: AstNode | None

@dataclass(frozen=True)
class Id:
"""Identifier for variable places."""

name: str

@cached_property
def id(self) -> "Variable.Id":
"""The unique `PlaceId` identifier for this place."""
return Variable.Id(self.name)

@property
def describe(self) -> str:
"""A human-readable description of this place for error messages."""
return f"Variable `{self}`"

def __str__(self) -> str:
"""String representation of this place."""
return self.name


@dataclass(frozen=True)
class FieldAccess:
"""A place identifying a field access on a local struct."""

parent: Place
field: "StructField"
exact_defined_at: AstNode | None

@dataclass(frozen=True)
class Id:
"""Identifier for field places."""

parent: PlaceId
field: str

def __post_init__(self) -> None:
# Check that the field access is consistent
assert self.struct_ty.field_dict[self.field.name] == self.field

@cached_property
def id(self) -> "FieldAccess.Id":
"""The unique `PlaceId` identifier for this place."""
return FieldAccess.Id(self.parent.id, self.field.name)

@property
def ty(self) -> Type:
"""The type of this place."""
return self.field.ty

@cached_property
def struct_ty(self) -> StructType:
"""The type of the struct whose field is accessed."""
assert isinstance(self.parent.ty, StructType)
return self.parent.ty

@cached_property
def defined_at(self) -> AstNode | None:
"""Optional location where this place was last assigned to."""
return self.exact_defined_at or self.parent.defined_at

@cached_property
def describe(self) -> str:
"""A human-readable description of this place for error messages."""
return f"Field `{self}`"

def __str__(self) -> str:
"""String representation of this place."""
return f"{self.parent}.{self.field.name}"


PyScope = dict[str, Any]

Expand Down
35 changes: 30 additions & 5 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from guppylang.checker.core import (
Context,
DummyEvalDict,
FieldAccess,
Globals,
Locals,
Variable,
Expand All @@ -53,13 +54,14 @@
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
FieldAccessAndDrop,
GlobalName,
IterEnd,
IterHasNext,
IterNext,
LocalCall,
LocalName,
MakeIter,
PlaceNode,
PyExpr,
TensorCall,
TypeApply,
Expand All @@ -81,6 +83,7 @@
FunctionType,
NoneType,
OpaqueType,
StructType,
TupleType,
Type,
TypeBase,
Expand Down Expand Up @@ -329,11 +332,11 @@ def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, Type]:
raise GuppyError("Unsupported constant", node)
return node, ty

def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]:
def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]:
x = node.id
if x in self.ctx.locals:
var = self.ctx.locals[x]
return with_loc(node, LocalName(id=x)), var.ty
return with_loc(node, PlaceNode(place=var)), var.ty
elif x in self.ctx.globals:
# Look-up what kind of definition it is
match self.ctx.globals[x]:
Expand All @@ -353,6 +356,28 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]:
"been caught by program analysis!"
)

def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
# A `value.attr` attribute access
node.value, ty = self.synthesize(node.value)
if isinstance(ty, StructType) and node.attr in ty.field_dict:
field = ty.field_dict[node.attr]
expr: ast.expr
if isinstance(node.value, PlaceNode):
# Field access on a place is itself a place
expr = PlaceNode(place=FieldAccess(node.value.place, field, None))
else:
# If the struct is not in a place, then there is no way to address the
# other fields after this one has been projected (e.g. `f().a` makes
# you loose access to all fields besides `a`).
expr = FieldAccessAndDrop(value=node.value, struct_ty=ty, field=field)
return with_loc(node, expr), field.ty
raise GuppyTypeError(
f"Expression of type `{ty}` has no attribute `{node.attr}`",
# Unfortunately, `node.attr` doesn't contain source annotations, so we have
# to use `node` as the error location
node,
)

def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]:
elems = [self.synthesize(elem) for elem in node.elts]

Expand Down Expand Up @@ -882,9 +907,9 @@ def synthesize_comprehension(
expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx)
gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign)
gen.next_assign = stmt_chk.visit_Assign(gen.next_assign)
gen.hasnext, hasnext_ty = expr_sth.visit_Name(gen.hasnext)
gen.hasnext, hasnext_ty = expr_sth.visit(gen.hasnext)
gen.hasnext = with_type(hasnext_ty, gen.hasnext)
gen.iter, iter_ty = expr_sth.visit_Name(gen.iter)
gen.iter, iter_ty = expr_sth.visit(gen.iter)
gen.iter = with_type(iter_ty, gen.iter)

# Check `if` guards
Expand Down
69 changes: 56 additions & 13 deletions guppylang/checker/stmt_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
import ast
from collections.abc import Sequence

from guppylang.ast_util import AstVisitor, with_loc, with_type
from guppylang.ast_util import AstVisitor, with_loc
from guppylang.cfg.bb import BB, BBStatement
from guppylang.checker.core import Context, Variable
from guppylang.checker.core import Context, FieldAccess, Variable
from guppylang.checker.expr_checker import ExprChecker, ExprSynthesizer
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.nodes import NestedFunctionDef
from guppylang.nodes import NestedFunctionDef, PlaceNode
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.subst import Subst
from guppylang.tys.ty import NoneType, TupleType, Type
from guppylang.tys.ty import NoneType, StructType, TupleType, Type


class StmtChecker(AstVisitor[BBStatement]):
Expand All @@ -46,17 +46,57 @@ def _check_expr(
) -> tuple[ast.expr, Subst]:
return ExprChecker(self.ctx).check(node, ty, kind)

def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> None:
def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr:
"""Helper function to check assignments with patterns."""
match lhs:
# Easiest case is if the LHS pattern is a single variable.
case ast.Name(id=x):
# Store the type in the AST
with_type(ty, lhs)
self.ctx.locals[x] = Variable(x, ty, lhs)
var = Variable(x, ty, lhs)
self.ctx.locals[x] = var
return with_loc(lhs, PlaceNode(place=var))

# The LHS could also be a field `expr.field`
case ast.Attribute(value=value, attr=attr):
value, struct_ty = self._synth_expr(value)
if (
not isinstance(struct_ty, StructType)
or attr not in struct_ty.field_dict
):
raise GuppyTypeError(
f"Expression of type `{struct_ty}` has no attribute `{attr}`",
# Unfortunately, `attr` doesn't contain source annotations, so
# we have to use `lhs` as the error location
lhs,
)
field = struct_ty.field_dict[attr]
# TODO: In the future, we could infer some type args here
if field.ty != ty:
raise GuppyTypeError(
f"Cannot assign expression of type `{ty}` to field with type "
f"`{field.ty}`",
lhs,
)
if not isinstance(value, PlaceNode):
# For now we complain if someone tries to assign to something that
# is not a place, e.g. `f().a = 4`. This would only make sense if
# there is another reference to the return value of `f`, otherwise
# the mutation cannot be observed. We can start supporting this once
# we have proper reference semantics.
raise GuppyError(
"Assigning to this expression is not supported yet. Consider "
"binding the expression to variable and mutate that variable "
"instead.",
value
)
if not field.ty.linear:
raise GuppyError(
"Mutation of classical fields is not supported yet", lhs
)
place = FieldAccess(value.place, struct_ty.field_dict[attr], lhs)
return with_loc(lhs, PlaceNode(place=place))

# The only other thing we support right now are tuples
case ast.Tuple(elts=elts):
case ast.Tuple(elts=elts) as lhs:
tys = ty.element_types if isinstance(ty, TupleType) else [ty]
n, m = len(elts), len(tys)
if n != m:
Expand All @@ -65,8 +105,11 @@ def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> None:
f"(expected {n}, got {m})",
node,
)
for pat, el_ty in zip(elts, tys, strict=True):
lhs.elts = [
self._check_assign(pat, el_ty, node)
for pat, el_ty in zip(elts, tys, strict=True)
]
return lhs

# TODO: Python also supports assignments like `[a, b] = [1, 2]` or
# `a, *b = ...`. The former would require some runtime checks but
Expand All @@ -81,7 +124,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign:

[target] = node.targets
node.value, ty = self._synth_expr(node.value)
self._check_assign(target, ty, node)
node.targets = [self._check_assign(target, ty, node)]
return node

def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
Expand All @@ -93,8 +136,8 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
node.value, subst = self._check_expr(node.value, ty)
assert not ty.unsolved_vars # `ty` must be closed!
assert len(subst) == 0
self._check_assign(node.target, ty, node)
return node
target = self._check_assign(node.target, ty, node)
return with_loc(node, ast.Assign(targets=[target], value=node.value))

def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt:
bin_op = with_loc(
Expand Down
29 changes: 22 additions & 7 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
from typing import TYPE_CHECKING, Any

from guppylang.tys.subst import Inst
from guppylang.tys.ty import FunctionType, Type
from guppylang.tys.ty import FunctionType, StructType, Type

if TYPE_CHECKING:
from guppylang.cfg.cfg import CFG
from guppylang.checker.cfg_checker import CheckedCFG
from guppylang.checker.core import Variable
from guppylang.checker.core import Place, Variable
from guppylang.definition.common import DefId
from guppylang.definition.struct import StructField


class LocalName(ast.Name):
id: str
class PlaceNode(ast.expr):
place: "Place"

_fields = ("id",)
_fields = ("place",)


class GlobalName(ast.Name):
Expand Down Expand Up @@ -77,6 +78,20 @@ class TypeApply(ast.expr):
)


class FieldAccessAndDrop(ast.expr):
"""A field access on a struct, dropping all the remaining other fields."""

value: ast.expr
struct_ty: "StructType"
field: "StructField"

_fields = (
"value",
"struct_ty",
"field",
)


class MakeIter(ast.expr):
"""Creates an iterator using the `__iter__` magic method.
Expand Down Expand Up @@ -137,8 +152,8 @@ class DesugaredGenerator(ast.expr):
hasnext_assign: ast.Assign
next_assign: ast.Assign
iterend: ast.expr
iter: ast.Name
hasnext: ast.Name
iter: ast.expr
hasnext: ast.expr
ifs: list[ast.expr]

_fields = (
Expand Down
5 changes: 5 additions & 0 deletions guppylang/tys/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,11 @@ def fields(self) -> list["StructField"]:
inst = Instantiator(self.args)
return [StructField(f.name, f.ty.transform(inst)) for f in self.defn.fields]

@cached_property
def field_dict(self) -> "dict[str, StructField]":
"""Mapping from names to fields of this struct type."""
return {field.name: field for field in self.fields}

@cached_property
def intrinsically_linear(self) -> bool:
"""Whether this type is linear, independent of the arguments."""
Expand Down

0 comments on commit acf1242

Please sign in to comment.