Skip to content

Commit

Permalink
Make warn-unreachable understand exception-swallowing contextmanagers
Browse files Browse the repository at this point in the history
This pull request fixes python#7214:
it makes mypy treat any contextmanagers where the `__exit__` returns
`bool` or `Literal[True]` as ones that can potentially swallow
exceptions.

Contextmanagers that return `Optional[bool]`, None, or `Literal[False]`
continue to be treated as non-exception-swallowing ones.

This distinction helps the `--warn-unreachable` flag do the right thing
in this example program:

```python
from contextlib import suppress

def should_warn() -> str:
    with contextlib.suppress(IndexError):
        return ["a", "b", "c"][0]

def should_not_warn() -> str:
    with open("foo.txt") as f:
        return "blah"
```

This pull request needs the typeshed changes I made in
python/typeshed#3179. Once that one gets
merged, I'll update typeshed and rebase this pull request.
  • Loading branch information
Michael0x2a committed Aug 10, 2019
1 parent 5bb6796 commit b667bd1
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 10 deletions.
34 changes: 24 additions & 10 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
Instance, NoneType, strip_type, TypeType, TypeOfAny,
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
true_only, false_only, function_type, is_named_instance, union_items, TypeQuery, LiteralType,
is_optional, remove_optional, TypeTranslator, StarType
is_optional, remove_optional, is_literal_type, TypeTranslator, StarType
)
from mypy.sametypes import is_same_type
from mypy.messages import (
Expand Down Expand Up @@ -3290,12 +3290,24 @@ def check_incompatible_property_override(self, e: Decorator) -> None:
self.fail(message_registry.READ_ONLY_PROPERTY_OVERRIDES_READ_WRITE, e)

def visit_with_stmt(self, s: WithStmt) -> None:
exceptions_maybe_suppressed = False
for expr, target in zip(s.expr, s.target):
if s.is_async:
self.check_async_with_item(expr, target, s.unanalyzed_type is None)
exit_ret_type = self.check_async_with_item(expr, target, s.unanalyzed_type is None)
else:
self.check_with_item(expr, target, s.unanalyzed_type is None)
self.accept(s.body)
exit_ret_type = self.check_with_item(expr, target, s.unanalyzed_type is None)
if is_literal_type(exit_ret_type, "builtins.bool", False):
continue
if is_literal_type(exit_ret_type, "builtins.bool", True):
exceptions_maybe_suppressed = True
elif (isinstance(exit_ret_type, Instance)
and exit_ret_type.type.fullname() == 'builtins.bool'):
exceptions_maybe_suppressed = True
if exceptions_maybe_suppressed:
with self.binder.frame_context(can_skip=True, try_frame=True):
self.accept(s.body)
else:
self.accept(s.body)

def check_untyped_after_decorator(self, typ: Type, func: FuncDef) -> None:
if not self.options.disallow_any_decorated or self.is_stub:
Expand All @@ -3305,7 +3317,7 @@ def check_untyped_after_decorator(self, typ: Type, func: FuncDef) -> None:
self.msg.untyped_decorated_function(typ, func)

def check_async_with_item(self, expr: Expression, target: Optional[Expression],
infer_lvalue_type: bool) -> None:
infer_lvalue_type: bool) -> Type:
echk = self.expr_checker
ctx = echk.accept(expr)
obj = echk.check_method_call_by_name('__aenter__', ctx, [], [], expr)[0]
Expand All @@ -3314,20 +3326,22 @@ def check_async_with_item(self, expr: Expression, target: Optional[Expression],
if target:
self.check_assignment(target, self.temp_node(obj, expr), infer_lvalue_type)
arg = self.temp_node(AnyType(TypeOfAny.special_form), expr)
res = echk.check_method_call_by_name(
'__aexit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0]
echk.check_awaitable_expr(
res, _ = echk.check_method_call_by_name(
'__aexit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr)
return echk.check_awaitable_expr(
res, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT)

def check_with_item(self, expr: Expression, target: Optional[Expression],
infer_lvalue_type: bool) -> None:
infer_lvalue_type: bool) -> Type:
echk = self.expr_checker
ctx = echk.accept(expr)
obj = echk.check_method_call_by_name('__enter__', ctx, [], [], expr)[0]
if target:
self.check_assignment(target, self.temp_node(obj, expr), infer_lvalue_type)
arg = self.temp_node(AnyType(TypeOfAny.special_form), expr)
echk.check_method_call_by_name('__exit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr)
res, _ = echk.check_method_call_by_name(
'__exit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr)
return res

def visit_print_stmt(self, s: PrintStmt) -> None:
for arg in s.args:
Expand Down
13 changes: 13 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,6 +2240,19 @@ def remove_optional(typ: Type) -> Type:
return typ


def is_literal_type(typ: Type, fallback_fullname: str, value: LiteralValue) -> bool:
"""Returns 'true' if this type is a LiteralType with the given value
and underlying base fallback type.
"""
if isinstance(typ, Instance) and typ.last_known_value:
typ = typ.last_known_value
if not isinstance(typ, LiteralType):
return False
if typ.fallback.type.fullname() != fallback_fullname:
return False
return typ.value == value


names = globals().copy() # type: Final
names.pop('NOT_READY', None)
deserialize_map = {
Expand Down
226 changes: 226 additions & 0 deletions test-data/unit/check-unreachable-code.test
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,229 @@ class Test3(Generic[T2]):
reveal_type(self.x)

[builtins fixtures/isinstancelist.pyi]

[case testUnreachableFlagContextManagers]
# flags: --warn-unreachable
from contextlib import contextmanager
from typing import Optional, Iterator, Any
from typing_extensions import Literal
class DoesNotSuppress1:
def __enter__(self) -> int: ...
def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Optional[bool]: ...

class DoesNotSuppress2:
def __enter__(self) -> int: ...
def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Literal[False]: ...

class DoesNotSuppress3:
def __enter__(self) -> int: ...
# TODO: We should report an error when somebody tries using a context manager like this
def __exit__(self, exctype: object, excvalue: object, traceback: object) -> object: ...

class DoesNotSuppress4:
def __enter__(self) -> int: ...
def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Any: ...

class DoesNotSuppress5:
def __enter__(self) -> int: ...
def __exit__(self, exctype: object, excvalue: object, traceback: object) -> None: ...

class Suppresses1:
def __enter__(self) -> int: ...
def __exit__(self, exctype: object, excvalue: object, traceback: object) -> bool: ...

class Suppresses2:
def __enter__(self) -> int: ...
def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Literal[True]: ...

@contextmanager
def simple() -> Iterator[int]:
yield 3

def cond() -> bool: ...

def noop() -> None: ...

def f_no_suppress_1() -> int:
with DoesNotSuppress1():
return 3
noop() # E: Statement is unreachable

def f_no_suppress_2() -> int:
with DoesNotSuppress1():
if cond():
return 3
else:
return 3
noop() # E: Statement is unreachable

def f_no_suppress_3() -> int:
with DoesNotSuppress2():
return 3
noop() # E: Statement is unreachable

def f_no_suppress_4() -> int:
with DoesNotSuppress3():
return 3
noop() # E: Statement is unreachable

def f_no_suppress_5() -> int:
with DoesNotSuppress4():
return 3
noop() # E: Statement is unreachable

def f_no_suppress_6() -> int:
with DoesNotSuppress5():
return 3
noop() # E: Statement is unreachable

def f_no_suppress_7() -> int:
with simple():
return 3
noop() # E: Statement is unreachable

def f_suppress_1() -> int: # E: Missing return statement
with Suppresses1():
return 3
noop()

def f_suppress_2() -> int: # E: Missing return statement
with Suppresses1():
if cond():
return 3
else:
return 3
noop()

def f_suppress_3() -> int: # E: Missing return statement
with Suppresses2():
return 3
noop()

def f_mix() -> int: # E: Missing return statement
with DoesNotSuppress1(), Suppresses1(), DoesNotSuppress1():
return 3
noop()
[typing fixtures/typing-full.pyi]

[case testUnreachableFlagContextAsyncManagers]
# flags: --warn-unreachable --python-version 3.7
from contextlib import asynccontextmanager
from typing import Optional, AsyncIterator, Any
from typing_extensions import Literal

class DoesNotSuppress1:
async def __aenter__(self) -> int: ...
async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> Optional[bool]: ...

class DoesNotSuppress2:
async def __aenter__(self) -> int: ...
async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> Literal[False]: ...

class DoesNotSuppress3:
async def __aenter__(self) -> int: ...
# TODO: We should report an error when somebody tries using a context manager like this
async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> object: ...

class DoesNotSuppress4:
async def __aenter__(self) -> int: ...
async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> Any: ...

class DoesNotSuppress5:
async def __aenter__(self) -> int: ...
async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> None: ...

class Suppresses1:
async def __aenter__(self) -> int: ...
async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> bool: ...

class Suppresses2:
async def __aenter__(self) -> int: ...
async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> Literal[True]: ...

class NotAsyncManager:
def __enter__(self) -> int: ...
def __exit__(self, exctype: object, excvalue: object, traceback: object) -> bool: ...

@asynccontextmanager
async def simple() -> AsyncIterator[int]:
yield 3

def cond() -> bool: ...

def noop() -> None: ...

async def f_no_suppress_1() -> int:
async with DoesNotSuppress1():
return 3
noop() # E: Statement is unreachable

async def f_no_suppress_2() -> int:
async with DoesNotSuppress1():
if cond():
return 3
else:
return 3
noop() # E: Statement is unreachable

async def f_no_suppress_3() -> int:
async with DoesNotSuppress2():
return 3
noop() # E: Statement is unreachable

async def f_no_suppress_4() -> int:
async with DoesNotSuppress3():
return 3
noop() # E: Statement is unreachable

async def f_no_suppress_5() -> int:
async with DoesNotSuppress4():
return 3
noop() # E: Statement is unreachable

async def f_no_suppress_6() -> int:
async with DoesNotSuppress5():
return 3
noop() # E: Statement is unreachable

async def f_no_suppress_7() -> int:
async with simple():
return 3
noop() # E: Statement is unreachable

async def f_suppress_1() -> int: # E: Missing return statement
async with Suppresses1():
return 3
noop()

async def f_suppress_2() -> int: # E: Missing return statement
async with Suppresses1():
if cond():
return 3
else:
return 3
noop()

async def f_suppress_3() -> int: # E: Missing return statement
async with Suppresses2():
return 3
noop()

async def f_mix() -> int: # E: Missing return statement
async with DoesNotSuppress1(), Suppresses1(), DoesNotSuppress1():
return 3
noop()

async def f_bad_1() -> int:
async with NotAsyncManager(): # E: "NotAsyncManager" has no attribute "__aenter__"; maybe "__enter__"? \
# E: "NotAsyncManager" has no attribute "__aexit__"; maybe "__exit__"?
return 3
noop() # E: Statement is unreachable

def f_bad_2() -> int:
with DoesNotSuppress1(): # E: "DoesNotSuppress1" has no attribute "__enter__"; maybe "__aenter__"? \
# E: "DoesNotSuppress1" has no attribute "__exit__"; maybe "__aexit__"?
return 3
noop() # E: Statement is unreachable

[typing fixtures/typing-full.pyi]
6 changes: 6 additions & 0 deletions test-data/unit/lib-stub/contextlib.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import Generic, TypeVar, Callable, Iterator
from typing import ContextManager as ContextManager

Expand All @@ -8,3 +9,8 @@ class GeneratorContextManager(ContextManager[_T], Generic[_T]):

def contextmanager(func: Callable[..., Iterator[_T]]) -> Callable[..., GeneratorContextManager[_T]]:
...

if sys.version_info >= (3, 7):
from typing import AsyncIterator
from typing import AsyncContextManager as AsyncContextManager
def asynccontextmanager(func: Callable[..., AsyncIterator[_T]]) -> Callable[..., AsyncContextManager[_T]]: ...
20 changes: 20 additions & 0 deletions test-data/unit/pythoneval.test
Original file line number Diff line number Diff line change
Expand Up @@ -1405,3 +1405,23 @@ _testStrictEqualityWhitelist.py:5: error: Non-overlapping equality check (left o
_testStrictEqualityWhitelist.py:11: error: Non-overlapping equality check (left operand type: "KeysView[int]", right operand type: "Set[str]")
_testStrictEqualityWhitelist.py:12: error: Non-overlapping equality check (left operand type: "ValuesView[int]", right operand type: "Set[int]")
_testStrictEqualityWhitelist.py:13: error: Non-overlapping equality check (left operand type: "KeysView[int]", right operand type: "List[int]")

[case testUnreachableWithOpen]
# mypy: warn-unreachable

class Suppresses:
def __enter__(self) -> int: ...
def __exit__(self, exctype: object, excvalue: object, traceback: object) -> bool: ...

# This test overlaps with some of the warn-unreachable tests in check-unreachable-code,
# but 'open(...)' is a very common function so we want to make sure we don't regress
# against it specifically
def f_open() -> str:
with open("foo.txt", "r") as f:
return f.read()

def f_suppresses() -> int:
with Suppresses():
return 3
[out]
_testUnreachableWithOpen.py:14: error: Missing return statement

0 comments on commit b667bd1

Please sign in to comment.