Skip to content

Commit ade4c79

Browse files
committed
Recognize named type guards
1 parent 89a76c7 commit ade4c79

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3583,6 +3583,7 @@ pub enum KnownFunction {
35833583
Repr,
35843584
/// `typing(_extensions).final`
35853585
Final,
3586+
// TODO: Move this to `KnownClass`
35863587
/// `builtins.staticmethod`
35873588
StaticMethod,
35883589

@@ -3642,7 +3643,9 @@ impl KnownFunction {
36423643
/// Return `true` if `self` is defined in `module` at runtime.
36433644
const fn check_module(self, module: KnownModule) -> bool {
36443645
match self {
3645-
Self::IsInstance | Self::IsSubclass | Self::Len | Self::Repr | Self::StaticMethod => module.is_builtins(),
3646+
Self::IsInstance | Self::IsSubclass | Self::Len | Self::Repr | Self::StaticMethod => {
3647+
module.is_builtins()
3648+
}
36463649
Self::AssertType
36473650
| Self::Cast
36483651
| Self::Overload
@@ -4562,7 +4565,8 @@ pub(crate) mod tests {
45624565
KnownFunction::Len
45634566
| KnownFunction::Repr
45644567
| KnownFunction::IsInstance
4565-
| KnownFunction::IsSubclass => KnownModule::Builtins,
4568+
| KnownFunction::IsSubclass
4569+
| KnownFunction::StaticMethod => KnownModule::Builtins,
45664570

45674571
KnownFunction::GetattrStatic => KnownModule::Inspect,
45684572

crates/red_knot_python_semantic/src/types/narrow.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
210210
is_positive: bool,
211211
) -> Option<NarrowingConstraints<'db>> {
212212
match expression_node {
213-
ast::Expr::Name(name) => Some(self.evaluate_expr_name(name, is_positive)),
213+
ast::Expr::Name(name) => Some(self.evaluate_expr_name(name, expression, is_positive)),
214214
ast::Expr::Compare(expr_compare) => {
215215
self.evaluate_expr_compare(expr_compare, expression, is_positive)
216216
}
@@ -257,16 +257,50 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
257257
fn evaluate_expr_name(
258258
&mut self,
259259
expr_name: &ast::ExprName,
260+
expression: Expression<'db>,
260261
is_positive: bool,
261262
) -> NarrowingConstraints<'db> {
262263
let ast::ExprName { id, .. } = expr_name;
264+
let inference = infer_expression_types(self.db, expression);
263265

264266
let symbol = self
265267
.symbols()
266268
.symbol_id_by_name(id)
267269
.expect("Should always have a symbol for every Name node");
270+
let ty = inference.expression_type(expr_name.scoped_expression_id(self.db, self.scope()));
271+
268272
let mut constraints = NarrowingConstraints::default();
269273

274+
// TODO: Handle unions and intersections
275+
let mut narrow_by_typeguards = || match ty {
276+
Type::TypeGuard(type_guard) => {
277+
let (_, guarded_symbol, _) = type_guard.symbol_info(self.db)?;
278+
279+
if !is_positive {
280+
return None;
281+
}
282+
283+
constraints.insert(
284+
guarded_symbol,
285+
type_guard.ty(self.db).negate_if(self.db, !is_positive),
286+
);
287+
288+
Some(())
289+
}
290+
Type::TypeIs(type_is) => {
291+
let (_, guarded_symbol, _) = type_is.symbol_info(self.db)?;
292+
293+
constraints.insert(
294+
guarded_symbol,
295+
type_is.ty(self.db).negate_if(self.db, !is_positive),
296+
);
297+
298+
Some(())
299+
}
300+
_ => None,
301+
};
302+
narrow_by_typeguards();
303+
270304
constraints.insert(
271305
symbol,
272306
if is_positive {

0 commit comments

Comments
 (0)