diff --git a/mypy/checker.py b/mypy/checker.py index f083ed8aa254..561e5fcd2439 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -36,7 +36,7 @@ UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, true_only, false_only, function_type, is_named_instance, union_items, TypeQuery, LiteralType, is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType, - get_proper_types + get_proper_types, is_literal_type ) from mypy.sametypes import is_same_type from mypy.messages import ( @@ -3341,12 +3341,40 @@ def check_incompatible_property_override(self, e: Decorator) -> None: self.fail(message_registry.READ_ONLY_PROPERTY_OVERRIDES_READ_WRITE, e) def visit_with_stmt(self, s: WithStmt) -> None: + exceptions_maybe_suppressed = False for expr, target in zip(s.expr, s.target): if s.is_async: - self.check_async_with_item(expr, target, s.unanalyzed_type is None) + exit_ret_type = self.check_async_with_item(expr, target, s.unanalyzed_type is None) else: - self.check_with_item(expr, target, s.unanalyzed_type is None) - self.accept(s.body) + exit_ret_type = self.check_with_item(expr, target, s.unanalyzed_type is None) + + # Based on the return type, determine if this context manager 'swallows' + # exceptions or not. We determine this using a heuristic based on the + # return type of the __exit__ method -- see the discussion in + # https://github.com/python/mypy/issues/7214 and the section about context managers + # in https://github.com/python/typeshed/blob/master/CONTRIBUTING.md#conventions + # for more details. + + exit_ret_type = get_proper_type(exit_ret_type) + if is_literal_type(exit_ret_type, "builtins.bool", False): + continue + + if (is_literal_type(exit_ret_type, "builtins.bool", True) + or (isinstance(exit_ret_type, Instance) + and exit_ret_type.type.fullname() == 'builtins.bool' + and state.strict_optional)): + # Note: if strict-optional is disabled, this bool instance + # could actually be an Optional[bool]. + exceptions_maybe_suppressed = True + + if exceptions_maybe_suppressed: + # Treat this 'with' block in the same way we'd treat a 'try: BODY; except: pass' + # block. This means control flow can continue after the 'with' even if the 'with' + # block immediately returns. + with self.binder.frame_context(can_skip=True, try_frame=True): + self.accept(s.body) + else: + self.accept(s.body) def check_untyped_after_decorator(self, typ: Type, func: FuncDef) -> None: if not self.options.disallow_any_decorated or self.is_stub: @@ -3356,7 +3384,7 @@ def check_untyped_after_decorator(self, typ: Type, func: FuncDef) -> None: self.msg.untyped_decorated_function(typ, func) def check_async_with_item(self, expr: Expression, target: Optional[Expression], - infer_lvalue_type: bool) -> None: + infer_lvalue_type: bool) -> Type: echk = self.expr_checker ctx = echk.accept(expr) obj = echk.check_method_call_by_name('__aenter__', ctx, [], [], expr)[0] @@ -3365,20 +3393,22 @@ def check_async_with_item(self, expr: Expression, target: Optional[Expression], if target: self.check_assignment(target, self.temp_node(obj, expr), infer_lvalue_type) arg = self.temp_node(AnyType(TypeOfAny.special_form), expr) - res = echk.check_method_call_by_name( - '__aexit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] - echk.check_awaitable_expr( + res, _ = echk.check_method_call_by_name( + '__aexit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr) + return echk.check_awaitable_expr( res, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT) def check_with_item(self, expr: Expression, target: Optional[Expression], - infer_lvalue_type: bool) -> None: + infer_lvalue_type: bool) -> Type: echk = self.expr_checker ctx = echk.accept(expr) obj = echk.check_method_call_by_name('__enter__', ctx, [], [], expr)[0] if target: self.check_assignment(target, self.temp_node(obj, expr), infer_lvalue_type) arg = self.temp_node(AnyType(TypeOfAny.special_form), expr) - echk.check_method_call_by_name('__exit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr) + res, _ = echk.check_method_call_by_name( + '__exit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr) + return res def visit_print_stmt(self, s: PrintStmt) -> None: for arg in s.args: diff --git a/mypy/types.py b/mypy/types.py index e406c86ae386..bd96d539ef59 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2406,6 +2406,17 @@ def remove_optional(typ: Type) -> ProperType: return typ +def is_literal_type(typ: ProperType, fallback_fullname: str, value: LiteralValue) -> bool: + """Check if this type is a LiteralType with the given fallback type and value.""" + if isinstance(typ, Instance) and typ.last_known_value: + typ = typ.last_known_value + if not isinstance(typ, LiteralType): + return False + if typ.fallback.type.fullname() != fallback_fullname: + return False + return typ.value == value + + @overload def get_proper_type(typ: None) -> None: ... @overload # noqa diff --git a/test-data/unit/check-unreachable-code.test b/test-data/unit/check-unreachable-code.test index 55d444148aed..3f52646f5ea2 100644 --- a/test-data/unit/check-unreachable-code.test +++ b/test-data/unit/check-unreachable-code.test @@ -963,3 +963,317 @@ class Test3(Generic[T2]): reveal_type(self.x) [builtins fixtures/isinstancelist.pyi] + +[case testUnreachableFlagContextManagersNoSuppress] +# flags: --warn-unreachable +from contextlib import contextmanager +from typing import Optional, Iterator, Any +from typing_extensions import Literal +class DoesNotSuppress1: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Optional[bool]: ... + +class DoesNotSuppress2: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Literal[False]: ... + +class DoesNotSuppress3: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Any: ... + +class DoesNotSuppress4: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> None: ... + +@contextmanager +def simple() -> Iterator[int]: + yield 3 + +def cond() -> bool: ... + +def noop() -> None: ... + +def f_no_suppress_1a() -> int: + with DoesNotSuppress1(): + return 3 + noop() # E: Statement is unreachable + +def f_no_suppress_1b() -> int: + with DoesNotSuppress1(): + if cond(): + return 3 + else: + return 3 + noop() # E: Statement is unreachable + +def f_no_suppress_2() -> int: + with DoesNotSuppress2(): + return 3 + noop() # E: Statement is unreachable + +def f_no_suppress_3() -> int: + with DoesNotSuppress3(): + return 3 + noop() # E: Statement is unreachable + +def f_no_suppress_4() -> int: + with DoesNotSuppress4(): + return 3 + noop() # E: Statement is unreachable + +def f_no_suppress_5() -> int: + with simple(): + return 3 + noop() # E: Statement is unreachable + +[typing fixtures/typing-full.pyi] + +[case testUnreachableFlagContextManagersSuppressed] +# flags: --warn-unreachable +from contextlib import contextmanager +from typing import Optional, Iterator, Any +from typing_extensions import Literal + +class DoesNotSuppress: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Optional[bool]: ... + +class Suppresses1: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> bool: ... + +class Suppresses2: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Literal[True]: ... + +def cond() -> bool: ... + +def noop() -> None: ... + +def f_suppress_1a() -> int: # E: Missing return statement + with Suppresses1(): + return 3 + noop() + +def f_suppress_1b() -> int: # E: Missing return statement + with Suppresses1(): + if cond(): + return 3 + else: + return 3 + noop() + +def f_suppress_2() -> int: # E: Missing return statement + with Suppresses2(): + return 3 + noop() + +def f_mix() -> int: # E: Missing return statement + with DoesNotSuppress(), Suppresses1(), DoesNotSuppress(): + return 3 + noop() +[typing fixtures/typing-full.pyi] + +[case testUnreachableFlagContextManagersSuppressedNoStrictOptional] +# flags: --warn-unreachable --no-strict-optional +from contextlib import contextmanager +from typing import Optional, Iterator, Any +from typing_extensions import Literal + +class DoesNotSuppress1: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Optional[bool]: ... + +# Normally, this should suppress. But when strict-optional mode is disabled, we can't +# necessarily distinguish between bool and Optional[bool]. So we default to assuming +# no suppression, since that's what most context managers will do. +class DoesNotSuppress2: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> bool: ... + +# But if we see Literal[True], it's pretty unlikely the return type is actually meant to +# be 'Optional[Literal[True]]'. So, we optimistically assume this is meant to be suppressing. +class Suppresses: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> Literal[True]: ... + +def noop() -> None: ... + +def f_no_suppress_1() -> int: + with DoesNotSuppress1(): + return 3 + noop() # E: Statement is unreachable + +def f_no_suppress_2() -> int: + with DoesNotSuppress1(): + return 3 + noop() # E: Statement is unreachable + +def f_suppress() -> int: # E: Missing return statement + with Suppresses(): + return 3 + noop() +[typing fixtures/typing-full.pyi] + +[case testUnreachableFlagContextAsyncManagersNoSuppress] +# flags: --warn-unreachable --python-version 3.7 +from contextlib import asynccontextmanager +from typing import Optional, AsyncIterator, Any +from typing_extensions import Literal + +class DoesNotSuppress1: + async def __aenter__(self) -> int: ... + async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> Optional[bool]: ... + +class DoesNotSuppress2: + async def __aenter__(self) -> int: ... + async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> Literal[False]: ... + +class DoesNotSuppress3: + async def __aenter__(self) -> int: ... + async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> Any: ... + +class DoesNotSuppress4: + async def __aenter__(self) -> int: ... + async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> None: ... + +@asynccontextmanager +async def simple() -> AsyncIterator[int]: + yield 3 + +def cond() -> bool: ... + +def noop() -> None: ... + +async def f_no_suppress_1a() -> int: + async with DoesNotSuppress1(): + return 3 + noop() # E: Statement is unreachable + +async def f_no_suppress_1b() -> int: + async with DoesNotSuppress1(): + if cond(): + return 3 + else: + return 3 + noop() # E: Statement is unreachable + +async def f_no_suppress_2() -> int: + async with DoesNotSuppress2(): + return 3 + noop() # E: Statement is unreachable + +async def f_no_suppress_3() -> int: + async with DoesNotSuppress3(): + return 3 + noop() # E: Statement is unreachable + +async def f_no_suppress_4() -> int: + async with DoesNotSuppress4(): + return 3 + noop() # E: Statement is unreachable + +async def f_no_suppress_5() -> int: + async with simple(): + return 3 + noop() # E: Statement is unreachable + +[typing fixtures/typing-full.pyi] + +[case testUnreachableFlagContextAsyncManagersSuppressed] +# flags: --warn-unreachable --python-version 3.7 +from contextlib import asynccontextmanager +from typing import Optional, AsyncIterator, Any +from typing_extensions import Literal + +class DoesNotSuppress: + async def __aenter__(self) -> int: ... + async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> Optional[bool]: ... + +class Suppresses1: + async def __aenter__(self) -> int: ... + async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> bool: ... + +class Suppresses2: + async def __aenter__(self) -> int: ... + async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> Literal[True]: ... + +def cond() -> bool: ... + +def noop() -> None: ... + +async def f_suppress_1() -> int: # E: Missing return statement + async with Suppresses1(): + return 3 + noop() + +async def f_suppress_2() -> int: # E: Missing return statement + async with Suppresses1(): + if cond(): + return 3 + else: + return 3 + noop() + +async def f_suppress_3() -> int: # E: Missing return statement + async with Suppresses2(): + return 3 + noop() + +async def f_mix() -> int: # E: Missing return statement + async with DoesNotSuppress(), Suppresses1(), DoesNotSuppress(): + return 3 + noop() +[typing fixtures/typing-full.pyi] + +[case testUnreachableFlagContextAsyncManagersAbnormal] +# flags: --warn-unreachable --python-version 3.7 +from contextlib import asynccontextmanager +from typing import Optional, AsyncIterator, Any +from typing_extensions import Literal + +class RegularManager: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> bool: ... + +class AsyncManager: + async def __aenter__(self) -> int: ... + async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> bool: ... + +def noop() -> None: ... + +async def f_bad_1() -> int: + async with RegularManager(): # E: "RegularManager" has no attribute "__aenter__"; maybe "__enter__"? \ + # E: "RegularManager" has no attribute "__aexit__"; maybe "__exit__"? + return 3 + noop() # E: Statement is unreachable + +def f_bad_2() -> int: + with AsyncManager(): # E: "AsyncManager" has no attribute "__enter__"; maybe "__aenter__"? \ + # E: "AsyncManager" has no attribute "__exit__"; maybe "__aexit__"? + return 3 + noop() # E: Statement is unreachable + +# TODO: We should consider reporting an error when the user tries using +# context manager with malformed signatures instead of silently continuing. + +class RegularManagerMalformedSignature: + def __enter__(self) -> int: ... + def __exit__(self, exctype: object, excvalue: object, traceback: object) -> object: ... + +class AsyncManagerMalformedSignature: + async def __aenter__(self) -> int: ... + async def __aexit__(self, exctype: object, excvalue: object, traceback: object) -> object: ... + +def f_malformed_1() -> int: + with RegularManagerMalformedSignature(): + return 3 + noop() # E: Statement is unreachable + +async def f_malformed_2() -> int: + async with AsyncManagerMalformedSignature(): + return 3 + noop() # E: Statement is unreachable + +[typing fixtures/typing-full.pyi] + diff --git a/test-data/unit/lib-stub/contextlib.pyi b/test-data/unit/lib-stub/contextlib.pyi index fa4760c71054..e7db25da1b5f 100644 --- a/test-data/unit/lib-stub/contextlib.pyi +++ b/test-data/unit/lib-stub/contextlib.pyi @@ -1,3 +1,4 @@ +import sys from typing import Generic, TypeVar, Callable, Iterator from typing import ContextManager as ContextManager @@ -8,3 +9,8 @@ class GeneratorContextManager(ContextManager[_T], Generic[_T]): def contextmanager(func: Callable[..., Iterator[_T]]) -> Callable[..., GeneratorContextManager[_T]]: ... + +if sys.version_info >= (3, 7): + from typing import AsyncIterator + from typing import AsyncContextManager as AsyncContextManager + def asynccontextmanager(func: Callable[..., AsyncIterator[_T]]) -> Callable[..., AsyncContextManager[_T]]: ... diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 6c6203237f3b..de18922aca28 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1405,3 +1405,46 @@ _testStrictEqualityWhitelist.py:5: error: Non-overlapping equality check (left o _testStrictEqualityWhitelist.py:11: error: Non-overlapping equality check (left operand type: "KeysView[int]", right operand type: "Set[str]") _testStrictEqualityWhitelist.py:12: error: Non-overlapping equality check (left operand type: "ValuesView[int]", right operand type: "Set[int]") _testStrictEqualityWhitelist.py:13: error: Non-overlapping equality check (left operand type: "KeysView[int]", right operand type: "List[int]") + +[case testUnreachableWithStdlibContextManagers] +# mypy: warn-unreachable, strict-optional + +from contextlib import suppress + +# This test overlaps with some of the warn-unreachable tests in check-unreachable-code, +# but 'open(...)' is a very common function so we want to make sure we don't regress +# against it specifically +def f_open() -> str: + with open("foo.txt", "r") as f: + return f.read() + print("noop") + +# contextlib.suppress is less common, but it's a fairly prominent example of an +# exception-suppressing context manager, so it'd be good to double-check. +def f_suppresses() -> int: + with suppress(Exception): + return 3 + print("noop") +[out] +_testUnreachableWithStdlibContextManagers.py:11: error: Statement is unreachable +_testUnreachableWithStdlibContextManagers.py:15: error: Missing return statement + +[case testUnreachableWithStdlibContextManagersNoStrictOptional] +# mypy: warn-unreachable, no-strict-optional + +from contextlib import suppress + +# When strict-optional is disabled, 'open' should still behave in the same way as before +def f_open() -> str: + with open("foo.txt", "r") as f: + return f.read() + print("noop") + +# ...but unfortunately, we can't +def f_suppresses() -> int: + with suppress(Exception): + return 3 + print("noop") +[out] +_testUnreachableWithStdlibContextManagersNoStrictOptional.py:9: error: Statement is unreachable +_testUnreachableWithStdlibContextManagersNoStrictOptional.py:15: error: Statement is unreachable