Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Turn linearity checking into separate compiler stage #273

Merged
merged 4 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import copy
import itertools
from collections.abc import Iterator
from typing import NamedTuple
Expand Down Expand Up @@ -270,7 +271,9 @@ def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.Name:
# assignment statement and replace the expression with `x`.
if not isinstance(node.target, ast.Name):
raise InternalGuppyError(f"Unexpected assign target: {node.target}")
assign = ast.Assign(targets=[node.target], value=self.visit(node.value))
assign = ast.Assign(
targets=[copy.deepcopy(node.target)], value=self.visit(node.value)
)
Comment on lines +274 to +276
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a drive-by?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No this is required now that we annotate the LHS of assignments with types

def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> None:
"""Helper function to check assignments with patterns."""
match lhs:
# Easiest case is if the LHS pattern is a single variable.
case ast.Name(id=x):
# Store the type in the AST
with_type(ty, lhs)
self.ctx.locals[x] = Variable(x, ty, lhs)

If we don't make a copy, the type checker will see that a type has been assigned when the variable is used the next time and won't bother checking again which leads incorrect behaviour

set_location_from(assign, node)
self.bb.statements.append(assign)
return node.target
Expand Down
30 changes: 6 additions & 24 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from guppylang.cfg.cfg import CFG, BaseCFG
from guppylang.checker.core import Context, Globals, Locals, Variable
from guppylang.checker.expr_checker import ExprSynthesizer, to_bool
from guppylang.checker.linearity_checker import check_cfg_linearity
from guppylang.checker.stmt_checker import StmtChecker
from guppylang.error import GuppyError
from guppylang.tys.ty import Type
Expand Down Expand Up @@ -84,7 +85,7 @@ def check_cfg(
while len(queue) > 0:
pred, num_output, bb = queue.popleft()
input_row = [
Variable(v.name, v.ty, v.defined_at, None)
Variable(v.name, v.ty, v.defined_at)
for v in pred.sig.output_rows[num_output]
]

Expand Down Expand Up @@ -114,6 +115,10 @@ def check_cfg(
checked_cfg.maybe_ass_before = {
compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs
}

# Finally, run the linearity check
check_cfg_linearity(checked_cfg)

return checked_cfg


Expand Down Expand Up @@ -160,29 +165,6 @@ def check_bb(
)
raise GuppyError(f"Variable `{x}` is not defined", use_bb.vars.used[x])

# We have to check that used linear variables are not being outputted
if x in ctx.locals:
var = ctx.locals[x]
if var.ty.linear and var.used:
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` was "
"already used (at {0})",
cfg.live_before[succ][x].vars.used[x],
[var.used],
)

# On the other hand, unused linear variables *must* be outputted
for x, var in ctx.locals.items():
if var.ty.linear and not var.used and x not in cfg.live_before[succ]:
# TODO: This should be "Variable x with linear type ty is not
# used in {bb}". But for this we need a way to associate BBs with
# source locations.
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` is "
"not used on all control-flow paths",
var.defined_at,
)

# Finally, we need to compute the signature of the basic block
outputs = [
[ctx.locals[x] for x in cfg.live_before[succ] if x in ctx.locals]
Expand Down
3 changes: 1 addition & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,13 @@
)


@dataclass
@dataclass(frozen=True)
class Variable:
"""Class holding data associated with a local variable."""

name: str
ty: Type
defined_at: AstNode | None
used: AstNode | None


PyScope = dict[str, Any]
Expand Down
65 changes: 4 additions & 61 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
AstVisitor,
breaks_in_loop,
get_type_opt,
name_nodes_in_ast,
return_nodes_in_ast,
with_loc,
with_type,
Expand Down Expand Up @@ -333,14 +332,6 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]:
x = node.id
if x in self.ctx.locals:
var = self.ctx.locals[x]
if var.ty.linear and var.used is not None:
raise GuppyError(
f"Variable `{x}` with linear type `{var.ty}` was "
"already used (at {0})",
node,
[var.used],
)
var.used = node
return with_loc(node, LocalName(id=x)), var.ty
elif x in self.ctx.globals:
# Look-up what kind of definition it is
Expand Down Expand Up @@ -874,34 +865,14 @@ def synthesize_comprehension(
"""Helper function to synthesise the element type of a list comprehension."""
from guppylang.checker.stmt_checker import StmtChecker

def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None:
"""Checks if an expression uses a linear variable from an outer scope.

Since the expression is executed multiple times in the inner scope, this would
mean that the outer linear variable is used multiple times, which is not
allowed.
"""
for name in name_nodes_in_ast(expr):
x = name.id
if x in locals and x not in locals.vars:
var = locals[x]
if var.ty.linear:
raise GuppyTypeError(
f"Variable `{x}` with linear type `{var.ty}` would be used "
"multiple times when evaluating this comprehension",
name,
)

# If there are no more generators left, we can check the list element
if not gens:
node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt)
check_linear_use_from_outer_scope(node.elt, ctx.locals)
return node, elt_ty

# Check the iterator in the outer context
gen, *gens = gens
gen.iter_assign = StmtChecker(ctx).visit_Assign(gen.iter_assign)
check_linear_use_from_outer_scope(gen.iter_assign.value, ctx.locals)

# The rest is checked in a new nested context to ensure that variables don't escape
# their scope
Expand All @@ -915,43 +886,15 @@ def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None:
gen.iter, iter_ty = expr_sth.visit_Name(gen.iter)
gen.iter = with_type(iter_ty, gen.iter)

# `if` guards are generally not allowed when we're iterating over linear variables.
# The only exception is if all linear variables are already consumed by the first
# guard
if gen.ifs:
gen.ifs[0], _ = expr_sth.synthesize(gen.ifs[0])

# Now, check if there are linear iteration variables that have not been used by
# the first guard
for target in name_nodes_in_ast(gen.next_assign.targets[0]):
var = inner_ctx.locals[target.id]
if var.ty.linear and not var.used and gen.ifs:
raise GuppyTypeError(
f"Variable `{var.name}` with linear type `{var.ty}` is not used on "
"all control-flow paths of the list comprehension",
target,
)

# Now, we can properly check all guards
for i in range(len(gen.ifs)):
gen.ifs[i], if_ty = expr_sth.synthesize(gen.ifs[i])
gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx)
check_linear_use_from_outer_scope(gen.ifs[i], inner_locals)
# Check `if` guards
for i in range(len(gen.ifs)):
gen.ifs[i], if_ty = expr_sth.synthesize(gen.ifs[i])
gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx)

# Check remaining generators
node, elt_ty = synthesize_comprehension(node, gens, inner_ctx)

# We have to make sure that all linear variables that were introduced in this scope
# have been used
for x, var in inner_ctx.locals.vars.items():
if var.ty.linear and not var.used:
raise GuppyTypeError(
f"Variable `{x}` with linear type `{var.ty}` is not used",
var.defined_at,
)

# The iter finalizer is again checked in the outer context
ctx.locals[gen.iter.id].used = None
gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend)
gen.iterend = with_type(iterend_ty, gen.iterend)
return node, elt_ty
Expand Down
6 changes: 3 additions & 3 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def check_global_func_def(

cfg = CFGBuilder().build(func_def.body, returns_none, globals)
inputs = [
Variable(x, ty, loc, None)
Variable(x, ty, loc)
for x, ty, loc in zip(ty.input_names, ty.inputs, args, strict=True)
]
return check_cfg(cfg, inputs, ty.output, globals)
Expand Down Expand Up @@ -87,7 +87,7 @@ def check_nested_func_def(

# Construct inputs for checking the body CFG
inputs = list(captured.values()) + [
Variable(x, ty, func_def.args.args[i], None)
Variable(x, ty, func_def.args.args[i])
for i, (x, ty) in enumerate(
zip(func_ty.input_names, func_ty.inputs, strict=True)
)
Expand All @@ -111,7 +111,7 @@ def check_nested_func_def(
)
else:
# Otherwise, we treat it like a local name
inputs.append(Variable(func_def.name, func_def.ty, func_def, None))
inputs.append(Variable(func_def.name, func_def.ty, func_def))

checked_cfg = check_cfg(cfg, inputs, func_ty.output, globals)
checked_def = CheckedNestedFunctionDef(
Expand Down
Loading
Loading