Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nested overload merging #12607

Merged
merged 1 commit into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 51 additions & 18 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,18 +496,9 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
if_overload_name: Optional[str] = None
if_block_with_overload: Optional[Block] = None
if_unknown_truth_value: Optional[IfStmt] = None
if (
isinstance(stmt, IfStmt)
and len(stmt.body[0].body) == 1
and seen_unconditional_func_def is False
and (
isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
or current_overload_name is not None
and isinstance(stmt.body[0].body[0], FuncDef)
)
):
if isinstance(stmt, IfStmt) and seen_unconditional_func_def is False:
# Check IfStmt block to determine if function overloads can be merged
if_overload_name = self._check_ifstmt_for_overloads(stmt)
if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name)
if if_overload_name is not None:
if_block_with_overload, if_unknown_truth_value = \
self._get_executable_if_block_with_overloads(stmt)
Expand Down Expand Up @@ -553,8 +544,11 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
else:
current_overload.append(last_if_overload)
last_if_stmt, last_if_overload = None, None
if isinstance(if_block_with_overload.body[0], OverloadedFuncDef):
current_overload.extend(if_block_with_overload.body[0].items)
if isinstance(if_block_with_overload.body[-1], OverloadedFuncDef):
skipped_if_stmts.extend(
cast(List[IfStmt], if_block_with_overload.body[:-1])
)
current_overload.extend(if_block_with_overload.body[-1].items)
else:
current_overload.append(
cast(Union[Decorator, FuncDef], if_block_with_overload.body[0])
Expand Down Expand Up @@ -600,9 +594,12 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
last_if_stmt = stmt
last_if_stmt_overload_name = None
if if_block_with_overload is not None:
skipped_if_stmts.extend(
cast(List[IfStmt], if_block_with_overload.body[:-1])
)
last_if_overload = cast(
Union[Decorator, FuncDef, OverloadedFuncDef],
if_block_with_overload.body[0]
if_block_with_overload.body[-1]
)
last_if_unknown_truth_value = if_unknown_truth_value
else:
Expand All @@ -620,23 +617,38 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
ret.append(current_overload[0])
elif len(current_overload) > 1:
ret.append(OverloadedFuncDef(current_overload))
elif last_if_overload is not None:
ret.append(last_if_overload)
elif last_if_stmt is not None:
ret.append(last_if_stmt)
return ret

def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
def _check_ifstmt_for_overloads(
self, stmt: IfStmt, current_overload_name: Optional[str] = None
) -> Optional[str]:
"""Check if IfStmt contains only overloads with the same name.
Return overload_name if found, None otherwise.
"""
# Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
# Multiple overloads have already been merged as OverloadedFuncDef.
if not (
len(stmt.body[0].body) == 1
and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
and (
isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
or current_overload_name is not None
and isinstance(stmt.body[0].body[0], FuncDef)
)
or len(stmt.body[0].body) > 1
and isinstance(stmt.body[0].body[-1], OverloadedFuncDef)
and all(
self._is_stripped_if_stmt(if_stmt)
for if_stmt in stmt.body[0].body[:-1]
)
):
return None

overload_name = stmt.body[0].body[0].name
overload_name = cast(
Union[Decorator, FuncDef, OverloadedFuncDef], stmt.body[0].body[-1]).name
if stmt.else_body is None:
return overload_name

Expand All @@ -649,7 +661,9 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
return overload_name
if (
isinstance(stmt.else_body.body[0], IfStmt)
and self._check_ifstmt_for_overloads(stmt.else_body.body[0]) == overload_name
and self._check_ifstmt_for_overloads(
stmt.else_body.body[0], current_overload_name
) == overload_name
):
return overload_name

Expand Down Expand Up @@ -704,6 +718,25 @@ def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None:
else:
stmt.else_body.body = []

def _is_stripped_if_stmt(self, stmt: Statement) -> bool:
"""Check stmt to make sure it is a stripped IfStmt.

See also: _strip_contents_from_if_stmt
"""
if not isinstance(stmt, IfStmt):
return False

if not (len(stmt.body) == 1 and len(stmt.body[0].body) == 0):
# Body not empty
return False

if not stmt.else_body or len(stmt.else_body.body) == 0:
# No or empty else_body
return True

# For elif, IfStmt are stored recursively in else_body
return self._is_stripped_if_stmt(stmt.else_body.body[0])

def in_method_scope(self) -> bool:
return self.class_and_function_stack[-2:] == ['C', 'F']

Expand Down
139 changes: 139 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6349,3 +6349,142 @@ def g(x: int) -> str: ...

def g(x: int = 0) -> int: # E: Overloaded function implementation cannot produce return type of signature 2
return x

[case testOverloadIfNestedOk]
# flags: --always-true True --always-false False
from typing import overload

class A: ...
class B: ...
class C: ...
class D: ...

@overload
def f1(g: A) -> A: ...
if True:
@overload
def f1(g: B) -> B: ...
if True:
@overload
def f1(g: C) -> C: ...
@overload
def f1(g: D) -> D: ...
def f1(g): ...
reveal_type(f1(A())) # N: Revealed type is "__main__.A"
reveal_type(f1(B())) # N: Revealed type is "__main__.B"
reveal_type(f1(C())) # N: Revealed type is "__main__.C"
reveal_type(f1(D())) # N: Revealed type is "__main__.D"

@overload
def f2(g: A) -> A: ...
if True:
@overload
def f2(g: B) -> B: ...
if True:
@overload
def f2(g: C) -> C: ...
if True:
@overload
def f2(g: D) -> D: ...
def f2(g): ...
reveal_type(f2(A())) # N: Revealed type is "__main__.A"
reveal_type(f2(B())) # N: Revealed type is "__main__.B"
reveal_type(f2(C())) # N: Revealed type is "__main__.C"
reveal_type(f2(D())) # N: Revealed type is "__main__.D"

@overload
def f3(g: A) -> A: ...
if True:
if True:
@overload
def f3(g: B) -> B: ...
if True:
@overload
def f3(g: C) -> C: ...
def f3(g): ...
reveal_type(f3(A())) # N: Revealed type is "__main__.A"
reveal_type(f3(B())) # N: Revealed type is "__main__.B"
reveal_type(f3(C())) # N: Revealed type is "__main__.C"

@overload
def f4(g: A) -> A: ...
if True:
if False:
@overload
def f4(g: B) -> B: ...
else:
@overload
def f4(g: C) -> C: ...
def f4(g): ...
reveal_type(f4(A())) # N: Revealed type is "__main__.A"
reveal_type(f4(B())) # E: No overload variant of "f4" matches argument type "B" \
# N: Possible overload variants: \
# N: def f4(g: A) -> A \
# N: def f4(g: C) -> C \
# N: Revealed type is "Any"
reveal_type(f4(C())) # N: Revealed type is "__main__.C"

@overload
def f5(g: A) -> A: ...
if True:
if False:
@overload
def f5(g: B) -> B: ...
elif True:
@overload
def f5(g: C) -> C: ...
def f5(g): ...
reveal_type(f5(A())) # N: Revealed type is "__main__.A"
reveal_type(f5(B())) # E: No overload variant of "f5" matches argument type "B" \
# N: Possible overload variants: \
# N: def f5(g: A) -> A \
# N: def f5(g: C) -> C \
# N: Revealed type is "Any"
reveal_type(f5(C())) # N: Revealed type is "__main__.C"

[case testOverloadIfNestedFailure]
# flags: --always-true True --always-false False
from typing import overload

class A: ...
class B: ...
class C: ...
class D: ...

@overload # E: Single overload definition, multiple required
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what's going on in this test case, could you explain? I see you have --always-true True, but why should we care about that? And why does it cause all of these errors?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation for overload merging requires that all conditions can be inferred and mypy knows which branch will be executed. I.e. we know what if True will do (with --always-true True). However, for if maybe_true we can't be sure the branch will actually be executed and thus we don't merge the overloads.

The errors then are a result of the default behavior if any of the if blocks contains a node other than overload / FuncDef. Since none of them are merge, we emit Single overload definition, multiple required and Name "f1" already defined errors.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It's weird though that there are so many errors; I feel like only the one on line 6459 should be emitted. Is it possible to make the "Single overload definition, multiple required" and "Name "f1" already defined on line 9" errors go away? Seeing those will be confusing for users who use an unsupported condition.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird though that there are so many errors

I agree. The issue here is that both single overload definition and Name already defined on line are emitted during the analysis phase. At least currently, I've implemented the overload merging during the initial file parse. Basically, I'm rearranging the overloads if it's safe to do so, i.e. removing the nesting and IfExp nodes to leave just one OverloadedFuncDef in its place.

I've not (yet) touched the original overload analysis. It could be possible, but I fear that the logic to handle it will get complicated pretty fast. An especially challenge would be to exclude those cases where it's actually desirable to have all these messages.

I'm just not sure the effort is worth it. Do also consider that it's unlikely to encounter this issue in practice. The docs clearly state that only a limited number of conditions are supported. Most users will likely be just fine if they stick to if TYPE_CHECKING and if sys.version_info checks. The only other case would be if they add some other node between all the overloads (L-6483). That is unfortunate, but should happen normally. Anyway, it's good that mypy fails loudly in this case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, thanks. We can merge this as is for now and revisit if users complain about the errors.

def f1(g: A) -> A: ...
if True:
@overload # E: Single overload definition, multiple required
def f1(g: B) -> B: ...
if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \
# E: Name "maybe_true" is not defined
@overload
def f1(g: C) -> C: ...
@overload
def f1(g: D) -> D: ...
def f1(g): ... # E: Name "f1" already defined on line 9

@overload # E: Single overload definition, multiple required
def f2(g: A) -> A: ...
if True:
if False:
@overload
def f2(g: B) -> B: ...
elif maybe_true: # E: Name "maybe_true" is not defined
@overload # E: Single overload definition, multiple required
def f2(g: C) -> C: ...
def f2(g): ... # E: Name "f2" already defined on line 21

@overload # E: Single overload definition, multiple required
def f3(g: A) -> A: ...
if True:
@overload # E: Single overload definition, multiple required
def f3(g: B) -> B: ...
if True:
pass # Some other node
@overload # E: Name "f3" already defined on line 32 \
# E: An overloaded function outside a stub file must have an implementation
def f3(g: C) -> C: ...
@overload
def f3(g: D) -> D: ...
def f3(g): ... # E: Name "f3" already defined on line 32