Skip to content

Commit

Permalink
feat: Allow access to struct fields and mutation of linear ones (#295)
Browse files Browse the repository at this point in the history
This is the feature branch for enabling struct field access and mutation
of linear fields.

Closes #276, closes #280, and closes #156. 

The difficult bit is that we want to track linearity of individual
struct fields separately (akin to Rust's partial moves):

```python
@guppy.struct
class MyStruct:
   q1: qubit
   q2: qubit
   x: int

def main(s: MyStruct):
   q = h(s.q1)
   t = h(s.q2)    # Using s.q2 is fine, only s.q1 has been used
   y = s.x + s.x  # Classical fields can be used multiple times
   use(s)         # Error: Linear fields of s have already been used
   ...
```

This is the plan:

* We introduce a new notion called `Place`: A place is a description for
a storage location of a local value that users can refer to in their
program. Roughly, these are values that can be lowered to a static wire
within the Hugr and are tracked separately when checking linearity.
* For the purposes of this PR, a place is either a local variable or a
field projection of another place that holds a struct. I.e. places are
paths `a.b.c.d` of zero or more nested struct accesses. In the future,
indexing into an array etc will also become a place.
* During type checking, we figure out which AST nodes correspond to
places and annotate them as such
* For linearity checking, we run a liveness analysis pass that tracks
usage and assignment of places across the CFG. This way, linearity of
different struct fields is tracked individually.
* When compiling to Hugr, we keep track of a mapping from places to
wires/ports


Tracked PRs:

* #288: Precursor PR to generalise our program analysis framework to run
on places in the future.
* #289: Adds the `Place` type and implements the type checking logic to
turn `ast.Name` and `ast.Attribute` nodes into places.
* #290: Update linearity checker to operate on places instead of
variables
* #291: Lower places to Hugr
* #292: Some missing pieces to handle nested functions correctlt
* #293
  • Loading branch information
mark-koch authored Jul 24, 2024
1 parent 4e27b06 commit 6698b75
Show file tree
Hide file tree
Showing 76 changed files with 1,661 additions and 344 deletions.
78 changes: 45 additions & 33 deletions guppylang/cfg/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Iterable
from typing import Generic, TypeVar

from guppylang.cfg.bb import BB
from guppylang.cfg.bb import BB, VariableStats, VId

# Type variable for the lattice domain
T = TypeVar("T")
Expand All @@ -11,7 +11,7 @@
Result = dict[BB, T]


class Analysis(ABC, Generic[T]):
class Analysis(Generic[T], ABC):
"""Abstract base class for a program analysis pass over the lattice `T`"""

def eq(self, t1: T, t2: T, /) -> bool:
Expand All @@ -34,7 +34,7 @@ def run(self, bbs: Iterable[BB]) -> Result[T]:
"""


class ForwardAnalysis(Analysis[T], ABC, Generic[T]):
class ForwardAnalysis(Generic[T], Analysis[T], ABC):
"""Abstract base class for a program analysis pass running in forward direction."""

@abstractmethod
Expand All @@ -59,7 +59,7 @@ def run(self, bbs: Iterable[BB]) -> Result[T]:
return vals_before


class BackwardAnalysis(Analysis[T], ABC, Generic[T]):
class BackwardAnalysis(Generic[T], Analysis[T], ABC):
"""Abstract base class for a program analysis pass running in backward direction."""

@abstractmethod
Expand All @@ -85,83 +85,92 @@ def run(self, bbs: Iterable[BB]) -> Result[T]:

# For live variable analysis, we also store a BB in which a use occurs as evidence of
# liveness.
LivenessDomain = dict[str, BB]
LivenessDomain = dict[VId, BB]


class LivenessAnalysis(BackwardAnalysis[LivenessDomain]):
class LivenessAnalysis(Generic[VId], BackwardAnalysis[LivenessDomain[VId]]):
"""Live variable analysis pass.
Computes the variables that are live before the execution of each BB. The analysis
runs over the lattice of mappings from variable names to BBs containing a use.
"""

def eq(self, live1: LivenessDomain, live2: LivenessDomain) -> bool:
stats: dict[BB, VariableStats[VId]]

def __init__(self, stats: dict[BB, VariableStats[VId]]) -> None:
self.stats = stats

def eq(self, live1: LivenessDomain[VId], live2: LivenessDomain[VId]) -> bool:
# Only check that both contain the same variables. We don't care about the BB
# in which the use occurs, we just need any one, to report to the user.
return live1.keys() == live2.keys()

def initial(self) -> LivenessDomain:
def initial(self) -> LivenessDomain[VId]:
return {}

def join(self, *ts: LivenessDomain) -> LivenessDomain:
res: LivenessDomain = {}
def join(self, *ts: LivenessDomain[VId]) -> LivenessDomain[VId]:
res: LivenessDomain[VId] = {}
for t in ts:
res |= t
return res

def apply_bb(self, live_after: LivenessDomain, bb: BB) -> LivenessDomain:
return {x: bb for x in bb.vars.used} | {
x: b for x, b in live_after.items() if x not in bb.vars.assigned
def apply_bb(self, live_after: LivenessDomain[VId], bb: BB) -> LivenessDomain[VId]:
stats = self.stats[bb]
return {x: bb for x in stats.used} | {
x: b for x, b in live_after.items() if x not in stats.assigned
}


# Set of variables that are definitely assigned at the start of a BB
DefAssignmentDomain = set[str]
DefAssignmentDomain = set[VId]

# Set of variables that are assigned on (at least) some paths to a BB. Definitely
# assigned variables are a subset of this
MaybeAssignmentDomain = set[str]
MaybeAssignmentDomain = set[VId]

# For assignment analysis, we do definite- and maybe-assignment in one pass
AssignmentDomain = tuple[DefAssignmentDomain, MaybeAssignmentDomain]
AssignmentDomain = tuple[DefAssignmentDomain[VId], MaybeAssignmentDomain[VId]]


class AssignmentAnalysis(ForwardAnalysis[AssignmentDomain]):
class AssignmentAnalysis(Generic[VId], ForwardAnalysis[AssignmentDomain[VId]]):
"""Assigned variable analysis pass.
Computes the set of variable that are definitely assigned at the start of a BB.
Additionally, we compute the set of variables that are assigned on (at least) some
paths to a BB (the definitely assigned variables are a subset of this).
Computes the set of variables (i.e. `V`s) that are definitely assigned at the start
of a BB. Additionally, we compute the set of variables that are assigned on (at
least) some paths to a BB (the definitely assigned variables are a subset of this).
"""

all_vars: set[str]
ass_before_entry: set[str]
maybe_ass_before_entry: set[str]
stats: dict[BB, VariableStats[VId]]
all_vars: set[VId]
ass_before_entry: set[VId]
maybe_ass_before_entry: set[VId]

def __init__(
self,
bbs: Iterable[BB],
ass_before_entry: set[str],
maybe_ass_before_entry: set[str],
stats: dict[BB, VariableStats[VId]],
ass_before_entry: set[VId],
maybe_ass_before_entry: set[VId],
) -> None:
"""Constructs an `AssignmentAnalysis` pass for a CFG.
Also takes a set variables that are definitely assigned before the entry of the
CFG (for example function arguments).
"""
assert ass_before_entry.issubset(maybe_ass_before_entry)
self.stats = stats
self.ass_before_entry = ass_before_entry
self.maybe_ass_before_entry = maybe_ass_before_entry
self.all_vars = (
set.union(*(set(bb.vars.assigned.keys()) for bb in bbs)) | ass_before_entry
set.union(*(set(stat.assigned.keys()) for stat in stats.values()))
| ass_before_entry
)

def initial(self) -> AssignmentDomain:
def initial(self) -> AssignmentDomain[VId]:
# Note that definite assignment must start with `all_vars` instead of only
# `ass_before_entry` since we want to compute the *greatest* fixpoint.
return self.all_vars, self.maybe_ass_before_entry

def join(self, *ts: AssignmentDomain) -> AssignmentDomain:
def join(self, *ts: AssignmentDomain[VId]) -> AssignmentDomain[VId]:
# We always include the variables that are definitely assigned before the entry,
# even if the join is empty
if len(ts) == 0:
Expand All @@ -171,16 +180,19 @@ def join(self, *ts: AssignmentDomain) -> AssignmentDomain:
maybe_ass = set.union(*(maybe_ass for _, maybe_ass in ts))
return def_ass, maybe_ass

def apply_bb(self, val_before: AssignmentDomain, bb: BB) -> AssignmentDomain:
def apply_bb(
self, val_before: AssignmentDomain[VId], bb: BB
) -> AssignmentDomain[VId]:
stats = self.stats[bb]
def_ass_before, maybe_ass_before = val_before
return (
def_ass_before | bb.vars.assigned.keys(),
maybe_ass_before | bb.vars.assigned.keys(),
def_ass_before | stats.assigned.keys(),
maybe_ass_before | stats.assigned.keys(),
)

def run_unpacked(
self, bbs: Iterable[BB]
) -> tuple[Result[DefAssignmentDomain], Result[MaybeAssignmentDomain]]:
) -> tuple[Result[DefAssignmentDomain[VId]], Result[MaybeAssignmentDomain[VId]]]:
"""Runs the analysis and unpacks the definite- and maybe-assignment results."""
res = self.run(bbs)
return {bb: res[bb][0] for bb in res}, {bb: res[bb][1] for bb in res}
78 changes: 47 additions & 31 deletions guppylang/cfg/bb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import ast
from abc import ABC
from collections.abc import Hashable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generic, TypeVar

from typing_extensions import Self

Expand All @@ -12,27 +13,21 @@
from guppylang.cfg.cfg import BaseCFG


# Type variable for ids of entities which we may wish to track during program analysis
# (generally ids for program variables or parts thereof)
VId = TypeVar("VId", bound=Hashable)


@dataclass
class VariableStats:
class VariableStats(Generic[VId]):
"""Stores variable usage information for a basic block."""

# Variables that are assigned in the BB
assigned: dict[str, AstNode] = field(default_factory=dict)
assigned: dict[VId, AstNode] = field(default_factory=dict)

# The (external) variables used in the BB, i.e. usages of variables that are
# assigned in the BB are not included here.
used: dict[str, ast.Name] = field(default_factory=dict)

def update_used(self, node: ast.AST) -> None:
"""Marks the variables occurring in a statement as used.
This method should be called whenever an expression is used in the BB.
"""
for name in name_nodes_in_ast(node):
# Should point to first use, so also check that the name is not already
# contained
if name.id not in self.assigned and name.id not in self.used:
self.used[name.id] = name
# created in the BB are not included here.
used: dict[VId, AstNode] = field(default_factory=dict)


BBStatement = (
Expand Down Expand Up @@ -66,10 +61,10 @@ class BB(ABC):
branch_pred: ast.expr | None = None

# Information about assigned/used variables in the BB
_vars: VariableStats | None = None
_vars: VariableStats[str] | None = None

@property
def vars(self) -> VariableStats:
def vars(self) -> VariableStats[str]:
"""Returns variable usage information for this BB.
Note that `compute_variable_stats` must be called before this property can be
Expand All @@ -78,46 +73,68 @@ def vars(self) -> VariableStats:
assert self._vars is not None
return self._vars

def compute_variable_stats(self) -> None:
def compute_variable_stats(self) -> VariableStats[str]:
"""Determines which variables are assigned/used in this BB."""
visitor = VariableVisitor(self)
for s in self.statements:
visitor.visit(s)
if self.branch_pred is not None:
visitor.visit(self.branch_pred)
self._vars = visitor.stats
return visitor.stats


class VariableVisitor(ast.NodeVisitor):
"""Visitor that computes used and assigned variables in a BB."""

bb: BB
stats: VariableStats
stats: VariableStats[str]

def __init__(self, bb: BB):
self.bb = bb
self.stats = VariableStats()

def _update_used(self, node: ast.AST) -> None:
"""Marks the variables occurring in a statement as used.
This method should be called whenever an expression is used in the BB.
"""
for name in name_nodes_in_ast(node):
# Should point to first use, so also check that the name is not already
# contained
x = name.id
if x not in self.stats.assigned and x not in self.stats.used:
self.stats.used[x] = name

def visit_Name(self, node: ast.Name) -> None:
self.stats.update_used(node)
self._update_used(node)

def visit_Assign(self, node: ast.Assign) -> None:
self.visit(node.value)
for t in node.targets:
for name in name_nodes_in_ast(t):
self.stats.assigned[name.id] = node
self._handle_assign_target(t, node)

def visit_AugAssign(self, node: ast.AugAssign) -> None:
self.visit(node.value)
self.stats.update_used(node.target) # The target is also used
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node
self._update_used(node.target) # The target is also used
self._handle_assign_target(node.target, node)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
if node.value:
self.visit(node.value)
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node
self._handle_assign_target(node.target, node)

def _handle_assign_target(self, lhs: ast.expr, node: ast.stmt) -> None:
match lhs:
case ast.Name(id=name):
self.stats.assigned[name] = node
case ast.Tuple(elts=elts):
for elt in elts:
self._handle_assign_target(elt, node)
case ast.Attribute(value=value):
# Setting attributes counts as a use of the value, so we do a regular
# visit instead of treating it like a LHS
self.visit(value)

def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:
# Names bound in the comprehension are only available inside, so we shouldn't
Expand Down Expand Up @@ -147,9 +164,8 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:
# definition, we have to run live variable analysis first
from guppylang.cfg.analysis import LivenessAnalysis

for bb in node.cfg.bbs:
bb.compute_variable_stats()
live = LivenessAnalysis().run(node.cfg.bbs)
stats = {bb: bb.compute_variable_stats() for bb in node.cfg.bbs}
live = LivenessAnalysis(stats).run(node.cfg.bbs)

# Only store used *external* variables: things defined in the current BB, as
# well as the function name and argument names should not be included
Expand Down
26 changes: 14 additions & 12 deletions guppylang/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
MaybeAssignmentDomain,
Result,
)
from guppylang.cfg.bb import BB, BBStatement
from guppylang.cfg.bb import BB, BBStatement, VariableStats

T = TypeVar("T", bound=BB)

Expand All @@ -20,9 +20,9 @@ class BaseCFG(Generic[T]):
entry_bb: T
exit_bb: T

live_before: Result[LivenessDomain]
ass_before: Result[DefAssignmentDomain]
maybe_ass_before: Result[MaybeAssignmentDomain]
live_before: Result[LivenessDomain[str]]
ass_before: Result[DefAssignmentDomain[str]]
maybe_ass_before: Result[MaybeAssignmentDomain[str]]

def __init__(
self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None
Expand All @@ -36,14 +36,6 @@ def __init__(
self.ass_before = {}
self.maybe_ass_before = {}

def analyze(self, def_ass_before: set[str], maybe_ass_before: set[str]) -> None:
for bb in self.bbs:
bb.compute_variable_stats()
self.live_before = LivenessAnalysis().run(self.bbs)
self.ass_before, self.maybe_ass_before = AssignmentAnalysis(
self.bbs, def_ass_before, maybe_ass_before
).run_unpacked(self.bbs)


class CFG(BaseCFG[BB]):
"""A control-flow graph of unchecked basic blocks."""
Expand All @@ -67,3 +59,13 @@ def link(self, src_bb: BB, tgt_bb: BB) -> None:
"""Adds a control-flow edge between two basic blocks."""
src_bb.successors.append(tgt_bb)
tgt_bb.predecessors.append(src_bb)

def analyze(
self, def_ass_before: set[str], maybe_ass_before: set[str]
) -> dict[BB, VariableStats[str]]:
stats = {bb: bb.compute_variable_stats() for bb in self.bbs}
self.live_before = LivenessAnalysis(stats).run(self.bbs)
self.ass_before, self.maybe_ass_before = AssignmentAnalysis(
stats, def_ass_before, maybe_ass_before
).run_unpacked(self.bbs)
return stats
Loading

0 comments on commit 6698b75

Please sign in to comment.