From 86a5367506f9d6384d6b0556a0b707d4f9294512 Mon Sep 17 00:00:00 2001 From: Carl Meyer <carl@astral.sh> Date: Sun, 5 Jan 2025 11:51:40 -0600 Subject: [PATCH] [red-knot] fix control flow for assignment expressions in elif tests --- .../mdtest/conditional/if_statement.md | 32 +++++++++++++++++++ .../src/semantic_index/builder.rs | 22 ++++++------- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/conditional/if_statement.md b/crates/red_knot_python_semantic/resources/mdtest/conditional/if_statement.md index b436a739a1141b..539fdf2a4869de 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/conditional/if_statement.md +++ b/crates/red_knot_python_semantic/resources/mdtest/conditional/if_statement.md @@ -115,3 +115,35 @@ def _(flag: bool, flag2: bool): reveal_type(y) # revealed: Literal[2, 3, 4] ``` + +## if-elif with assignment expressions in tests + +```py +def check(x: int) -> bool: + return bool(x) + +if check(x := 1): + x = 2 +elif check(x := 3): + x = 4 + +reveal_type(x) # revealed: Literal[2, 3, 4] +``` + +## constraints apply to later test expressions + +```py +def check(x) -> bool: + return bool(x) + +def _(flag: bool): + x = 1 if flag else None + y = 0 + + if x is None: + pass + elif check(y := x): + pass + + reveal_type(y) # revealed: Literal[0, 1] +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index c8ca5dd2e66188..5d8c7208a65cd4 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -878,12 +878,11 @@ where } ast::Stmt::If(node) => { self.visit_expr(&node.test); - let pre_if = self.flow_snapshot(); - let constraint = self.record_expression_constraint(&node.test); - let mut constraints = vec![constraint]; + let mut no_branch_taken = self.flow_snapshot(); + let mut last_constraint = self.record_expression_constraint(&node.test); self.visit_body(&node.body); - let visibility_constraint_id = self.record_visibility_constraint(constraint); + let visibility_constraint_id = self.record_visibility_constraint(last_constraint); let mut vis_constraints = vec![visibility_constraint_id]; let mut post_clauses: Vec<FlowSnapshot> = vec![]; @@ -907,26 +906,27 @@ where // the state that we merge the other snapshots into post_clauses.push(self.flow_snapshot()); // we can only take an elif/else branch if none of the previous ones were - // taken, so the block entry state is always `pre_if` - self.flow_restore(pre_if.clone()); - for constraint in &constraints { - self.record_negated_constraint(*constraint); - } + // taken + self.flow_restore(no_branch_taken.clone()); + self.record_negated_constraint(last_constraint); let elif_constraint = if let Some(elif_test) = clause_test { self.visit_expr(elif_test); + // test expression is evaluated whether we take the branch or not + no_branch_taken = self.flow_snapshot(); let constraint = self.record_expression_constraint(elif_test); - constraints.push(constraint); Some(constraint) } else { None }; + self.visit_body(clause_body); for id in &vis_constraints { self.record_negated_visibility_constraint(*id); } if let Some(elif_constraint) = elif_constraint { + last_constraint = elif_constraint; let id = self.record_visibility_constraint(elif_constraint); vis_constraints.push(id); } @@ -936,7 +936,7 @@ where self.flow_merge(post_clause_state); } - self.simplify_visibility_constraints(pre_if); + self.simplify_visibility_constraints(no_branch_taken); } ast::Stmt::While(ast::StmtWhile { test,