Skip to content

Commit

Permalink
Fix nested overload merging (#12607)
Browse files Browse the repository at this point in the history
Closes #12606
  • Loading branch information
cdce8p committed Apr 19, 2022
1 parent 9e9de71 commit cf6a48c
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 18 deletions.
69 changes: 51 additions & 18 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,18 +496,9 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
if_overload_name: Optional[str] = None
if_block_with_overload: Optional[Block] = None
if_unknown_truth_value: Optional[IfStmt] = None
if (
isinstance(stmt, IfStmt)
and len(stmt.body[0].body) == 1
and seen_unconditional_func_def is False
and (
isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
or current_overload_name is not None
and isinstance(stmt.body[0].body[0], FuncDef)
)
):
if isinstance(stmt, IfStmt) and seen_unconditional_func_def is False:
# Check IfStmt block to determine if function overloads can be merged
if_overload_name = self._check_ifstmt_for_overloads(stmt)
if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name)
if if_overload_name is not None:
if_block_with_overload, if_unknown_truth_value = \
self._get_executable_if_block_with_overloads(stmt)
Expand Down Expand Up @@ -553,8 +544,11 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
else:
current_overload.append(last_if_overload)
last_if_stmt, last_if_overload = None, None
if isinstance(if_block_with_overload.body[0], OverloadedFuncDef):
current_overload.extend(if_block_with_overload.body[0].items)
if isinstance(if_block_with_overload.body[-1], OverloadedFuncDef):
skipped_if_stmts.extend(
cast(List[IfStmt], if_block_with_overload.body[:-1])
)
current_overload.extend(if_block_with_overload.body[-1].items)
else:
current_overload.append(
cast(Union[Decorator, FuncDef], if_block_with_overload.body[0])
Expand Down Expand Up @@ -600,9 +594,12 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
last_if_stmt = stmt
last_if_stmt_overload_name = None
if if_block_with_overload is not None:
skipped_if_stmts.extend(
cast(List[IfStmt], if_block_with_overload.body[:-1])
)
last_if_overload = cast(
Union[Decorator, FuncDef, OverloadedFuncDef],
if_block_with_overload.body[0]
if_block_with_overload.body[-1]
)
last_if_unknown_truth_value = if_unknown_truth_value
else:
Expand All @@ -620,23 +617,38 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
ret.append(current_overload[0])
elif len(current_overload) > 1:
ret.append(OverloadedFuncDef(current_overload))
elif last_if_overload is not None:
ret.append(last_if_overload)
elif last_if_stmt is not None:
ret.append(last_if_stmt)
return ret

def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
def _check_ifstmt_for_overloads(
self, stmt: IfStmt, current_overload_name: Optional[str] = None
) -> Optional[str]:
"""Check if IfStmt contains only overloads with the same name.
Return overload_name if found, None otherwise.
"""
# Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
# Multiple overloads have already been merged as OverloadedFuncDef.
if not (
len(stmt.body[0].body) == 1
and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
and (
isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
or current_overload_name is not None
and isinstance(stmt.body[0].body[0], FuncDef)
)
or len(stmt.body[0].body) > 1
and isinstance(stmt.body[0].body[-1], OverloadedFuncDef)
and all(
self._is_stripped_if_stmt(if_stmt)
for if_stmt in stmt.body[0].body[:-1]
)
):
return None

overload_name = stmt.body[0].body[0].name
overload_name = cast(
Union[Decorator, FuncDef, OverloadedFuncDef], stmt.body[0].body[-1]).name
if stmt.else_body is None:
return overload_name

Expand All @@ -649,7 +661,9 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
return overload_name
if (
isinstance(stmt.else_body.body[0], IfStmt)
and self._check_ifstmt_for_overloads(stmt.else_body.body[0]) == overload_name
and self._check_ifstmt_for_overloads(
stmt.else_body.body[0], current_overload_name
) == overload_name
):
return overload_name

Expand Down Expand Up @@ -704,6 +718,25 @@ def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None:
else:
stmt.else_body.body = []

def _is_stripped_if_stmt(self, stmt: Statement) -> bool:
"""Check stmt to make sure it is a stripped IfStmt.
See also: _strip_contents_from_if_stmt
"""
if not isinstance(stmt, IfStmt):
return False

if not (len(stmt.body) == 1 and len(stmt.body[0].body) == 0):
# Body not empty
return False

if not stmt.else_body or len(stmt.else_body.body) == 0:
# No or empty else_body
return True

# For elif, IfStmt are stored recursively in else_body
return self._is_stripped_if_stmt(stmt.else_body.body[0])

def in_method_scope(self) -> bool:
return self.class_and_function_stack[-2:] == ['C', 'F']

Expand Down
139 changes: 139 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6367,3 +6367,142 @@ def g(x: int) -> str: ...

def g(x: int = 0) -> int: # E: Overloaded function implementation cannot produce return type of signature 2
return x

[case testOverloadIfNestedOk]
# flags: --always-true True --always-false False
from typing import overload

class A: ...
class B: ...
class C: ...
class D: ...

@overload
def f1(g: A) -> A: ...
if True:
@overload
def f1(g: B) -> B: ...
if True:
@overload
def f1(g: C) -> C: ...
@overload
def f1(g: D) -> D: ...
def f1(g): ...
reveal_type(f1(A())) # N: Revealed type is "__main__.A"
reveal_type(f1(B())) # N: Revealed type is "__main__.B"
reveal_type(f1(C())) # N: Revealed type is "__main__.C"
reveal_type(f1(D())) # N: Revealed type is "__main__.D"

@overload
def f2(g: A) -> A: ...
if True:
@overload
def f2(g: B) -> B: ...
if True:
@overload
def f2(g: C) -> C: ...
if True:
@overload
def f2(g: D) -> D: ...
def f2(g): ...
reveal_type(f2(A())) # N: Revealed type is "__main__.A"
reveal_type(f2(B())) # N: Revealed type is "__main__.B"
reveal_type(f2(C())) # N: Revealed type is "__main__.C"
reveal_type(f2(D())) # N: Revealed type is "__main__.D"

@overload
def f3(g: A) -> A: ...
if True:
if True:
@overload
def f3(g: B) -> B: ...
if True:
@overload
def f3(g: C) -> C: ...
def f3(g): ...
reveal_type(f3(A())) # N: Revealed type is "__main__.A"
reveal_type(f3(B())) # N: Revealed type is "__main__.B"
reveal_type(f3(C())) # N: Revealed type is "__main__.C"

@overload
def f4(g: A) -> A: ...
if True:
if False:
@overload
def f4(g: B) -> B: ...
else:
@overload
def f4(g: C) -> C: ...
def f4(g): ...
reveal_type(f4(A())) # N: Revealed type is "__main__.A"
reveal_type(f4(B())) # E: No overload variant of "f4" matches argument type "B" \
# N: Possible overload variants: \
# N: def f4(g: A) -> A \
# N: def f4(g: C) -> C \
# N: Revealed type is "Any"
reveal_type(f4(C())) # N: Revealed type is "__main__.C"

@overload
def f5(g: A) -> A: ...
if True:
if False:
@overload
def f5(g: B) -> B: ...
elif True:
@overload
def f5(g: C) -> C: ...
def f5(g): ...
reveal_type(f5(A())) # N: Revealed type is "__main__.A"
reveal_type(f5(B())) # E: No overload variant of "f5" matches argument type "B" \
# N: Possible overload variants: \
# N: def f5(g: A) -> A \
# N: def f5(g: C) -> C \
# N: Revealed type is "Any"
reveal_type(f5(C())) # N: Revealed type is "__main__.C"

[case testOverloadIfNestedFailure]
# flags: --always-true True --always-false False
from typing import overload

class A: ...
class B: ...
class C: ...
class D: ...

@overload # E: Single overload definition, multiple required
def f1(g: A) -> A: ...
if True:
@overload # E: Single overload definition, multiple required
def f1(g: B) -> B: ...
if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \
# E: Name "maybe_true" is not defined
@overload
def f1(g: C) -> C: ...
@overload
def f1(g: D) -> D: ...
def f1(g): ... # E: Name "f1" already defined on line 9

@overload # E: Single overload definition, multiple required
def f2(g: A) -> A: ...
if True:
if False:
@overload
def f2(g: B) -> B: ...
elif maybe_true: # E: Name "maybe_true" is not defined
@overload # E: Single overload definition, multiple required
def f2(g: C) -> C: ...
def f2(g): ... # E: Name "f2" already defined on line 21

@overload # E: Single overload definition, multiple required
def f3(g: A) -> A: ...
if True:
@overload # E: Single overload definition, multiple required
def f3(g: B) -> B: ...
if True:
pass # Some other node
@overload # E: Name "f3" already defined on line 32 \
# E: An overloaded function outside a stub file must have an implementation
def f3(g: C) -> C: ...
@overload
def f3(g: D) -> D: ...
def f3(g): ... # E: Name "f3" already defined on line 32

0 comments on commit cf6a48c

Please sign in to comment.