Skip to content

Commit

Permalink
Propagate type narrowing to nested functions (#15133)
Browse files Browse the repository at this point in the history
Fixes #2608.

Use the heuristic suggested in #2608 and allow narrowed types of
variables (but not attributes) to be propagated to nested functions if
the variable is not assigned to after the definition of the nested
function in the outer function.

Since we don't have a full control flow graph, we simply look for
assignments that are textually after the nested function in the outer
function. This can result in false negatives (at least in loops) and
false positives (in if statements, and if the assigned type is narrow
enough), but I expect these to be rare and not a significant issue. Type
narrowing is already unsound, and the additional unsoundness seems
minor, while the usability benefit is big.

This doesn't do the right thing for nested classes yet. I'll create an
issue to track that.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
  • Loading branch information
JukkaL and AlexWaygood authored May 3, 2023
1 parent d71ece8 commit 8c14cba
Show file tree
Hide file tree
Showing 8 changed files with 518 additions and 5 deletions.
6 changes: 5 additions & 1 deletion mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(self, id: int, conditional_frame: bool = False) -> None:
# need this field.
self.suppress_unreachable_warnings = False

def __repr__(self) -> str:
    """Return a compact debugging representation of this frame.

    The string lists, in order, the frame's id, its types mapping, its
    unreachable flag and its conditional_frame flag.
    """
    frame_id, narrowed = self.id, self.types
    is_unreachable, is_conditional = self.unreachable, self.conditional_frame
    return f"Frame({frame_id}, {narrowed}, {is_unreachable}, {is_conditional})"


Assigns = DefaultDict[Expression, List[Tuple[Type, Optional[Type]]]]

Expand All @@ -63,7 +66,7 @@ class ConditionalTypeBinder:
```
class A:
a = None # type: Union[int, str]
a: Union[int, str] = None
x = A()
lst = [x]
reveal_type(x.a) # Union[int, str]
Expand Down Expand Up @@ -446,6 +449,7 @@ def top_frame_context(self) -> Iterator[Frame]:
assert len(self.frames) == 1
yield self.push_frame()
self.pop_frame(True, 0)
assert len(self.frames) == 1


def get_declaration(expr: BindableExpression) -> Type | None:
Expand Down
117 changes: 114 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import mypy.checkexpr
from mypy import errorcodes as codes, message_registry, nodes, operators
from mypy.binder import ConditionalTypeBinder, get_declaration
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
from mypy.checkmember import (
MemberContext,
analyze_decorator_or_funcbase_access,
Expand All @@ -41,7 +41,7 @@
from mypy.errors import Errors, ErrorWatcher, report_internal_error
from mypy.expandtype import expand_self_type, expand_type, expand_type_by_instance
from mypy.join import join_types
from mypy.literals import Key, literal, literal_hash
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import is_overlapping_erased_types, is_overlapping_types
from mypy.message_registry import ErrorMessage
Expand Down Expand Up @@ -134,6 +134,7 @@
is_final_node,
)
from mypy.options import Options
from mypy.patterns import AsPattern, StarredPattern
from mypy.plugin import CheckerPluginInterface, Plugin
from mypy.scope import Scope
from mypy.semanal import is_trivial_body, refers_to_fullname, set_callable_name
Expand All @@ -151,7 +152,7 @@
restrict_subtype_away,
unify_generic_callable,
)
from mypy.traverser import all_return_statements, has_return_statement
from mypy.traverser import TraverserVisitor, all_return_statements, has_return_statement
from mypy.treetransform import TransformVisitor
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type
from mypy.typeops import (
Expand Down Expand Up @@ -1207,6 +1208,20 @@ def check_func_def(

# Type check body in a new scope.
with self.binder.top_frame_context():
# Copy some type narrowings from an outer function when it seems safe enough
# (i.e. we can't find an assignment that might change the type of the
# variable afterwards).
new_frame: Frame | None = None
for frame in old_binder.frames:
for key, narrowed_type in frame.types.items():
key_var = extract_var_from_literal_hash(key)
if key_var is not None and not self.is_var_redefined_in_outer_context(
key_var, defn.line
):
# It seems safe to propagate the type narrowing to a nested scope.
if new_frame is None:
new_frame = self.binder.push_frame()
new_frame.types[key] = narrowed_type
with self.scope.push_function(defn):
# We suppress reachability warnings when we use TypeVars with value
# restrictions: we only want to report a warning if a certain statement is
Expand All @@ -1218,6 +1233,8 @@ def check_func_def(
self.binder.suppress_unreachable_warnings()
self.accept(item.body)
unreachable = self.binder.is_unreachable()
if new_frame is not None:
self.binder.pop_frame(True, 0)

if not unreachable:
if defn.is_generator or is_named_instance(
Expand Down Expand Up @@ -1310,6 +1327,23 @@ def check_func_def(

self.binder = old_binder

def is_var_redefined_in_outer_context(self, v: Var, after_line: int) -> bool:
    """Report whether `v` may be assigned to in an enclosing scope at or after `after_line`.

    This is a line-number-based heuristic rather than a full control flow
    analysis, so it can give the wrong answer in some (rare) cases.

    Returns True when there is no enclosing function at all (the outer context
    is then the module top level, and globals can't be reasoned about), or when
    some enclosing FuncDef contains a potential assignment to `v` whose line
    number is >= `after_line`. Returns False otherwise.
    """
    enclosing = self.tscope.outer_functions()
    if not enclosing:
        # Top-level function: be conservative, since any code could rebind a global.
        return True
    # Only FuncDef nodes are scanned; other enclosing function kinds are ignored.
    return any(
        isinstance(func, FuncDef)
        and find_last_var_assignment_line(func.body, v) >= after_line
        for func in enclosing
    )

def check_unbound_return_typevar(self, typ: CallableType) -> None:
"""Fails when the return typevar is not defined in arguments."""
if isinstance(typ.ret_type, TypeVarType) and typ.ret_type in typ.variables:
Expand Down Expand Up @@ -7629,3 +7663,80 @@ def collapse_walrus(e: Expression) -> Expression:
if isinstance(e, AssignmentExpr):
return e.target
return e


def find_last_var_assignment_line(n: Node, v: Var) -> int:
    """Return the highest line number of a potential assignment to `v` within `n`.

    Both local and global variables are supported. The result is -1 when no
    assignment to the variable is found.
    """
    finder = VarAssignVisitor(v)
    n.accept(finder)
    return finder.last_line


class VarAssignVisitor(TraverserVisitor):
    """Record the highest line on which a given variable is an assignment target.

    The visitor keeps an `lvalue` flag that is True only while it is traversing
    an expression in binding position (assignment targets, `with ... as`
    targets, `for` loop indexes, walrus targets and pattern captures). A name
    expression referring to the tracked variable is recorded only while that
    flag is set. After traversal, `last_line` holds the highest such line, or
    -1 if the variable was never bound.
    """

    def __init__(self, v: Var) -> None:
        super().__init__()
        # Highest line of a matching binding seen so far; -1 means "none found".
        self.last_line = -1
        # True only while visiting an expression in binding position.
        self.lvalue = False
        # The variable whose (re)bindings are being searched for.
        self.var_node = v

    def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
        self.lvalue = True
        for lv in s.lvalues:
            lv.accept(self)
        self.lvalue = False
        # The right-hand side may itself rebind the variable through an
        # assignment expression, e.g. `y = (x := f())`.
        s.rvalue.accept(self)

    def visit_name_expr(self, e: NameExpr) -> None:
        if self.lvalue and e.node is self.var_node:
            self.last_line = max(self.last_line, e.line)

    def visit_member_expr(self, e: MemberExpr) -> None:
        # In `x.attr = ...` the name `x` is only read, not rebound.
        old_lvalue = self.lvalue
        self.lvalue = False
        super().visit_member_expr(e)
        self.lvalue = old_lvalue

    def visit_index_expr(self, e: IndexExpr) -> None:
        # In `x[i] = ...` neither `x` nor `i` is rebound.
        old_lvalue = self.lvalue
        self.lvalue = False
        super().visit_index_expr(e)
        self.lvalue = old_lvalue

    def visit_with_stmt(self, s: WithStmt) -> None:
        # Context expressions are ordinary expressions, but they can contain
        # assignment expressions, e.g. `with (x := open_it()) as f:`.
        for ctx in s.expr:
            ctx.accept(self)
        self.lvalue = True
        for lv in s.target:
            if lv is not None:
                lv.accept(self)
        self.lvalue = False
        s.body.accept(self)

    def visit_for_stmt(self, s: ForStmt) -> None:
        # The iterable can contain assignment expressions as well,
        # e.g. `for i in (x := items()):`.
        s.expr.accept(self)
        self.lvalue = True
        s.index.accept(self)
        self.lvalue = False
        s.body.accept(self)
        if s.else_body:
            s.else_body.accept(self)

    def visit_assignment_expr(self, e: AssignmentExpr) -> None:
        self.lvalue = True
        e.target.accept(self)
        self.lvalue = False
        e.value.accept(self)

    def visit_as_pattern(self, p: AsPattern) -> None:
        if p.pattern is not None:
            p.pattern.accept(self)
        if p.name is not None:
            # The capture name of `case <pattern> as name` is a binding.
            self.lvalue = True
            p.name.accept(self)
            self.lvalue = False

    def visit_starred_pattern(self, p: StarredPattern) -> None:
        if p.capture is not None:
            # The capture name of `case [*rest]` is a binding.
            self.lvalue = True
            p.capture.accept(self)
            self.lvalue = False
3 changes: 2 additions & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,7 +1766,8 @@ def visit_MatchStar(self, n: MatchStar) -> StarredPattern:
if n.name is None:
node = StarredPattern(None)
else:
node = StarredPattern(NameExpr(n.name))
name = self.set_line(NameExpr(n.name), n)
node = StarredPattern(name)

return self.set_line(node, n)

Expand Down
10 changes: 10 additions & 0 deletions mypy/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ def literal_hash(e: Expression) -> Key | None:
return e.accept(_hasher)


def extract_var_from_literal_hash(key: Key) -> Var | None:
    """Return the Var node that `key` refers to, or None if it isn't a Var key.

    A key refers to a Var when it has exactly two items, the first being the
    string "Var" and the second being a Var instance.
    """
    if len(key) != 2:
        return None
    tag, payload = key
    if tag == "Var" and isinstance(payload, Var):
        return payload
    return None


class _Hasher(ExpressionVisitor[Optional[Key]]):
def visit_int_expr(self, e: IntExpr) -> Key:
return ("Literal", e.value)
Expand Down
6 changes: 6 additions & 0 deletions mypy/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self) -> None:
self.module: str | None = None
self.classes: list[TypeInfo] = []
self.function: FuncBase | None = None
self.functions: list[FuncBase] = []
# Number of nested scopes ignored (that don't get their own separate targets)
self.ignored = 0

Expand Down Expand Up @@ -65,19 +66,24 @@ def module_scope(self, prefix: str) -> Iterator[None]:

@contextmanager
def function_scope(self, fdef: FuncBase) -> Iterator[None]:
    """Context manager that enters the scope of function `fdef` and leaves it on exit.

    Every entered function is pushed onto `self.functions`. Only the outermost
    function becomes `self.function`; each nested function instead increments
    the `self.ignored` counter, because nested functions are part of the
    topmost function target. Leaving the scope undoes these steps.
    """
    self.functions.append(fdef)
    if self.function:
        # Nested functions are part of the topmost function target.
        self.ignored += 1
    else:
        self.function = fdef
    yield
    self.functions.pop()
    if not self.ignored:
        # Leaving the outermost function: clear the current target.
        assert self.function
        self.function = None
    else:
        # Leave a scope that's included in the enclosing target.
        self.ignored -= 1

def outer_functions(self) -> list[FuncBase]:
    """Return the functions enclosing the innermost entered one, outermost first.

    The result is a new list; it is empty when at most one function scope is
    currently entered.
    """
    enclosing = list(self.functions)
    # Drop the innermost function (a no-op when the list is empty).
    del enclosing[-1:]
    return enclosing

def enter_class(self, info: TypeInfo) -> None:
"""Enter a class target scope."""
if not self.function:
Expand Down
Loading

0 comments on commit 8c14cba

Please sign in to comment.