1010from mypy .checkmember import analyze_member_access
1111from mypy .expandtype import expand_type_by_instance
1212from mypy .join import join_types
13- from mypy .literals import literal_hash
13+ from mypy .literals import Key , literal_hash
1414from mypy .maptype import map_instance_to_supertype
1515from mypy .meet import narrow_declared_type
1616from mypy .messages import MessageBuilder
17- from mypy .nodes import ARG_POS , Context , Expression , NameExpr , TypeAlias , TypeInfo , Var
17+ from mypy .nodes import (
18+ ARG_POS ,
19+ Context ,
20+ Expression ,
21+ IndexExpr ,
22+ IntExpr ,
23+ MemberExpr ,
24+ NameExpr ,
25+ TypeAlias ,
26+ TypeInfo ,
27+ UnaryExpr ,
28+ Var ,
29+ )
1830from mypy .options import Options
1931from mypy .patterns import (
2032 AsPattern ,
@@ -96,10 +108,8 @@ class PatternChecker(PatternVisitor[PatternType]):
96108 msg : MessageBuilder
97109 # Currently unused
98110 plugin : Plugin
99- # The expression being matched against the pattern
100- subject : Expression
101-
102- subject_type : Type
111+ # The expressions being matched against the (sub)pattern
112+ subject_context : list [list [Expression ]]
103113 # Type of the subject to check the (sub)pattern against
104114 type_context : list [Type ]
105115 # Types that match against self instead of their __match_args__ if used as a class pattern
@@ -118,24 +128,28 @@ def __init__(
118128 self .msg = msg
119129 self .plugin = plugin
120130
131+ self .subject_context = []
121132 self .type_context = []
122133 self .self_match_types = self .generate_types_from_names (self_match_type_names )
123134 self .non_sequence_match_types = self .generate_types_from_names (
124135 non_sequence_match_type_names
125136 )
126137 self .options = options
127138
128- def accept (self , o : Pattern , type_context : Type ) -> PatternType :
139+ def accept (self , o : Pattern , type_context : Type , subject : list [Expression ]) -> PatternType :
140+ self .subject_context .append (subject )
129141 self .type_context .append (type_context )
130142 result = o .accept (self )
143+ self .subject_context .pop ()
131144 self .type_context .pop ()
132145
133146 return result
134147
135148 def visit_as_pattern (self , o : AsPattern ) -> PatternType :
149+ current_subject = self .subject_context [- 1 ]
136150 current_type = self .type_context [- 1 ]
137151 if o .pattern is not None :
138- pattern_type = self .accept (o .pattern , current_type )
152+ pattern_type = self .accept (o .pattern , current_type , current_subject )
139153 typ , rest_type , type_map = pattern_type
140154 else :
141155 typ , rest_type , type_map = current_type , UninhabitedType (), {}
@@ -150,14 +164,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
150164 return PatternType (typ , rest_type , type_map )
151165
152166 def visit_or_pattern (self , o : OrPattern ) -> PatternType :
167+ current_subject = self .subject_context [- 1 ]
153168 current_type = self .type_context [- 1 ]
154169
155170 #
156171 # Check all the subpatterns
157172 #
158- pattern_types = []
173+ pattern_types : list [ PatternType ] = []
159174 for pattern in o .patterns :
160- pattern_type = self .accept (pattern , current_type )
175+ pattern_type = self .accept (pattern , current_type , current_subject )
161176 pattern_types .append (pattern_type )
162177 if not is_uninhabited (pattern_type .type ):
163178 current_type = pattern_type .rest_type
@@ -173,28 +188,40 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
173188 #
174189 # Check the capture types
175190 #
176- capture_types : dict [Var , list [tuple [Expression , Type ]]] = defaultdict (list )
191+ capture_types : dict [Var , dict [Key | None , list [tuple [Expression , Type ]]]] = defaultdict (
192+ lambda : defaultdict (list )
193+ )
194+ capture_expr_keys : set [Key | None ] = set ()
177195 # Collect captures from the first subpattern
178196 for expr , typ in pattern_types [0 ].captures .items ():
179197 node = get_var (expr )
180- capture_types [node ].append ((expr , typ ))
198+ key = literal_hash (expr )
199+ capture_types [node ][key ].append ((expr , typ ))
200+ if isinstance (expr , NameExpr ):
201+ capture_expr_keys .add (key )
181202
182203 # Check if other subpatterns capture the same names
183204 for i , pattern_type in enumerate (pattern_types [1 :]):
184- vars = {get_var (expr ) for expr , _ in pattern_type .captures .items ()}
185- if capture_types .keys () != vars :
205+ vars = {
206+ literal_hash (expr ) for expr in pattern_type .captures if isinstance (expr , NameExpr )
207+ }
208+ if capture_expr_keys != vars :
209+ # Only fail for directly captured names (with NameExpr)
186210 self .msg .fail (message_registry .OR_PATTERN_ALTERNATIVE_NAMES , o .patterns [i ])
187211 for expr , typ in pattern_type .captures .items ():
188212 node = get_var (expr )
189- capture_types [node ].append ((expr , typ ))
213+ key = literal_hash (expr )
214+ capture_types [node ][key ].append ((expr , typ ))
190215
191216 captures : dict [Expression , Type ] = {}
192- for capture_list in capture_types .values ():
193- typ = UninhabitedType ()
194- for _ , other in capture_list :
195- typ = make_simplified_union ([typ , other ])
217+ for expressions in capture_types .values ():
218+ for key , capture_list in expressions .items ():
219+ if other_types := [entry [1 ] for entry in capture_list ]:
220+ typ = make_simplified_union (other_types )
221+ else :
222+ typ = UninhabitedType ()
196223
197- captures [capture_list [0 ][0 ]] = typ
224+ captures [capture_list [0 ][0 ]] = typ
198225
199226 union_type = make_simplified_union (types )
200227 return PatternType (union_type , current_type , captures )
@@ -284,12 +311,24 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
284311 contracted_inner_types = self .contract_starred_pattern_types (
285312 inner_types , star_position , required_patterns
286313 )
287- for p , t in zip (o .patterns , contracted_inner_types ):
288- pattern_type = self .accept (p , t )
314+ current_subjects : list [list [Expression ]] = [[] for _ in range (len (contracted_inner_types ))]
315+ for s in self .subject_context [- 1 ]:
316+ # Support x[0], x[1], ... lookup until wildcard
317+ end_pos = len (contracted_inner_types ) if star_position is None else star_position
318+ for i in range (end_pos ):
319+ current_subjects [i ].append (IndexExpr (s , IntExpr (i )))
320+ # For everything after wildcard use x[-2], x[-1]
321+ for i in range ((star_position or - 1 ) + 1 , len (contracted_inner_types )):
322+ offset = len (contracted_inner_types ) - i
323+ current_subjects [i ].append (IndexExpr (s , UnaryExpr ("-" , IntExpr (offset ))))
324+ for p , t , s in zip (o .patterns , contracted_inner_types , current_subjects ):
325+ pattern_type = self .accept (p , t , s )
289326 typ , rest , type_map = pattern_type
290327 contracted_new_inner_types .append (typ )
291328 contracted_rest_inner_types .append (rest )
292329 self .update_type_map (captures , type_map )
330+ if s :
331+ self .update_type_map (captures , {subject : typ for subject in s })
293332
294333 new_inner_types = self .expand_starred_pattern_types (
295334 contracted_new_inner_types , star_position , len (inner_types ), unpack_index is not None
@@ -473,11 +512,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
473512 if inner_type is None :
474513 can_match = False
475514 inner_type = self .chk .named_type ("builtins.object" )
476- pattern_type = self .accept (value , inner_type )
515+ current_subjects : list [Expression ] = [
516+ IndexExpr (s , key ) for s in self .subject_context [- 1 ]
517+ ]
518+ pattern_type = self .accept (value , inner_type , current_subjects )
477519 if is_uninhabited (pattern_type .type ):
478520 can_match = False
479521 else :
480522 self .update_type_map (captures , pattern_type .captures )
523+ if current_subjects :
524+ self .update_type_map (
525+ captures , {subject : pattern_type .type for subject in current_subjects }
526+ )
481527
482528 if o .rest is not None :
483529 mapping = self .chk .named_type ("typing.Mapping" )
@@ -581,7 +627,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
581627 if self .should_self_match (typ ):
582628 if len (o .positionals ) > 1 :
583629 self .msg .fail (message_registry .CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS , o )
584- pattern_type = self .accept (o .positionals [0 ], narrowed_type )
630+ pattern_type = self .accept (o .positionals [0 ], narrowed_type , [] )
585631 if not is_uninhabited (pattern_type .type ):
586632 return PatternType (
587633 pattern_type .type ,
@@ -681,11 +727,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
681727 elif keyword is not None :
682728 new_type = self .chk .add_any_attribute_to_type (new_type , keyword )
683729
684- inner_type , inner_rest_type , inner_captures = self .accept (pattern , key_type )
730+ current_subjects : list [Expression ] = []
731+ if keyword is not None :
732+ current_subjects = [MemberExpr (s , keyword ) for s in self .subject_context [- 1 ]]
733+ inner_type , inner_rest_type , inner_captures = self .accept (
734+ pattern , key_type , current_subjects
735+ )
685736 if is_uninhabited (inner_type ):
686737 can_match = False
687738 else :
688739 self .update_type_map (captures , inner_captures )
740+ if current_subjects :
741+ self .update_type_map (
742+ captures , {subject : inner_type for subject in current_subjects }
743+ )
689744 if not is_uninhabited (inner_rest_type ):
690745 rest_type = current_type
691746
@@ -799,6 +854,10 @@ def get_var(expr: Expression) -> Var:
799854 Warning: this in only true for expressions captured by a match statement.
800855 Don't call it from anywhere else
801856 """
857+ if isinstance (expr , MemberExpr ):
858+ return get_var (expr .expr )
859+ if isinstance (expr , IndexExpr ):
860+ return get_var (expr .base )
802861 assert isinstance (expr , NameExpr ), expr
803862 node = expr .node
804863 assert isinstance (node , Var ), node
0 commit comments