@@ -703,24 +703,17 @@ def run(self, mod: ast.Module) -> None:
703703 return
704704 pos = 0
705705 for item in mod .body :
706- if (
707- expect_docstring
708- and isinstance (item , ast .Expr )
709- and isinstance (item .value , ast .Constant )
710- and isinstance (item .value .value , str )
711- ):
712- doc = item .value .value
713- if self .is_rewrite_disabled (doc ):
714- return
715- expect_docstring = False
716- elif (
717- isinstance (item , ast .ImportFrom )
718- and item .level == 0
719- and item .module == "__future__"
720- ):
721- pass
722- else :
723- break
706+ match item :
707+ case ast .Expr (value = ast .Constant (value = str () as doc )) if (
708+ expect_docstring
709+ ):
710+ if self .is_rewrite_disabled (doc ):
711+ return
712+ expect_docstring = False
713+ case ast .ImportFrom (level = 0 , module = "__future__" ):
714+ pass
715+ case _:
716+ break
724717 pos += 1
725718 # Special case: for a decorated function, set the lineno to that of the
726719 # first decorator, not the `def`. Issue #4984.
@@ -1017,20 +1010,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
10171010 # cond is set in a prior loop iteration below
10181011 self .expl_stmts .append (ast .If (cond , fail_inner , [])) # noqa: F821
10191012 self .expl_stmts = fail_inner
1020- # Check if the left operand is a ast.NamedExpr and the value has already been visited
1021- if (
1022- isinstance (v , ast .Compare )
1023- and isinstance (v .left , ast .NamedExpr )
1024- and v .left .target .id
1025- in [
1026- ast_expr .id
1027- for ast_expr in boolop .values [:i ]
1028- if hasattr (ast_expr , "id" )
1029- ]
1030- ):
1031- pytest_temp = self .variable ()
1032- self .variables_overwrite [self .scope ][v .left .target .id ] = v .left # type:ignore[assignment]
1033- v .left .target .id = pytest_temp
1013+ match v :
1014+ # Check if the left operand is an ast.NamedExpr and the value has already been visited
1015+ case ast .Compare (
1016+ left = ast .NamedExpr (target = ast .Name (id = target_id ))
1017+ ) if target_id in [
1018+ e .id for e in boolop .values [:i ] if hasattr (e , "id" )
1019+ ]:
1020+ pytest_temp = self .variable ()
1021+ self .variables_overwrite [self .scope ][target_id ] = v .left # type:ignore[assignment]
1022+ # mypy's false positive, we're checking that the 'target' attribute exists.
1023+ v .left .target .id = pytest_temp # type:ignore[attr-defined]
10341024 self .push_format_context ()
10351025 res , expl = self .visit (v )
10361026 body .append (ast .Assign ([ast .Name (res_var , ast .Store ())], res ))
@@ -1080,10 +1070,11 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
10801070 arg_expls .append (expl )
10811071 new_args .append (res )
10821072 for keyword in call .keywords :
1083- if isinstance (
1084- keyword .value , ast .Name
1085- ) and keyword .value .id in self .variables_overwrite .get (self .scope , {}):
1086- keyword .value = self .variables_overwrite [self .scope ][keyword .value .id ] # type:ignore[assignment]
1073+ match keyword .value :
1074+ case ast .Name (id = id ) if id in self .variables_overwrite .get (
1075+ self .scope , {}
1076+ ):
1077+ keyword .value = self .variables_overwrite [self .scope ][id ] # type:ignore[assignment]
10871078 res , expl = self .visit (keyword .value )
10881079 new_kwargs .append (ast .keyword (keyword .arg , res ))
10891080 if keyword .arg :
@@ -1119,12 +1110,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
11191110 def visit_Compare (self , comp : ast .Compare ) -> tuple [ast .expr , str ]:
11201111 self .push_format_context ()
11211112 # We first check if we have overwritten a variable in the previous assert
1122- if isinstance (
1123- comp .left , ast .Name
1124- ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1125- comp .left = self .variables_overwrite [self .scope ][comp .left .id ] # type:ignore[assignment]
1126- if isinstance (comp .left , ast .NamedExpr ):
1127- self .variables_overwrite [self .scope ][comp .left .target .id ] = comp .left # type:ignore[assignment]
1113+ match comp .left :
1114+ case ast .Name (id = name_id ) if name_id in self .variables_overwrite .get (
1115+ self .scope , {}
1116+ ):
1117+ comp .left = self .variables_overwrite [self .scope ][name_id ] # type: ignore[assignment]
1118+ case ast .NamedExpr (target = ast .Name (id = target_id )):
1119+ self .variables_overwrite [self .scope ][target_id ] = comp .left # type: ignore[assignment]
11281120 left_res , left_expl = self .visit (comp .left )
11291121 if isinstance (comp .left , ast .Compare | ast .BoolOp ):
11301122 left_expl = f"({ left_expl } )"
@@ -1136,13 +1128,14 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
11361128 syms : list [ast .expr ] = []
11371129 results = [left_res ]
11381130 for i , op , next_operand in it :
1139- if (
1140- isinstance (next_operand , ast .NamedExpr )
1141- and isinstance (left_res , ast .Name )
1142- and next_operand .target .id == left_res .id
1143- ):
1144- next_operand .target .id = self .variable ()
1145- self .variables_overwrite [self .scope ][left_res .id ] = next_operand # type:ignore[assignment]
1131+ match (next_operand , left_res ):
1132+ case (
1133+ ast .NamedExpr (target = ast .Name (id = target_id )),
1134+ ast .Name (id = name_id ),
1135+ ) if target_id == name_id :
1136+ next_operand .target .id = self .variable ()
1137+ self .variables_overwrite [self .scope ][name_id ] = next_operand # type: ignore[assignment]
1138+
11461139 next_res , next_expl = self .visit (next_operand )
11471140 if isinstance (next_operand , ast .Compare | ast .BoolOp ):
11481141 next_expl = f"({ next_expl } )"
0 commit comments