Skip to content

Commit 8a6787b

Browse files
AlexWaygoodsharkdp
andauthored
[red-knot] Fix control flow for assert statements (#17702)
## Summary @sharkdp and I realised in our 1:1 this morning that our control flow for `assert` statements isn't quite accurate at the moment. Namely, for something like this: ```py def _(x: int | None): assert x is None, reveal_type(x) ``` we currently reveal `None` for `x` here, but this is incorrect. In actual fact, the `msg` expression of an `assert` statement (the expression after the comma) will only be evaluated if the test (`x is None`) evaluates to `False`. As such, we should be adding a constraint of `~None` to `x` in the `msg` expression, which should simplify the inferred type of `x` to `int` in that context (`(int | None) & ~None` -> `int`). ## Test Plan Mdtests added. --------- Co-authored-by: David Peter <mail@david-peter.de>
1 parent 4a621c2 commit 8a6787b

File tree

4 files changed

+118
-10
lines changed

4 files changed

+118
-10
lines changed

crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,64 @@ def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
5151
assert y not in (1, 2)
5252
reveal_type(y) # revealed: Literal[3]
5353
```
54+
55+
## Assertions with messages
56+
57+
```py
58+
def _(x: int | None, y: int | None):
59+
reveal_type(x) # revealed: int | None
60+
assert x is None, reveal_type(x) # revealed: int
61+
reveal_type(x) # revealed: None
62+
63+
reveal_type(y) # revealed: int | None
64+
assert isinstance(y, int), reveal_type(y) # revealed: None
65+
reveal_type(y) # revealed: int
66+
```
67+
68+
## Assertions with definitions inside the message
69+
70+
```py
71+
def one(x: int | None):
72+
assert x is None, (y := x * 42) * reveal_type(y) # revealed: int
73+
74+
# error: [unresolved-reference]
75+
reveal_type(y) # revealed: Unknown
76+
77+
def two(x: int | None, y: int | None):
78+
assert x is None, (y := 42) * reveal_type(y) # revealed: Literal[42]
79+
reveal_type(y) # revealed: int | None
80+
```
81+
82+
## Assertions with `test` predicates that are statically known to always be `True`
83+
84+
```py
85+
assert True, (x := 1)
86+
87+
# error: [unresolved-reference]
88+
reveal_type(x) # revealed: Unknown
89+
90+
assert False, (y := 1)
91+
92+
# The `assert` statement is terminal if `test` resolves to `False`,
93+
# so even though we know the `msg` branch will have been taken here
94+
# (we know what the truthiness of `False is!), we also know that the
95+
# `y` definition is not visible from this point in control flow
96+
# (because this point in control flow is unreachable).
97+
# We make sure that this does not emit an `[unresolved-reference]`
98+
# diagnostic by adding a reachability constraint,
99+
# but the inferred type is `Unknown`.
100+
#
101+
reveal_type(y) # revealed: Unknown
102+
```
103+
104+
## Assertions with messages that reference definitions from the `test`
105+
106+
```py
107+
def one(x: int | None):
108+
assert (y := x), reveal_type(y) # revealed: (int & ~AlwaysTruthy) | None
109+
reveal_type(y) # revealed: int & ~AlwaysFalsy
110+
111+
def two(x: int | None):
112+
assert isinstance((y := x), int), reveal_type(y) # revealed: None
113+
reveal_type(y) # revealed: int
114+
```

crates/red_knot_python_semantic/resources/mdtest/unreachable.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,15 @@ def f():
362362
ExceptionGroup
363363
```
364364

365+
Similarly, assertions with statically-known falsy conditions can lead to unreachable code:
366+
367+
```py
368+
def f():
369+
assert sys.version_info > (3, 11)
370+
371+
ExceptionGroup
372+
```
373+
365374
Finally, not that anyone would ever use it, but it also works for `while` loops:
366375

367376
```py

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -532,11 +532,8 @@ impl<'db> SemanticIndexBuilder<'db> {
532532

533533
/// Negates a predicate and adds it to the list of all predicates, does not record it.
534534
fn add_negated_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
535-
let negated = Predicate {
536-
node: predicate.node,
537-
is_positive: false,
538-
};
539-
self.current_use_def_map_mut().add_predicate(negated)
535+
self.current_use_def_map_mut()
536+
.add_predicate(predicate.negated())
540537
}
541538

542539
/// Records a previously added narrowing constraint by adding it to all live bindings.
@@ -1383,14 +1380,46 @@ where
13831380
}
13841381
}
13851382

1386-
ast::Stmt::Assert(node) => {
1387-
self.visit_expr(&node.test);
1388-
let predicate = self.record_expression_narrowing_constraint(&node.test);
1389-
self.record_visibility_constraint(predicate);
1383+
ast::Stmt::Assert(ast::StmtAssert {
1384+
test,
1385+
msg,
1386+
range: _,
1387+
}) => {
1388+
// We model an `assert test, msg` statement here. Conceptually, we can think of
1389+
// this as being equivalent to the following:
1390+
//
1391+
// ```py
1392+
// if not test:
1393+
// msg
1394+
// <halt>
1395+
//
1396+
// <whatever code comes after>
1397+
// ```
1398+
//
1399+
// Importantly, the `msg` expression is only evaluated if the `test` expression is
1400+
// falsy. This is why we apply the negated `test` predicate as a narrowing and
1401+
// reachability constraint on the `msg` expression.
1402+
//
1403+
// The other important part is the `<halt>`. This lets us skip the usual merging of
1404+
// flow states and simplification of visibility constraints, since there is no way
1405+
// of getting out of that `msg` branch. We simply restore to the post-test state.
1406+
1407+
self.visit_expr(test);
1408+
let predicate = self.build_predicate(test);
13901409

1391-
if let Some(msg) = &node.msg {
1410+
if let Some(msg) = msg {
1411+
let post_test = self.flow_snapshot();
1412+
let negated_predicate = predicate.negated();
1413+
self.record_narrowing_constraint(negated_predicate);
1414+
self.record_reachability_constraint(negated_predicate);
13921415
self.visit_expr(msg);
1416+
self.record_visibility_constraint(negated_predicate);
1417+
self.flow_restore(post_test);
13931418
}
1419+
1420+
self.record_narrowing_constraint(predicate);
1421+
self.record_visibility_constraint(predicate);
1422+
self.record_reachability_constraint(predicate);
13941423
}
13951424

13961425
ast::Stmt::Assign(node) => {

crates/red_knot_python_semantic/src/semantic_index/predicate.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ pub(crate) struct Predicate<'db> {
4949
pub(crate) is_positive: bool,
5050
}
5151

52+
impl Predicate<'_> {
53+
pub(crate) fn negated(self) -> Self {
54+
Self {
55+
node: self.node,
56+
is_positive: !self.is_positive,
57+
}
58+
}
59+
}
60+
5261
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update)]
5362
pub(crate) enum PredicateNode<'db> {
5463
Expression(Expression<'db>),

0 commit comments

Comments
 (0)