Skip to content

Commit

Permalink
feat: Linearity checking for places (#290)
Browse files Browse the repository at this point in the history
See #295 for context, the tests are at #293

The basic idea is that we implicitly "unfold" structs into tuples and
then rely on the existing linearity checking logic. The example from
#295 is checked as follows:

```python
@guppy.struct
class MyStruct:
   q1: qubit
   q2: qubit
   x: int

def main(s: MyStruct):
   s_q1, s_q2, s_x = s
   q = h(s_q1)
   t = h(s_q2)
   y = s_x + s_x
   use((s_q1, s_q2, s_x))  # Error
   ...
```

This is implemented using the `leaf_places()` iterator that gives us the
leaf projections of places with (nested) struct types.
  • Loading branch information
mark-koch authored Jul 23, 2024
1 parent acf1242 commit 6561f05
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 83 deletions.
2 changes: 1 addition & 1 deletion guppylang/cfg/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def initial(self) -> AssignmentDomain[VId]:
# `ass_before_entry` since we want to compute the *greatest* fixpoint.
return self.all_vars, self.maybe_ass_before_entry

def join(self, *ts: AssignmentDomain[VId]) -> AssignmentDomain[VId]:
def join(self, *ts: AssignmentDomain[P]) -> AssignmentDomain[P]:
# We always include the variables that are definitely assigned before the entry,
# even if the join is empty
if len(ts) == 0:
Expand Down
6 changes: 6 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,12 @@ def keys(self) -> set[VId]:
parent_keys = self.parent_scope.keys() if self.parent_scope else set()
return parent_keys | self.vars.keys()

def values(self) -> Iterable[V]:
parent_values = (
iter(self.parent_scope.values()) if self.parent_scope else iter(())
)
return itertools.chain(self.vars.values(), parent_values)

def items(self) -> Iterable[tuple[VId, V]]:
parent_items = (
iter(self.parent_scope.items()) if self.parent_scope else iter(())
Expand Down
242 changes: 160 additions & 82 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,60 @@
"""

import ast
from collections.abc import Generator, Iterable
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

from guppylang.ast_util import get_type, name_nodes_in_ast
from guppylang.checker.core import Locals, Variable
from guppylang.ast_util import AstNode, find_nodes, get_type
from guppylang.cfg.analysis import LivenessAnalysis
from guppylang.cfg.bb import BB, VariableStats
from guppylang.checker.core import (
FieldAccess,
Locals,
Place,
PlaceId,
Variable,
)
from guppylang.error import GuppyError, GuppyTypeError
from guppylang.nodes import DesugaredGenerator, DesugaredListComp, LocalName
from guppylang.nodes import (
CheckedNestedFunctionDef,
DesugaredGenerator,
DesugaredListComp,
FieldAccessAndDrop,
PlaceNode,
)
from guppylang.tys.ty import StructType

if TYPE_CHECKING:
from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG


class Scope(Locals[str, Variable]):
"""Scoped collection of assigned variables indexed by name.
class Scope(Locals[PlaceId, Place]):
"""Scoped collection of assigned places indexed by their id.
Keeps track of which variables have already been used.
Keeps track of which places have already been used.
"""

parent_scope: "Scope | None"
used_local: dict[str, ast.Name]
used_parent: dict[str, ast.Name]
used_local: dict[PlaceId, AstNode]
used_parent: dict[PlaceId, AstNode]

def __init__(self, assigned: Iterable[Variable], parent: "Scope | None" = None):
def __init__(self, parent: "Scope | None" = None):
self.used_local = {}
self.used_parent = {}
super().__init__({var.name: var for var in assigned}, parent)
super().__init__({}, parent)

def used(self, x: str) -> ast.Name | None:
"""Checks whether a variable has already been used."""
def used(self, x: PlaceId) -> AstNode | None:
"""Checks whether a place has already been used."""
if x in self.vars:
return self.used_local.get(x, None)
assert self.parent_scope is not None
return self.parent_scope.used(x)

def use(self, x: str, node: ast.Name) -> None:
"""Records a use of a variable.
def use(self, x: PlaceId, node: AstNode) -> None:
"""Records a use of a place.
Works for local variables in the current scope as well as variables in any
parent scope.
Works for places in the current scope as well as places in any parent scope.
"""
if x in self.vars:
self.used_local[x] = node
Expand All @@ -53,21 +67,42 @@ def use(self, x: str, node: ast.Name) -> None:
self.used_parent[x] = node
self.parent_scope.use(x, node)

def assign(self, var: Variable) -> None:
"""Records an assignment of a variable."""
x = var.name
self.vars[x] = var
def assign(self, place: Place) -> None:
"""Records an assignment of a place."""
assert place.defined_at is not None
x = place.id
self.vars[x] = place
if x in self.used_local:
self.used_local.pop(x)

def stats(self) -> VariableStats[PlaceId]:
assigned = {}
for x, place in self.vars.items():
assert place.defined_at is not None
assigned[x] = place.defined_at
return VariableStats(assigned, self.used_parent)


class BBLinearityChecker(ast.NodeVisitor):
"""AST visitor that checks linearity for a single basic block."""

scope: Scope
stats: VariableStats[PlaceId]

def check(self, bb: "CheckedBB", is_entry: bool) -> Scope:
# Manufacture a scope that holds all places that are live at the start
# of this BB
input_scope = Scope()
for var in bb.sig.input_row:
for place in leaf_places(var):
input_scope.assign(place)

# Open up a new nested scope to check the BB contents. This way we can track
# when we use variables from the outside vs ones assigned in this BB. The only
# exception is the entry BB since function arguments should be treated as part
# of the entry BB
self.scope = input_scope if is_entry else Scope(input_scope)

def check(self, bb: "CheckedBB") -> Scope:
self.scope = Scope(bb.sig.input_row)
for stmt in bb.statements:
self.visit(stmt)
if bb.branch_pred:
Expand All @@ -76,18 +111,17 @@ def check(self, bb: "CheckedBB") -> Scope:

@contextmanager
def new_scope(self) -> Generator[Scope, None, None]:
scope, new_scope = self.scope, Scope({}, self.scope)
scope, new_scope = self.scope, Scope(self.scope)
self.scope = new_scope
yield new_scope
self.scope = scope

def visit_LocalName(self, node: LocalName) -> None:
x = node.id
if x in self.scope:
var = self.scope[x]
if (use := self.scope.used(x)) and var.ty.linear:
def visit_PlaceNode(self, node: PlaceNode) -> None:
for place in leaf_places(node.place):
x = place.id
if (use := self.scope.used(x)) and place.ty.linear:
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` was already used "
f"{place.describe} with linear type `{place.ty}` was already used "
"(at {0})",
node,
[use],
Expand All @@ -98,6 +132,19 @@ def visit_Assign(self, node: ast.Assign) -> None:
self.visit(node.value)
self._check_assign_targets(node.targets)

def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None:
# A field access on a value that is not a place. This means the value can no
# longer be accessed after the field has been projected out. Thus, this is only
# legal if there are no remaining linear fields on the value
self.visit(node.value)
for field in node.struct_ty.fields:
if field.name != node.field.name and field.ty.linear:
raise GuppyTypeError(
f"Linear field `{field.name}` of expression with type "
f"`{node.struct_ty}` is not used",
node.value,
)

def visit_Expr(self, node: ast.Expr) -> None:
# An expression statement where the return value is discarded
self.visit(node.value)
Expand All @@ -110,18 +157,21 @@ def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:

def _check_assign_targets(self, targets: list[ast.expr]) -> None:
"""Helper function to check assignments."""
# We're not allowed to override an unused linear variable
# We're not allowed to override an unused linear place
[target] = targets
for name in name_nodes_in_ast(target):
x = name.id
if x in self.scope and not self.scope.used(x):
var = self.scope[x]
if var.ty.linear:
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` is not used",
var.defined_at,
)
self.scope.assign(Variable(x, get_type(name), name))
for tgt in find_nodes(lambda n: isinstance(n, PlaceNode), target):
assert isinstance(tgt, PlaceNode)
for tgt_place in leaf_places(tgt.place):
x = tgt_place.id
if x in self.scope and not self.scope.used(x):
place = self.scope[x]
if place.ty.linear:
raise GuppyError(
f"{place.describe} with linear type `{place.ty}` is not "
"used",
place.defined_at,
)
self.scope.assign(tgt_place)

def _check_comprehension(
self, node: DesugaredListComp, gens: list[DesugaredGenerator]
Expand All @@ -134,6 +184,7 @@ def _check_comprehension(
# Check the iterator expression in the current scope
gen, *gens = gens
self.visit(gen.iter_assign.value)
assert isinstance(gen.iter, PlaceNode)

# The rest is checked in a new nested scope so we can track which variables
# are introduced and used inside the loop
Expand All @@ -151,17 +202,20 @@ def _check_comprehension(
# Check if there are linear iteration variables that have not been used
# by the first guard
self.visit(first_if)
for x, var in self.scope.vars.items():
for place in self.scope.vars.values():
# The only exception is the iterator variable since we make sure
# that it is carried through each iteration during Hugr generation
if x == gen.iter.id:
if place == gen.iter.place:
continue
if not self.scope.used(x) and var.ty.linear:
raise GuppyTypeError(
f"Variable `{var.name}` with linear type `{var.ty}` is not "
"used on all control-flow paths of the list comprehension",
var.defined_at,
)
for leaf in leaf_places(place):
x = leaf.id
if not self.scope.used(x) and place.ty.linear:
raise GuppyTypeError(
f"{place.describe} with linear type `{place.ty}` is "
"not used on all control-flow paths of the list "
"comprehension",
place.defined_at,
)
for expr in other_ifs:
self.visit(expr)

Expand All @@ -173,58 +227,82 @@ def _check_comprehension(

# We have to make sure that all linear variables that were introduced in the
# inner scope have been used
for x, var in inner_scope.vars.items():
if var.ty.linear and not inner_scope.used(x):
raise GuppyTypeError(
f"Variable `{x}` with linear type `{var.ty}` is not used",
var.defined_at,
)
for place in inner_scope.vars.values():
for leaf in leaf_places(place):
x = leaf.id
if leaf.ty.linear and not inner_scope.used(x):
raise GuppyTypeError(
f"{leaf.describe} with linear type `{leaf.ty}` is not used",
leaf.defined_at,
)

# On the other hand, we have to ensure that no linear variables from the
# On the other hand, we have to ensure that no linear places from the
# outer scope have been used inside the comprehension (they would be used
# multiple times since the comprehension body is executed repeatedly)
for x, use in inner_scope.used_parent.items():
var = inner_scope[x]
if var.ty.linear:
place = inner_scope[x]
if place.ty.linear:
raise GuppyTypeError(
f"Variable `{x}` with linear type `{var.ty}` would be used "
f"{place.describe} with linear type `{place.ty}` would be used "
"multiple times when evaluating this comprehension",
use,
)


def leaf_places(place: Place) -> Iterator[Place]:
"""Returns all leaf descendant projections of a place."""
stack = [place]
while stack:
place = stack.pop()
if isinstance(place.ty, StructType):
for field in place.ty.fields:
stack.append(FieldAccess(place, field, place.defined_at))
else:
yield place


def check_cfg_linearity(cfg: "CheckedCFG") -> None:
"""Checks whether a CFG satisfies the linearity requirements.
Raises a user-error if linearity violations are found.
"""
bb_checker = BBLinearityChecker()
for bb in cfg.bbs:
scope = bb_checker.check(bb)
scopes: dict[BB, Scope] = {
bb: bb_checker.check(bb, is_entry=bb == cfg.entry_bb) for bb in cfg.bbs
}

# Run liveness analysis
stats = {bb: scope.stats() for bb, scope in scopes.items()}
live_before = LivenessAnalysis(stats).run(cfg.bbs)

for bb, scope in scopes.items():
# We have to check that used linear variables are not being outputted
for succ in bb.successors:
live = cfg.live_before[succ]
live = live_before[succ]
for x, use_bb in live.items():
if x in scope:
var = scope[x]
if var.ty.linear and (use := scope.used(x)):
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` was "
"already used (at {0})",
use_bb.vars.used[x],
[use],
)

# On the other hand, unused linear variables *must* be outputted
for x, var in scope.vars.items():
used_later = x in cfg.live_before[succ]
if var.ty.linear and not scope.used(x) and not used_later:
# TODO: This should be "Variable x with linear type ty is not
# used in {bb}". But for this we need a way to associate BBs with
# source locations.
use_scope = scopes[use_bb]
place = use_scope[x]
if place.ty.linear and (use := scope.used(x)):
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` is "
"not used on all control-flow paths",
var.defined_at,
f"{place.describe} with linear type `{place.ty}` was "
"already used (at {0})",
use_scope.used_parent[x],
[use],
)

# On the other hand, unused linear variables *must* be outputted
for place in scope.vars.values():
for leaf in leaf_places(place):
x = leaf.id
used_later = x in live
if leaf.ty.linear and not scope.used(x) and not used_later:
# TODO: This should be "Variable x with linear type ty is not
# used in {bb}". But for this we need a way to associate BBs
# with source locations.
raise GuppyError(
f"{leaf.describe} with linear type `{leaf.ty}` is "
"not used on all control-flow paths",
# Re-lookup defined_at in scope because we might have a
# more precise location
scope[x].defined_at,
)

0 comments on commit 6561f05

Please sign in to comment.