Skip to content

Commit

Permalink
Fix bug with in operator used with a union of Container and Iterable (#…
Browse files Browse the repository at this point in the history
…14384)

Fixes #4954.

Modifies analysis of `in` comparison expressions. Previously, mypy would
check the right operand of an `in` expression to see if it was a union
of `Container`s, and then if it was a union of `Iterable`s, but would
fail on unions of both `Container`s and `Iterable`s.
  • Loading branch information
koogoro authored Jan 30, 2023
1 parent b2cf9d1 commit 1d247ea
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 51 deletions.
20 changes: 20 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4500,6 +4500,26 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
# Non-tuple iterable.
return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0]

def analyze_iterable_item_type_without_expression(
self, type: Type, context: Context
) -> tuple[Type, Type]:
"""Analyse iterable type and return iterator and iterator item types."""
echk = self.expr_checker
iterable = get_proper_type(type)
iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0]

if isinstance(iterable, TupleType):
joined: Type = UninhabitedType()
for item in iterable.items:
joined = join_types(joined, item)
return iterator, joined
else:
# Non-tuple iterable.
return (
iterator,
echk.check_method_call_by_name("__next__", iterator, [], [], context)[0],
)

def analyze_range_native_int_type(self, expr: Expression) -> Type | None:
"""Try to infer native int item type from arguments to range(...).
Expand Down
141 changes: 90 additions & 51 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2919,75 +2919,116 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
That is, 'a < b > c == d' is check as 'a < b and b > c and c == d'
"""
result: Type | None = None
sub_result: Type | None = None
sub_result: Type

# Check each consecutive operand pair and their operator
for left, right, operator in zip(e.operands, e.operands[1:], e.operators):
left_type = self.accept(left)

method_type: mypy.types.Type | None = None

if operator == "in" or operator == "not in":
# This case covers both iterables and containers, which have different meanings.
# For a container, the in operator calls the __contains__ method.
# For an iterable, the in operator iterates over the iterable, and compares each item one-by-one.
# We allow `in` for a union of containers and iterables as long as at least one of them matches the
# type of the left operand, as the operation will simply return False if the union's container/iterator
# type doesn't match the left operand.

# If the right operand has partial type, look it up without triggering
# a "Need type annotation ..." message, as it would be noise.
right_type = self.find_partial_type_ref_fast_path(right)
if right_type is None:
right_type = self.accept(right) # Validate the right operand

# Keep track of whether we get type check errors (these won't be reported, they
# are just to verify whether something is valid typing wise).
with self.msg.filter_errors(save_filtered_errors=True) as local_errors:
_, method_type = self.check_method_call_by_name(
method="__contains__",
base_type=right_type,
args=[left],
arg_kinds=[ARG_POS],
context=e,
)
right_type = get_proper_type(right_type)
item_types: Sequence[Type] = [right_type]
if isinstance(right_type, UnionType):
item_types = list(right_type.items)

sub_result = self.bool_type()
# Container item type for strict type overlap checks. Note: we need to only
# check for nominal type, because a usual "Unsupported operands for in"
# will be reported for types incompatible with __contains__().
# See testCustomContainsCheckStrictEquality for an example.
cont_type = self.chk.analyze_container_item_type(right_type)
if isinstance(right_type, PartialType):
# We don't really know if this is an error or not, so just shut up.
pass
elif (
local_errors.has_new_errors()
and
# is_valid_var_arg is True for any Iterable
self.is_valid_var_arg(right_type)
):
_, itertype = self.chk.analyze_iterable_item_type(right)
method_type = CallableType(
[left_type],
[nodes.ARG_POS],
[None],
self.bool_type(),
self.named_type("builtins.function"),
)
if not is_subtype(left_type, itertype):
self.msg.unsupported_operand_types("in", left_type, right_type, e)
# Only show dangerous overlap if there are no other errors.
elif (
not local_errors.has_new_errors()
and cont_type
and self.dangerous_comparison(
left_type, cont_type, original_container=right_type, prefer_literal=False
)
):
self.msg.dangerous_comparison(left_type, cont_type, "container", e)
else:
self.msg.add_errors(local_errors.filtered_errors())

container_types: list[Type] = []
iterable_types: list[Type] = []
failed_out = False
encountered_partial_type = False

for item_type in item_types:
# Keep track of whether we get type check errors (these won't be reported, they
# are just to verify whether something is valid typing wise).
with self.msg.filter_errors(save_filtered_errors=True) as container_errors:
_, method_type = self.check_method_call_by_name(
method="__contains__",
base_type=item_type,
args=[left],
arg_kinds=[ARG_POS],
context=e,
original_type=right_type,
)
# Container item type for strict type overlap checks. Note: we need to only
# check for nominal type, because a usual "Unsupported operands for in"
# will be reported for types incompatible with __contains__().
# See testCustomContainsCheckStrictEquality for an example.
cont_type = self.chk.analyze_container_item_type(item_type)

if isinstance(item_type, PartialType):
# We don't really know if this is an error or not, so just shut up.
encountered_partial_type = True
pass
elif (
container_errors.has_new_errors()
and
# is_valid_var_arg is True for any Iterable
self.is_valid_var_arg(item_type)
):
# it's not a container, but it is an iterable
with self.msg.filter_errors(save_filtered_errors=True) as iterable_errors:
_, itertype = self.chk.analyze_iterable_item_type_without_expression(
item_type, e
)
if iterable_errors.has_new_errors():
self.msg.add_errors(iterable_errors.filtered_errors())
failed_out = True
else:
method_type = CallableType(
[left_type],
[nodes.ARG_POS],
[None],
self.bool_type(),
self.named_type("builtins.function"),
)
e.method_types.append(method_type)
iterable_types.append(itertype)
elif not container_errors.has_new_errors() and cont_type:
container_types.append(cont_type)
e.method_types.append(method_type)
else:
self.msg.add_errors(container_errors.filtered_errors())
failed_out = True

if not encountered_partial_type and not failed_out:
iterable_type = UnionType.make_union(iterable_types)
if not is_subtype(left_type, iterable_type):
if len(container_types) == 0:
self.msg.unsupported_operand_types("in", left_type, right_type, e)
else:
container_type = UnionType.make_union(container_types)
if self.dangerous_comparison(
left_type,
container_type,
original_container=right_type,
prefer_literal=False,
):
self.msg.dangerous_comparison(
left_type, container_type, "container", e
)

elif operator in operators.op_methods:
method = operators.op_methods[operator]

with ErrorWatcher(self.msg.errors) as w:
sub_result, method_type = self.check_op(
method, left_type, right, e, allow_reverse=True
)
e.method_types.append(method_type)

# Only show dangerous overlap if there are no other errors. See
# testCustomEqCheckStrictEquality for an example.
Expand All @@ -3007,12 +3048,10 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
self.msg.dangerous_comparison(left_type, right_type, "identity", e)
method_type = None
e.method_types.append(None)
else:
raise RuntimeError(f"Unknown comparison operator {operator}")

e.method_types.append(method_type)

# Determine type of boolean-and of result and sub_result
if result is None:
result = sub_result
Expand Down
17 changes: 17 additions & 0 deletions test-data/unit/check-unions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1202,3 +1202,20 @@ def foo(
yield i
foo([1])
[builtins fixtures/list.pyi]

[case testUnionIterableContainer]
from typing import Iterable, Container, Union

i: Iterable[str]
c: Container[str]
u: Union[Iterable[str], Container[str]]
ni: Union[Iterable[str], int]
nc: Union[Container[str], int]

'x' in i
'x' in c
'x' in u
'x' in ni # E: Unsupported right operand type for in ("Union[Iterable[str], int]")
'x' in nc # E: Unsupported right operand type for in ("Union[Container[str], int]")
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

0 comments on commit 1d247ea

Please sign in to comment.