Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ def _(target: Literal[True, False]):
case None:
y = 4

# TODO: with exhaustiveness checking, this should be Literal[2, 3]
reveal_type(y) # revealed: Literal[1, 2, 3]
reveal_type(y) # revealed: Literal[2, 3]

def _(target: bool):
y = 1
Expand All @@ -215,8 +214,7 @@ def _(target: bool):
case None:
y = 4

# TODO: with exhaustiveness checking, this should be Literal[2, 3]
reveal_type(y) # revealed: Literal[1, 2, 3]
reveal_type(y) # revealed: Literal[2, 3]

def _(target: None):
y = 1
Expand All @@ -242,8 +240,7 @@ def _(target: None | Literal[True]):
case None:
y = 4

# TODO: with exhaustiveness checking, this should be Literal[2, 4]
reveal_type(y) # revealed: Literal[1, 2, 4]
reveal_type(y) # revealed: Literal[2, 4]

# bool is an int subclass
def _(target: int):
Expand Down Expand Up @@ -292,7 +289,7 @@ def _(answer: Answer):
reveal_type(answer) # revealed: Literal[Answer.NO]
y = 2

reveal_type(y) # revealed: Literal[0, 1, 2]
reveal_type(y) # revealed: Literal[1, 2]
```

## Or match
Expand All @@ -311,8 +308,7 @@ def _(target: Literal["foo", "baz"]):
case "baz":
y = 3

# TODO: with exhaustiveness, this should be Literal[2, 3]
reveal_type(y) # revealed: Literal[1, 2, 3]
reveal_type(y) # revealed: Literal[2, 3]

def _(target: None):
y = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ def match_singletons_success(obj: Literal[1, "a"] | None):
case None:
pass
case _ as obj:
# TODO: Ideally, we would not emit an error here
# error: [type-assertion-failure] "Argument does not have asserted type `Never`"
assert_never(obj)

def match_singletons_error(obj: Literal[1, "a"] | None):
Expand Down
2 changes: 0 additions & 2 deletions crates/ty_python_semantic/resources/mdtest/enums.md
Original file line number Diff line number Diff line change
Expand Up @@ -720,8 +720,6 @@ def color_name(color: Color) -> str:
case _:
assert_never(color)

# TODO: this should not be an error, see https://github.com/astral-sh/ty/issues/99#issuecomment-2983054488
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `str`"
def color_name_without_assertion(color: Color) -> str:
match color:
case Color.RED:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,10 @@ def match_exhaustive(x: Literal[0, 1, "a"]):
case "a":
pass
case _:
# TODO: this should not be an error
no_diagnostic_here # error: [unresolved-reference]
no_diagnostic_here

assert_never(x)

# TODO: there should be no error here
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`"
def match_exhaustive_no_assertion(x: Literal[0, 1, "a"]) -> int:
match x:
case 0:
Expand Down Expand Up @@ -130,13 +127,21 @@ def match_exhaustive(x: Color):
case Color.BLUE:
pass
case _:
# TODO: this should not be an error
no_diagnostic_here # error: [unresolved-reference]
no_diagnostic_here

assert_never(x)

def match_exhaustive_2(x: Color):
match x:
case Color.RED:
pass
case Color.GREEN | Color.BLUE:
pass
case _:
no_diagnostic_here

assert_never(x)

# TODO: there should be no error here
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`"
def match_exhaustive_no_assertion(x: Color) -> int:
match x:
case Color.RED:
Expand Down Expand Up @@ -208,13 +213,10 @@ def match_exhaustive(x: A | B | C):
case C():
pass
case _:
# TODO: this should not be an error
no_diagnostic_here # error: [unresolved-reference]
no_diagnostic_here

assert_never(x)

# TODO: there should be no error here
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`"
def match_exhaustive_no_assertion(x: A | B | C) -> int:
match x:
case A():
Expand Down
21 changes: 13 additions & 8 deletions crates/ty_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
subject: Expression<'db>,
pattern: &ast::Pattern,
guard: Option<&ast::Expr>,
) -> PredicateOrLiteral<'db> {
previous_pattern: Option<PatternPredicate<'db>>,
) -> (PredicateOrLiteral<'db>, PatternPredicate<'db>) {
// This is called for the top-level pattern of each match arm. We need to create a
// standalone expression for each arm of a match statement, since they can introduce
// constraints on the match subject. (Or more accurately, for the match arm's pattern,
Expand All @@ -756,13 +757,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
subject,
kind,
guard,
previous_pattern.map(Box::new),
);
let predicate = PredicateOrLiteral::Predicate(Predicate {
node: PredicateNode::Pattern(pattern_predicate),
is_positive: true,
});
self.record_narrowing_constraint(predicate);
predicate
(predicate, pattern_predicate)
}

/// Record an expression that needs to be a Salsa ingredient, because we need to infer its type
Expand Down Expand Up @@ -1747,7 +1749,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard());

let mut post_case_snapshots = vec![];
let mut match_predicate;
let mut previous_pattern: Option<PatternPredicate<'_>> = None;

for (i, case) in cases.iter().enumerate() {
self.current_match_case = Some(CurrentMatchCase::new(&case.pattern));
Expand All @@ -1757,11 +1759,14 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
// here because the effects of visiting a pattern is binding
// symbols, and this doesn't occur unless the pattern
// actually matches
match_predicate = self.add_pattern_narrowing_constraint(
subject_expr,
&case.pattern,
case.guard.as_deref(),
);
let (match_predicate, match_pattern_predicate) = self
.add_pattern_narrowing_constraint(
subject_expr,
&case.pattern,
case.guard.as_deref(),
previous_pattern,
);
previous_pattern = Some(match_pattern_predicate);
let reachability_constraint =
self.record_reachability_constraint(match_predicate);

Expand Down
3 changes: 3 additions & 0 deletions crates/ty_python_semantic/src/semantic_index/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ pub(crate) struct PatternPredicate<'db> {
pub(crate) kind: PatternPredicateKind<'db>,

pub(crate) guard: Option<Expression<'db>>,

/// A reference to the pattern of the previous match case
pub(crate) previous_predicate: Option<Box<PatternPredicate<'db>>>,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So there are probably more efficient ways to implement this, but this linked-list approach was so unintrusive..

}

// The Salsa heap is tracked separately.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,14 @@ use crate::Db;
use crate::dunder_all::dunder_all_names;
use crate::place::{RequiresExplicitReExport, imported_symbol};
use crate::rank::RankBitBox;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::place_table;
use crate::semantic_index::predicate::{
CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
Predicates, ScopedPredicateId,
};
use crate::types::{Truthiness, Type, infer_expression_type};
use crate::types::{
IntersectionBuilder, Truthiness, Type, UnionBuilder, UnionType, infer_expression_type,
};

/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
/// is just like a boolean formula, but with `Ambiguous` as a third potential result. See the
Expand Down Expand Up @@ -311,6 +312,55 @@ const AMBIGUOUS: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId
const ALWAYS_FALSE: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId::ALWAYS_FALSE;
const SMALLEST_TERMINAL: ScopedReachabilityConstraintId = ALWAYS_FALSE;

fn singleton_to_type(db: &dyn Db, singleton: ruff_python_ast::Singleton) -> Type<'_> {
let ty = match singleton {
ruff_python_ast::Singleton::None => Type::none(db),
ruff_python_ast::Singleton::True => Type::BooleanLiteral(true),
ruff_python_ast::Singleton::False => Type::BooleanLiteral(false),
};
debug_assert!(ty.is_singleton(db));
ty
}

/// Turn a `match` pattern kind into a type that represents the set of all values that would definitely
/// match that pattern.
fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> {
match kind {
PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton),
PatternPredicateKind::Value(value) => infer_expression_type(db, *value),
PatternPredicateKind::Class(class_expr, kind) => {
if kind.is_irrefutable() {
infer_expression_type(db, *class_expr)
.to_instance(db)
.unwrap_or(Type::Never)
} else {
Type::Never
}
}
PatternPredicateKind::Or(predicates) => {
UnionType::from_elements(db, predicates.iter().map(|p| pattern_kind_to_type(db, p)))
}
PatternPredicateKind::Unsupported => Type::Never,
}
}

/// Go through the list of previous match cases, and accumulate a union of all types that were already
/// matched by these patterns.
fn type_excluded_by_previous_patterns<'db>(
db: &'db dyn Db,
mut predicate: PatternPredicate<'db>,
) -> Type<'db> {
let mut builder = UnionBuilder::new(db);
while let Some(previous) = predicate.previous_predicate(db) {
predicate = *previous;

if predicate.guard(db).is_none() {
builder = builder.add(pattern_kind_to_type(db, predicate.kind(db)));
}
}
builder.build()
}

/// A collection of reachability constraints for a given scope.
#[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(crate) struct ReachabilityConstraints {
Expand Down Expand Up @@ -637,11 +687,10 @@ impl ReachabilityConstraints {
fn analyze_single_pattern_predicate_kind<'db>(
db: &'db dyn Db,
predicate_kind: &PatternPredicateKind<'db>,
subject: Expression<'db>,
subject_ty: Type<'db>,
) -> Truthiness {
match predicate_kind {
PatternPredicateKind::Value(value) => {
let subject_ty = infer_expression_type(db, subject);
let value_ty = infer_expression_type(db, *value);

if subject_ty.is_single_valued(db) {
Expand All @@ -651,15 +700,7 @@ impl ReachabilityConstraints {
}
}
PatternPredicateKind::Singleton(singleton) => {
let subject_ty = infer_expression_type(db, subject);

let singleton_ty = match singleton {
ruff_python_ast::Singleton::None => Type::none(db),
ruff_python_ast::Singleton::True => Type::BooleanLiteral(true),
ruff_python_ast::Singleton::False => Type::BooleanLiteral(false),
};

debug_assert!(singleton_ty.is_singleton(db));
let singleton_ty = singleton_to_type(db, *singleton);

if subject_ty.is_equivalent_to(db, singleton_ty) {
Truthiness::AlwaysTrue
Expand All @@ -671,10 +712,21 @@ impl ReachabilityConstraints {
}
PatternPredicateKind::Or(predicates) => {
use std::ops::ControlFlow;

let mut excluded_types = vec![];
let (ControlFlow::Break(truthiness) | ControlFlow::Continue(truthiness)) =
predicates
.iter()
.map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject))
.map(|p| {
let narrowed_subject_ty = IntersectionBuilder::new(db)
.add_positive(subject_ty)
.add_negative(UnionType::from_elements(db, excluded_types.iter()))
.build();

excluded_types.push(pattern_kind_to_type(db, p));

Self::analyze_single_pattern_predicate_kind(db, p, narrowed_subject_ty)
})
// this is just a "max", but with a slight optimization: `AlwaysTrue` is the "greatest" possible element, so we short-circuit if we get there
.try_fold(Truthiness::AlwaysFalse, |acc, next| match (acc, next) {
(Truthiness::AlwaysTrue, _) | (_, Truthiness::AlwaysTrue) => {
Expand All @@ -690,7 +742,6 @@ impl ReachabilityConstraints {
truthiness
}
PatternPredicateKind::Class(class_expr, kind) => {
let subject_ty = infer_expression_type(db, subject);
let class_ty = infer_expression_type(db, *class_expr).to_instance(db);

class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
Expand All @@ -715,10 +766,17 @@ impl ReachabilityConstraints {
}

fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
let subject_ty = infer_expression_type(db, predicate.subject(db));

let narrowed_subject_ty = IntersectionBuilder::new(db)
.add_positive(subject_ty)
.add_negative(type_excluded_by_previous_patterns(db, predicate))
.build();

let truthiness = Self::analyze_single_pattern_predicate_kind(
db,
predicate.kind(db),
predicate.subject(db),
narrowed_subject_ty,
);

if truthiness == Truthiness::AlwaysTrue && predicate.guard(db).is_some() {
Expand Down
Loading