Skip to content

Commit 7e70581

Browse files
committed
[ty] Exhaustiveness checking & reachability for match statements
1 parent 3d17897 commit 7e70581

File tree

7 files changed

+109
-49
lines changed

7 files changed

+109
-49
lines changed

crates/ty_python_semantic/resources/mdtest/conditional/match.md

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,7 @@ def _(target: Literal[True, False]):
201201
case None:
202202
y = 4
203203

204-
# TODO: with exhaustiveness checking, this should be Literal[2, 3]
205-
reveal_type(y) # revealed: Literal[1, 2, 3]
204+
reveal_type(y) # revealed: Literal[2, 3]
206205

207206
def _(target: bool):
208207
y = 1
@@ -215,8 +214,7 @@ def _(target: bool):
215214
case None:
216215
y = 4
217216

218-
# TODO: with exhaustiveness checking, this should be Literal[2, 3]
219-
reveal_type(y) # revealed: Literal[1, 2, 3]
217+
reveal_type(y) # revealed: Literal[2, 3]
220218

221219
def _(target: None):
222220
y = 1
@@ -242,8 +240,7 @@ def _(target: None | Literal[True]):
242240
case None:
243241
y = 4
244242

245-
# TODO: with exhaustiveness checking, this should be Literal[2, 4]
246-
reveal_type(y) # revealed: Literal[1, 2, 4]
243+
reveal_type(y) # revealed: Literal[2, 4]
247244

248245
# bool is an int subclass
249246
def _(target: int):
@@ -292,7 +289,7 @@ def _(answer: Answer):
292289
reveal_type(answer) # revealed: Literal[Answer.NO]
293290
y = 2
294291

295-
reveal_type(y) # revealed: Literal[0, 1, 2]
292+
reveal_type(y) # revealed: Literal[1, 2]
296293
```
297294

298295
## Or match
@@ -311,8 +308,7 @@ def _(target: Literal["foo", "baz"]):
311308
case "baz":
312309
y = 3
313310

314-
# TODO: with exhaustiveness, this should be Literal[2, 3]
315-
reveal_type(y) # revealed: Literal[1, 2, 3]
311+
reveal_type(y) # revealed: Literal[2, 3]
316312

317313
def _(target: None):
318314
y = 1

crates/ty_python_semantic/resources/mdtest/directives/assert_never.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,6 @@ def match_singletons_success(obj: Literal[1, "a"] | None):
119119
case None:
120120
pass
121121
case _ as obj:
122-
# TODO: Ideally, we would not emit an error here
123-
# error: [type-assertion-failure] "Argument does not have asserted type `Never`"
124122
assert_never(obj)
125123

126124
def match_singletons_error(obj: Literal[1, "a"] | None):

crates/ty_python_semantic/resources/mdtest/enums.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -720,8 +720,6 @@ def color_name(color: Color) -> str:
720720
case _:
721721
assert_never(color)
722722

723-
# TODO: this should not be an error, see https://github.com/astral-sh/ty/issues/99#issuecomment-2983054488
724-
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `str`"
725723
def color_name_without_assertion(color: Color) -> str:
726724
match color:
727725
case Color.RED:

crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,10 @@ def match_exhaustive(x: Literal[0, 1, "a"]):
5050
case "a":
5151
pass
5252
case _:
53-
# TODO: this should not be an error
54-
no_diagnostic_here # error: [unresolved-reference]
53+
no_diagnostic_here
5554

5655
assert_never(x)
5756

58-
# TODO: there should be no error here
59-
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`"
6057
def match_exhaustive_no_assertion(x: Literal[0, 1, "a"]) -> int:
6158
match x:
6259
case 0:
@@ -130,13 +127,21 @@ def match_exhaustive(x: Color):
130127
case Color.BLUE:
131128
pass
132129
case _:
133-
# TODO: this should not be an error
134-
no_diagnostic_here # error: [unresolved-reference]
130+
no_diagnostic_here
131+
132+
assert_never(x)
133+
134+
def match_exhaustive_2(x: Color):
135+
match x:
136+
case Color.RED:
137+
pass
138+
case Color.GREEN | Color.BLUE:
139+
pass
140+
case _:
141+
no_diagnostic_here
135142

136143
assert_never(x)
137144

138-
# TODO: there should be no error here
139-
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`"
140145
def match_exhaustive_no_assertion(x: Color) -> int:
141146
match x:
142147
case Color.RED:
@@ -208,13 +213,10 @@ def match_exhaustive(x: A | B | C):
208213
case C():
209214
pass
210215
case _:
211-
# TODO: this should not be an error
212-
no_diagnostic_here # error: [unresolved-reference]
216+
no_diagnostic_here
213217

214218
assert_never(x)
215219

216-
# TODO: there should be no error here
217-
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`"
218220
def match_exhaustive_no_assertion(x: A | B | C) -> int:
219221
match x:
220222
case A():

crates/ty_python_semantic/src/semantic_index/builder.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
734734
subject: Expression<'db>,
735735
pattern: &ast::Pattern,
736736
guard: Option<&ast::Expr>,
737-
) -> PredicateOrLiteral<'db> {
737+
previous_pattern: Option<PatternPredicate<'db>>,
738+
) -> (PredicateOrLiteral<'db>, PatternPredicate<'db>) {
738739
// This is called for the top-level pattern of each match arm. We need to create a
739740
// standalone expression for each arm of a match statement, since they can introduce
740741
// constraints on the match subject. (Or more accurately, for the match arm's pattern,
@@ -756,13 +757,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
756757
subject,
757758
kind,
758759
guard,
760+
previous_pattern.map(Box::new),
759761
);
760762
let predicate = PredicateOrLiteral::Predicate(Predicate {
761763
node: PredicateNode::Pattern(pattern_predicate),
762764
is_positive: true,
763765
});
764766
self.record_narrowing_constraint(predicate);
765-
predicate
767+
(predicate, pattern_predicate)
766768
}
767769

768770
/// Record an expression that needs to be a Salsa ingredient, because we need to infer its type
@@ -1747,7 +1749,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
17471749
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard());
17481750

17491751
let mut post_case_snapshots = vec![];
1750-
let mut match_predicate;
1752+
let mut previous_pattern: Option<PatternPredicate<'_>> = None;
17511753

17521754
for (i, case) in cases.iter().enumerate() {
17531755
self.current_match_case = Some(CurrentMatchCase::new(&case.pattern));
@@ -1757,11 +1759,14 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
17571759
// here because the effects of visiting a pattern is binding
17581760
// symbols, and this doesn't occur unless the pattern
17591761
// actually matches
1760-
match_predicate = self.add_pattern_narrowing_constraint(
1761-
subject_expr,
1762-
&case.pattern,
1763-
case.guard.as_deref(),
1764-
);
1762+
let (match_predicate, match_pattern_predicate) = self
1763+
.add_pattern_narrowing_constraint(
1764+
subject_expr,
1765+
&case.pattern,
1766+
case.guard.as_deref(),
1767+
previous_pattern,
1768+
);
1769+
previous_pattern = Some(match_pattern_predicate);
17651770
let reachability_constraint =
17661771
self.record_reachability_constraint(match_predicate);
17671772

crates/ty_python_semantic/src/semantic_index/predicate.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ pub(crate) struct PatternPredicate<'db> {
150150
pub(crate) kind: PatternPredicateKind<'db>,
151151

152152
pub(crate) guard: Option<Expression<'db>>,
153+
154+
/// A reference to the pattern of the previous match case
155+
pub(crate) previous_predicate: Option<Box<PatternPredicate<'db>>>,
153156
}
154157

155158
// The Salsa heap is tracked separately.

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,14 @@ use crate::Db;
202202
use crate::dunder_all::dunder_all_names;
203203
use crate::place::{RequiresExplicitReExport, imported_symbol};
204204
use crate::rank::RankBitBox;
205-
use crate::semantic_index::expression::Expression;
206205
use crate::semantic_index::place_table;
207206
use crate::semantic_index::predicate::{
208207
CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
209208
Predicates, ScopedPredicateId,
210209
};
211-
use crate::types::{Truthiness, Type, infer_expression_type};
210+
use crate::types::{
211+
IntersectionBuilder, Truthiness, Type, UnionBuilder, UnionType, infer_expression_type,
212+
};
212213

213214
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
214215
/// is just like a boolean formula, but with `Ambiguous` as a third potential result. See the
@@ -311,6 +312,55 @@ const AMBIGUOUS: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId
311312
const ALWAYS_FALSE: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId::ALWAYS_FALSE;
312313
const SMALLEST_TERMINAL: ScopedReachabilityConstraintId = ALWAYS_FALSE;
313314

315+
fn singleton_to_type(db: &dyn Db, singleton: ruff_python_ast::Singleton) -> Type<'_> {
316+
let ty = match singleton {
317+
ruff_python_ast::Singleton::None => Type::none(db),
318+
ruff_python_ast::Singleton::True => Type::BooleanLiteral(true),
319+
ruff_python_ast::Singleton::False => Type::BooleanLiteral(false),
320+
};
321+
debug_assert!(ty.is_singleton(db));
322+
ty
323+
}
324+
325+
/// Turn a `match` pattern kind into a type that represents the set of all values that would definitely
326+
/// match that pattern.
327+
fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> {
328+
match kind {
329+
PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton),
330+
PatternPredicateKind::Value(value) => infer_expression_type(db, *value),
331+
PatternPredicateKind::Class(class_expr, kind) => {
332+
if kind.is_irrefutable() {
333+
infer_expression_type(db, *class_expr)
334+
.to_instance(db)
335+
.unwrap_or(Type::Never)
336+
} else {
337+
Type::Never
338+
}
339+
}
340+
PatternPredicateKind::Or(predicates) => {
341+
UnionType::from_elements(db, predicates.iter().map(|p| pattern_kind_to_type(db, p)))
342+
}
343+
PatternPredicateKind::Unsupported => Type::Never,
344+
}
345+
}
346+
347+
/// Go through the list of previous match cases, and accumulate a union of all types that were already
348+
/// matched by these patterns.
349+
fn type_excluded_by_previous_patterns<'db>(
350+
db: &'db dyn Db,
351+
mut predicate: PatternPredicate<'db>,
352+
) -> Type<'db> {
353+
let mut builder = UnionBuilder::new(db);
354+
while let Some(previous) = predicate.previous_predicate(db) {
355+
predicate = *previous;
356+
357+
if predicate.guard(db).is_none() {
358+
builder = builder.add(pattern_kind_to_type(db, predicate.kind(db)));
359+
}
360+
}
361+
builder.build()
362+
}
363+
314364
/// A collection of reachability constraints for a given scope.
315365
#[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
316366
pub(crate) struct ReachabilityConstraints {
@@ -637,11 +687,10 @@ impl ReachabilityConstraints {
637687
fn analyze_single_pattern_predicate_kind<'db>(
638688
db: &'db dyn Db,
639689
predicate_kind: &PatternPredicateKind<'db>,
640-
subject: Expression<'db>,
690+
subject_ty: Type<'db>,
641691
) -> Truthiness {
642692
match predicate_kind {
643693
PatternPredicateKind::Value(value) => {
644-
let subject_ty = infer_expression_type(db, subject);
645694
let value_ty = infer_expression_type(db, *value);
646695

647696
if subject_ty.is_single_valued(db) {
@@ -651,15 +700,7 @@ impl ReachabilityConstraints {
651700
}
652701
}
653702
PatternPredicateKind::Singleton(singleton) => {
654-
let subject_ty = infer_expression_type(db, subject);
655-
656-
let singleton_ty = match singleton {
657-
ruff_python_ast::Singleton::None => Type::none(db),
658-
ruff_python_ast::Singleton::True => Type::BooleanLiteral(true),
659-
ruff_python_ast::Singleton::False => Type::BooleanLiteral(false),
660-
};
661-
662-
debug_assert!(singleton_ty.is_singleton(db));
703+
let singleton_ty = singleton_to_type(db, *singleton);
663704

664705
if subject_ty.is_equivalent_to(db, singleton_ty) {
665706
Truthiness::AlwaysTrue
@@ -671,10 +712,21 @@ impl ReachabilityConstraints {
671712
}
672713
PatternPredicateKind::Or(predicates) => {
673714
use std::ops::ControlFlow;
715+
716+
let mut excluded_types = vec![];
674717
let (ControlFlow::Break(truthiness) | ControlFlow::Continue(truthiness)) =
675718
predicates
676719
.iter()
677-
.map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject))
720+
.map(|p| {
721+
let narrowed_subject_ty = IntersectionBuilder::new(db)
722+
.add_positive(subject_ty)
723+
.add_negative(UnionType::from_elements(db, excluded_types.iter()))
724+
.build();
725+
726+
excluded_types.push(pattern_kind_to_type(db, p));
727+
728+
Self::analyze_single_pattern_predicate_kind(db, p, narrowed_subject_ty)
729+
})
678730
// this is just a "max", but with a slight optimization: `AlwaysTrue` is the "greatest" possible element, so we short-circuit if we get there
679731
.try_fold(Truthiness::AlwaysFalse, |acc, next| match (acc, next) {
680732
(Truthiness::AlwaysTrue, _) | (_, Truthiness::AlwaysTrue) => {
@@ -690,7 +742,6 @@ impl ReachabilityConstraints {
690742
truthiness
691743
}
692744
PatternPredicateKind::Class(class_expr, kind) => {
693-
let subject_ty = infer_expression_type(db, subject);
694745
let class_ty = infer_expression_type(db, *class_expr).to_instance(db);
695746

696747
class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
@@ -715,10 +766,17 @@ impl ReachabilityConstraints {
715766
}
716767

717768
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
769+
let subject_ty = infer_expression_type(db, predicate.subject(db));
770+
771+
let narrowed_subject_ty = IntersectionBuilder::new(db)
772+
.add_positive(subject_ty)
773+
.add_negative(type_excluded_by_previous_patterns(db, predicate))
774+
.build();
775+
718776
let truthiness = Self::analyze_single_pattern_predicate_kind(
719777
db,
720778
predicate.kind(db),
721-
predicate.subject(db),
779+
narrowed_subject_ty,
722780
);
723781

724782
if truthiness == Truthiness::AlwaysTrue && predicate.guard(db).is_some() {

0 commit comments

Comments
 (0)