Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle empty bodies safely #8111

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/source/class_basics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,28 @@ however:
in this case, but any attempt to construct an instance will be
flagged as an error.

Mypy allows you to omit the body for an abstract method, but if you do so,
it is unsafe to call such method via ``super()``. For example:

.. code-block:: python

from abc import abstractmethod

class Base:
@abstractmethod
def foo(self) -> int: pass
@abstractmethod
def bar(self) -> int:
return 0

class Sub(Base):
def foo(self) -> int:
return super().foo() + 1 # error: Call to abstract method "foo" of "Base"
# with trivial body via super() is unsafe
@abstractmethod
def bar(self) -> int:
return super().bar() + 1 # This is OK however.

A class can inherit any number of classes, both abstract and
concrete. As with normal overrides, a dynamically typed method can
override or implement a statically typed method defined in any base
Expand Down
25 changes: 25 additions & 0 deletions docs/source/error_code_list.rst
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,31 @@ Example:
# Error: Cannot instantiate abstract class 'Thing' with abstract attribute 'save' [abstract]
t = Thing()

Check that call to an abstract method via super is valid [safe-super]
---------------------------------------------------------------------

Abstract methods often don't have any default implementation, i.e. their
bodies are just empty. Calling such methods in subclasses via ``super()``
will cause runtime errors, so mypy prevents you from doing so:

.. code-block:: python

from abc import abstractmethod

class Base:
@abstractmethod
def foo(self) -> int: ...

class Sub(Base):
def foo(self) -> int:
return super().foo() + 1 # error: Call to abstract method "foo" of "Base" with
# trivial body via super() is unsafe [safe-super]

Sub().foo() # This will crash at runtime.

Mypy considers the following as trivial bodies: a ``pass`` statement, a literal
ellipsis ``...``, a docstring, and a ``raise NotImplementedError`` statement.

Check the target of NewType [valid-newtype]
-------------------------------------------

Expand Down
16 changes: 15 additions & 1 deletion docs/source/protocols.rst
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,21 @@ protocols. If you explicitly subclass these protocols you can inherit
these default implementations. Explicitly including a protocol as a
base class is also a way of documenting that your class implements a
particular protocol, and it forces mypy to verify that your class
implementation is actually compatible with the protocol.
implementation is actually compatible with the protocol. In particular,
omitting a value for an attribute or a method body will make it implicitly
abstract:

.. code-block:: python

class SomeProto(Protocol):
attr: int # Note, no right hand side
def method(self) -> str: ... # Literal ... here

class ExplicitSubclass(SomeProto):
pass

ExplicitSubclass() # error: Cannot instantiate abstract class 'ExplicitSubclass'
# with abstract attributes 'attr' and 'method'

.. note::

Expand Down
77 changes: 21 additions & 56 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
ClassDef, Block, AssignmentStmt, NameExpr, MemberExpr, IndexExpr,
TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, IfStmt,
WhileStmt, OperatorAssignmentStmt, WithStmt, AssertStmt,
RaiseStmt, TryStmt, ForStmt, DelStmt, CallExpr, IntExpr, StrExpr,
UnicodeExpr, OpExpr, UnaryExpr, LambdaExpr, TempNode, SymbolTableNode,
RaiseStmt, TryStmt, ForStmt, DelStmt, CallExpr, IntExpr,
OpExpr, UnaryExpr, LambdaExpr, TempNode, SymbolTableNode,
Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt,
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr,
Import, ImportFrom, ImportAll, ImportBase, TypeAlias,
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF,
CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr, AssignmentExpr,
is_final_node,
ARG_NAMED)
is_final_node, is_trivial_body, ARG_NAMED
)
from mypy import nodes
from mypy.literals import literal, literal_hash
from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any
Expand Down Expand Up @@ -445,7 +445,9 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
self.visit_decorator(cast(Decorator, defn.items[0]))
for fdef in defn.items:
assert isinstance(fdef, Decorator)
self.check_func_item(fdef.func, name=fdef.func.name)
self.check_func_item(fdef.func,
name=fdef.func.name,
allow_empty=True)
if fdef.func.is_abstract:
num_abstract += 1
if num_abstract not in (0, len(defn.items)):
Expand Down Expand Up @@ -774,7 +776,8 @@ def _visit_func_def(self, defn: FuncDef) -> None:

def check_func_item(self, defn: FuncItem,
type_override: Optional[CallableType] = None,
name: Optional[str] = None) -> None:
name: Optional[str] = None,
allow_empty: bool = False) -> None:
"""Type check a function.

If type_override is provided, use it as the function type.
Expand All @@ -787,7 +790,7 @@ def check_func_item(self, defn: FuncItem,
typ = type_override.copy_modified(line=typ.line, column=typ.column)
if isinstance(typ, CallableType):
with self.enter_attribute_inference_context():
self.check_func_def(defn, typ, name)
self.check_func_def(defn, typ, name, allow_empty)
else:
raise RuntimeError('Not supported')

Expand All @@ -804,7 +807,8 @@ def enter_attribute_inference_context(self) -> Iterator[None]:
yield None
self.inferred_attribute_types = old_types

def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) -> None:
def check_func_def(self, defn: FuncItem, typ: CallableType,
name: Optional[str], allow_empty: bool = False) -> None:
"""Type check a function definition."""
# Expand type variables with value restrictions to ordinary types.
expanded = self.expand_typevars(defn, typ)
Expand Down Expand Up @@ -956,7 +960,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
item.arguments[i].variable.type = arg_type

# Type check initialization expressions.
body_is_trivial = self.is_trivial_body(defn.body)
body_is_trivial = is_trivial_body(defn.body)
self.check_default_args(item, body_is_trivial)

# Type check body in a new scope.
Expand All @@ -973,7 +977,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
self.accept(item.body)
unreachable = self.binder.is_unreachable()

if (self.options.warn_no_return and not unreachable):
if self.options.warn_no_return and not unreachable:
if (defn.is_generator or
is_named_instance(self.return_types[-1], 'typing.AwaitableGenerator')):
return_type = self.get_generator_return_type(self.return_types[-1],
Expand All @@ -984,7 +988,13 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
return_type = self.return_types[-1]

return_type = get_proper_type(return_type)
if not isinstance(return_type, (NoneType, AnyType)) and not body_is_trivial:
allow_empty = allow_empty or self.options.allow_empty_bodies
if (not isinstance(return_type, (NoneType, AnyType)) and
(not body_is_trivial or
# Allow empty bodies for abstract methods, overloads, in tests and stubs.
not allow_empty
and not (isinstance(defn, FuncDef) and defn.is_abstract)
and not self.is_stub)):
# Control flow fell off the end of a function that was
# declared to return a non-None type and is not
# entirely pass/Ellipsis/raise NotImplementedError.
Expand Down Expand Up @@ -1097,51 +1107,6 @@ def check___new___signature(self, fdef: FuncDef, typ: CallableType) -> None:
'but must return a subtype of'
)

def is_trivial_body(self, block: Block) -> bool:
"""Returns 'true' if the given body is "trivial" -- if it contains just a "pass",
"..." (ellipsis), or "raise NotImplementedError()". A trivial body may also
start with a statement containing just a string (e.g. a docstring).

Note: functions that raise other kinds of exceptions do not count as
"trivial". We use this function to help us determine when it's ok to
relax certain checks on body, but functions that raise arbitrary exceptions
are more likely to do non-trivial work. For example:

def halt(self, reason: str = ...) -> NoReturn:
raise MyCustomError("Fatal error: " + reason, self.line, self.context)

A function that raises just NotImplementedError is much less likely to be
this complex.
"""
body = block.body

# Skip a docstring
if (body and isinstance(body[0], ExpressionStmt) and
isinstance(body[0].expr, (StrExpr, UnicodeExpr))):
body = block.body[1:]

if len(body) == 0:
# There's only a docstring (or no body at all).
return True
elif len(body) > 1:
return False

stmt = body[0]

if isinstance(stmt, RaiseStmt):
expr = stmt.expr
if expr is None:
return False
if isinstance(expr, CallExpr):
expr = expr.callee

return (isinstance(expr, NameExpr)
and expr.fullname == 'builtins.NotImplementedError')

return (isinstance(stmt, PassStmt) or
(isinstance(stmt, ExpressionStmt) and
isinstance(stmt.expr, EllipsisExpr)))

def check_reverse_op_method(self, defn: FuncItem,
reverse_type: CallableType, reverse_name: str,
context: Context) -> None:
Expand Down
3 changes: 2 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,8 @@ def check_callable_call(self,
if (callee.is_type_obj() and callee.type_object().is_abstract
# Exception for Type[...]
and not callee.from_type_type
and not callee.type_object().fallback_to_any):
and not callee.type_object().fallback_to_any
and not callee.type_object().is_protocol):
type = callee.type_object()
self.msg.cannot_instantiate_abstract_class(
callee.type_object().name, type.abstract_attributes,
Expand Down
21 changes: 20 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ def analyze_instance_member_access(name: str,
# Look up the member. First look up the method dictionary.
method = info.get_method(name)
if method:
unsafe_super = False
if mx.is_super:
if isinstance(method, FuncDef) and method.is_trivial_body:
unsafe_super = True
impl = method
elif isinstance(method, OverloadedFuncDef):
if method.impl:
impl = method.impl if isinstance(method.impl, FuncDef) else method.impl.func
unsafe_super = impl.is_trivial_body
if unsafe_super:
ret_type = (impl.type.ret_type if isinstance(impl.type, CallableType)
else AnyType(TypeOfAny.unannotated))
if not subtypes.is_subtype(NoneType(), ret_type):
mx.msg.unsafe_super(method.name, method.info.name, mx.context)
if method.is_property:
assert isinstance(method, OverloadedFuncDef)
first_item = cast(Decorator, method.items[0])
Expand Down Expand Up @@ -346,6 +360,11 @@ def analyze_member_var_access(name: str,
if isinstance(vv, Decorator):
# The associated Var node of a decorator contains the type.
v = vv.var
if mx.is_super and vv.func.is_trivial_body:
ret_type = (vv.func.type.ret_type if isinstance(vv.func.type, CallableType)
else AnyType(TypeOfAny.unannotated))
if not subtypes.is_subtype(NoneType(), ret_type):
mx.msg.unsafe_super(vv.func.name, vv.func.info.name, mx.context)

if isinstance(vv, TypeInfo):
# If the associated variable is a TypeInfo synthesize a Var node for
Expand Down Expand Up @@ -565,7 +584,7 @@ def analyze_var(name: str,
# * B.f: Callable[[B1], None] where B1 <: B (maybe B1 == B)
# * x: Union[A1, B1]
# In `x.f`, when checking `x` against A1 we assume x is compatible with A
# and similarly for B1 when checking agains B
# and similarly for B1 when checking against B
dispatched_type = meet.meet_types(mx.original_type, itype)
signature = freshen_function_type_vars(functype)
signature = check_self_arg(signature, dispatched_type, var.is_classmethod,
Expand Down
3 changes: 3 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def __str__(self) -> str:
'General') # type: Final
EXIT_RETURN = ErrorCode(
'exit-return', "Warn about too general return type for '__exit__'", 'General') # type: Final
SAFE_SUPER = ErrorCode(
'safe-super', "Warn about calls to abstract methods with empty/trivial bodies",
'General') # type: Final

# These error codes aren't enabled by default.
NO_UNTYPED_DEF = ErrorCode(
Expand Down
2 changes: 2 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,8 @@ def add_invertible_flag(flag: str,
"the contents of SHADOW_FILE instead.")
add_invertible_flag('--fast-exit', default=False, help=argparse.SUPPRESS,
group=internals_group)
add_invertible_flag('--allow-empty-bodies', default=False, help=argparse.SUPPRESS,
group=internals_group)

report_group = parser.add_argument_group(
title='Report generation',
Expand Down
4 changes: 4 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,10 @@ def first_argument_for_super_must_be_type(self, actual: Type, context: Context)
type_str = format_type(actual)
self.fail('Argument 1 for "super" must be a type object; got {}'.format(type_str), context)

def unsafe_super(self, method: str, cls: str, ctx: Context) -> None:
self.fail('Call to abstract method "{}" of "{}" with trivial body'
' via super() is unsafe'.format(method, cls), ctx, code=codes.SAFE_SUPER)

def too_few_string_formatting_arguments(self, context: Context) -> None:
self.fail('Not enough arguments for format string', context,
code=codes.STRING_FORMATTING)
Expand Down
52 changes: 51 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def is_dynamic(self) -> bool:


FUNCDEF_FLAGS = FUNCITEM_FLAGS + [
'is_decorated', 'is_conditional', 'is_abstract',
'is_decorated', 'is_conditional', 'is_abstract', 'is_trivial_body',
] # type: Final


Expand All @@ -651,6 +651,7 @@ class FuncDef(FuncItem, SymbolNode, Statement):
'is_decorated',
'is_conditional',
'is_abstract',
'is_trivial_body',
'original_def',
)

Expand All @@ -664,6 +665,9 @@ def __init__(self,
self.is_decorated = False
self.is_conditional = False # Defined conditionally (within block)?
self.is_abstract = False
# Is this an abstract method with trivial body?
# Such methods can't be called via super().
self.is_trivial_body = False
self.is_final = False
# Original conditional definition
self.original_def = None # type: Union[None, FuncDef, Var, Decorator]
Expand Down Expand Up @@ -3203,3 +3207,49 @@ def local_definitions(names: SymbolTable,
yield fullname, symnode, info
if isinstance(node, TypeInfo):
yield from local_definitions(node.names, fullname, node)


def is_trivial_body(block: Block) -> bool:
"""Returns 'true' if the given body is "trivial" -- if it contains just a "pass",
"..." (ellipsis), or "raise NotImplementedError()". A trivial body may also
start with a statement containing just a string (e.g. a docstring).

Note: functions that raise other kinds of exceptions do not count as
"trivial". We use this function to help us determine when it's ok to
relax certain checks on body, but functions that raise arbitrary exceptions
are more likely to do non-trivial work. For example:

def halt(self, reason: str = ...) -> NoReturn:
raise MyCustomError("Fatal error: " + reason, self.line, self.context)

A function that raises just NotImplementedError is much less likely to be
this complex.
"""
body = block.body

# Skip a docstring
if (body and isinstance(body[0], ExpressionStmt) and
isinstance(body[0].expr, (StrExpr, UnicodeExpr))):
body = block.body[1:]

if len(body) == 0:
# There's only a docstring (or no body at all).
return True
elif len(body) > 1:
return False

stmt = body[0]

if isinstance(stmt, RaiseStmt):
expr = stmt.expr
if expr is None:
return False
if isinstance(expr, CallExpr):
expr = expr.callee

return (isinstance(expr, NameExpr)
and expr.fullname == 'builtins.NotImplementedError')

return (isinstance(stmt, PassStmt) or
(isinstance(stmt, ExpressionStmt) and
isinstance(stmt.expr, EllipsisExpr)))
2 changes: 2 additions & 0 deletions mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ def __init__(self) -> None:
self.cache_map = {} # type: Dict[str, Tuple[str, str]]
# Don't properly free objects on exit, just kill the current process.
self.fast_exit = False
# Allow empty function bodies even if it is not safe, used for testing only.
self.allow_empty_bodies = False
# Used to transform source code before parsing if not None
# TODO: Make the type precise (AnyStr -> AnyStr)
self.transform_source = None # type: Optional[Callable[[Any], Any]]
Expand Down
Loading