Skip to content

Commit

Permalink
Implement miscellaneous fixes for partially-defined check (#14175)
Browse files Browse the repository at this point in the history
These are the issues that I've found using mypy-primer.

You should be able to review this PR commit-by-commit. Each commit
includes the relevant tests:
- Process imports correctly
- Support for function names
- Skip stub files (this change has no tests)
- Handle builtins and implicit module attrs (e.g. `str` and `__doc__`)
- Improved support for lambdas.
  • Loading branch information
ilinum authored Nov 25, 2022
1 parent db0beb1 commit d58a851
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 0 deletions.
3 changes: 3 additions & 0 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2349,6 +2349,9 @@ def type_check_second_pass(self) -> bool:

def detect_partially_defined_vars(self, type_map: dict[Expression, Type]) -> None:
assert self.tree is not None, "Internal error: method must be called on parsed file only"
if self.tree.is_stub:
# We skip stub files because they aren't actually executed.
return
manager = self.manager
if manager.errors.is_error_code_enabled(
codes.PARTIALLY_DEFINED
Expand Down
43 changes: 43 additions & 0 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
FuncItem,
GeneratorExpr,
IfStmt,
Import,
ImportFrom,
LambdaExpr,
ListExpr,
Lvalue,
MatchStmt,
NameExpr,
RaiseStmt,
RefExpr,
ReturnStmt,
StarExpr,
TupleExpr,
WhileStmt,
WithStmt,
implicit_module_attrs,
)
from mypy.patterns import AsPattern, StarredPattern
from mypy.reachability import ALWAYS_TRUE, infer_pattern_value
Expand Down Expand Up @@ -213,6 +219,10 @@ def is_undefined(self, name: str) -> bool:
return self._scope().branch_stmts[-1].is_undefined(name)


def refers_to_builtin(o: RefExpr) -> bool:
return o.fullname is not None and o.fullname.startswith("builtins.")


class PartiallyDefinedVariableVisitor(ExtendedTraverserVisitor):
"""Detects the following cases:
- A variable that's defined only part of the time.
Expand All @@ -236,6 +246,8 @@ def __init__(self, msg: MessageBuilder, type_map: dict[Expression, Type]) -> Non
self.type_map = type_map
self.loop_depth = 0
self.tracker = DefinedVariableTracker()
for name in implicit_module_attrs:
self.tracker.record_definition(name)

def process_lvalue(self, lvalue: Lvalue | None) -> None:
if isinstance(lvalue, NameExpr):
Expand All @@ -244,6 +256,8 @@ def process_lvalue(self, lvalue: Lvalue | None) -> None:
for ref in refs:
self.msg.var_used_before_def(lvalue.name, ref)
self.tracker.record_definition(lvalue.name)
elif isinstance(lvalue, StarExpr):
self.process_lvalue(lvalue.expr)
elif isinstance(lvalue, (ListExpr, TupleExpr)):
for item in lvalue.items:
self.process_lvalue(item)
Expand Down Expand Up @@ -291,6 +305,7 @@ def visit_match_stmt(self, o: MatchStmt) -> None:
self.tracker.end_branch_statement()

def visit_func_def(self, o: FuncDef) -> None:
self.tracker.record_definition(o.name)
self.tracker.enter_scope()
super().visit_func_def(o)
self.tracker.exit_scope()
Expand Down Expand Up @@ -332,6 +347,11 @@ def visit_return_stmt(self, o: ReturnStmt) -> None:
super().visit_return_stmt(o)
self.tracker.skip_branch()

def visit_lambda_expr(self, o: LambdaExpr) -> None:
self.tracker.enter_scope()
super().visit_lambda_expr(o)
self.tracker.exit_scope()

def visit_assert_stmt(self, o: AssertStmt) -> None:
super().visit_assert_stmt(o)
if checker.is_false_literal(o.expr):
Expand Down Expand Up @@ -377,6 +397,8 @@ def visit_starred_pattern(self, o: StarredPattern) -> None:
super().visit_starred_pattern(o)

def visit_name_expr(self, o: NameExpr) -> None:
if refers_to_builtin(o):
return
if self.tracker.is_partially_defined(o.name):
# A variable is only defined in some branches.
if self.msg.errors.is_error_code_enabled(errorcodes.PARTIALLY_DEFINED):
Expand Down Expand Up @@ -404,3 +426,24 @@ def visit_with_stmt(self, o: WithStmt) -> None:
expr.accept(self)
self.process_lvalue(idx)
o.body.accept(self)

def visit_import(self, o: Import) -> None:
for mod, alias in o.ids:
if alias is not None:
self.tracker.record_definition(alias)
else:
# When you do `import x.y`, only `x` becomes defined.
names = mod.split(".")
if len(names) > 0:
# `names` should always be nonempty, but we don't want mypy
# to crash on invalid code.
self.tracker.record_definition(names[0])
super().visit_import(o)

def visit_import_from(self, o: ImportFrom) -> None:
for mod, alias in o.names:
name = alias
if name is None:
name = mod
self.tracker.record_definition(name)
super().visit_import_from(o)
121 changes: 121 additions & 0 deletions test-data/unit/check-partially-defined.test
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ else:
a = y + x # E: Name "x" may be undefined
a = y + z # E: Name "z" may be undefined

[case testIndexExpr]
# flags: --enable-error-code partially-defined

if int():
*x, y = (1, 2)
else:
x = [1, 2]
a = x # No error.
b = y # E: Name "y" may be undefined

[case testRedefined]
# flags: --enable-error-code partially-defined
y = 3
Expand All @@ -104,6 +114,32 @@ else:

x = y + 2

[case testFunction]
# flags: --enable-error-code partially-defined
def f0() -> None:
if int():
def some_func() -> None:
pass

some_func() # E: Name "some_func" may be undefined

def f1() -> None:
if int():
def some_func() -> None:
pass
else:
def some_func() -> None:
pass

some_func() # No error.

[case testLambda]
# flags: --enable-error-code partially-defined
def f0(b: bool) -> None:
if b:
fn = lambda: 2
y = fn # E: Name "fn" may be undefined

[case testGenerator]
# flags: --enable-error-code partially-defined
if int():
Expand Down Expand Up @@ -460,3 +496,88 @@ def f4() -> None:
y = z # E: Name "z" is used before definition
x = z # E: Name "z" is used before definition
z: int = 2

[case testUseBeforeDefImportsBasic]
# flags: --enable-error-code use-before-def
import foo # type: ignore
import x.y # type: ignore

def f0() -> None:
a = foo # No error.
foo: int = 1

def f1() -> None:
a = y # E: Name "y" is used before definition
y: int = 1

def f2() -> None:
a = x # No error.
x: int = 1

def f3() -> None:
a = x.y # No error.
x: int = 1

[case testUseBeforeDefImportBasicRename]
# flags: --enable-error-code use-before-def
import x.y as z # type: ignore
from typing import Any

def f0() -> None:
a = z # No error.
z: int = 1

def f1() -> None:
a = x # E: Name "x" is used before definition
x: int = 1

def f2() -> None:
a = x.y # E: Name "x" is used before definition
x: Any = 1

def f3() -> None:
a = y # E: Name "y" is used before definition
y: int = 1

[case testUseBeforeDefImportFrom]
# flags: --enable-error-code use-before-def
from foo import x # type: ignore

def f0() -> None:
a = x # No error.
x: int = 1

[case testUseBeforeDefImportFromRename]
# flags: --enable-error-code use-before-def
from foo import x as y # type: ignore

def f0() -> None:
a = y # No error.
y: int = 1

def f1() -> None:
a = x # E: Name "x" is used before definition
x: int = 1

[case testUseBeforeDefFunctionDeclarations]
# flags: --enable-error-code use-before-def

def f0() -> None:
def inner() -> None:
pass

inner() # No error.
inner = lambda: None

[case testUseBeforeDefBuiltins]
# flags: --enable-error-code use-before-def

def f0() -> None:
s = type(123)
type = "abc"
a = type

[case testUseBeforeDefImplicitModuleAttrs]
# flags: --enable-error-code use-before-def
a = __name__ # No error.
__name__ = "abc"

0 comments on commit d58a851

Please sign in to comment.