diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 077d287655fb..0b3322db2af3 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -496,18 +496,9 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: if_overload_name: Optional[str] = None if_block_with_overload: Optional[Block] = None if_unknown_truth_value: Optional[IfStmt] = None - if ( - isinstance(stmt, IfStmt) - and len(stmt.body[0].body) == 1 - and seen_unconditional_func_def is False - and ( - isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef)) - or current_overload_name is not None - and isinstance(stmt.body[0].body[0], FuncDef) - ) - ): + if isinstance(stmt, IfStmt) and seen_unconditional_func_def is False: # Check IfStmt block to determine if function overloads can be merged - if_overload_name = self._check_ifstmt_for_overloads(stmt) + if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name) if if_overload_name is not None: if_block_with_overload, if_unknown_truth_value = \ self._get_executable_if_block_with_overloads(stmt) @@ -553,8 +544,11 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: else: current_overload.append(last_if_overload) last_if_stmt, last_if_overload = None, None - if isinstance(if_block_with_overload.body[0], OverloadedFuncDef): - current_overload.extend(if_block_with_overload.body[0].items) + if isinstance(if_block_with_overload.body[-1], OverloadedFuncDef): + skipped_if_stmts.extend( + cast(List[IfStmt], if_block_with_overload.body[:-1]) + ) + current_overload.extend(if_block_with_overload.body[-1].items) else: current_overload.append( cast(Union[Decorator, FuncDef], if_block_with_overload.body[0]) @@ -600,9 +594,12 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: last_if_stmt = stmt last_if_stmt_overload_name = None if if_block_with_overload is not None: + skipped_if_stmts.extend( + cast(List[IfStmt], if_block_with_overload.body[:-1]) + ) last_if_overload = cast( Union[Decorator, FuncDef, OverloadedFuncDef], - if_block_with_overload.body[0] + if_block_with_overload.body[-1] ) last_if_unknown_truth_value = if_unknown_truth_value else: @@ -620,11 +617,15 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: ret.append(current_overload[0]) elif len(current_overload) > 1: ret.append(OverloadedFuncDef(current_overload)) + elif last_if_overload is not None: + ret.append(last_if_overload) elif last_if_stmt is not None: ret.append(last_if_stmt) return ret - def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]: + def _check_ifstmt_for_overloads( + self, stmt: IfStmt, current_overload_name: Optional[str] = None + ) -> Optional[str]: """Check if IfStmt contains only overloads with the same name. Return overload_name if found, None otherwise. """ @@ -632,11 +633,22 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]: # Multiple overloads have already been merged as OverloadedFuncDef. if not ( len(stmt.body[0].body) == 1 - and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef)) + and ( + isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef)) + or current_overload_name is not None + and isinstance(stmt.body[0].body[0], FuncDef) + ) + or len(stmt.body[0].body) > 1 + and isinstance(stmt.body[0].body[-1], OverloadedFuncDef) + and all( + self._is_stripped_if_stmt(if_stmt) + for if_stmt in stmt.body[0].body[:-1] + ) ): return None - overload_name = stmt.body[0].body[0].name + overload_name = cast( + Union[Decorator, FuncDef, OverloadedFuncDef], stmt.body[0].body[-1]).name if stmt.else_body is None: return overload_name @@ -649,7 +661,9 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]: return overload_name if ( isinstance(stmt.else_body.body[0], IfStmt) - and self._check_ifstmt_for_overloads(stmt.else_body.body[0]) == overload_name + and self._check_ifstmt_for_overloads( + stmt.else_body.body[0], current_overload_name + ) == overload_name ): return overload_name @@ -704,6 +718,25 @@ def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None: else: stmt.else_body.body = [] + def _is_stripped_if_stmt(self, stmt: Statement) -> bool: + """Check stmt to make sure it is a stripped IfStmt. + + See also: _strip_contents_from_if_stmt + """ + if not isinstance(stmt, IfStmt): + return False + + if not (len(stmt.body) == 1 and len(stmt.body[0].body) == 0): + # Body not empty + return False + + if not stmt.else_body or len(stmt.else_body.body) == 0: + # No or empty else_body + return True + + # For elif, IfStmt are stored recursively in else_body + return self._is_stripped_if_stmt(stmt.else_body.body[0]) + def in_method_scope(self) -> bool: return self.class_and_function_stack[-2:] == ['C', 'F'] diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index e2a87ea62a92..8259f2754bce 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6367,3 +6367,142 @@ def g(x: int) -> str: ... def g(x: int = 0) -> int: # E: Overloaded function implementation cannot produce return type of signature 2 return x + +[case testOverloadIfNestedOk] +# flags: --always-true True --always-false False +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +@overload +def f1(g: A) -> A: ... +if True: + @overload + def f1(g: B) -> B: ... + if True: + @overload + def f1(g: C) -> C: ... + @overload + def f1(g: D) -> D: ... +def f1(g): ... +reveal_type(f1(A())) # N: Revealed type is "__main__.A" +reveal_type(f1(B())) # N: Revealed type is "__main__.B" +reveal_type(f1(C())) # N: Revealed type is "__main__.C" +reveal_type(f1(D())) # N: Revealed type is "__main__.D" + +@overload +def f2(g: A) -> A: ... +if True: + @overload + def f2(g: B) -> B: ... + if True: + @overload + def f2(g: C) -> C: ... + if True: + @overload + def f2(g: D) -> D: ... +def f2(g): ... +reveal_type(f2(A())) # N: Revealed type is "__main__.A" +reveal_type(f2(B())) # N: Revealed type is "__main__.B" +reveal_type(f2(C())) # N: Revealed type is "__main__.C" +reveal_type(f2(D())) # N: Revealed type is "__main__.D" + +@overload +def f3(g: A) -> A: ... +if True: + if True: + @overload + def f3(g: B) -> B: ... + if True: + @overload + def f3(g: C) -> C: ... +def f3(g): ... +reveal_type(f3(A())) # N: Revealed type is "__main__.A" +reveal_type(f3(B())) # N: Revealed type is "__main__.B" +reveal_type(f3(C())) # N: Revealed type is "__main__.C" + +@overload +def f4(g: A) -> A: ... +if True: + if False: + @overload + def f4(g: B) -> B: ... + else: + @overload + def f4(g: C) -> C: ... +def f4(g): ... +reveal_type(f4(A())) # N: Revealed type is "__main__.A" +reveal_type(f4(B())) # E: No overload variant of "f4" matches argument type "B" \ + # N: Possible overload variants: \ + # N: def f4(g: A) -> A \ + # N: def f4(g: C) -> C \ + # N: Revealed type is "Any" +reveal_type(f4(C())) # N: Revealed type is "__main__.C" + +@overload +def f5(g: A) -> A: ... +if True: + if False: + @overload + def f5(g: B) -> B: ... + elif True: + @overload + def f5(g: C) -> C: ... +def f5(g): ... +reveal_type(f5(A())) # N: Revealed type is "__main__.A" +reveal_type(f5(B())) # E: No overload variant of "f5" matches argument type "B" \ + # N: Possible overload variants: \ + # N: def f5(g: A) -> A \ + # N: def f5(g: C) -> C \ + # N: Revealed type is "Any" +reveal_type(f5(C())) # N: Revealed type is "__main__.C" + +[case testOverloadIfNestedFailure] +# flags: --always-true True --always-false False +from typing import overload + +class A: ... +class B: ... +class C: ... +class D: ... + +@overload # E: Single overload definition, multiple required +def f1(g: A) -> A: ... +if True: + @overload # E: Single overload definition, multiple required + def f1(g: B) -> B: ... + if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \ + # E: Name "maybe_true" is not defined + @overload + def f1(g: C) -> C: ... + @overload + def f1(g: D) -> D: ... +def f1(g): ... # E: Name "f1" already defined on line 9 + +@overload # E: Single overload definition, multiple required +def f2(g: A) -> A: ... +if True: + if False: + @overload + def f2(g: B) -> B: ... + elif maybe_true: # E: Name "maybe_true" is not defined + @overload # E: Single overload definition, multiple required + def f2(g: C) -> C: ... +def f2(g): ... # E: Name "f2" already defined on line 21 + +@overload # E: Single overload definition, multiple required +def f3(g: A) -> A: ... +if True: + @overload # E: Single overload definition, multiple required + def f3(g: B) -> B: ... + if True: + pass # Some other node + @overload # E: Name "f3" already defined on line 32 \ + # E: An overloaded function outside a stub file must have an implementation + def f3(g: C) -> C: ... + @overload + def f3(g: D) -> D: ... +def f3(g): ... # E: Name "f3" already defined on line 32