Skip to content

Commit

Permalink
Improve type narrowing for walrus operator in conditional statements (#…
Browse files Browse the repository at this point in the history
…11202)

Authored-by: @kprzybyla <>
  • Loading branch information
hauntsaninja authored Sep 27, 2021
1 parent d469295 commit 8e82171
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 27 deletions.
6 changes: 3 additions & 3 deletions mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from mypy.erasetype import remove_instance_last_known_values
from mypy.nodes import Expression, Var, RefExpr
from mypy.literals import Key, literal, literal_hash, subkeys
from mypy.nodes import IndexExpr, MemberExpr, NameExpr
from mypy.nodes import IndexExpr, MemberExpr, AssignmentExpr, NameExpr


BindableExpression = Union[IndexExpr, MemberExpr, NameExpr]
BindableExpression = Union[IndexExpr, MemberExpr, AssignmentExpr, NameExpr]


class Frame:
Expand Down Expand Up @@ -136,7 +136,7 @@ def _get(self, key: Key, index: int = -1) -> Optional[Type]:
return None

def put(self, expr: Expression, typ: Type) -> None:
if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
if not isinstance(expr, (IndexExpr, MemberExpr, AssignmentExpr, NameExpr)):
return
if not literal(expr):
return
Expand Down
64 changes: 42 additions & 22 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4286,12 +4286,10 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
type_map = self.type_map
if is_true_literal(node):
return {}, None
elif is_false_literal(node):
if is_false_literal(node):
return None, {}
elif isinstance(node, CallExpr):
self._check_for_truthy_type(type_map[node], node)
if len(node.args) == 0:
return {}, {}

if isinstance(node, CallExpr) and len(node.args) != 0:
expr = collapse_walrus(node.args[0])
if refers_to_fullname(node.callee, 'builtins.isinstance'):
if len(node.args) != 2: # the error will be reported elsewhere
Expand Down Expand Up @@ -4472,21 +4470,27 @@ def has_no_custom_eq_checks(t: Type) -> bool:

return reduce_conditional_maps(partial_type_maps)
elif isinstance(node, AssignmentExpr):
return self.find_isinstance_check_helper(node.target)
elif isinstance(node, RefExpr):
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
# respectively
vartype = type_map[node]
self._check_for_truthy_type(vartype, node)
if_type: Type = true_only(vartype)
else_type: Type = false_only(vartype)
ref: Expression = node
if_map = ({ref: if_type} if not isinstance(get_proper_type(if_type), UninhabitedType)
else None)
else_map = ({ref: else_type} if not isinstance(get_proper_type(else_type),
UninhabitedType)
else None)
return if_map, else_map
if_map = {}
else_map = {}

if_assignment_map, else_assignment_map = self.find_isinstance_check_helper(node.target)

if if_assignment_map is not None:
if_map.update(if_assignment_map)
if else_assignment_map is not None:
else_map.update(else_assignment_map)

if_condition_map, else_condition_map = self.find_isinstance_check_helper(node.value)

if if_condition_map is not None:
if_map.update(if_condition_map)
if else_condition_map is not None:
else_map.update(else_condition_map)

return (
(None if if_assignment_map is None or if_condition_map is None else if_map),
(None if else_assignment_map is None or else_condition_map is None else else_map),
)
elif isinstance(node, OpExpr) and node.op == 'and':
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)
Expand All @@ -4507,8 +4511,24 @@ def has_no_custom_eq_checks(t: Type) -> bool:
left, right = self.find_isinstance_check_helper(node.expr)
return right, left

# Not a supported isinstance check
return {}, {}
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
# respectively
vartype = type_map[node]
self._check_for_truthy_type(vartype, node)
if_type = true_only(vartype) # type: Type
else_type = false_only(vartype) # type: Type
ref = node # type: Expression
if_map = (
{ref: if_type}
if not isinstance(get_proper_type(if_type), UninhabitedType)
else None
)
else_map = (
{ref: else_type}
if not isinstance(get_proper_type(else_type), UninhabitedType)
else None
)
return if_map, else_map

def propagate_up_typemap_info(self,
existing_types: Mapping[Expression, Type],
Expand Down
7 changes: 5 additions & 2 deletions mypy/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def literal(e: Expression) -> int:
elif isinstance(e, (MemberExpr, UnaryExpr, StarExpr)):
return literal(e.expr)

elif isinstance(e, AssignmentExpr):
return literal(e.target)

elif isinstance(e, IndexExpr):
if literal(e.index) == LITERAL_YES:
return literal(e.base)
Expand Down Expand Up @@ -160,8 +163,8 @@ def visit_index_expr(self, e: IndexExpr) -> Optional[Key]:
return ('Index', literal_hash(e.base), literal_hash(e.index))
return None

def visit_assignment_expr(self, e: AssignmentExpr) -> None:
return None
def visit_assignment_expr(self, e: AssignmentExpr) -> Optional[Key]:
return literal_hash(e.target)

def visit_call_expr(self, e: CallExpr) -> None:
return None
Expand Down
122 changes: 122 additions & 0 deletions test-data/unit/check-python38.test
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,57 @@ reveal_type(z2) # E: Name "z2" is not defined # N: Revealed type is "Any"

[builtins fixtures/isinstancelist.pyi]

[case testWalrusConditionalTypeBinder]
from typing import Union
from typing_extensions import Literal

class Good:
@property
def is_good(self) -> Literal[True]: ...

class Bad:
@property
def is_good(self) -> Literal[False]: ...

def get_thing() -> Union[Good, Bad]: ...

if (thing := get_thing()).is_good:
reveal_type(thing) # N: Revealed type is "__main__.Good"
else:
reveal_type(thing) # N: Revealed type is "__main__.Bad"
[builtins fixtures/property.pyi]

[case testWalrusConditionalTypeCheck]
# flags: --strict-optional
from typing import Optional

maybe_str: Optional[str]

if (is_str := maybe_str is not None):
reveal_type(is_str) # N: Revealed type is "builtins.bool"
reveal_type(maybe_str) # N: Revealed type is "builtins.str"
else:
reveal_type(is_str) # N: Revealed type is "builtins.bool"
reveal_type(maybe_str) # N: Revealed type is "None"

reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"
[builtins fixtures/bool.pyi]

[case testWalrusConditionalTypeCheck2]
from typing import Optional

maybe_str: Optional[str]

if (x := maybe_str) is not None:
reveal_type(x) # N: Revealed type is "builtins.str"
reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"
else:
reveal_type(x) # N: Revealed type is "None"
reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"

reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"
[builtins fixtures/bool.pyi]

[case testWalrusPartialTypes]
from typing import List

Expand All @@ -400,6 +451,77 @@ def check_partial_list() -> None:
reveal_type(z) # N: Revealed type is "builtins.list[builtins.int]"
[builtins fixtures/list.pyi]

[case testWalrusAssignmentAndConditionScopeForLiteral]
# flags: --warn-unreachable

if (x := 0):
reveal_type(x) # E: Statement is unreachable
else:
reveal_type(x) # N: Revealed type is "builtins.int"

reveal_type(x) # N: Revealed type is "builtins.int"

[case testWalrusAssignmentAndConditionScopeForProperty]
# flags: --warn-unreachable

from typing_extensions import Literal

class PropertyWrapper:
@property
def f(self) -> str: ...
@property
def always_false(self) -> Literal[False]: ...

wrapper = PropertyWrapper()

if x := wrapper.f:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "builtins.str"

reveal_type(x) # N: Revealed type is "builtins.str"

if y := wrapper.always_false:
reveal_type(y) # E: Statement is unreachable
else:
reveal_type(y) # N: Revealed type is "Literal[False]"

reveal_type(y) # N: Revealed type is "Literal[False]"
[builtins fixtures/property.pyi]

[case testWalrusAssignmentAndConditionScopeForFunction]
# flags: --warn-unreachable

from typing_extensions import Literal

def f() -> str: ...

if x := f():
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "builtins.str"

reveal_type(x) # N: Revealed type is "builtins.str"

def always_false() -> Literal[False]: ...

if y := always_false():
reveal_type(y) # E: Statement is unreachable
else:
reveal_type(y) # N: Revealed type is "Literal[False]"

reveal_type(y) # N: Revealed type is "Literal[False]"

def always_false_with_parameter(x: int) -> Literal[False]: ...

if z := always_false_with_parameter(5):
reveal_type(z) # E: Statement is unreachable
else:
reveal_type(z) # N: Revealed type is "Literal[False]"

reveal_type(z) # N: Revealed type is "Literal[False]"
[builtins fixtures/tuple.pyi]

[case testWalrusExpr]
def func() -> None:
foo = Foo()
Expand Down
12 changes: 12 additions & 0 deletions test-data/unit/check-unreachable-code.test
Original file line number Diff line number Diff line change
Expand Up @@ -1378,3 +1378,15 @@ def f(t: T) -> None:
except BaseException as e:
pass
[builtins fixtures/dict.pyi]


[case testUnreachableLiteral]
# flags: --warn-unreachable
from typing_extensions import Literal

def nope() -> Literal[False]: ...

def f() -> None:
if nope():
x = 1 # E: Statement is unreachable
[builtins fixtures/dict.pyi]

0 comments on commit 8e82171

Please sign in to comment.