Skip to content

Refactor reversible operators #5475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 101 additions & 54 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
self.msg, context=fdef)

if name: # Special method names
if defn.info and name in nodes.reverse_op_method_set:
if defn.info and self.is_reverse_op_method(name):
self.check_reverse_op_method(item, typ, name, defn)
elif name in ('__getattr__', '__getattribute__'):
self.check_getattr_method(typ, defn, name)
Expand Down Expand Up @@ -923,6 +923,18 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])

self.binder = old_binder

def is_forward_op_method(self, method_name: str) -> bool:
if self.options.python_version[0] == 2 and method_name == '__div__':
return True
else:
return method_name in nodes.reverse_op_methods

def is_reverse_op_method(self, method_name: str) -> bool:
if self.options.python_version[0] == 2 and method_name == '__rdiv__':
return True
else:
return method_name in nodes.reverse_op_method_set

def check_for_missing_annotations(self, fdef: FuncItem) -> None:
# Check for functions with unspecified/not fully specified types.
def is_unannotated_any(t: Type) -> bool:
Expand Down Expand Up @@ -1010,7 +1022,10 @@ def check_reverse_op_method(self, defn: FuncItem,
arg_names=[reverse_type.arg_names[0], "_"])
assert len(reverse_type.arg_types) >= 2

forward_name = nodes.normal_from_reverse_op[reverse_name]
if self.options.python_version[0] == 2 and reverse_name == '__rdiv__':
forward_name = '__div__'
else:
forward_name = nodes.normal_from_reverse_op[reverse_name]
forward_inst = reverse_type.arg_types[1]
if isinstance(forward_inst, TypeVarType):
forward_inst = forward_inst.upper_bound
Expand Down Expand Up @@ -1042,73 +1057,105 @@ def check_overlapping_op_methods(self,
context: Context) -> None:
"""Check for overlapping method and reverse method signatures.

Assume reverse method has valid argument count and kinds.
This function assumes that:

- The reverse method has valid argument count and kinds.
- If the reverse operator method accepts some argument of type
X, the forward operator method also belong to class X.

For example, if we have the reverse operator `A.__radd__(B)`, then the
corresponding forward operator must have the type `B.__add__(...)`.
"""

# Reverse operator method that overlaps unsafely with the
# forward operator method can result in type unsafety. This is
# similar to overlapping overload variants.
# Note: Suppose we have two operator methods "A.__rOP__(B) -> R1" and
# "B.__OP__(C) -> R2". We check if these two methods are unsafely overlapping
# by using the following algorithm:
#
# 1. Rewrite "B.__OP__(C) -> R1" to "temp1(B, C) -> R1"
#
# 2. Rewrite "A.__rOP__(B) -> R2" to "temp2(B, A) -> R2"
#
# 3. Treat temp1 and temp2 as if they were both variants in the same
# overloaded function. (This mirrors how the Python runtime calls
# operator methods: we first try __OP__, then __rOP__.)
#
# If the first signature is unsafely overlapping with the second,
# report an error.
#
# This example illustrates the issue:
# 4. However, if temp1 shadows temp2 (e.g. the __rOP__ method can never
# be called), do NOT report an error.
#
# class X: pass
# class A:
# def __add__(self, x: X) -> int:
# if isinstance(x, X):
# return 1
# return NotImplemented
# class B:
# def __radd__(self, x: A) -> str: return 'x'
# class C(X, B): pass
# def f(b: B) -> None:
# A() + b # Result is 1, even though static type seems to be str!
# f(C())
# This behavior deviates from how we handle overloads -- many of the
# modules in typeshed seem to define __OP__ methods that shadow the
# corresponding __rOP__ method.
#
# The reason for the problem is that B and X are overlapping
# types, and the return types are different. Also, if the type
# of x in __radd__ would not be A, the methods could be
# non-overlapping.
# Note: we do not attempt to handle unsafe overlaps related to multiple
# inheritance. (This is consistent with how we handle overloads: we also
# do not try checking unsafe overlaps due to multiple inheritance there.)

for forward_item in union_items(forward_type):
if isinstance(forward_item, CallableType):
# TODO check argument kinds
if len(forward_item.arg_types) < 1:
# Not a valid operator method -- can't succeed anyway.
return

# Construct normalized function signatures corresponding to the
# operator methods. The first argument is the left operand and the
# second operand is the right argument -- we switch the order of
# the arguments of the reverse method.
forward_tweaked = CallableType(
[forward_base, forward_item.arg_types[0]],
[nodes.ARG_POS] * 2,
[None] * 2,
forward_item.ret_type,
forward_item.fallback,
name=forward_item.name)
reverse_args = reverse_type.arg_types
reverse_tweaked = CallableType(
[reverse_args[1], reverse_args[0]],
[nodes.ARG_POS] * 2,
[None] * 2,
reverse_type.ret_type,
fallback=self.named_type('builtins.function'),
name=reverse_type.name)

if is_unsafe_overlapping_operator_signatures(
forward_tweaked, reverse_tweaked):
if self.is_unsafe_overlapping_op(forward_item, forward_base, reverse_type):
self.msg.operator_method_signatures_overlap(
reverse_class, reverse_name,
forward_base, forward_name, context)
elif isinstance(forward_item, Overloaded):
for item in forward_item.items():
self.check_overlapping_op_methods(
reverse_type, reverse_name, reverse_class,
item, forward_name, forward_base, context)
if self.is_unsafe_overlapping_op(item, forward_base, reverse_type):
self.msg.operator_method_signatures_overlap(
reverse_class, reverse_name,
forward_base, forward_name,
context)
elif not isinstance(forward_item, AnyType):
self.msg.forward_operator_not_callable(forward_name, context)

def is_unsafe_overlapping_op(self,
forward_item: CallableType,
forward_base: Type,
reverse_type: CallableType) -> bool:
# TODO: check argument kinds?
if len(forward_item.arg_types) < 1:
# Not a valid operator method -- can't succeed anyway.
return False

# Erase the type if necessary to make sure we don't have a single
# TypeVar in forward_tweaked. (Having a function signature containing
# just a single TypeVar can lead to unpredictable behavior.)
forward_base_erased = forward_base
if isinstance(forward_base, TypeVarType):
forward_base_erased = erase_to_bound(forward_base)

# Construct normalized function signatures corresponding to the
# operator methods. The first argument is the left operand and the
# second operand is the right argument -- we switch the order of
# the arguments of the reverse method.

forward_tweaked = forward_item.copy_modified(
arg_types=[forward_base_erased, forward_item.arg_types[0]],
arg_kinds=[nodes.ARG_POS] * 2,
arg_names=[None] * 2,
)
reverse_tweaked = reverse_type.copy_modified(
arg_types=[reverse_type.arg_types[1], reverse_type.arg_types[0]],
arg_kinds=[nodes.ARG_POS] * 2,
arg_names=[None] * 2,
)

reverse_base_erased = reverse_type.arg_types[0]
if isinstance(reverse_base_erased, TypeVarType):
reverse_base_erased = erase_to_bound(reverse_base_erased)

if is_same_type(reverse_base_erased, forward_base_erased):
return False
elif is_subtype(reverse_base_erased, forward_base_erased):
first = reverse_tweaked
second = forward_tweaked
else:
first = forward_tweaked
second = reverse_tweaked

return is_unsafe_overlapping_overload_signatures(first, second)

def check_inplace_operator_method(self, defn: FuncBase) -> None:
"""Check an inplace operator method such as __iadd__.

Expand Down Expand Up @@ -1312,7 +1359,7 @@ def check_override(self, override: FunctionLike, original: FunctionLike,
fail = True
elif (not isinstance(original, Overloaded) and
isinstance(override, Overloaded) and
name in nodes.reverse_op_methods.keys()):
self.is_forward_op_method(name)):
# Operator method overrides cannot introduce overloading, as
# this could be unsafe with reverse operator methods.
fail = True
Expand Down
Loading