From 2b0c5d546082420f1ad11834e08db0d981770fca Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 1 Oct 2024 13:36:40 -0400 Subject: [PATCH] Add tests; support unions --- crates/red_knot_python_semantic/src/types.rs | 10 + .../src/types/infer.rs | 198 ++++++++++++------ 2 files changed, 147 insertions(+), 61 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 50bab3d55563b4..5252fe63c7c8f0 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -397,6 +397,16 @@ impl<'db> Type<'db> { } } + /// Return true if the type is a class or a union of classes. + pub fn is_class(&self, db: &'db dyn Db) -> bool { + match self { + Type::Union(union) => union.elements(db).iter().all(|ty| ty.is_class(db)), + Type::Class(_) => true, + // / TODO include type[X], once we add that type + _ => false, + } + } + /// Return true if this type is a [subtype of] type `target`. /// /// [subtype of]: https://typing.readthedocs.io/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index eed268fc04ba02..00dacdc9aea56d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -50,8 +50,8 @@ use crate::stdlib::builtins_module_scope; use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics}; use crate::types::{ bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty, - typing_extensions_symbol_ty, BytesLiteralType, CallOutcome, ClassType, FunctionKind, - FunctionType, StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType, + typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionKind, FunctionType, + StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType, }; use crate::Db; @@ -1338,23 +1338,6 @@ impl<'db> TypeInferenceBuilder<'db> { ); } - /// Emit a diagnostic declaring that a dunder method is not callable. - pub(super) fn dunder_not_callable_diagnostic( - &mut self, - node: AnyNodeRef, - not_callable_ty: Type<'db>, - dunder: &str, - ) { - self.add_diagnostic( - node, - "not-callable", - format_args!( - "Method `{dunder}` is not callable on object of type '{}'.", - not_callable_ty.display(self.db) - ), - ); - } - fn infer_for_statement_definition( &mut self, target: &ast::ExprName, @@ -2630,35 +2613,19 @@ impl<'db> TypeInferenceBuilder<'db> { // See: https://docs.python.org/3/reference/datamodel.html#class-getitem-versus-getitem let dunder_getitem_method = value_meta_ty.member(self.db, "__getitem__"); if !dunder_getitem_method.is_unbound() { - let CallOutcome::Callable { return_ty } = - dunder_getitem_method.call(self.db, &[slice_ty]) - else { - self.dunder_not_callable_diagnostic( - (&**value).into(), - value_ty, - "__getitem__", - ); - return Type::Unknown; - }; - return return_ty; + return dunder_getitem_method + .call(self.db, &[slice_ty]) + .unwrap_with_diagnostic(self.db, value.as_ref().into(), self); } // Otherwise, if the value is itself a class and defines `__class_getitem__`, // return its return type. - if matches!(value_ty, Type::Class(_)) { + if value_ty.is_class(self.db) { let dunder_class_getitem_method = value_ty.member(self.db, "__class_getitem__"); if !dunder_class_getitem_method.is_unbound() { - let CallOutcome::Callable { return_ty } = - dunder_class_getitem_method.call(self.db, &[slice_ty]) - else { - self.dunder_not_callable_diagnostic( - (&**value).into(), - value_ty, - "__class_getitem__", - ); - return Type::Unknown; - }; - return return_ty; + return dunder_class_getitem_method + .call(self.db, &[slice_ty]) + .unwrap_with_diagnostic(self.db, value.as_ref().into(), self); } } @@ -6801,14 +6768,14 @@ mod tests { } #[test] - fn subscript_not_callable_getitem() -> anyhow::Result<()> { + fn subscript_getitem_unbound() -> anyhow::Result<()> { let mut db = setup_db(); db.write_dedented( "/src/a.py", " class NotSubscriptable: - __getitem__ = None + pass a = NotSubscriptable()[0] ", @@ -6818,39 +6785,33 @@ mod tests { assert_file_diagnostics( &db, "/src/a.py", - &["Method `__getitem__` is not callable on object of type 'NotSubscriptable'."], + &["Cannot subscript object of type 'NotSubscriptable' with no `__getitem__` method."], ); Ok(()) } #[test] - fn dunder_call() -> anyhow::Result<()> { + fn subscript_not_callable_getitem() -> anyhow::Result<()> { let mut db = setup_db(); db.write_dedented( "/src/a.py", " - class Multiplier: - def __init__(self, factor: float): - self.factor = factor - - def __call__(self, number: float) -> float: - return number * self.factor - - a = Multiplier(2.0)(3.0) - - class Unit: - ... + class NotSubscriptable: + __getitem__ = None - b = Unit()(3.0) + a = NotSubscriptable()[0] ", )?; - assert_public_ty(&db, "/src/a.py", "a", "float"); - assert_public_ty(&db, "/src/a.py", "b", "Unknown"); + assert_public_ty(&db, "/src/a.py", "a", "Unknown"); + assert_file_diagnostics( + &db, + "/src/a.py", + &["Object of type 'None' is not callable."], + ); - assert_file_diagnostics(&db, "src/a.py", &["Object of type 'Unit' is not callable."]); Ok(()) } @@ -6913,6 +6874,121 @@ mod tests { Ok(()) } + #[test] + fn subscript_getitem_union() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + flag = True + + class Identity: + if flag: + def __getitem__(self, index: int) -> int: + return index + else: + def __getitem__(self, index: int) -> str: + return str(index) + + a = Identity()[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "int | str"); + + Ok(()) + } + + #[test] + fn subscript_class_getitem_union() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + flag = True + + class Identity: + if flag: + def __class_getitem__(cls, item: int) -> str: + return item + else: + def __class_getitem__(cls, item: int) -> int: + return item + + a = Identity[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "str | int"); + + Ok(()) + } + + #[test] + fn subscript_class_getitem_class_union() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + flag = True + + class Identity1: + def __class_getitem__(cls, item: int) -> str: + return item + + class Identity2: + def __class_getitem__(cls, item: int) -> int: + return item + + if flag: + a = Identity1 + else: + a = Identity2 + + b = a[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "Literal[Identity1, Identity2]"); + assert_public_ty(&db, "/src/a.py", "b", "str | int"); + + Ok(()) + } + + #[test] + fn dunder_call() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + class Multiplier: + def __init__(self, factor: float): + self.factor = factor + + def __call__(self, number: float) -> float: + return number * self.factor + + a = Multiplier(2.0)(3.0) + + class Unit: + ... + + b = Unit()(3.0) + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "float"); + assert_public_ty(&db, "/src/a.py", "b", "Unknown"); + + assert_file_diagnostics(&db, "src/a.py", &["Object of type 'Unit' is not callable."]); + + Ok(()) + } + #[test] fn boolean_or_expression() -> anyhow::Result<()> { let mut db = setup_db();