@@ -202,13 +202,14 @@ use crate::Db;
202202use crate :: dunder_all:: dunder_all_names;
203203use crate :: place:: { RequiresExplicitReExport , imported_symbol} ;
204204use crate :: rank:: RankBitBox ;
205- use crate :: semantic_index:: expression:: Expression ;
206205use crate :: semantic_index:: place_table;
207206use crate :: semantic_index:: predicate:: {
208207 CallableAndCallExpr , PatternPredicate , PatternPredicateKind , Predicate , PredicateNode ,
209208 Predicates , ScopedPredicateId ,
210209} ;
211- use crate :: types:: { Truthiness , Type , infer_expression_type} ;
210+ use crate :: types:: {
211+ IntersectionBuilder , Truthiness , Type , UnionBuilder , UnionType , infer_expression_type,
212+ } ;
212213
213214/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
214215/// is just like a boolean formula, but with `Ambiguous` as a third potential result. See the
@@ -311,6 +312,55 @@ const AMBIGUOUS: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId
311312const ALWAYS_FALSE : ScopedReachabilityConstraintId = ScopedReachabilityConstraintId :: ALWAYS_FALSE ;
312313const SMALLEST_TERMINAL : ScopedReachabilityConstraintId = ALWAYS_FALSE ;
313314
315+ fn singleton_to_type ( db : & dyn Db , singleton : ruff_python_ast:: Singleton ) -> Type < ' _ > {
316+ let ty = match singleton {
317+ ruff_python_ast:: Singleton :: None => Type :: none ( db) ,
318+ ruff_python_ast:: Singleton :: True => Type :: BooleanLiteral ( true ) ,
319+ ruff_python_ast:: Singleton :: False => Type :: BooleanLiteral ( false ) ,
320+ } ;
321+ debug_assert ! ( ty. is_singleton( db) ) ;
322+ ty
323+ }
324+
325+ /// Turn a `match` pattern kind into a type that represents the set of all values that would definitely
326+ /// match that pattern.
327+ fn pattern_kind_to_type < ' db > ( db : & ' db dyn Db , kind : & PatternPredicateKind < ' db > ) -> Type < ' db > {
328+ match kind {
329+ PatternPredicateKind :: Singleton ( singleton) => singleton_to_type ( db, * singleton) ,
330+ PatternPredicateKind :: Value ( value) => infer_expression_type ( db, * value) ,
331+ PatternPredicateKind :: Class ( class_expr, kind) => {
332+ if kind. is_irrefutable ( ) {
333+ infer_expression_type ( db, * class_expr)
334+ . to_instance ( db)
335+ . unwrap_or ( Type :: Never )
336+ } else {
337+ Type :: Never
338+ }
339+ }
340+ PatternPredicateKind :: Or ( predicates) => {
341+ UnionType :: from_elements ( db, predicates. iter ( ) . map ( |p| pattern_kind_to_type ( db, p) ) )
342+ }
343+ PatternPredicateKind :: Unsupported => Type :: Never ,
344+ }
345+ }
346+
347+ /// Go through the list of previous match cases, and accumulate a union of all types that were already
348+ /// matched by these patterns.
349+ fn type_excluded_by_previous_patterns < ' db > (
350+ db : & ' db dyn Db ,
351+ mut predicate : PatternPredicate < ' db > ,
352+ ) -> Type < ' db > {
353+ let mut builder = UnionBuilder :: new ( db) ;
354+ while let Some ( previous) = predicate. previous_predicate ( db) {
355+ predicate = * previous;
356+
357+ if predicate. guard ( db) . is_none ( ) {
358+ builder = builder. add ( pattern_kind_to_type ( db, predicate. kind ( db) ) ) ;
359+ }
360+ }
361+ builder. build ( )
362+ }
363+
314364/// A collection of reachability constraints for a given scope.
315365#[ derive( Debug , PartialEq , Eq , salsa:: Update , get_size2:: GetSize ) ]
316366pub ( crate ) struct ReachabilityConstraints {
@@ -637,11 +687,10 @@ impl ReachabilityConstraints {
637687 fn analyze_single_pattern_predicate_kind < ' db > (
638688 db : & ' db dyn Db ,
639689 predicate_kind : & PatternPredicateKind < ' db > ,
640- subject : Expression < ' db > ,
690+ subject_ty : Type < ' db > ,
641691 ) -> Truthiness {
642692 match predicate_kind {
643693 PatternPredicateKind :: Value ( value) => {
644- let subject_ty = infer_expression_type ( db, subject) ;
645694 let value_ty = infer_expression_type ( db, * value) ;
646695
647696 if subject_ty. is_single_valued ( db) {
@@ -651,15 +700,7 @@ impl ReachabilityConstraints {
651700 }
652701 }
653702 PatternPredicateKind :: Singleton ( singleton) => {
654- let subject_ty = infer_expression_type ( db, subject) ;
655-
656- let singleton_ty = match singleton {
657- ruff_python_ast:: Singleton :: None => Type :: none ( db) ,
658- ruff_python_ast:: Singleton :: True => Type :: BooleanLiteral ( true ) ,
659- ruff_python_ast:: Singleton :: False => Type :: BooleanLiteral ( false ) ,
660- } ;
661-
662- debug_assert ! ( singleton_ty. is_singleton( db) ) ;
703+ let singleton_ty = singleton_to_type ( db, * singleton) ;
663704
664705 if subject_ty. is_equivalent_to ( db, singleton_ty) {
665706 Truthiness :: AlwaysTrue
@@ -671,10 +712,21 @@ impl ReachabilityConstraints {
671712 }
672713 PatternPredicateKind :: Or ( predicates) => {
673714 use std:: ops:: ControlFlow ;
715+
716+ let mut excluded_types = vec ! [ ] ;
674717 let ( ControlFlow :: Break ( truthiness) | ControlFlow :: Continue ( truthiness) ) =
675718 predicates
676719 . iter ( )
677- . map ( |p| Self :: analyze_single_pattern_predicate_kind ( db, p, subject) )
720+ . map ( |p| {
721+ let narrowed_subject_ty = IntersectionBuilder :: new ( db)
722+ . add_positive ( subject_ty)
723+ . add_negative ( UnionType :: from_elements ( db, excluded_types. iter ( ) ) )
724+ . build ( ) ;
725+
726+ excluded_types. push ( pattern_kind_to_type ( db, p) ) ;
727+
728+ Self :: analyze_single_pattern_predicate_kind ( db, p, narrowed_subject_ty)
729+ } )
678730 // this is just a "max", but with a slight optimization: `AlwaysTrue` is the "greatest" possible element, so we short-circuit if we get there
679731 . try_fold ( Truthiness :: AlwaysFalse , |acc, next| match ( acc, next) {
680732 ( Truthiness :: AlwaysTrue , _) | ( _, Truthiness :: AlwaysTrue ) => {
@@ -690,7 +742,6 @@ impl ReachabilityConstraints {
690742 truthiness
691743 }
692744 PatternPredicateKind :: Class ( class_expr, kind) => {
693- let subject_ty = infer_expression_type ( db, subject) ;
694745 let class_ty = infer_expression_type ( db, * class_expr) . to_instance ( db) ;
695746
696747 class_ty. map_or ( Truthiness :: Ambiguous , |class_ty| {
@@ -715,10 +766,17 @@ impl ReachabilityConstraints {
715766 }
716767
717768 fn analyze_single_pattern_predicate ( db : & dyn Db , predicate : PatternPredicate ) -> Truthiness {
769+ let subject_ty = infer_expression_type ( db, predicate. subject ( db) ) ;
770+
771+ let narrowed_subject_ty = IntersectionBuilder :: new ( db)
772+ . add_positive ( subject_ty)
773+ . add_negative ( type_excluded_by_previous_patterns ( db, predicate) )
774+ . build ( ) ;
775+
718776 let truthiness = Self :: analyze_single_pattern_predicate_kind (
719777 db,
720778 predicate. kind ( db) ,
721- predicate . subject ( db ) ,
779+ narrowed_subject_ty ,
722780 ) ;
723781
724782 if truthiness == Truthiness :: AlwaysTrue && predicate. guard ( db) . is_some ( ) {
0 commit comments