Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Place union and lower AST to it during checking #289

Merged
merged 16 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 45 additions & 33 deletions guppylang/cfg/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Iterable
from typing import Generic, TypeVar

from guppylang.cfg.bb import BB
from guppylang.cfg.bb import BB, VariableStats, VId

# Type variable for the lattice domain
T = TypeVar("T")
Expand All @@ -11,7 +11,7 @@
Result = dict[BB, T]


class Analysis(ABC, Generic[T]):
class Analysis(Generic[T], ABC):
"""Abstract base class for a program analysis pass over the lattice `T`"""

def eq(self, t1: T, t2: T, /) -> bool:
Expand All @@ -34,7 +34,7 @@ def run(self, bbs: Iterable[BB]) -> Result[T]:
"""


class ForwardAnalysis(Analysis[T], ABC, Generic[T]):
class ForwardAnalysis(Generic[T], Analysis[T], ABC):
"""Abstract base class for a program analysis pass running in forward direction."""

@abstractmethod
Expand All @@ -59,7 +59,7 @@ def run(self, bbs: Iterable[BB]) -> Result[T]:
return vals_before


class BackwardAnalysis(Analysis[T], ABC, Generic[T]):
class BackwardAnalysis(Generic[T], Analysis[T], ABC):
"""Abstract base class for a program analysis pass running in backward direction."""

@abstractmethod
Expand All @@ -85,83 +85,92 @@ def run(self, bbs: Iterable[BB]) -> Result[T]:

# For live variable analysis, we also store a BB in which a use occurs as evidence of
# liveness.
LivenessDomain = dict[str, BB]
LivenessDomain = dict[VId, BB]


class LivenessAnalysis(BackwardAnalysis[LivenessDomain]):
class LivenessAnalysis(Generic[VId], BackwardAnalysis[LivenessDomain[VId]]):
"""Live variable analysis pass.

Computes the variables that are live before the execution of each BB. The analysis
runs over the lattice of mappings from variable ids to BBs containing a use.
"""

def eq(self, live1: LivenessDomain, live2: LivenessDomain) -> bool:
stats: dict[BB, VariableStats[VId]]

def __init__(self, stats: dict[BB, VariableStats[VId]]) -> None:
self.stats = stats

def eq(self, live1: LivenessDomain[VId], live2: LivenessDomain[VId]) -> bool:
# Only check that both contain the same variables. We don't care about the BB
# in which the use occurs, we just need any one, to report to the user.
return live1.keys() == live2.keys()

def initial(self) -> LivenessDomain:
def initial(self) -> LivenessDomain[VId]:
return {}

def join(self, *ts: LivenessDomain) -> LivenessDomain:
res: LivenessDomain = {}
def join(self, *ts: LivenessDomain[VId]) -> LivenessDomain[VId]:
res: LivenessDomain[VId] = {}
for t in ts:
res |= t
return res

def apply_bb(self, live_after: LivenessDomain, bb: BB) -> LivenessDomain:
return {x: bb for x in bb.vars.used} | {
x: b for x, b in live_after.items() if x not in bb.vars.assigned
def apply_bb(self, live_after: LivenessDomain[VId], bb: BB) -> LivenessDomain[VId]:
stats = self.stats[bb]
return {x: bb for x in stats.used} | {
x: b for x, b in live_after.items() if x not in stats.assigned
}


# Set of variables that are definitely assigned at the start of a BB
DefAssignmentDomain = set[str]
DefAssignmentDomain = set[VId]

# Set of variables that are assigned on (at least) some paths to a BB. Definitely
# assigned variables are a subset of this
MaybeAssignmentDomain = set[str]
MaybeAssignmentDomain = set[VId]

# For assignment analysis, we do definite- and maybe-assignment in one pass
AssignmentDomain = tuple[DefAssignmentDomain, MaybeAssignmentDomain]
AssignmentDomain = tuple[DefAssignmentDomain[VId], MaybeAssignmentDomain[VId]]


class AssignmentAnalysis(ForwardAnalysis[AssignmentDomain]):
class AssignmentAnalysis(Generic[VId], ForwardAnalysis[AssignmentDomain[VId]]):
"""Assigned variable analysis pass.

Computes the set of variable that are definitely assigned at the start of a BB.
Additionally, we compute the set of variables that are assigned on (at least) some
paths to a BB (the definitely assigned variables are a subset of this).
Computes the set of variables (i.e. `VId`s) that are definitely assigned at the
start of a BB. Additionally, we compute the set of variables that are assigned on
(at least) some paths to a BB (the definitely assigned variables are a subset of
this).
"""

all_vars: set[str]
ass_before_entry: set[str]
maybe_ass_before_entry: set[str]
stats: dict[BB, VariableStats[VId]]
all_vars: set[VId]
ass_before_entry: set[VId]
maybe_ass_before_entry: set[VId]

def __init__(
self,
bbs: Iterable[BB],
ass_before_entry: set[str],
maybe_ass_before_entry: set[str],
stats: dict[BB, VariableStats[VId]],
ass_before_entry: set[VId],
maybe_ass_before_entry: set[VId],
) -> None:
"""Constructs an `AssignmentAnalysis` pass for a CFG.

Also takes a set of variables that are definitely assigned before the entry of
the CFG (for example function arguments).
"""
assert ass_before_entry.issubset(maybe_ass_before_entry)
self.stats = stats
self.ass_before_entry = ass_before_entry
self.maybe_ass_before_entry = maybe_ass_before_entry
self.all_vars = (
set.union(*(set(bb.vars.assigned.keys()) for bb in bbs)) | ass_before_entry
set.union(*(set(stat.assigned.keys()) for stat in stats.values()))
| ass_before_entry
)

def initial(self) -> AssignmentDomain:
def initial(self) -> AssignmentDomain[VId]:
# Note that definite assignment must start with `all_vars` instead of only
# `ass_before_entry` since we want to compute the *greatest* fixpoint.
return self.all_vars, self.maybe_ass_before_entry

def join(self, *ts: AssignmentDomain) -> AssignmentDomain:
def join(self, *ts: AssignmentDomain[VId]) -> AssignmentDomain[VId]:
# We always include the variables that are definitely assigned before the entry,
# even if the join is empty
if len(ts) == 0:
Expand All @@ -171,16 +180,19 @@ def join(self, *ts: AssignmentDomain) -> AssignmentDomain:
maybe_ass = set.union(*(maybe_ass for _, maybe_ass in ts))
return def_ass, maybe_ass

def apply_bb(self, val_before: AssignmentDomain, bb: BB) -> AssignmentDomain:
def apply_bb(
self, val_before: AssignmentDomain[VId], bb: BB
) -> AssignmentDomain[VId]:
stats = self.stats[bb]
def_ass_before, maybe_ass_before = val_before
return (
def_ass_before | bb.vars.assigned.keys(),
maybe_ass_before | bb.vars.assigned.keys(),
def_ass_before | stats.assigned.keys(),
maybe_ass_before | stats.assigned.keys(),
)

def run_unpacked(
self, bbs: Iterable[BB]
) -> tuple[Result[DefAssignmentDomain], Result[MaybeAssignmentDomain]]:
) -> tuple[Result[DefAssignmentDomain[VId]], Result[MaybeAssignmentDomain[VId]]]:
"""Runs the analysis and unpacks the definite- and maybe-assignment results."""
res = self.run(bbs)
return {bb: res[bb][0] for bb in res}, {bb: res[bb][1] for bb in res}
57 changes: 32 additions & 25 deletions guppylang/cfg/bb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import ast
from abc import ABC
from collections.abc import Hashable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generic, TypeVar

from typing_extensions import Self

Expand All @@ -12,27 +13,21 @@
from guppylang.cfg.cfg import BaseCFG


# Type variable for ids of entities which we may wish to track during program analysis
# (generally ids for program variables or parts thereof)
VId = TypeVar("VId", bound=Hashable)


@dataclass
class VariableStats:
class VariableStats(Generic[VId]):
"""Stores variable usage information for a basic block."""

# Variables that are assigned in the BB
assigned: dict[str, AstNode] = field(default_factory=dict)
assigned: dict[VId, AstNode] = field(default_factory=dict)

# The (external) variables used in the BB, i.e. usages of variables that are
# assigned in the BB are not included here.
used: dict[str, ast.Name] = field(default_factory=dict)

def update_used(self, node: ast.AST) -> None:
"""Marks the variables occurring in a statement as used.

This method should be called whenever an expression is used in the BB.
"""
for name in name_nodes_in_ast(node):
# Should point to first use, so also check that the name is not already
# contained
if name.id not in self.assigned and name.id not in self.used:
self.used[name.id] = name
# created in the BB are not included here.
used: dict[VId, ast.Name] = field(default_factory=dict)


BBStatement = (
Expand Down Expand Up @@ -66,10 +61,10 @@ class BB(ABC):
branch_pred: ast.expr | None = None

# Information about assigned/used variables in the BB
_vars: VariableStats | None = None
_vars: VariableStats[str] | None = None

@property
def vars(self) -> VariableStats:
def vars(self) -> VariableStats[str]:
"""Returns variable usage information for this BB.

Note that `compute_variable_stats` must be called before this property can be
Expand All @@ -78,28 +73,41 @@ def vars(self) -> VariableStats:
assert self._vars is not None
return self._vars

def compute_variable_stats(self) -> None:
def compute_variable_stats(self) -> VariableStats[str]:
"""Determines which variables are assigned/used in this BB."""
visitor = VariableVisitor(self)
for s in self.statements:
visitor.visit(s)
if self.branch_pred is not None:
visitor.visit(self.branch_pred)
self._vars = visitor.stats
return visitor.stats


class VariableVisitor(ast.NodeVisitor):
"""Visitor that computes used and assigned variables in a BB."""

bb: BB
stats: VariableStats
stats: VariableStats[str]

def __init__(self, bb: BB):
self.bb = bb
self.stats = VariableStats()

def _update_used(self, node: ast.AST) -> None:
"""Marks the variables occurring in a statement as used.

This method should be called whenever an expression is used in the BB.
"""
for name in name_nodes_in_ast(node):
# Should point to first use, so also check that the name is not already
# contained
x = name.id
if x not in self.stats.assigned and x not in self.stats.used:
self.stats.used[x] = name

def visit_Name(self, node: ast.Name) -> None:
self.stats.update_used(node)
self._update_used(node)

def visit_Assign(self, node: ast.Assign) -> None:
self.visit(node.value)
Expand All @@ -109,7 +117,7 @@ def visit_Assign(self, node: ast.Assign) -> None:

def visit_AugAssign(self, node: ast.AugAssign) -> None:
self.visit(node.value)
self.stats.update_used(node.target) # The target is also used
self._update_used(node.target) # The target is also used
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

Expand Down Expand Up @@ -147,9 +155,8 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:
# definition, we have to run live variable analysis first
from guppylang.cfg.analysis import LivenessAnalysis

for bb in node.cfg.bbs:
bb.compute_variable_stats()
live = LivenessAnalysis().run(node.cfg.bbs)
stats = {bb: bb.compute_variable_stats() for bb in node.cfg.bbs}
live = LivenessAnalysis(stats).run(node.cfg.bbs)

# Only store used *external* variables: things defined in the current BB, as
# well as the function name and argument names should not be included
Expand Down
26 changes: 14 additions & 12 deletions guppylang/cfg/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
MaybeAssignmentDomain,
Result,
)
from guppylang.cfg.bb import BB, BBStatement
from guppylang.cfg.bb import BB, BBStatement, VariableStats

T = TypeVar("T", bound=BB)

Expand All @@ -20,9 +20,9 @@ class BaseCFG(Generic[T]):
entry_bb: T
exit_bb: T

live_before: Result[LivenessDomain]
ass_before: Result[DefAssignmentDomain]
maybe_ass_before: Result[MaybeAssignmentDomain]
live_before: Result[LivenessDomain[str]]
ass_before: Result[DefAssignmentDomain[str]]
maybe_ass_before: Result[MaybeAssignmentDomain[str]]

def __init__(
self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None
Expand All @@ -36,14 +36,6 @@ def __init__(
self.ass_before = {}
self.maybe_ass_before = {}

def analyze(self, def_ass_before: set[str], maybe_ass_before: set[str]) -> None:
for bb in self.bbs:
bb.compute_variable_stats()
self.live_before = LivenessAnalysis().run(self.bbs)
self.ass_before, self.maybe_ass_before = AssignmentAnalysis(
self.bbs, def_ass_before, maybe_ass_before
).run_unpacked(self.bbs)


class CFG(BaseCFG[BB]):
"""A control-flow graph of unchecked basic blocks."""
Expand All @@ -67,3 +59,13 @@ def link(self, src_bb: BB, tgt_bb: BB) -> None:
"""Adds a control-flow edge between two basic blocks."""
src_bb.successors.append(tgt_bb)
tgt_bb.predecessors.append(src_bb)

def analyze(
self, def_ass_before: set[str], maybe_ass_before: set[str]
) -> dict[BB, VariableStats[str]]:
stats = {bb: bb.compute_variable_stats() for bb in self.bbs}
self.live_before = LivenessAnalysis(stats).run(self.bbs)
self.ass_before, self.maybe_ass_before = AssignmentAnalysis(
stats, def_ass_before, maybe_ass_before
).run_unpacked(self.bbs)
return stats
Loading