Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Place union and lower AST to it during checking #289

Merged
merged 16 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 99 additions & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import itertools
from collections.abc import Iterable, Iterator, Mapping
from dataclasses import dataclass
from typing import Any, Generic, NamedTuple, TypeVar
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Generic,
NamedTuple,
TypeAlias,
TypeVar,
)

from typing_extensions import assert_never

Expand Down Expand Up @@ -37,15 +45,104 @@
Type,
)

if TYPE_CHECKING:
from guppylang.definition.struct import StructField


#: A "place" is a description for a storage location of a local value that users
#: can refer to in their program.
#:
#: Roughly, these are values that can be lowered to a static wire within the Hugr
#: representation. The most basic example of a place is a single local variable. Beyond
#: that, we also treat some projections of local variables (e.g. nested struct field
#: accesses) as places.
#:
#: All places are equipped with a unique id, a type and an optional definition AST
#: location. During linearity checking, they are tracked separately.
Place: TypeAlias = "Variable | FieldAccess"

#: Unique identifier for a `Place`.
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id"


@dataclass(frozen=True)
class Variable:
"""Class holding data associated with a local variable."""
"""A place identifying a local variable."""

name: str
ty: Type
defined_at: AstNode | None

@dataclass(frozen=True)
class Id:
"""Identifier for variable places."""

name: str

@cached_property
def id(self) -> "Variable.Id":
"""The unique `PlaceId` identifier for this place."""
return Variable.Id(self.name)

@property
def describe(self) -> str:
"""A human-readable description of this place for error messages."""
return f"Variable `{self}`"

def __str__(self) -> str:
"""String representation of this place."""
return self.name


@dataclass(frozen=True)
class FieldAccess:
"""A place identifying a field access on a local struct."""

parent: Place
field: "StructField"
exact_defined_at: AstNode | None

@dataclass(frozen=True)
class Id:
"""Identifier for field places."""

parent: PlaceId
field: str

def __post_init__(self) -> None:
# Check that the field access is consistent
assert self.struct_ty.field_dict[self.field.name] == self.field

@cached_property
def id(self) -> "FieldAccess.Id":
"""The unique `PlaceId` identifier for this place."""
return FieldAccess.Id(self.parent.id, self.field.name)

@property
def ty(self) -> Type:
"""The type of this place."""
return self.field.ty

@cached_property
def struct_ty(self) -> StructType:
"""The type of the struct whose field is accessed."""
assert isinstance(self.parent.ty, StructType)
return self.parent.ty

@cached_property
def defined_at(self) -> AstNode | None:
"""Optional location where this place was last assigned to."""
return self.exact_defined_at or self.parent.defined_at

@cached_property
def describe(self) -> str:
"""A human-readable description of this place for error messages."""
return f"Field `{self}`"

def __str__(self) -> str:
"""String representation of this place."""
return f"{self.parent}.{self.field.name}"


PyScope = dict[str, Any]

Expand Down
35 changes: 30 additions & 5 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from guppylang.checker.core import (
Context,
DummyEvalDict,
FieldAccess,
Globals,
Locals,
Variable,
Expand All @@ -53,13 +54,14 @@
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
FieldAccessAndDrop,
GlobalName,
IterEnd,
IterHasNext,
IterNext,
LocalCall,
LocalName,
MakeIter,
PlaceNode,
PyExpr,
TensorCall,
TypeApply,
Expand All @@ -81,6 +83,7 @@
FunctionType,
NoneType,
OpaqueType,
StructType,
TupleType,
Type,
TypeBase,
Expand Down Expand Up @@ -329,11 +332,11 @@ def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, Type]:
raise GuppyError("Unsupported constant", node)
return node, ty

def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]:
def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]:
x = node.id
if x in self.ctx.locals:
var = self.ctx.locals[x]
return with_loc(node, LocalName(id=x)), var.ty
return with_loc(node, PlaceNode(place=var)), var.ty
elif x in self.ctx.globals:
# Look-up what kind of definition it is
match self.ctx.globals[x]:
Expand All @@ -353,6 +356,28 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]:
"been caught by program analysis!"
)

def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
# A `value.attr` attribute access
node.value, ty = self.synthesize(node.value)
if isinstance(ty, StructType) and node.attr in ty.field_dict:
field = ty.field_dict[node.attr]
expr: ast.expr
if isinstance(node.value, PlaceNode):
# Field access on a place is itself a place
expr = PlaceNode(place=FieldAccess(node.value.place, field, None))
else:
# If the struct is not in a place, then there is no way to address the
# other fields after this one has been projected (e.g. `f().a` makes
# you loose access to all fields besides `a`).
expr = FieldAccessAndDrop(value=node.value, struct_ty=ty, field=field)
return with_loc(node, expr), field.ty
raise GuppyTypeError(
f"Expression of type `{ty}` has no attribute `{node.attr}`",
# Unfortunately, `node.attr` doesn't contain source annotations, so we have
# to use `node` as the error location
node,
)

def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]:
elems = [self.synthesize(elem) for elem in node.elts]

Expand Down Expand Up @@ -882,9 +907,9 @@ def synthesize_comprehension(
expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx)
gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign)
gen.next_assign = stmt_chk.visit_Assign(gen.next_assign)
gen.hasnext, hasnext_ty = expr_sth.visit_Name(gen.hasnext)
gen.hasnext, hasnext_ty = expr_sth.visit(gen.hasnext)
gen.hasnext = with_type(hasnext_ty, gen.hasnext)
gen.iter, iter_ty = expr_sth.visit_Name(gen.iter)
gen.iter, iter_ty = expr_sth.visit(gen.iter)
gen.iter = with_type(iter_ty, gen.iter)

# Check `if` guards
Expand Down
69 changes: 56 additions & 13 deletions guppylang/checker/stmt_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
import ast
from collections.abc import Sequence

from guppylang.ast_util import AstVisitor, with_loc, with_type
from guppylang.ast_util import AstVisitor, with_loc
from guppylang.cfg.bb import BB, BBStatement
from guppylang.checker.core import Context, Variable
from guppylang.checker.core import Context, FieldAccess, Variable
from guppylang.checker.expr_checker import ExprChecker, ExprSynthesizer
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.nodes import NestedFunctionDef
from guppylang.nodes import NestedFunctionDef, PlaceNode
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.subst import Subst
from guppylang.tys.ty import NoneType, TupleType, Type
from guppylang.tys.ty import NoneType, StructType, TupleType, Type


class StmtChecker(AstVisitor[BBStatement]):
Expand All @@ -46,17 +46,57 @@ def _check_expr(
) -> tuple[ast.expr, Subst]:
return ExprChecker(self.ctx).check(node, ty, kind)

def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> None:
def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr:
"""Helper function to check assignments with patterns."""
match lhs:
# Easiest case is if the LHS pattern is a single variable.
case ast.Name(id=x):
# Store the type in the AST
with_type(ty, lhs)
self.ctx.locals[x] = Variable(x, ty, lhs)
var = Variable(x, ty, lhs)
self.ctx.locals[x] = var
return with_loc(lhs, PlaceNode(place=var))

# The LHS could also be a field `expr.field`
case ast.Attribute(value=value, attr=attr):
value, struct_ty = self._synth_expr(value)
if (
not isinstance(struct_ty, StructType)
or attr not in struct_ty.field_dict
):
raise GuppyTypeError(
f"Expression of type `{struct_ty}` has no attribute `{attr}`",
# Unfortunately, `attr` doesn't contain source annotations, so
# we have to use `lhs` as the error location
lhs,
)
field = struct_ty.field_dict[attr]
# TODO: In the future, we could infer some type args here
if field.ty != ty:
raise GuppyTypeError(
f"Cannot assign expression of type `{ty}` to field with type "
f"`{field.ty}`",
lhs,
)
if not isinstance(value, PlaceNode):
# For now we complain if someone tries to assign to something that
# is not a place, e.g. `f().a = 4`. This would only make sense if
# there is another reference to the return value of `f`, otherwise
# the mutation cannot be observed. We can start supporting this once
# we have proper reference semantics.
raise GuppyError(
"Assigning to this expression is not supported yet. Consider "
"binding the expression to variable and mutate that variable "
"instead.",
value
)
if not field.ty.linear:
raise GuppyError(
"Mutation of classical fields is not supported yet", lhs
)
place = FieldAccess(value.place, struct_ty.field_dict[attr], lhs)
return with_loc(lhs, PlaceNode(place=place))

# The only other thing we support right now are tuples
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slightly outdated comment

case ast.Tuple(elts=elts):
case ast.Tuple(elts=elts) as lhs:
tys = ty.element_types if isinstance(ty, TupleType) else [ty]
n, m = len(elts), len(tys)
if n != m:
Expand All @@ -65,8 +105,11 @@ def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> None:
f"(expected {n}, got {m})",
node,
)
for pat, el_ty in zip(elts, tys, strict=True):
lhs.elts = [
self._check_assign(pat, el_ty, node)
for pat, el_ty in zip(elts, tys, strict=True)
]
return lhs

# TODO: Python also supports assignments like `[a, b] = [1, 2]` or
# `a, *b = ...`. The former would require some runtime checks but
Expand All @@ -81,7 +124,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign:

[target] = node.targets
node.value, ty = self._synth_expr(node.value)
self._check_assign(target, ty, node)
node.targets = [self._check_assign(target, ty, node)]
return node

def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
Expand All @@ -93,8 +136,8 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
node.value, subst = self._check_expr(node.value, ty)
assert not ty.unsolved_vars # `ty` must be closed!
assert len(subst) == 0
self._check_assign(node.target, ty, node)
return node
target = self._check_assign(node.target, ty, node)
return with_loc(node, ast.Assign(targets=[target], value=node.value))

def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt:
bin_op = with_loc(
Expand Down
29 changes: 22 additions & 7 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
from typing import TYPE_CHECKING, Any

from guppylang.tys.subst import Inst
from guppylang.tys.ty import FunctionType, Type
from guppylang.tys.ty import FunctionType, StructType, Type

if TYPE_CHECKING:
from guppylang.cfg.cfg import CFG
from guppylang.checker.cfg_checker import CheckedCFG
from guppylang.checker.core import Variable
from guppylang.checker.core import Place, Variable
from guppylang.definition.common import DefId
from guppylang.definition.struct import StructField


class LocalName(ast.Name):
id: str
class PlaceNode(ast.expr):
place: "Place"

_fields = ("id",)
_fields = ("place",)


class GlobalName(ast.Name):
Expand Down Expand Up @@ -77,6 +78,20 @@ class TypeApply(ast.expr):
)


class FieldAccessAndDrop(ast.expr):
"""A field access on a struct, dropping all the remaining other fields."""

value: ast.expr
struct_ty: "StructType"
field: "StructField"

_fields = (
"value",
"struct_ty",
"field",
)


class MakeIter(ast.expr):
"""Creates an iterator using the `__iter__` magic method.

Expand Down Expand Up @@ -137,8 +152,8 @@ class DesugaredGenerator(ast.expr):
hasnext_assign: ast.Assign
next_assign: ast.Assign
iterend: ast.expr
iter: ast.Name
hasnext: ast.Name
iter: ast.expr
hasnext: ast.expr
ifs: list[ast.expr]

_fields = (
Expand Down
5 changes: 5 additions & 0 deletions guppylang/tys/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,11 @@ def fields(self) -> list["StructField"]:
inst = Instantiator(self.args)
return [StructField(f.name, f.ty.transform(inst)) for f in self.defn.fields]

@cached_property
def field_dict(self) -> "dict[str, StructField]":
"""Mapping from names to fields of this struct type."""
return {field.name: field for field in self.fields}

@cached_property
def intrinsically_linear(self) -> bool:
"""Whether this type is linear, independent of the arguments."""
Expand Down