From 5806bc77a5711e25bfbcac9673ca31064fec01ad Mon Sep 17 00:00:00 2001 From: Edward Paget Date: Tue, 14 Nov 2023 14:15:57 -0800 Subject: [PATCH 1/5] Fix narrowing on match with function subject Fixes #12998 mypy can't narrow match statements with function subjects because the callexpr node is not a literal node. This adds a 'dummy' literal node that the match statement visitor can use to do the type narrowing. The Python grammar describes the match subject as a named expression so this uses that nameexpr node as its literal. --- mypy/checker.py | 11 ++++++++--- test-data/unit/check-python310.test | 12 ++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index e4eb58d40715d..00341157e60c8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5043,8 +5043,13 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: return None def visit_match_stmt(self, s: MatchStmt) -> None: + # Create a dummy subject expression to handle cases where a match + # statement's subject is not a literal value which prevent us from correctly + # narrowing types and checking exhaustivity + named_subject = NameExpr("match") if isinstance(s.subject, CallExpr) else s.subject with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) + self.store_type(named_subject, subject_type) if isinstance(subject_type, DeletedType): self.msg.deleted_as_rvalue(subject_type, s) @@ -5061,7 +5066,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: # The second pass narrows down the types and type checks bodies. 
for p, g, b in zip(s.patterns, s.guards, s.bodies): current_subject_type = self.expr_checker.narrow_type_from_binder( - s.subject, subject_type + named_subject, subject_type ) pattern_type = self.pattern_checker.accept(p, current_subject_type) with self.binder.frame_context(can_skip=True, fall_through=2): @@ -5072,7 +5077,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: else_map: TypeMap = {} else: pattern_map, else_map = conditional_types_to_typemaps( - s.subject, pattern_type.type, pattern_type.rest_type + named_subject, pattern_type.type, pattern_type.rest_type ) self.remove_capture_conflicts(pattern_type.captures, inferred_types) self.push_type_map(pattern_map) @@ -5100,7 +5105,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: and expr.fullname == case_target.fullname ): continue - type_map[s.subject] = type_map[expr] + type_map[named_subject] = type_map[expr] self.push_type_map(guard_map) self.accept(b) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index d3cdf3af849d4..3c31911f98f7c 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1139,6 +1139,18 @@ match m: reveal_type(a) # N: Revealed type is "builtins.str" +[case testMatchCapturePatternFromFunctionReturningUnion] +def func(arg: bool) -> str | int: + if arg: + return 1 + return "a" + +match func(True): + case str(a): + reveal_type(a) # N: Revealed type is "builtins.str" + case a: + reveal_type(a) # N: Revealed type is "builtins.int" + -- Guards -- [case testMatchSimplePatternGuard] From 5c659a0a48194a2a870ea53b6b34b43ce7c2331b Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Mon, 25 Dec 2023 01:44:27 -0600 Subject: [PATCH 2/5] . 
--- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 00341157e60c8..b31401ea8e153 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5046,7 +5046,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: # Create a dummy subject expression to handle cases where a match # statement's subject is not a literal value which prevent us from correctly # narrowing types and checking exhaustivity - named_subject = NameExpr("match") if isinstance(s.subject, CallExpr) else s.subject + named_subject = NameExpr("dummy-match") if isinstance(s.subject, CallExpr) else s.subject with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) self.store_type(named_subject, subject_type) From 1081f316103bb50e0d63aa05eb21cd29056f9121 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sun, 11 Feb 2024 02:13:22 -0800 Subject: [PATCH 3/5] set a node to prevent incorrect narrowing, add test --- mypy/checker.py | 16 +++++++++++----- test-data/unit/check-python310.test | 23 +++++++++++++---------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index b31401ea8e153..0cb2103687962 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5043,13 +5043,19 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: return None def visit_match_stmt(self, s: MatchStmt) -> None: - # Create a dummy subject expression to handle cases where a match - # statement's subject is not a literal value which prevent us from correctly - # narrowing types and checking exhaustivity - named_subject = NameExpr("dummy-match") if isinstance(s.subject, CallExpr) else s.subject + # Create a dummy subject expression to handle cases where a match statement's subject is + # not a literal value. 
This lets us correctly narrow types and check exhaustivity + if isinstance(s.subject, CallExpr): + id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else "" + name = "dummy-match-" + id + v = Var(name) + named_subject = NameExpr(name) + named_subject.node = v + else: + named_subject = s.subject + with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) - self.store_type(named_subject, subject_type) if isinstance(subject_type, DeletedType): self.msg.deleted_as_rvalue(subject_type, s) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 3c31911f98f7c..7b4cdf349198e 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1140,16 +1140,19 @@ match m: reveal_type(a) # N: Revealed type is "builtins.str" [case testMatchCapturePatternFromFunctionReturningUnion] -def func(arg: bool) -> str | int: - if arg: - return 1 - return "a" - -match func(True): - case str(a): - reveal_type(a) # N: Revealed type is "builtins.str" - case a: - reveal_type(a) # N: Revealed type is "builtins.int" +def func1(arg: bool) -> str | int: ... +def func2(arg: bool) -> bytes | int: ... + +def main() -> None: + match func1(True): + case str(a): + match func2(True): + case c: + reveal_type(a) # N: Revealed type is "builtins.str" + reveal_type(c) # N: Revealed type is "Union[builtins.bytes, builtins.int]" + reveal_type(a) # N: Revealed type is "builtins.str" + case a: + reveal_type(a) # N: Revealed type is "builtins.int" -- Guards -- From e392514d01846b53685e37aab476cde9bd8cdfcf Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sun, 11 Feb 2024 02:14:12 -0800 Subject: [PATCH 4/5] . 
--- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 0cb2103687962..9550c7a982ccb 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5047,7 +5047,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: # not a literal value. This lets us correctly narrow types and check exhaustivity if isinstance(s.subject, CallExpr): id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else "" - name = "dummy-match-" + id + name = "dummy-match-" + id # this is a hack v = Var(name) named_subject = NameExpr(name) named_subject.node = v From b44c9aa910f3410c219202a2cc18a66bd02b0cd6 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sun, 11 Feb 2024 02:18:39 -0800 Subject: [PATCH 5/5] . --- mypy/checker.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 9550c7a982ccb..70831c93c3053 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5043,11 +5043,13 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: return None def visit_match_stmt(self, s: MatchStmt) -> None: - # Create a dummy subject expression to handle cases where a match statement's subject is - # not a literal value. This lets us correctly narrow types and check exhaustivity + named_subject: Expression if isinstance(s.subject, CallExpr): + # Create a dummy subject expression to handle cases where a match statement's subject + # is not a literal value. This lets us correctly narrow types and check exhaustivity + # This is hack! id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else "" - name = "dummy-match-" + id # this is a hack + name = "dummy-match-" + id v = Var(name) named_subject = NameExpr(name) named_subject.node = v