diff --git a/guppylang/cfg/analysis.py b/guppylang/cfg/analysis.py index cbe3742d..d71822f0 100644 --- a/guppylang/cfg/analysis.py +++ b/guppylang/cfg/analysis.py @@ -170,7 +170,7 @@ def initial(self) -> AssignmentDomain[VId]: # `ass_before_entry` since we want to compute the *greatest* fixpoint. return self.all_vars, self.maybe_ass_before_entry - def join(self, *ts: AssignmentDomain[VId]) -> AssignmentDomain[VId]: + def join(self, *ts: AssignmentDomain[P]) -> AssignmentDomain[P]: # We always include the variables that are definitely assigned before the entry, # even if the join is empty if len(ts) == 0: diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 1b92c01a..39fb58ea 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -302,6 +302,12 @@ def keys(self) -> set[VId]: parent_keys = self.parent_scope.keys() if self.parent_scope else set() return parent_keys | self.vars.keys() + def values(self) -> Iterable[V]: + parent_values = ( + iter(self.parent_scope.values()) if self.parent_scope else iter(()) + ) + return itertools.chain(self.vars.values(), parent_values) + def items(self) -> Iterable[tuple[VId, V]]: parent_items = ( iter(self.parent_scope.items()) if self.parent_scope else iter(()) diff --git a/guppylang/checker/linearity_checker.py b/guppylang/checker/linearity_checker.py index 0499a048..6f995e91 100644 --- a/guppylang/checker/linearity_checker.py +++ b/guppylang/checker/linearity_checker.py @@ -4,46 +4,60 @@ """ import ast -from collections.abc import Generator, Iterable +from collections.abc import Generator, Iterator from contextlib import contextmanager from typing import TYPE_CHECKING -from guppylang.ast_util import get_type, name_nodes_in_ast -from guppylang.checker.core import Locals, Variable +from guppylang.ast_util import AstNode, find_nodes, get_type +from guppylang.cfg.analysis import LivenessAnalysis +from guppylang.cfg.bb import BB, VariableStats +from guppylang.checker.core import ( + FieldAccess, + Locals, + Place, + PlaceId, + Variable, +) from guppylang.error import GuppyError, GuppyTypeError -from guppylang.nodes import DesugaredGenerator, DesugaredListComp, LocalName +from guppylang.nodes import ( + CheckedNestedFunctionDef, + DesugaredGenerator, + DesugaredListComp, + FieldAccessAndDrop, + PlaceNode, +) +from guppylang.tys.ty import StructType if TYPE_CHECKING: from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG -class Scope(Locals[str, Variable]): - """Scoped collection of assigned variables indexed by name. +class Scope(Locals[PlaceId, Place]): + """Scoped collection of assigned places indexed by their id. - Keeps track of which variables have already been used. + Keeps track of which places have already been used. """ parent_scope: "Scope | None" - used_local: dict[str, ast.Name] - used_parent: dict[str, ast.Name] + used_local: dict[PlaceId, AstNode] + used_parent: dict[PlaceId, AstNode] - def __init__(self, assigned: Iterable[Variable], parent: "Scope | None" = None): + def __init__(self, parent: "Scope | None" = None): self.used_local = {} self.used_parent = {} - super().__init__({var.name: var for var in assigned}, parent) + super().__init__({}, parent) - def used(self, x: str) -> ast.Name | None: - """Checks whether a variable has already been used.""" + def used(self, x: PlaceId) -> AstNode | None: + """Checks whether a place has already been used.""" if x in self.vars: return self.used_local.get(x, None) assert self.parent_scope is not None return self.parent_scope.used(x) - def use(self, x: str, node: ast.Name) -> None: - """Records a use of a variable. + def use(self, x: PlaceId, node: AstNode) -> None: + """Records a use of a place. - Works for local variables in the current scope as well as variables in any - parent scope. + Works for places in the current scope as well as places in any parent scope. """ if x in self.vars: self.used_local[x] = node @@ -53,21 +67,42 @@ def use(self, x: str, node: ast.Name) -> None: self.used_parent[x] = node self.parent_scope.use(x, node) - def assign(self, var: Variable) -> None: - """Records an assignment of a variable.""" - x = var.name - self.vars[x] = var + def assign(self, place: Place) -> None: + """Records an assignment of a place.""" + assert place.defined_at is not None + x = place.id + self.vars[x] = place if x in self.used_local: self.used_local.pop(x) + def stats(self) -> VariableStats[PlaceId]: + assigned = {} + for x, place in self.vars.items(): + assert place.defined_at is not None + assigned[x] = place.defined_at + return VariableStats(assigned, self.used_parent) + class BBLinearityChecker(ast.NodeVisitor): """AST visitor that checks linearity for a single basic block.""" scope: Scope + stats: VariableStats[PlaceId] + + def check(self, bb: "CheckedBB", is_entry: bool) -> Scope: + # Manufacture a scope that holds all places that are live at the start + # of this BB + input_scope = Scope() + for var in bb.sig.input_row: + for place in leaf_places(var): + input_scope.assign(place) + + # Open up a new nested scope to check the BB contents. This way we can track + # when we use variables from the outside vs ones assigned in this BB. The only + # exception is the entry BB since function arguments should be treated as part + # of the entry BB + self.scope = input_scope if is_entry else Scope(input_scope) - def check(self, bb: "CheckedBB") -> Scope: - self.scope = Scope(bb.sig.input_row) for stmt in bb.statements: self.visit(stmt) if bb.branch_pred: @@ -76,18 +111,17 @@ def check(self, bb: "CheckedBB") -> Scope: @contextmanager def new_scope(self) -> Generator[Scope, None, None]: - scope, new_scope = self.scope, Scope({}, self.scope) + scope, new_scope = self.scope, Scope(self.scope) self.scope = new_scope yield new_scope self.scope = scope - def visit_LocalName(self, node: LocalName) -> None: - x = node.id - if x in self.scope: - var = self.scope[x] - if (use := self.scope.used(x)) and var.ty.linear: + def visit_PlaceNode(self, node: PlaceNode) -> None: + for place in leaf_places(node.place): + x = place.id + if (use := self.scope.used(x)) and place.ty.linear: raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` was already used " + f"{place.describe} with linear type `{place.ty}` was already used " "(at {0})", node, [use], @@ -98,6 +132,19 @@ def visit_Assign(self, node: ast.Assign) -> None: self.visit(node.value) self._check_assign_targets(node.targets) + def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None: + # A field access on a value that is not a place. This means the value can no + # longer be accessed after the field has been projected out. Thus, this is only + # legal if there are no remaining linear fields on the value + self.visit(node.value) + for field in node.struct_ty.fields: + if field.name != node.field.name and field.ty.linear: + raise GuppyTypeError( + f"Linear field `{field.name}` of expression with type " + f"`{node.struct_ty}` is not used", + node.value, + ) + def visit_Expr(self, node: ast.Expr) -> None: # An expression statement where the return value is discarded self.visit(node.value) @@ -110,18 +157,21 @@ def visit_DesugaredListComp(self, node: DesugaredListComp) -> None: def _check_assign_targets(self, targets: list[ast.expr]) -> None: """Helper function to check assignments.""" - # We're not allowed to override an unused linear variable + # We're not allowed to override an unused linear place [target] = targets - for name in name_nodes_in_ast(target): - x = name.id - if x in self.scope and not self.scope.used(x): - var = self.scope[x] - if var.ty.linear: - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` is not used", - var.defined_at, - ) - self.scope.assign(Variable(x, get_type(name), name)) + for tgt in find_nodes(lambda n: isinstance(n, PlaceNode), target): + assert isinstance(tgt, PlaceNode) + for tgt_place in leaf_places(tgt.place): + x = tgt_place.id + if x in self.scope and not self.scope.used(x): + place = self.scope[x] + if place.ty.linear: + raise GuppyError( + f"{place.describe} with linear type `{place.ty}` is not " + "used", + place.defined_at, + ) + self.scope.assign(tgt_place) def _check_comprehension( self, node: DesugaredListComp, gens: list[DesugaredGenerator] @@ -134,6 +184,7 @@ def _check_comprehension( # Check the iterator expression in the current scope gen, *gens = gens self.visit(gen.iter_assign.value) + assert isinstance(gen.iter, PlaceNode) # The rest is checked in a new nested scope so we can track which variables # are introduced and used inside the loop @@ -151,17 +202,20 @@ def _check_comprehension( # Check if there are linear iteration variables that have not been used # by the first guard self.visit(first_if) - for x, var in self.scope.vars.items(): + for place in self.scope.vars.values(): # The only exception is the iterator variable since we make sure # that it is carried through each iteration during Hugr generation - if x == gen.iter.id: + if place == gen.iter.place: continue - if not self.scope.used(x) and var.ty.linear: - raise GuppyTypeError( - f"Variable `{var.name}` with linear type `{var.ty}` is not " - "used on all control-flow paths of the list comprehension", - var.defined_at, - ) + for leaf in leaf_places(place): + x = leaf.id + if not self.scope.used(x) and place.ty.linear: + raise GuppyTypeError( + f"{place.describe} with linear type `{place.ty}` is " + "not used on all control-flow paths of the list " + "comprehension", + place.defined_at, + ) for expr in other_ifs: self.visit(expr) @@ -173,58 +227,82 @@ def _check_comprehension( # We have to make sure that all linear variables that were introduced in the # inner scope have been used - for x, var in inner_scope.vars.items(): - if var.ty.linear and not inner_scope.used(x): - raise GuppyTypeError( - f"Variable `{x}` with linear type `{var.ty}` is not used", - var.defined_at, - ) + for place in inner_scope.vars.values(): + for leaf in leaf_places(place): + x = leaf.id + if leaf.ty.linear and not inner_scope.used(x): + raise GuppyTypeError( + f"{leaf.describe} with linear type `{leaf.ty}` is not used", + leaf.defined_at, + ) - # On the other hand, we have to ensure that no linear variables from the + # On the other hand, we have to ensure that no linear places from the # outer scope have been used inside the comprehension (they would be used # multiple times since the comprehension body is executed repeatedly) for x, use in inner_scope.used_parent.items(): - var = inner_scope[x] - if var.ty.linear: + place = inner_scope[x] + if place.ty.linear: raise GuppyTypeError( - f"Variable `{x}` with linear type `{var.ty}` would be used " + f"{place.describe} with linear type `{place.ty}` would be used " "multiple times when evaluating this comprehension", use, ) +def leaf_places(place: Place) -> Iterator[Place]: + """Returns all leaf descendant projections of a place.""" + stack = [place] + while stack: + place = stack.pop() + if isinstance(place.ty, StructType): + for field in place.ty.fields: + stack.append(FieldAccess(place, field, place.defined_at)) + else: + yield place + + def check_cfg_linearity(cfg: "CheckedCFG") -> None: """Checks whether a CFG satisfies the linearity requirements. Raises a user-error if linearity violations are found. """ bb_checker = BBLinearityChecker() - for bb in cfg.bbs: - scope = bb_checker.check(bb) + scopes: dict[BB, Scope] = { + bb: bb_checker.check(bb, is_entry=bb == cfg.entry_bb) for bb in cfg.bbs + } + # Run liveness analysis + stats = {bb: scope.stats() for bb, scope in scopes.items()} + live_before = LivenessAnalysis(stats).run(cfg.bbs) + + for bb, scope in scopes.items(): # We have to check that used linear variables are not being outputted for succ in bb.successors: - live = cfg.live_before[succ] + live = live_before[succ] for x, use_bb in live.items(): - if x in scope: - var = scope[x] - if var.ty.linear and (use := scope.used(x)): - raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` was " - "already used (at {0})", - use_bb.vars.used[x], - [use], - ) - - # On the other hand, unused linear variables *must* be outputted - for x, var in scope.vars.items(): - used_later = x in cfg.live_before[succ] - if var.ty.linear and not scope.used(x) and not used_later: - # TODO: This should be "Variable x with linear type ty is not - # used in {bb}". But for this we need a way to associate BBs with - # source locations. + use_scope = scopes[use_bb] + place = use_scope[x] + if place.ty.linear and (use := scope.used(x)): raise GuppyError( - f"Variable `{x}` with linear type `{var.ty}` is " - "not used on all control-flow paths", - var.defined_at, + f"{place.describe} with linear type `{place.ty}` was " + "already used (at {0})", + use_scope.used_parent[x], + [use], ) + + # On the other hand, unused linear variables *must* be outputted + for place in scope.vars.values(): + for leaf in leaf_places(place): + x = leaf.id + used_later = x in live + if leaf.ty.linear and not scope.used(x) and not used_later: + # TODO: This should be "Variable x with linear type ty is not + # used in {bb}". But for this we need a way to associate BBs + # with source locations. + raise GuppyError( + f"{leaf.describe} with linear type `{leaf.ty}` is " + "not used on all control-flow paths", + # Re-lookup defined_at in scope because we might have a + # more precise location + scope[x].defined_at, + )