Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[used before def] improve handling of global definitions in local scopes #14517

Merged
merged 14 commits into from
Mar 1, 2023
41 changes: 27 additions & 14 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def copy(self) -> BranchState:


class BranchStatement:
def __init__(self, initial_state: BranchState) -> None:
def __init__(self, initial_state: BranchState | None = None) -> None:
if initial_state is None:
initial_state = BranchState()
self.initial_state = initial_state
self.branches: list[BranchState] = [
BranchState(
Expand Down Expand Up @@ -171,7 +173,7 @@ class ScopeType(Enum):
Global = 1
Class = 2
Func = 3
Generator = 3
Generator = 4
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This didn't matter until we needed to handle generators differently. Generators actually do inherit the scope!



class Scope:
Expand Down Expand Up @@ -199,7 +201,7 @@ class DefinedVariableTracker:

def __init__(self) -> None:
# There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
self.scopes: list[Scope] = [Scope([BranchStatement(BranchState())], ScopeType.Global)]
self.scopes: list[Scope] = [Scope([BranchStatement()], ScopeType.Global)]
# disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful
# in things like try/except/finally statements.
self.disable_branch_skip = False
Expand All @@ -216,9 +218,11 @@ def _scope(self) -> Scope:

def enter_scope(self, scope_type: ScopeType) -> None:
assert len(self._scope().branch_stmts) > 0
self.scopes.append(
Scope([BranchStatement(self._scope().branch_stmts[-1].branches[-1])], scope_type)
)
initial_state = None
if scope_type == ScopeType.Generator:
# Generators are special because they inherit the outer scope.
initial_state = self._scope().branch_stmts[-1].branches[-1]
self.scopes.append(Scope([BranchStatement(initial_state)], scope_type))

def exit_scope(self) -> None:
self.scopes.pop()
Expand Down Expand Up @@ -342,13 +346,15 @@ def variable_may_be_undefined(self, name: str, context: Context) -> None:
def process_definition(self, name: str) -> None:
# Was this name previously used? If yes, it's a used-before-definition error.
if not self.tracker.in_scope(ScopeType.Class):
# Errors in class scopes are caught by the semantic analyzer.
refs = self.tracker.pop_undefined_ref(name)
for ref in refs:
if self.loops:
self.variable_may_be_undefined(name, ref)
else:
self.var_used_before_def(name, ref)
else:
# Errors in class scopes are caught by the semantic analyzer.
pass
self.tracker.record_definition(name)

def visit_global_decl(self, o: GlobalDecl) -> None:
Expand Down Expand Up @@ -415,17 +421,24 @@ def visit_match_stmt(self, o: MatchStmt) -> None:

def visit_func_def(self, o: FuncDef) -> None:
self.process_definition(o.name)
self.tracker.enter_scope(ScopeType.Func)
super().visit_func_def(o)
self.tracker.exit_scope()

def visit_func(self, o: FuncItem) -> None:
if o.is_dynamic() and not self.options.check_untyped_defs:
return
if o.arguments is not None:
for arg in o.arguments:
self.tracker.record_definition(arg.variable.name)
super().visit_func(o)

args = o.arguments or []
# Process initializers (defaults) outside the function scope.
for arg in args:
if arg.initializer is not None:
arg.initializer.accept(self)

self.tracker.enter_scope(ScopeType.Func)
for arg in args:
self.process_definition(arg.variable.name)
super().visit_var(arg.variable)
o.body.accept(self)
self.tracker.exit_scope()

def visit_generator_expr(self, o: GeneratorExpr) -> None:
self.tracker.enter_scope(ScopeType.Generator)
Expand Down Expand Up @@ -603,7 +616,7 @@ def visit_starred_pattern(self, o: StarredPattern) -> None:
super().visit_starred_pattern(o)

def visit_name_expr(self, o: NameExpr) -> None:
if o.name in self.builtins:
if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global):
return
if self.tracker.is_possibly_undefined(o.name):
# A variable is only defined in some branches.
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/run-sets.test
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def test_in_set() -> None:
assert main_set(item), f"{item!r} should be in set_main"
assert not main_negated_set(item), item

assert non_final_name_set(non_const)
global non_const
assert non_final_name_set(non_const)
non_const = "updated"
assert non_final_name_set("updated")

Expand Down
39 changes: 19 additions & 20 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -491,62 +491,61 @@ if int():

[case testDefaultArgumentExpressions]
import typing
class B: pass
class A: pass

def f(x: 'A' = A()) -> None:
b = x # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B")
a = x # type: A

class B: pass
class A: pass
[out]

[case testDefaultArgumentExpressions2]
import typing
def f(x: 'A' = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "A")
b = x # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B")
a = x # type: A

class B: pass
class A: pass

def f(x: 'A' = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "A")
b = x # type: B # E: Incompatible types in assignment (expression has type "A", variable has type "B")
a = x # type: A
[case testDefaultArgumentExpressionsGeneric]
from typing import TypeVar
T = TypeVar('T', bound='A')
def f(x: T = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "T")
b = x # type: B # E: Incompatible types in assignment (expression has type "T", variable has type "B")
a = x # type: A

class B: pass
class A: pass

def f(x: T = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "T")
b = x # type: B # E: Incompatible types in assignment (expression has type "T", variable has type "B")
a = x # type: A
[case testDefaultArgumentsWithSubtypes]
import typing
class A: pass
class B(A): pass

def f(x: 'B' = A()) -> None: # E: Incompatible default for argument "x" (default has type "A", argument has type "B")
pass
def g(x: 'A' = B()) -> None:
pass

class A: pass
class B(A): pass
[out]

[case testMultipleDefaultArgumentExpressions]
import typing
class A: pass
class B: pass

def f(x: 'A' = B(), y: 'B' = B()) -> None: # E: Incompatible default for argument "x" (default has type "B", argument has type "A")
pass
def h(x: 'A' = A(), y: 'B' = B()) -> None:
pass

class A: pass
class B: pass
[out]

[case testMultipleDefaultArgumentExpressions2]
import typing
def g(x: 'A' = A(), y: 'B' = A()) -> None: # E: Incompatible default for argument "y" (default has type "A", argument has type "B")
pass

class A: pass
class B: pass

def g(x: 'A' = A(), y: 'B' = A()) -> None: # E: Incompatible default for argument "y" (default has type "A", argument has type "B")
pass
[out]

[case testDefaultArgumentsAndSignatureAsComment]
Expand Down Expand Up @@ -2612,7 +2611,7 @@ def f() -> int: ...
[case testLambdaDefaultTypeErrors]
lambda a=(1 + 'asdf'): a # E: Unsupported operand types for + ("int" and "str")
lambda a=nonsense: a # E: Name "nonsense" is not defined
def f(x: int = i): # E: Name "i" is not defined # E: Name "i" is used before definition
def f(x: int = i): # E: Name "i" is not defined
i = 42

[case testRevealTypeOfCallExpressionReturningNoneWorks]
Expand Down
Loading